24 #include <MemoryX/interface/core/EntityBase.h>
25 #include <MemoryX/interface/memorytypes/MemoryEntities.h>
59 longtermMemoryPrx = getProxy<LongtermMemoryInterfacePrx>(
"LongtermMemory");
60 priorKnowledgePrx = getProxy<PriorKnowledgeInterfacePrx>(
"PriorKnowledge");
62 std::string segmentName = getProperty<std::string>(
"LTMSegment").getValue();
65 if (!ltmInstancesSegmentPrx)
67 ARMARX_FATAL <<
"LTM segment not found or has an invalid type: " << segmentName;
70 const float agingFactor = getProperty<float>(
"AgingFactor").getValue();
71 const float pruningThreshold = getProperty<float>(
"PruningThreshold").getValue();
73 const float distanceThreshold = getProperty<float>(
"MergingThreshold").getValue();
74 const std::string distanceName = getProperty<std::string>(
"MergingDistanceType").getValue();
77 if (distanceName ==
"Mahalanobis")
81 else if (distanceName ==
"KL")
85 else if (distanceName ==
"ISD")
92 <<
". Will use Kullback-Leibler instead.";
99 const std::string reducerName = getProperty<std::string>(
"GMMReducerAlgorithm").getValue();
100 ARMARX_INFO <<
"Using GMM reduction algorithm: " << reducerName;
102 if (reducerName ==
"West")
106 else if (reducerName ==
"Runnalls")
110 else if (reducerName ==
"Williams")
117 <<
". Will use WestReducer instead.";
129 const ::Ice::Current&
c)
131 if (!ltmInstancesSegmentPrx || ltmInstancesSegmentPrx->getSegmentName() != segmentName)
133 ltmInstancesSegmentPrx =
134 longtermMemoryPrx->getCustomInstancesSegment(segmentName,
true);
141 fusionMethod->setAgingFactor(factor);
147 assMethod->setThreshold(threshold);
152 const MultivariateNormalDistributionBasePtr& posDist,
153 const ::Ice::Current&
c)
158 posMean->x =
mean[0];
159 posMean->y =
mean[1];
160 posMean->z =
mean[2];
161 posMean->frame =
"world";
163 instance->setPosition(posMean);
164 instance->setPositionUncertainty(posDist);
165 instance->setClass(objectName, 1.);
171 const ::Ice::Current&
c)
173 const std::string clsName = newObject->getMostProbableClass();
174 ObjectInstanceList ltmObjects = ltmInstancesSegmentPrx->getObjectInstancesByClass(clsName);
176 if (ltmObjects.size() > 0)
178 for (ObjectInstanceList::const_iterator it = ltmObjects.begin(); it != ltmObjects.end();
182 EntityBasePtr fusedEntity = fusionMethod->fuseEntity(*it, newObject);
183 ltmInstancesSegmentPrx->updateEntity((*it)->getId(), fusedEntity);
184 ARMARX_INFO <<
"Updated existing object: " << newObject->getName();
190 EntityBasePtr fusedEntity = fusionMethod->initEntity(newObject);
191 ltmInstancesSegmentPrx->addEntity(fusedEntity);
192 ARMARX_INFO <<
"Added new object: " << newObject->getName();
198 const ::Ice::Current&
c)
200 WorkingMemorySnapshotInterfacePrx snapshot =
201 longtermMemoryPrx->openWorkingMemorySnapshot(snapshotName);
205 ARMARX_ERROR <<
"Snapshot not found in LTM: " << snapshotName;
209 const std::string segmentName =
"objectInstances";
210 PersistentEntitySegmentBasePrx segInstances = snapshot->getSegment(segmentName);
214 ARMARX_ERROR <<
"Segment not found in snapshot: " << segmentName;
218 EntityIdList ids = segInstances->getAllEntityIds();
221 struct EntityIdComparator
224 compare(
const std::string& i,
const std::string& j)
226 return (std::stoi(i) < std::stoi(j));
232 for (EntityIdList::const_iterator it = ids.begin(); it != ids.end(); ++it)
234 EntityBasePtr snapEntity = segInstances->getEntityById(*it);
239 ARMARX_INFO <<
"Processing complete! Learned objects: " << ids.size();
242 GaussianMixtureDistributionBasePtr
245 GaussianMixtureDistributionBasePtr result = GaussianMixtureDistributionBasePtr();
249 clsList.push_back(className);
250 getChildClasses(className, clsList);
253 ObjectInstanceList ltmObjects =
254 ltmInstancesSegmentPrx->getObjectInstancesByClassList(clsList);
255 ObjectInstanceList::const_iterator it = ltmObjects.begin();
257 if (it != ltmObjects.end())
261 inst->getPositionAttribute()->getUncertainty());
263 for (++it; it != ltmObjects.end(); ++it)
265 inst = ObjectInstancePtr::dynamicCast(*it);
268 inst->getPositionAttribute()->getUncertainty()));
275 GaussianMixtureDistributionBasePtr
278 const ::Ice::Current&)
280 GaussianMixtureDistributionBasePtr fullGMM =
getPositionFull(objectName);
284 return gmmReducer->reduceByComponentCount(fullGMM, compCount);
288 return GaussianMixtureDistributionBasePtr();
292 GaussianMixtureDistributionBasePtr
295 DeviationType devType,
296 const ::Ice::Current&)
298 GaussianMixtureDistributionBasePtr fullGMM =
getPositionFull(objectName);
302 return gmmReducer->reduceByMaxDeviation(
303 fullGMM, maxDeviation, convertDeviationType(devType));
307 return GaussianMixtureDistributionBasePtr();
311 NormalDistributionBasePtr
313 const ::Ice::Current&)
315 GaussianMixtureDistributionBasePtr result =
320 return result->getModalComponent().gaussian;
324 return NormalDistributionBasePtr();
331 const ::Ice::Current&
c)
333 GaussianMixtureDistributionBasePtr reducedGMM =
335 return gmmToClusterList(reducedGMM);
341 DeviationType devType,
342 const ::Ice::Current&
c)
344 GaussianMixtureDistributionBasePtr reducedGMM =
346 return gmmToClusterList(reducedGMM);
352 CommonPlacesLearner::convertDeviationType(DeviationType devType)
359 case eDevOrientedBBox:
362 case eDevEqualSphere:
371 CommonPlacesLearner::gmmToClusterList(
const GaussianMixtureDistributionBasePtr& gmm)
373 Cluster3DList result;
377 for (
int i = 0; i < gmm->size(); ++i)
380 FloatVector compMean = gmm->getComponent(i).gaussian->getMean();
381 cluster.center.x = compMean[0];
382 cluster.center.y = compMean[1];
383 cluster.center.z = compMean[2];
384 cluster.weight = gmm->getComponent(i).weight;
385 result.push_back(cluster);
390 struct Cluster3DComparator
393 compare(
const Cluster3D& i,
const Cluster3D& j)
395 return (i.weight > j.weight);
405 CommonPlacesLearner::getChildClasses(std::string className, NameList& result)
407 ObjectClassList children =
408 priorKnowledgePrx->getObjectClassesSegment()->getChildClasses(className);
409 ARMARX_VERBOSE_S <<
"Found " << children.size() <<
" subclasses for " << className
412 for (ObjectClassList::const_iterator it = children.begin(); it != children.end(); ++it)
414 result.push_back((*it)->getName());