3 #include <boost/archive/text_iarchive.hpp>
4 #include <boost/archive/text_oarchive.hpp>
5 #include <boost/archive/xml_iarchive.hpp>
6 #include <boost/archive/xml_oarchive.hpp>
11 #include <mplib/core/SystemState.h>
21 mplib::factories::VMPFactory mpfactory;
22 mpfactory.addConfig(
"kernelSize",
c.kernelSize);
24 resetVMPType(
c.mpTypeString);
26 std::shared_ptr<mplib::representation::AbstractMovementPrimitive> mp =
28 vmp = std::dynamic_pointer_cast<mplib::representation::vmp::PrincipalComponentVMP>(mp);
30 if (
cfg.viaPoints.size() > 0)
34 for (
const auto& vp :
cfg.viaPoints)
36 ARMARX_INFO <<
"---- (" <<
index <<
") canonical value: " << vp.canonicalValue
61 return cfg.nodeSetName;
73 start(std::vector<double>());
79 start(goals, std::vector<double>());
85 cfg.durationSec = timeDuration;
92 cfg.durationSec = timeDuration;
98 cfg.durationSec = timeDuration;
120 auto vps =
vmp->getViaPoints();
123 for (
const auto&
v : vps)
127 mplib::core::SystemState::convertStatesToArray(
v.second, 0));
160 ARMARX_INFO <<
"cannot reset a running mp, please stop it first.";
201 ARMARX_INFO <<
"-- Train MP '" <<
cfg.name <<
"' from trajectories";
205 ARMARX_INFO <<
"---- using default trajectories from the MP configuration";
206 trajList =
cfg.trajectoryList;
214 std::vector<mplib::core::SampledTrajectory> trajs;
215 for (
const auto& mpTraj : trajList)
217 std::map<double, std::vector<double>> trajMap;
218 for (
int i = 0; i < mpTraj.time.size(); ++i)
220 auto& row_vector = mpTraj.traj.row(i);
221 std::vector<double> vec(row_vector.data(), row_vector.data() + row_vector.size());
222 trajMap.emplace(mpTraj.time(i), vec);
224 mplib::core::SampledTrajectory traj(trajMap);
225 trajs.push_back(traj);
235 std::vector<std::string> fileList;
236 if (fileNames.empty())
238 ARMARX_INFO <<
"---- using default trajectory list: " <<
cfg.fileList[0];
239 fileList =
cfg.fileList;
243 fileList = fileNames;
246 std::vector<mplib::core::SampledTrajectory> trajs;
247 for (
auto& file : fileList)
249 mplib::core::SampledTrajectory traj;
250 traj.readFromCSVFile(file);
251 trajs.push_back(traj);
259 if (
cfg.regressionModel ==
"gpr")
261 vmp->setBaseFunctionApproximator(
262 std::make_unique<mplib::math::function_approximation::GaussianProcessRegression<
263 mplib::math::kernel::SquaredExponentialCovarianceFunction>>(
264 0.01, 0.1, 0.000001));
265 for (
auto& traj : trajs)
267 traj = mplib::core::SampledTrajectory::downSample(traj, 100);
271 vmp->learnFromTrajectories(trajs);
278 for (
size_t i = 0; i < trajs[0].dim(); i++)
280 mplib::representation::MPState state(trajs[0].begin()->getPosition(i),
281 trajs[0].begin()->getDeriv(i, 1));
293 if (not
cfg.fileList.empty())
297 else if (not
cfg.trajectoryList.empty())
304 <<
"You don't provide any trajectory files (path on the robot) nor trajectories in "
305 "type std::vector<armarx::control::common::mp::arondto::MPTraj> in your "
306 "configuration file. \n"
307 <<
"If you intended train MPs by providing files or "
308 "trajectories on the fly, consider using \n"
309 <<
"ctrl.learnFromCSV(Ice::StringSeq& fileNames) or \n"
310 <<
"ctrl.learnFromCSVlearnFromTrajs(armarx::aron::data::dto::DictPtr& dict)";
326 <<
"---- user specified empty goal, try to use goals in your configuration file:\n"
332 <<
"---- user specified empty goal, try to use goals learned from trajectory:\n"
349 <<
": size of starts and learned mp are not consistent";
358 std::vector<double> learnedStart;
363 ARMARX_INFO <<
"---- user doesn't define start point, fall back to start point "
364 "learned from trajectory: "
372 const auto& startState =
vmp->getStartState();
373 std::vector<double> mpStartVec;
374 for (
const auto& state : startState)
376 mpStartVec.push_back(state.pos);
394 if (canVal <= vmp->getUmin())
398 else if (canVal >= 1.0 -
vmp->getUmin())
404 vmp->setViaPoint(canVal, viapoint);
409 ARMARX_ERROR <<
" -- mp is not trained yet, potential memory allocation error! "
410 "Please train mp before setting any via points";
418 vmp->removeViaPoints();
431 vmp->setWeights(weights);
437 return vmp->getWeights();
479 return vmp->getGoals();