3 #include <VirtualRobot/Robot.h>
10 #include <dmp/representation/dmp/umidmp.h>
13 #include <armarx/control/deprecated_njoint_mp_controller/joint_space/ControllerInterface.h>
// Returns the fixed string literal "NJointJointSpaceDMPController".
// NOTE(review): the enclosing function's signature is not part of this excerpt;
// presumably this is the controller's class-name accessor — confirm.
24 return "NJointJointSpaceDMPController";
// Setup fragment: takes the generic controller configuration `config`, narrows
// it to the DMP-specific configuration type, fills the per-joint lookup maps and
// copies the tuning parameters into members.
// NOTE(review): the function header, braces and several statements are missing
// from this excerpt (e.g. where `ct` and `sv` come from) — presumably this is
// the controller constructor; confirm against the full file.
29 const armarx::NJointControllerConfigPtr& config,
// Downcast to the concrete config type; the check macro receives `cfg` and a
// hint string naming the required type.
32 NJointJointSpaceDMPControllerConfigPtr cfg =
33 NJointJointSpaceDMPControllerConfigPtr::dynamicCast(config);
34 ARMARX_CHECK_EXPRESSION_W_HINT(cfg,
"Needed type: NJointJointSpaceDMPControllerConfigPtr");
// One pass per configured joint name. `jointName` is declared by value, so each
// iteration copies the std::string.
// Within the loop the joint name is used as the key for typed views of `ct`
// (velocity control target) and `sv` (position / torque / gravity torque /
// velocity sensor values). `ct` and `sv` are assigned on lines not shown here.
36 for (std::string jointName : cfg->jointNames)
40 targets.insert(std::make_pair(jointName, ct->
asA<ControlTarget1DoFActuatorVelocity>()));
41 positionSensors.insert(
42 std::make_pair(jointName, sv->
asA<SensorValue1DoFActuatorPosition>()));
44 std::make_pair(jointName, sv->
asA<SensorValue1DoFActuatorTorque>()));
45 gravityTorqueSensors.insert(
46 std::make_pair(jointName, sv->
asA<SensorValue1DoFGravityTorque>()));
47 velocitySensors.insert(
48 std::make_pair(jointName, sv->
asA<SensorValue1DoFActuatorVelocity>()));
// Special-cases an empty joint-name list; the guarded body is not visible.
50 if (cfg->jointNames.size() == 0)
// Creates the DMP object from the configured kernel size, Kd, base mode and tau.
55 dmpPtr.reset(
new DMP::UMIDMP(cfg->kernelSize, cfg->DMPKd, cfg->baseMode, cfg->tau));
// `canVal` starts out equal to the configured `timeDuration`.
56 timeDuration = cfg->timeDuration;
57 canVal = timeDuration;
// Phase-related parameters are copied verbatim from the configuration.
61 phaseDist0 = cfg->phaseDist0;
62 phaseDist1 = cfg->phaseDist1;
63 phaseKp = cfg->phaseKp;
// The disturbance flag is initialised to false.
65 isDisturbance =
false;
// Control-step fragment: samples the joint state for every DMP dimension,
// derives a phase-stop value from the accumulated tracking error, advances the
// DMP by one step and writes blended velocity targets plus debug values.
// NOTE(review): the function header, braces, the fallback operands of the two
// conditionals and the conditions around `phaseDist` / `isDisturbance` are not
// part of this excerpt — presumably this is the real-time run method; confirm.
82 std::vector<double> currentPosition;
83 std::vector<double> currentVelocity;
// One DMPState per entry of `dimNames`. Position and velocity are read from the
// sensor maps when the joint name has exactly one entry there; the alternative
// operand of each conditional is on a line that is not shown.
84 for (
size_t i = 0; i < dimNames.size(); i++)
86 const auto& jointName = dimNames.at(i);
87 DMP::DMPState currentPos;
88 currentPos.pos = (positionSensors.count(jointName) == 1)
89 ? positionSensors[jointName]->position
91 currentPos.vel = (velocitySensors.count(jointName) == 1)
92 ? velocitySensors[jointName]->velocity
// The velocity is multiplied by `timeDuration` before it is stored, so
// `currentVelocity` holds the scaled value as well.
94 currentPos.vel *= timeDuration;
95 currentState.push_back(currentPos);
96 currentPosition.push_back(currentPos.pos);
97 currentVelocity.push_back(currentPos.vel);
// Adds the squared difference between measured position and `targetState[i]`
// to `error`. Where `error` is reset is not visible here.
// NOTE(review): the square is computed with pow(x, 2).
99 error += pow(currentPos.pos - targetState[i], 2);
// `phaseDist` is set to `phaseDist1` or `phaseDist0`; the selecting conditions
// are on lines not shown.
106 phaseDist = phaseDist1;
110 phaseDist = phaseDist0;
// Logistic function of (error - phaseDist) with steepness `phaseK`, scaled by
// `phaseL`. `mpcFactor` is one minus the unscaled logistic term.
114 phaseStop = phaseL / (1 + exp(-phaseK * (error - phaseDist)));
115 mpcFactor = 1 - (phaseStop / phaseL);
// The disturbance flag is set and cleared under conditions that are not shown.
119 isDisturbance =
true;
124 isDisturbance =
false;
// The canonical value is decreased by deltaT / tau, further divided by
// (1 + phaseStop), so a larger `phaseStop` means a smaller decrement.
// NOTE(review): `1 / tau` is evaluated first; if `tau` has an integral type this
// is integer division. The declaration of `tau` is not visible — confirm it is
// floating-point.
128 double deltaT = timeSinceLastIteration.toSecondsDouble();
129 canVal -= 1 / tau * deltaT * 1 / (1 + phaseStop);
// Both the step size and the canonical value are divided by `timeDuration`
// before being handed to the DMP.
130 double dmpDeltaT = deltaT / timeDuration;
131 dmpPtr->setTemporalFactor(tau);
133 currentState = dmpPtr->calculateDirectlyVelocity(
134 currentState, canVal / timeDuration, dmpDeltaT, targetState);
// Writes one velocity command per dimension that has exactly one entry in
// `targets`.
141 for (
size_t i = 0; i < dimNames.size(); ++i)
143 const auto& jointName = dimNames.at(i);
144 if (targets.count(jointName) == 1)
// vel0: DMP velocity divided by `timeDuration` (the inverse of the scaling
// applied when sampling). vel1: proportional term on the position error with
// gain `phaseKp`. The command blends both, weighted by `mpcFactor`.
146 double vel0 = currentState[i].vel / timeDuration;
147 double vel1 = phaseKp * (targetState[i] - currentPosition[i]);
148 double vel = mpcFactor * vel0 + (1 - mpcFactor) * vel1;
// A zero command is written instead when `finished` is true.
149 targets[jointName]->velocity = finished ? 0.0f : vel;
// NOTE(review): `targetVelstr` and `targetStatestr` are built here but are not
// used on any visible line; the debug maps below are keyed by `jointName`.
// Confirm whether these two strings are still needed.
151 std::string targetVelstr = jointName +
"_targetvel";
152 std::string targetStatestr = jointName +
"_dmpTarget";
// The blended (pre-`finished`) velocity and the current DMP target are stored
// in the write buffer of `debugOutputData`.
153 debugOutputData.
getWriteBuffer().latestTargetVelocities[jointName] = vel;
154 debugOutputData.
getWriteBuffer().dmpTargetState[jointName] = targetState[i];
// Writes a zero velocity command for every entry of `dimNames` that has exactly
// one entry in `targets`.
// NOTE(review): the enclosing scope is not visible — this may be an alternative
// branch of the control step or a separate function; confirm.
164 for (
size_t i = 0; i < dimNames.size(); ++i)
166 const auto& jointName = dimNames.at(i);
167 if (targets.count(jointName) == 1)
169 targets[jointName]->velocity = 0.0f;
// Learning fragment: reads one sampled trajectory per entry of `fileNames`,
// trains the DMP on all of them and sets its style parameters from `ratios`.
// NOTE(review): the function header, braces, the declaration of `ratios` and
// the conditions around the two `ratios.push_back` calls are not visible.
179 DMP::Vec<DMP::SampledTrajectoryV2> trajs;
182 for (
size_t i = 0; i < fileNames.size(); ++i)
184 DMP::SampledTrajectoryV2 traj;
185 traj.readFromCSVFile(fileNames.at(i));
// `dimNames` is overwritten on every iteration, so after the loop it holds the
// dimension names of the last file read.
186 dimNames = traj.getDimensionNames();
188 trajs.push_back(traj);
// A ratio of 1.0 or 0.0 is appended; which one is chosen depends on conditions
// that are not shown.
192 ratios.push_back(1.0);
196 ratios.push_back(0.0);
// Learn from the collected trajectories, enable one-step MPC on the DMP, then
// store the style parameters the DMP computes for `ratios`.
199 dmpPtr->learnFromTrajectories(trajs);
200 dmpPtr->setOneStepMPC(
true);
201 dmpPtr->styleParas = dmpPtr->getStyleParasWithRatio(ratios);
// Preparation fragment: rebuilds `currentState` from the sensor maps, seeds
// `targetState` with the sampled positions and prepares the DMP for execution
// towards `goals`.
// NOTE(review): the function header, braces and the fallback operands of the
// two conditionals are not visible. No `targetState.clear()` appears on the
// visible lines (original line 212 is not shown) — confirm that `targetState`
// is reset before the push_back calls below.
211 currentState.clear();
213 for (
size_t i = 0; i < dimNames.size(); i++)
215 const auto& jointName = dimNames.at(i);
216 DMP::DMPState currentPos;
217 currentPos.pos = (positionSensors.count(jointName) == 1)
218 ? positionSensors[jointName]->position
220 currentPos.vel = (velocitySensors.count(jointName) == 1)
221 ? velocitySensors[jointName]->velocity
// On the visible lines the velocity is stored as read, without any scaling.
223 currentState.push_back(currentPos);
224 targetState.push_back(currentPos.pos);
// Passes the goals, the sampled state, the literal 1 and `tau` to the DMP.
226 dmpPtr->prepareExecution(goals, currentState, 1, tau);
// Publishing fragment: wraps key/value pairs in newly allocated Variant objects,
// collects them in `datafields` and sends the map to the debug observer on the
// channel "latestDMPTargetVelocities".
// NOTE(review): the excerpt starts inside a loop whose header is not visible,
// and neither the type of `datafields` nor the origin of `valuesst` is shown.
// Presumably the mapped type is a handle that takes ownership of each
// `new Variant(...)` — confirm.
272 datafields[pair.first] =
new Variant(pair.second);
// Second pass: same wrapping for every entry of `valuesst`. Both loops assign
// through operator[], so a repeated key replaces the earlier entry.
276 for (
auto& pair : valuesst)
278 datafields[pair.first] =
new Variant(pair.second);
283 debugObs->setDebugChannel(
"latestDMPTargetVelocities", datafields);