CommonPlacesLearner.cpp
Go to the documentation of this file.
1/*
2* This file is part of ArmarX.
3*
4* ArmarX is free software; you can redistribute it and/or modify
5* it under the terms of the GNU General Public License version 2 as
6* published by the Free Software Foundation.
7*
8* ArmarX is distributed in the hope that it will be useful, but
9* WITHOUT ANY WARRANTY; without even the implied warranty of
10* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11* GNU General Public License for more details.
12*
13* You should have received a copy of the GNU General Public License
14* along with this program. If not, see <http://www.gnu.org/licenses/>.
15*
16* @package MemoryX::CommonPlacesLearner
17* @author Alexey Kozlov ( kozlov at kit dot edu)
18* @date 2013
19* @copyright http://www.gnu.org/licenses/gpl-2.0.txt
20* GNU General Public License
21*/
22
23// memoryx interface
24#include <MemoryX/interface/core/EntityBase.h>
25#include <MemoryX/interface/memorytypes/MemoryEntities.h>
26
27// memoryx helpers
34
35// Object Factories
37
38#include "CommonPlacesLearner.h"
43
44namespace memoryx
45{
// Initialization hook: registers the "LongtermMemory" and "PriorKnowledge"
// proxies via usingProxy(), which (per its documentation) schedules each proxy
// for retrieval after initialization and adds it to the dependency list.
// NOTE(review): the line carrying the function name between "void" and "{" is
// missing from this listing; presumably CommonPlacesLearner::onInitComponent()
// — confirm against the real source.
46 void
48 {
49 usingProxy("LongtermMemory");
50 usingProxy("PriorKnowledge");
51 }
52
// Connection hook: fetches the LTM and prior-knowledge proxies, selects the LTM
// instances segment from the "LTMSegment" property, and builds the association,
// fusion and GMM-reduction strategies from the component properties.
// NOTE(review): the line carrying the function name between "void" and "{" is
// missing from this listing; presumably CommonPlacesLearner::onConnectComponent()
// — confirm.
53 void
55 {
// NOTE(review): `ic` is never used in the visible body.
56 Ice::CommunicatorPtr ic = getIceManager()->getCommunicator();
57
58
59 longtermMemoryPrx = getProxy<LongtermMemoryInterfacePrx>("LongtermMemory");
60 priorKnowledgePrx = getProxy<PriorKnowledgeInterfacePrx>("PriorKnowledge");
61
// Resolve (and cache) the LTM segment that stores the learned object instances.
62 std::string segmentName = getProperty<std::string>("LTMSegment").getValue();
63 setLTMSegmentName(segmentName);
64
65 if (!ltmInstancesSegmentPrx)
66 {
// NOTE(review): ARMARX_FATAL is a logging level; unless it aborts, execution
// continues here with a null segment proxy, which later calls dereference
// — confirm intended behaviour.
67 ARMARX_FATAL << "LTM segment not found or has an invalid type: " << segmentName;
68 }
69
// Parameters of the Gaussian-mixture position fusion.
70 const float agingFactor = getProperty<float>("AgingFactor").getValue();
71 const float pruningThreshold = getProperty<float>("PruningThreshold").getValue();
72
// Parameters of the association step (merging threshold + distance measure).
73 const float distanceThreshold = getProperty<float>("MergingThreshold").getValue();
74 const std::string distanceName = getProperty<std::string>("MergingDistanceType").getValue();
// NOTE(review): original line 75 is missing from this listing; it presumably
// declares `distance` (e.g. a GMMDistancePtr, i.e. std::shared_ptr<GMMDistance>)
// — confirm.
76
// Map the configured distance name to an implementation; unknown names fall
// back to the Runnalls KL distance with a warning.
77 if (distanceName == "Mahalanobis")
78 {
79 distance.reset(new MahalanobisDistance());
80 }
81 else if (distanceName == "KL")
82 {
83 distance.reset(new RunnallsKLDistance());
84 }
85 else if (distanceName == "ISD")
86 {
87 distance.reset(new ISDDistance);
88 }
89 else
90 {
91 ARMARX_WARNING << "Unknown MergingDistanceType: " << distanceName
92 << ". Will use Kullback-Leibler instead.";
93 distance.reset(new RunnallsKLDistance());
94 }
95
// The fusion method owns a reference to the association method, so
// setMergingThreshold() / setAgingFactor() can later tune them at runtime.
96 assMethod = new GaussianMixtureAssociationMethod(distanceThreshold, distance);
97 fusionMethod = new GaussianMixturePositionFusion(agingFactor, pruningThreshold, assMethod);
98
// Choose the GMM reduction algorithm; unknown names fall back to West.
99 const std::string reducerName = getProperty<std::string>("GMMReducerAlgorithm").getValue();
100 ARMARX_INFO << "Using GMM reduction algorithm: " << reducerName;
101
102 if (reducerName == "West")
103 {
104 gmmReducer.reset(new WestGMMReducer());
105 }
106 else if (reducerName == "Runnalls")
107 {
108 gmmReducer.reset(new RunnallsGMMReducer());
109 }
110 else if (reducerName == "Williams")
111 {
112 gmmReducer.reset(new WilliamsGMMReducer());
113 }
114 else
115 {
116 ARMARX_WARNING << "Unknown GMMReducerAlgorithm: " << reducerName
117 << ". Will use WestReducer instead.";
118 gmmReducer.reset(new WestGMMReducer());
119 }
120 }
121
// NOTE(review): original lines 123-125 (the function name and body) are missing
// from this listing; presumably CommonPlacesLearner::onExitComponent() with an
// empty or trivial body — confirm against the real source.
122 void
126
127 void
128 CommonPlacesLearner::setLTMSegmentName(const ::std::string& segmentName,
129 const ::Ice::Current& c)
130 {
131 if (!ltmInstancesSegmentPrx || ltmInstancesSegmentPrx->getSegmentName() != segmentName)
132 {
133 ltmInstancesSegmentPrx =
134 longtermMemoryPrx->getCustomInstancesSegment(segmentName, true);
135 }
136 }
137
138 void
139 CommonPlacesLearner::setAgingFactor(float factor, const ::Ice::Current& c)
140 {
141 fusionMethod->setAgingFactor(factor);
142 }
143
144 void
145 CommonPlacesLearner::setMergingThreshold(float threshold, const ::Ice::Current& c)
146 {
147 assMethod->setThreshold(threshold);
148 }
149
// Builds an ObjectInstance named/classed `objectName` (class probability 1.)
// whose position is the mean of `posDist` in frame "world" and whose position
// uncertainty is `posDist` itself, then learns from it via learnFromObject().
// NOTE(review): mean[0..2] are read unchecked — assumes posDist is at least
// 3-dimensional; confirm with callers.
150 void
151 CommonPlacesLearner::learnFromObjectMCA(const ::std::string& objectName,
152 const MultivariateNormalDistributionBasePtr& posDist,
153 const ::Ice::Current& c)
154 {
155 ObjectInstancePtr instance = new ObjectInstance(objectName);
156 FloatVector mean = posDist->getMean();
// NOTE(review): original line 157 is missing from this listing; it presumably
// declares `posMean` (likely a FramedPositionPtr) — confirm.
158 posMean->x = mean[0];
159 posMean->y = mean[1];
160 posMean->z = mean[2];
161 posMean->frame = "world";
162
163 instance->setPosition(posMean);
164 instance->setPositionUncertainty(posDist);
165 instance->setClass(objectName, 1.);
166 learnFromObject(instance, c);
167 }
168
169 void
170 CommonPlacesLearner::learnFromObject(const ObjectInstanceBasePtr& newObject,
171 const ::Ice::Current& c)
172 {
173 const std::string clsName = newObject->getMostProbableClass();
174 ObjectInstanceList ltmObjects = ltmInstancesSegmentPrx->getObjectInstancesByClass(clsName);
175
176 if (ltmObjects.size() > 0)
177 {
178 for (ObjectInstanceList::const_iterator it = ltmObjects.begin(); it != ltmObjects.end();
179 ++it)
180 {
181 // update existing
182 EntityBasePtr fusedEntity = fusionMethod->fuseEntity(*it, newObject);
183 ltmInstancesSegmentPrx->updateEntity((*it)->getId(), fusedEntity);
184 ARMARX_INFO << "Updated existing object: " << newObject->getName();
185 }
186 }
187 else
188 {
189 // add new
190 EntityBasePtr fusedEntity = fusionMethod->initEntity(newObject);
191 ltmInstancesSegmentPrx->addEntity(fusedEntity);
192 ARMARX_INFO << "Added new object: " << newObject->getName();
193 }
194 }
195
196 void
197 CommonPlacesLearner::learnFromSnapshot(const ::std::string& snapshotName,
198 const ::Ice::Current& c)
199 {
200 WorkingMemorySnapshotInterfacePrx snapshot =
201 longtermMemoryPrx->openWorkingMemorySnapshot(snapshotName);
202
203 if (!snapshot)
204 {
205 ARMARX_ERROR << "Snapshot not found in LTM: " << snapshotName;
206 return;
207 }
208
209 const std::string segmentName = "objectInstances";
210 PersistentEntitySegmentBasePrx segInstances = snapshot->getSegment(segmentName);
211
212 if (!segInstances)
213 {
214 ARMARX_ERROR << "Segment not found in snapshot: " << segmentName;
215 return;
216 }
217
218 EntityIdList ids = segInstances->getAllEntityIds();
219
220 // sort ids
221 struct EntityIdComparator
222 {
223 static bool
224 compare(const std::string& i, const std::string& j)
225 {
226 return (std::stoi(i) < std::stoi(j));
227 }
228 };
229
230 std::sort(ids.begin(), ids.end(), EntityIdComparator::compare);
231
232 for (EntityIdList::const_iterator it = ids.begin(); it != ids.end(); ++it)
233 {
234 EntityBasePtr snapEntity = segInstances->getEntityById(*it);
235 ObjectInstancePtr snapInstance = ObjectInstancePtr::dynamicCast(snapEntity);
236 learnFromObject(snapInstance, c);
237 }
238
239 ARMARX_INFO << "Processing complete! Learned objects: " << ids.size();
240 }
241
// Collects the position uncertainties of every LTM instance whose class is
// `className` or one of the classes returned by getChildClasses(), combined
// into one Gaussian mixture. Returns a null handle when no instance matches.
242 GaussianMixtureDistributionBasePtr
243 CommonPlacesLearner::getPositionFull(const ::std::string& className, const ::Ice::Current&)
244 {
245 GaussianMixtureDistributionBasePtr result = GaussianMixtureDistributionBasePtr();
246
247 // get subclasses
248 NameList clsList;
249 clsList.push_back(className);
250 getChildClasses(className, clsList);
251
252 // get all objects with the class specified or one of its subclasses
253 ObjectInstanceList ltmObjects =
254 ltmInstancesSegmentPrx->getObjectInstancesByClassList(clsList);
255 ObjectInstanceList::const_iterator it = ltmObjects.begin();
256
257 if (it != ltmObjects.end())
258 {
// NOTE(review): dynamicCast yields null for a non-ObjectInstance entry and
// `inst` is then dereferenced unchecked.
259 ObjectInstancePtr inst = ObjectInstancePtr::dynamicCast(*it);
// NOTE(review): original line 260 is missing from this listing; it presumably
// assigns `result` from the first instance's uncertainty (likely via
// GaussianMixtureDistribution::FromProbabilityMeasure) — confirm.
261 inst->getPositionAttribute()->getUncertainty());
262
// The first instance seeded `result`; the remaining ones are accumulated.
263 for (++it; it != ltmObjects.end(); ++it)
264 {
265 inst = ObjectInstancePtr::dynamicCast(*it);
266 ARMARX_VERBOSE_S << "Processing " << inst->getName() << "..." << std::endl;
// NOTE(review): original line 267 is missing from this listing; it presumably
// adds this instance's uncertainty to `result` — confirm.
268 inst->getPositionAttribute()->getUncertainty()));
269 }
270 }
271
272 return result;
273 }
274
// Returns the full position GMM of `objectName` reduced by the configured
// gmmReducer to `compCount` components, or a null handle when getPositionFull()
// finds nothing.
// NOTE(review): the line with the function name and the `objectName` parameter
// is missing from this listing; presumably
// CommonPlacesLearner::getPositionReducedByComponentCount(const ::std::string& objectName,
// — confirm.
275 GaussianMixtureDistributionBasePtr
277 Ice::Int compCount,
278 const ::Ice::Current&)
279 {
280 GaussianMixtureDistributionBasePtr fullGMM = getPositionFull(objectName);
281
282 if (fullGMM)
283 {
284 return gmmReducer->reduceByComponentCount(fullGMM, compCount);
285 }
286 else
287 {
288 return GaussianMixtureDistributionBasePtr();
289 }
290 }
291
// Returns the full position GMM of `objectName` reduced by the configured
// gmmReducer under a maximum-deviation bound (`maxDeviation`, measured as
// selected by `devType`), or a null handle when getPositionFull() finds nothing.
// NOTE(review): the line with the function name and the `objectName` parameter
// is missing from this listing; presumably
// CommonPlacesLearner::getPositionReducedByMaxDeviation(const ::std::string& objectName,
// — confirm.
292 GaussianMixtureDistributionBasePtr
294 Ice::Float maxDeviation,
295 DeviationType devType,
296 const ::Ice::Current&)
297 {
298 GaussianMixtureDistributionBasePtr fullGMM = getPositionFull(objectName);
299
300 if (fullGMM)
301 {
// The Ice-level DeviationType is translated to the reducer's own enum.
302 return gmmReducer->reduceByMaxDeviation(
303 fullGMM, maxDeviation, convertDeviationType(devType));
304 }
305 else
306 {
307 return GaussianMixtureDistributionBasePtr();
308 }
309 }
310
// Returns the Gaussian of the modal (presumably highest-weight) component of a
// position GMM for `objectName`, or a null handle if no GMM is available.
// NOTE(review): original line 316, the initializer of `result`, is missing from
// this listing, so it cannot be told here whether the full or a reduced GMM is
// used — confirm against the real source.
311 NormalDistributionBasePtr
312 CommonPlacesLearner::getPositionAsGaussian(const ::std::string& objectName,
313 const ::Ice::Current&)
314 {
315 GaussianMixtureDistributionBasePtr result =
317
318 if (result)
319 {
320 return result->getModalComponent().gaussian;
321 }
322 else
323 {
324 return NormalDistributionBasePtr();
325 }
326 }
327
// Reduces the position GMM of `objectName` to `compCount` components and
// converts the result into a list of weighted 3D cluster centres (see
// gmmToClusterList, which also handles a null GMM).
// NOTE(review): the line with the function name and the `objectName` parameter
// is missing from this listing; presumably
// CommonPlacesLearner::getPositionClustersByComponentCount(const ::std::string& objectName,
// — confirm.
328 Cluster3DList
330 Ice::Int compCount,
331 const ::Ice::Current& c)
332 {
333 GaussianMixtureDistributionBasePtr reducedGMM =
334 getPositionReducedByComponentCount(objectName, compCount, c);
335 return gmmToClusterList(reducedGMM);
336 }
337
// Reduces the position GMM of `objectName` under a maximum-deviation bound and
// converts the result into a list of weighted 3D cluster centres (see
// gmmToClusterList, which also handles a null GMM).
// NOTE(review): the line with the function name and the `objectName` parameter
// is missing from this listing; presumably
// CommonPlacesLearner::getPositionClustersByMaxDeviation(const ::std::string& objectName,
// — confirm.
338 Cluster3DList
340 Ice::Float maxDeviation,
341 DeviationType devType,
342 const ::Ice::Current& c)
343 {
344 GaussianMixtureDistributionBasePtr reducedGMM =
345 getPositionReducedByMaxDeviation(objectName, maxDeviation, devType, c);
346 return gmmToClusterList(reducedGMM);
347 }
348
349 // the only purpose of this dummy function is to avoid making GaussianMixtureHelpers an Ice component or
350 // making it dependent on CommonPlacesLearner
// Maps the Ice-level DeviationType onto the reducer's deviation enum
// (eAABB / eOrientedBBox / eEqualSphere). Any unlisted value silently maps to eAABB.
// NOTE(review): original line 351 (the return type) is missing from this
// listing; presumably DeviationMeasure — confirm.
352 CommonPlacesLearner::convertDeviationType(DeviationType devType)
353 {
354 switch (devType)
355 {
356 case eDevAABB:
357 return eAABB;
358
359 case eDevOrientedBBox:
360 return eOrientedBBox;
361
362 case eDevEqualSphere:
363 return eEqualSphere;
364
365 default:
366 return eAABB;
367 }
368 }
369
370 Cluster3DList
371 CommonPlacesLearner::gmmToClusterList(const GaussianMixtureDistributionBasePtr& gmm)
372 {
373 Cluster3DList result;
374
375 if (gmm)
376 {
377 for (int i = 0; i < gmm->size(); ++i)
378 {
379 Cluster3D cluster;
380 FloatVector compMean = gmm->getComponent(i).gaussian->getMean();
381 cluster.center.x = compMean[0];
382 cluster.center.y = compMean[1];
383 cluster.center.z = compMean[2];
384 cluster.weight = gmm->getComponent(i).weight;
385 result.push_back(cluster);
386 }
387 }
388
389 // sort cluster with descending weights
390 struct Cluster3DComparator
391 {
392 static bool
393 compare(const Cluster3D& i, const Cluster3D& j)
394 {
395 return (i.weight > j.weight);
396 }
397 };
398
399 std::sort(result.begin(), result.end(), Cluster3DComparator::compare);
400
401 return result;
402 }
403
404 void
405 CommonPlacesLearner::getChildClasses(std::string className, NameList& result)
406 {
407 ObjectClassList children =
408 priorKnowledgePrx->getObjectClassesSegment()->getChildClasses(className);
409 ARMARX_VERBOSE_S << "Found " << children.size() << " subclasses for " << className
410 << std::endl;
411
412 for (ObjectClassList::const_iterator it = children.begin(); it != children.end(); ++it)
413 {
414 result.push_back((*it)->getName());
415 }
416 }
417} // namespace memoryx
constexpr T c
Property< PropertyType > getProperty(const std::string &name)
The FramedPosition class.
Definition FramedPose.h:158
bool usingProxy(const std::string &name, const std::string &endpoints="")
Registers a proxy for retrieval after initialization and adds it to the dependency list.
IceManagerPtr getIceManager() const
Returns the IceManager.
Ice::ObjectPrx getProxy(long timeoutMs=0, bool waitForScheduler=true) const
Returns the proxy of this object (optionally it waits for the proxy)
GaussianMixtureDistributionBasePtr getPositionFull(const ::std::string &objectName, const ::Ice::Current &=Ice::emptyCurrent) override
void onInitComponent() override
Pure virtual hook for the subclass.
Cluster3DList getPositionClustersByMaxDeviation(const ::std::string &objectName, Ice::Float maxDeviation, DeviationType devType, const ::Ice::Current &=Ice::emptyCurrent) override
void setLTMSegmentName(const ::std::string &segmentName, const ::Ice::Current &=Ice::emptyCurrent) override
void setMergingThreshold(float threshold, const ::Ice::Current &=Ice::emptyCurrent) override
void learnFromObjectMCA(const ::std::string &objectName, const MultivariateNormalDistributionBasePtr &posDist, const ::Ice::Current &=Ice::emptyCurrent) override
GaussianMixtureDistributionBasePtr getPositionReducedByComponentCount(const ::std::string &objectName, Ice::Int compCount, const ::Ice::Current &=Ice::emptyCurrent) override
void onConnectComponent() override
Pure virtual hook for the subclass.
void learnFromObject(const ObjectInstanceBasePtr &newObject, const ::Ice::Current &=Ice::emptyCurrent) override
GaussianMixtureDistributionBasePtr getPositionReducedByMaxDeviation(const ::std::string &objectName, Ice::Float maxDeviation, DeviationType devType, const ::Ice::Current &=Ice::emptyCurrent) override
void setAgingFactor(float factor, const ::Ice::Current &=Ice::emptyCurrent) override
void onExitComponent() override
Hook for subclass.
NormalDistributionBasePtr getPositionAsGaussian(const ::std::string &objectName, const ::Ice::Current &=Ice::emptyCurrent) override
void learnFromSnapshot(const ::std::string &snapshotName, const ::Ice::Current &=Ice::emptyCurrent) override
Cluster3DList getPositionClustersByComponentCount(const ::std::string &objectName, Ice::Int compCount, const ::Ice::Current &=Ice::emptyCurrent) override
static GaussianMixtureDistributionPtr FromProbabilityMeasure(const ProbabilityMeasureBasePtr &probMeasure)
Convert or approximate given ProbabilityMeasure to a gaussian mixture.
#define ARMARX_INFO
The normal logging level.
Definition Logging.h:181
#define ARMARX_FATAL
The logging level for unexpected behaviour, that will lead to a seriously malfunctioning program and ...
Definition Logging.h:199
#define ARMARX_ERROR
The logging level for unexpected behaviour, that must be fixed.
Definition Logging.h:196
#define ARMARX_VERBOSE_S
Definition Logging.h:207
#define ARMARX_WARNING
The logging level for unexpected behaviour, but not a serious problem.
Definition Logging.h:193
::IceInternal::Handle<::Ice::Communicator > CommunicatorPtr
Definition IceManager.h:49
int compare(const T &lhs, const T &rhs)
::std::vector<::Ice::Float > FloatVector
IceInternal::Handle< FramedPosition > FramedPositionPtr
Definition FramedPose.h:149
VirtualRobot headers.
std::shared_ptr< GMMDistance > GMMDistancePtr
Definition GMMDistance.h:62
IceInternal::Handle< ObjectInstance > ObjectInstancePtr
DeviationMeasure
Definition GMMReducer.h:38
@ eOrientedBBox
Definition GMMReducer.h:40
@ eEqualSphere
Definition GMMReducer.h:41
double distance(const Point &a, const Point &b)
Definition point.hpp:95