3 #include <boost/archive/text_iarchive.hpp>
4 #include <boost/archive/text_oarchive.hpp>
5 #include <boost/archive/xml_iarchive.hpp>
6 #include <boost/archive/xml_oarchive.hpp>
11 #include <mplib/core/SystemState.h>
21 mplib::factories::VMPFactory mpfactory;
22 mpfactory.addConfig(
"kernelSize",
c.kernelSize);
24 resetVMPType(
c.mpTypeString);
26 std::shared_ptr<mplib::representation::AbstractMovementPrimitive> mp =
28 vmp = std::dynamic_pointer_cast<mplib::representation::vmp::PrincipalComponentVMP>(mp);
30 if (
cfg.viaPoints.size() > 0)
34 for (
const auto& vp :
cfg.viaPoints)
36 ARMARX_INFO <<
"---- (" <<
index <<
") canonical value: " << vp.canonicalValue
61 return cfg.nodeSetName;
73 start(std::vector<double>());
79 start(goals, std::vector<double>());
85 cfg.durationSec = timeDuration;
92 cfg.durationSec = timeDuration;
98 cfg.durationSec = timeDuration;
120 auto vps =
vmp->getViaPoints();
123 for (
const auto&
v : vps)
127 mplib::core::SystemState::convertStatesToArray(
v.second, 0));
160 ARMARX_INFO <<
"cannot reset a running mp, please stop it first.";
202 ARMARX_INFO <<
"-- Train MP '" <<
cfg.name <<
"' from trajectories";
206 ARMARX_INFO <<
"---- using default trajectories from the MP configuration";
207 trajList =
cfg.trajectoryList;
215 std::vector<mplib::core::SampledTrajectory> trajs;
216 for (
const auto& mpTraj : trajList)
218 std::map<double, std::vector<double>> trajMap;
219 for (
int i = 0; i < mpTraj.time.size(); ++i)
221 auto& row_vector = mpTraj.traj.row(i);
222 std::vector<double> vec(row_vector.data(), row_vector.data() + row_vector.size());
223 trajMap.emplace(mpTraj.time(i), vec);
225 mplib::core::SampledTrajectory traj(trajMap);
226 trajs.push_back(traj);
236 std::vector<std::string> fileList;
237 if (fileNames.empty())
239 ARMARX_INFO <<
"---- using default trajectory list: " <<
cfg.fileList[0];
240 fileList =
cfg.fileList;
244 fileList = fileNames;
247 std::vector<mplib::core::SampledTrajectory> trajs;
248 for (
auto& file : fileList)
250 mplib::core::SampledTrajectory traj;
251 traj.readFromCSVFile(file);
252 trajs.push_back(traj);
260 if (
cfg.regressionModel ==
"gpr")
262 vmp->setBaseFunctionApproximator(
263 std::make_unique<mplib::math::function_approximation::GaussianProcessRegression<
264 mplib::math::kernel::SquaredExponentialCovarianceFunction>>(
265 0.01, 0.1, 0.000001));
266 for (
auto& traj : trajs)
268 traj = mplib::core::SampledTrajectory::downSample(traj, 100);
272 vmp->learnFromTrajectories(trajs);
279 for (
size_t i = 0; i < trajs[0].dim(); i++)
281 mplib::representation::MPState state(trajs[0].begin()->getPosition(i),
282 trajs[0].begin()->getDeriv(i, 1));
294 if (not
cfg.fileList.empty())
298 else if (not
cfg.trajectoryList.empty())
306 <<
": You don't provide any trajectory files (path on the robot) nor trajectories "
308 "type std::vector<armarx::control::common::mp::arondto::MPTraj> in your "
309 "configuration file. \n"
310 <<
"If you intended train MPs by providing files or "
311 "trajectories on the fly, consider using \n"
312 <<
"ctrl.learnFromCSV(Ice::StringSeq& fileNames) or \n"
313 <<
"ctrl.learnFromCSVlearnFromTrajs(armarx::aron::data::dto::DictPtr& dict)";
329 <<
"---- user specified empty goal, try to use goals in your configuration file:\n"
335 <<
"---- user specified empty goal, try to use goals learned from trajectory:\n"
352 <<
": size of starts and learned mp are not consistent";
361 std::vector<double> learnedStart;
366 ARMARX_INFO <<
"---- user doesn't define start point, fall back to start point "
367 "learned from trajectory: "
375 const auto& startState =
vmp->getStartState();
376 std::vector<double> mpStartVec;
377 for (
const auto& state : startState)
379 mpStartVec.push_back(state.pos);
387 return cfg.startFromPrevTarget;
405 if (canVal <= vmp->getUmin())
409 else if (canVal >= 1.0 -
vmp->getUmin())
420 ARMARX_ERROR <<
" -- mp is not trained yet, potential memory allocation error! "
421 "Please train mp before setting any via points";
429 vmp->removeViaPoints();
442 vmp->setWeights(weights);
448 return vmp->getWeights();
490 return vmp->getGoals();