CommonPlacesLearner.cpp
Go to the documentation of this file.
1/*
2* This file is part of ArmarX.
3*
4* ArmarX is free software; you can redistribute it and/or modify
5* it under the terms of the GNU General Public License version 2 as
6* published by the Free Software Foundation.
7*
8* ArmarX is distributed in the hope that it will be useful, but
9* WITHOUT ANY WARRANTY; without even the implied warranty of
10* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11* GNU General Public License for more details.
12*
13* You should have received a copy of the GNU General Public License
14* along with this program. If not, see <http://www.gnu.org/licenses/>.
15*
16* @package MemoryX::CommonPlacesLearner
17* @author Alexey Kozlov ( kozlov at kit dot edu)
18* @date 2013
19* @copyright http://www.gnu.org/licenses/gpl-2.0.txt
20* GNU General Public License
21*/
22
23// memoryx interface
24#include <MemoryX/interface/core/EntityBase.h>
25#include <MemoryX/interface/memorytypes/MemoryEntities.h>
26
27// memoryx helpers
34
35// Object Factories
38
39#include "CommonPlacesLearner.h"
44
45namespace memoryx
46{
47 std::string
49 {
50 return "CommonPlacesLearner";
51 }
52
53 void
55 {
56 usingProxy("LongtermMemory");
57 usingProxy("PriorKnowledge");
58 }
59
60 void
62 {
63 Ice::CommunicatorPtr ic = getIceManager()->getCommunicator();
64
65
66 longtermMemoryPrx = getProxy<LongtermMemoryInterfacePrx>("LongtermMemory");
67 priorKnowledgePrx = getProxy<PriorKnowledgeInterfacePrx>("PriorKnowledge");
68
69 std::string segmentName = getProperty<std::string>("LTMSegment").getValue();
70 setLTMSegmentName(segmentName);
71
72 if (!ltmInstancesSegmentPrx)
73 {
74 ARMARX_FATAL << "LTM segment not found or has an invalid type: " << segmentName;
75 }
76
77 const float agingFactor = getProperty<float>("AgingFactor").getValue();
78 const float pruningThreshold = getProperty<float>("PruningThreshold").getValue();
79
80 const float distanceThreshold = getProperty<float>("MergingThreshold").getValue();
81 const std::string distanceName = getProperty<std::string>("MergingDistanceType").getValue();
83
84 if (distanceName == "Mahalanobis")
85 {
86 distance.reset(new MahalanobisDistance());
87 }
88 else if (distanceName == "KL")
89 {
90 distance.reset(new RunnallsKLDistance());
91 }
92 else if (distanceName == "ISD")
93 {
94 distance.reset(new ISDDistance);
95 }
96 else
97 {
98 ARMARX_WARNING << "Unknown MergingDistanceType: " << distanceName
99 << ". Will use Kullback-Leibler instead.";
100 distance.reset(new RunnallsKLDistance());
101 }
102
103 assMethod = new GaussianMixtureAssociationMethod(distanceThreshold, distance);
104 fusionMethod = new GaussianMixturePositionFusion(agingFactor, pruningThreshold, assMethod);
105
106 const std::string reducerName = getProperty<std::string>("GMMReducerAlgorithm").getValue();
107 ARMARX_INFO << "Using GMM reduction algorithm: " << reducerName;
108
109 if (reducerName == "West")
110 {
111 gmmReducer.reset(new WestGMMReducer());
112 }
113 else if (reducerName == "Runnalls")
114 {
115 gmmReducer.reset(new RunnallsGMMReducer());
116 }
117 else if (reducerName == "Williams")
118 {
119 gmmReducer.reset(new WilliamsGMMReducer());
120 }
121 else
122 {
123 ARMARX_WARNING << "Unknown GMMReducerAlgorithm: " << reducerName
124 << ". Will use WestReducer instead.";
125 gmmReducer.reset(new WestGMMReducer());
126 }
127 }
128
129 void
133
134 void
135 CommonPlacesLearner::setLTMSegmentName(const ::std::string& segmentName,
136 const ::Ice::Current& c)
137 {
138 if (!ltmInstancesSegmentPrx || ltmInstancesSegmentPrx->getSegmentName() != segmentName)
139 {
140 ltmInstancesSegmentPrx =
141 longtermMemoryPrx->getCustomInstancesSegment(segmentName, true);
142 }
143 }
144
145 void
146 CommonPlacesLearner::setAgingFactor(float factor, const ::Ice::Current& c)
147 {
148 fusionMethod->setAgingFactor(factor);
149 }
150
151 void
152 CommonPlacesLearner::setMergingThreshold(float threshold, const ::Ice::Current& c)
153 {
154 assMethod->setThreshold(threshold);
155 }
156
157 void
158 CommonPlacesLearner::learnFromObjectMCA(const ::std::string& objectName,
159 const MultivariateNormalDistributionBasePtr& posDist,
160 const ::Ice::Current& c)
161 {
162 ObjectInstancePtr instance = new ObjectInstance(objectName);
163 FloatVector mean = posDist->getMean();
165 posMean->x = mean[0];
166 posMean->y = mean[1];
167 posMean->z = mean[2];
168 posMean->frame = "world";
169
170 instance->setPosition(posMean);
171 instance->setPositionUncertainty(posDist);
172 instance->setClass(objectName, 1.);
173 learnFromObject(instance, c);
174 }
175
176 void
177 CommonPlacesLearner::learnFromObject(const ObjectInstanceBasePtr& newObject,
178 const ::Ice::Current& c)
179 {
180 const std::string clsName = newObject->getMostProbableClass();
181 ObjectInstanceList ltmObjects = ltmInstancesSegmentPrx->getObjectInstancesByClass(clsName);
182
183 if (ltmObjects.size() > 0)
184 {
185 for (ObjectInstanceList::const_iterator it = ltmObjects.begin(); it != ltmObjects.end();
186 ++it)
187 {
188 // update existing
189 EntityBasePtr fusedEntity = fusionMethod->fuseEntity(*it, newObject);
190 ltmInstancesSegmentPrx->updateEntity((*it)->getId(), fusedEntity);
191 ARMARX_INFO << "Updated existing object: " << newObject->getName();
192 }
193 }
194 else
195 {
196 // add new
197 EntityBasePtr fusedEntity = fusionMethod->initEntity(newObject);
198 ltmInstancesSegmentPrx->addEntity(fusedEntity);
199 ARMARX_INFO << "Added new object: " << newObject->getName();
200 }
201 }
202
203 void
204 CommonPlacesLearner::learnFromSnapshot(const ::std::string& snapshotName,
205 const ::Ice::Current& c)
206 {
207 WorkingMemorySnapshotInterfacePrx snapshot =
208 longtermMemoryPrx->openWorkingMemorySnapshot(snapshotName);
209
210 if (!snapshot)
211 {
212 ARMARX_ERROR << "Snapshot not found in LTM: " << snapshotName;
213 return;
214 }
215
216 const std::string segmentName = "objectInstances";
217 PersistentEntitySegmentBasePrx segInstances = snapshot->getSegment(segmentName);
218
219 if (!segInstances)
220 {
221 ARMARX_ERROR << "Segment not found in snapshot: " << segmentName;
222 return;
223 }
224
225 EntityIdList ids = segInstances->getAllEntityIds();
226
227 // sort ids
228 struct EntityIdComparator
229 {
230 static bool
231 compare(const std::string& i, const std::string& j)
232 {
233 return (std::stoi(i) < std::stoi(j));
234 }
235 };
236
237 std::sort(ids.begin(), ids.end(), EntityIdComparator::compare);
238
239 for (EntityIdList::const_iterator it = ids.begin(); it != ids.end(); ++it)
240 {
241 EntityBasePtr snapEntity = segInstances->getEntityById(*it);
242 ObjectInstancePtr snapInstance = ObjectInstancePtr::dynamicCast(snapEntity);
243 learnFromObject(snapInstance, c);
244 }
245
246 ARMARX_INFO << "Processing complete! Learned objects: " << ids.size();
247 }
248
249 GaussianMixtureDistributionBasePtr
250 CommonPlacesLearner::getPositionFull(const ::std::string& className, const ::Ice::Current&)
251 {
252 GaussianMixtureDistributionBasePtr result = GaussianMixtureDistributionBasePtr();
253
254 // get subclasses
255 NameList clsList;
256 clsList.push_back(className);
257 getChildClasses(className, clsList);
258
259 // get all objects with the class specified or one of its subclasses
260 ObjectInstanceList ltmObjects =
261 ltmInstancesSegmentPrx->getObjectInstancesByClassList(clsList);
262 ObjectInstanceList::const_iterator it = ltmObjects.begin();
263
264 if (it != ltmObjects.end())
265 {
266 ObjectInstancePtr inst = ObjectInstancePtr::dynamicCast(*it);
268 inst->getPositionAttribute()->getUncertainty());
269
270 for (++it; it != ltmObjects.end(); ++it)
271 {
272 inst = ObjectInstancePtr::dynamicCast(*it);
273 ARMARX_VERBOSE_S << "Processing " << inst->getName() << "..." << std::endl;
275 inst->getPositionAttribute()->getUncertainty()));
276 }
277 }
278
279 return result;
280 }
281
282 GaussianMixtureDistributionBasePtr
284 Ice::Int compCount,
285 const ::Ice::Current&)
286 {
287 GaussianMixtureDistributionBasePtr fullGMM = getPositionFull(objectName);
288
289 if (fullGMM)
290 {
291 return gmmReducer->reduceByComponentCount(fullGMM, compCount);
292 }
293 else
294 {
295 return GaussianMixtureDistributionBasePtr();
296 }
297 }
298
299 GaussianMixtureDistributionBasePtr
301 Ice::Float maxDeviation,
302 DeviationType devType,
303 const ::Ice::Current&)
304 {
305 GaussianMixtureDistributionBasePtr fullGMM = getPositionFull(objectName);
306
307 if (fullGMM)
308 {
309 return gmmReducer->reduceByMaxDeviation(
310 fullGMM, maxDeviation, convertDeviationType(devType));
311 }
312 else
313 {
314 return GaussianMixtureDistributionBasePtr();
315 }
316 }
317
318 NormalDistributionBasePtr
319 CommonPlacesLearner::getPositionAsGaussian(const ::std::string& objectName,
320 const ::Ice::Current&)
321 {
322 GaussianMixtureDistributionBasePtr result =
324
325 if (result)
326 {
327 return result->getModalComponent().gaussian;
328 }
329 else
330 {
331 return NormalDistributionBasePtr();
332 }
333 }
334
335 Cluster3DList
337 Ice::Int compCount,
338 const ::Ice::Current& c)
339 {
340 GaussianMixtureDistributionBasePtr reducedGMM =
341 getPositionReducedByComponentCount(objectName, compCount, c);
342 return gmmToClusterList(reducedGMM);
343 }
344
345 Cluster3DList
347 Ice::Float maxDeviation,
348 DeviationType devType,
349 const ::Ice::Current& c)
350 {
351 GaussianMixtureDistributionBasePtr reducedGMM =
352 getPositionReducedByMaxDeviation(objectName, maxDeviation, devType, c);
353 return gmmToClusterList(reducedGMM);
354 }
355
356 // The only purpose of this conversion function is to avoid making GaussianMixtureHelpers an Ice
357 // component or making it depend on CommonPlacesLearner.
359 CommonPlacesLearner::convertDeviationType(DeviationType devType)
360 {
361 switch (devType)
362 {
363 case eDevAABB:
364 return eAABB;
365
366 case eDevOrientedBBox:
367 return eOrientedBBox;
368
369 case eDevEqualSphere:
370 return eEqualSphere;
371
372 default:
373 return eAABB;
374 }
375 }
376
377 Cluster3DList
378 CommonPlacesLearner::gmmToClusterList(const GaussianMixtureDistributionBasePtr& gmm)
379 {
380 Cluster3DList result;
381
382 if (gmm)
383 {
384 for (int i = 0; i < gmm->size(); ++i)
385 {
386 Cluster3D cluster;
387 FloatVector compMean = gmm->getComponent(i).gaussian->getMean();
388 cluster.center.x = compMean[0];
389 cluster.center.y = compMean[1];
390 cluster.center.z = compMean[2];
391 cluster.weight = gmm->getComponent(i).weight;
392 result.push_back(cluster);
393 }
394 }
395
396 // sort cluster with descending weights
397 struct Cluster3DComparator
398 {
399 static bool
400 compare(const Cluster3D& i, const Cluster3D& j)
401 {
402 return (i.weight > j.weight);
403 }
404 };
405
406 std::sort(result.begin(), result.end(), Cluster3DComparator::compare);
407
408 return result;
409 }
410
411 void
412 CommonPlacesLearner::getChildClasses(std::string className, NameList& result)
413 {
414 ObjectClassList children =
415 priorKnowledgePrx->getObjectClassesSegment()->getChildClasses(className);
416 ARMARX_VERBOSE_S << "Found " << children.size() << " subclasses for " << className
417 << std::endl;
418
419 for (ObjectClassList::const_iterator it = children.begin(); it != children.end(); ++it)
420 {
421 result.push_back((*it)->getName());
422 }
423 }
424} // namespace memoryx
425
#define ARMARX_REGISTER_COMPONENT_EXECUTABLE(ComponentT, applicationName)
Definition Decoupled.h:29
constexpr T c
Property< PropertyType > getProperty(const std::string &name)
The FramedPosition class.
Definition FramedPose.h:158
bool usingProxy(const std::string &name, const std::string &endpoints="")
Registers a proxy for retrieval after initialization and adds it to the dependency list.
IceManagerPtr getIceManager() const
Returns the IceManager.
Ice::ObjectPrx getProxy(long timeoutMs=0, bool waitForScheduler=true) const
Returns the proxy of this object (optionally it waits for the proxy)
GaussianMixtureDistributionBasePtr getPositionFull(const ::std::string &objectName, const ::Ice::Current &=Ice::emptyCurrent) override
void onInitComponent() override
Pure virtual hook for the subclass.
Cluster3DList getPositionClustersByMaxDeviation(const ::std::string &objectName, Ice::Float maxDeviation, DeviationType devType, const ::Ice::Current &=Ice::emptyCurrent) override
void setLTMSegmentName(const ::std::string &segmentName, const ::Ice::Current &=Ice::emptyCurrent) override
void setMergingThreshold(float threshold, const ::Ice::Current &=Ice::emptyCurrent) override
void learnFromObjectMCA(const ::std::string &objectName, const MultivariateNormalDistributionBasePtr &posDist, const ::Ice::Current &=Ice::emptyCurrent) override
GaussianMixtureDistributionBasePtr getPositionReducedByComponentCount(const ::std::string &objectName, Ice::Int compCount, const ::Ice::Current &=Ice::emptyCurrent) override
void onConnectComponent() override
Pure virtual hook for the subclass.
void learnFromObject(const ObjectInstanceBasePtr &newObject, const ::Ice::Current &=Ice::emptyCurrent) override
GaussianMixtureDistributionBasePtr getPositionReducedByMaxDeviation(const ::std::string &objectName, Ice::Float maxDeviation, DeviationType devType, const ::Ice::Current &=Ice::emptyCurrent) override
void setAgingFactor(float factor, const ::Ice::Current &=Ice::emptyCurrent) override
void onExitComponent() override
Hook for subclass.
NormalDistributionBasePtr getPositionAsGaussian(const ::std::string &objectName, const ::Ice::Current &=Ice::emptyCurrent) override
void learnFromSnapshot(const ::std::string &snapshotName, const ::Ice::Current &=Ice::emptyCurrent) override
Cluster3DList getPositionClustersByComponentCount(const ::std::string &objectName, Ice::Int compCount, const ::Ice::Current &=Ice::emptyCurrent) override
static GaussianMixtureDistributionPtr FromProbabilityMeasure(const ProbabilityMeasureBasePtr &probMeasure)
Convert or approximate given ProbabilityMeasure to a gaussian mixture.
#define ARMARX_INFO
The normal logging level.
Definition Logging.h:181
#define ARMARX_FATAL
The logging level for unexpected behaviour, that will lead to a seriously malfunctioning program and ...
Definition Logging.h:199
#define ARMARX_ERROR
The logging level for unexpected behaviour, that must be fixed.
Definition Logging.h:196
#define ARMARX_VERBOSE_S
Definition Logging.h:207
#define ARMARX_WARNING
The logging level for unexpected behaviour, but not a serious problem.
Definition Logging.h:193
::IceInternal::Handle<::Ice::Communicator > CommunicatorPtr
Definition IceManager.h:49
int compare(const T &lhs, const T &rhs)
::std::vector<::Ice::Float > FloatVector
IceInternal::Handle< FramedPosition > FramedPositionPtr
Definition FramedPose.h:149
VirtualRobot headers.
std::shared_ptr< GMMDistance > GMMDistancePtr
Definition GMMDistance.h:62
IceInternal::Handle< ObjectInstance > ObjectInstancePtr
DeviationMeasure
Definition GMMReducer.h:38
@ eOrientedBBox
Definition GMMReducer.h:40
@ eEqualSphere
Definition GMMReducer.h:41
double distance(const Point &a, const Point &b)
Definition point.hpp:95