Prediction.h
Go to the documentation of this file.
1/*
2 * This file is part of ArmarX.
3 *
4 * ArmarX is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License version 2 as
6 * published by the Free Software Foundation.
7 *
8 * ArmarX is distributed in the hope that it will be useful, but
9 * WITHOUT ANY WARRANTY; without even the implied warranty of
10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 * GNU General Public License for more details.
12 *
13 * You should have received a copy of the GNU General Public License
14 * along with this program. If not, see <http://www.gnu.org/licenses/>.
15 *
16 * @package RobotAPI::armem::core::base::detail
17 * @author phesch ( phesch at student dot kit dot edu )
18 * @date 2022
19 * @copyright http://www.gnu.org/licenses/gpl-2.0.txt
20 * GNU General Public License
21 */
22
23#pragma once
24
#include <functional>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <SimoxUtility/algorithm/get_map_keys_values.h>
28
30
37
39{
40
// Callable that computes the prediction result for a single prediction request.
41 using Predictor = std::function<PredictionResult(const PredictionRequest&)>;
42
43 /**
44 * Can do predictions, but has no children it could delegate predictions to.
45 *
46 * This class is integrated with `armem::base::detail::Predictive`:
47 * If `DerivedT` is also a `Predictive`, the setters of this class also update the
48 * `Predictive` part.
49 */
50 template <class DerivedT>
52 {
54
55 Predictive*
56 _asPredictive()
57 {
58 return dynamic_cast<Predictive*>(&base::detail::derived<DerivedT>(this));
59 }
60
61 public:
62 explicit Prediction(const std::map<PredictionEngine, Predictor>& predictors = {})
63 {
64 this->setPredictors(predictors);
65 }
66
67 void
68 addPredictor(const PredictionEngine& engine, Predictor&& predictor)
69 {
70 _predictors.emplace(engine.engineID, predictor);
71
72 if (Predictive* predictive = this->_asPredictive())
73 {
74 predictive->addPredictionEngine(engine);
75 }
76 }
77
78 void
79 setPredictors(const std::map<PredictionEngine, Predictor>& predictors)
80 {
81 this->_predictors.clear();
82 for (const auto& [engine, predictor] : predictors)
83 {
84 _predictors.emplace(engine.engineID, predictor);
85 }
86
87 if (Predictive* predictive = this->_asPredictive())
88 {
89 predictive->setPredictionEngines(simox::alg::get_keys(predictors));
90 }
91 }
92
93 /**
94 * Resolves mapping of requests to predictors and dispatches them.
95 *
96 * In this case, the resolution is basically no-op because there are no children.
97 */
98 std::vector<PredictionResult>
99 dispatchPredictions(const std::vector<PredictionRequest>& requests)
100 {
101 const MemoryID ownID = base::detail::derived<DerivedT>(this).id();
102 std::vector<PredictionResult> results;
103 for (const auto& request : requests)
104 {
105 results.push_back(dispatchTargetedPrediction(request, ownID));
106 }
107 return results;
108 }
109
110 /**
111 * Dispatches a single prediction request (assuming resolution was done by the caller).
 *
 * If `target` equals this item's own ID, the predictor registered for
 * `request.predictionSettings.predictionEngineID` is invoked and its result is
 * returned as-is. Otherwise (engine not registered, or `target` is a different ID)
 * the returned result has `success == false`, a descriptive `errorMessage`, and the
 * `snapshotID` of the request.
112 */
115 {
116 PredictionResult result;
117 result.snapshotID = request.snapshotID;
118
119 MemoryID ownID = base::detail::derived<DerivedT>(this).id();
// Only requests addressed to this very item can be handled here (no children).
120 if (ownID == target)
121 {
// Predictors are keyed by engine ID.
122 auto it = _predictors.find(request.predictionSettings.predictionEngineID);
123 if (it != _predictors.end())
124 {
125 const Predictor& predictor = it->second;
// The predictor's result replaces `result` entirely, including `snapshotID`.
126 result = predictor(request);
127 }
128 else
129 {
130 result.success = false;
131 std::stringstream sstream;
132 sstream << "Could not dispatch prediction request for " << request.snapshotID
133 << " with engine '" << request.predictionSettings.predictionEngineID
134 << "' in " << ownID << ": Engine not registered.";
135 result.errorMessage = sstream.str();
136 }
137 }
138 else
139 {
// The request was resolved to some other item; this item cannot serve it.
140 result.success = false;
141 std::stringstream sstream;
142 sstream << "Could not dispatch prediction request for " << request.snapshotID
143 << " to " << target << " from " << ownID;
144 result.errorMessage = sstream.str();
145 }
146 return result;
147 }
148
149 private:
// Registered predictors, keyed by prediction engine ID.
150 std::map<std::string, Predictor> _predictors; // NOLINT
151 };
152
153 /**
154 * Can do predictions itself and has children it could delegate predictions to.
155 */
156 template <class DerivedT>
157 class PredictionContainer : public Prediction<DerivedT>
158 {
159 public:
160 using Prediction<DerivedT>::Prediction;
161
162 explicit PredictionContainer(const std::map<PredictionEngine, Predictor>& predictors = {}) :
163 Prediction<DerivedT>(predictors)
164 {
165 }
166
/**
 * Resolves each request to the item responsible for it and dispatches it there.
 *
 * For each request, an entry of `getAllPredictionEngines()` is looked up using the
 * request's `snapshotID` and a predicate that accepts an entry iff its supported
 * engines include `request.predictionSettings.predictionEngineID`. If an entry is
 * found, the request is dispatched to that entry's ID via `dispatchTargetedPrediction`.
 * Otherwise the result has `success == false` and an `errorMessage`.
 *
 * NOTE(review): the lookup function's name is on a line not visible here; presumably
 * it is `findMostSpecificEntryContainingIDAnd` (most specific containing key
 * satisfying the predicate) — confirm.
 *
 * @param requests The prediction requests to handle.
 * @return One result per request, in the same order as `requests`.
 */
167 std::vector<PredictionResult>
168 dispatchPredictions(const std::vector<PredictionRequest>& requests)
169 {
170 const auto& derivedThis = base::detail::derived<DerivedT>(this);
171 const std::map<MemoryID, std::vector<PredictionEngine>> engines =
172 derivedThis.getAllPredictionEngines();
173
174 std::vector<PredictionResult> results;
175 for (const PredictionRequest& request : requests)
176 {
// `result` refers into `results`. It stays valid because nothing else is
// appended to `results` before it is last used in this iteration.
177 PredictionResult& result = results.emplace_back();
178 result.snapshotID = request.snapshotID;
179
// Find the ID responsible for this request: it must offer the requested engine.
180 auto iter =
182 engines,
183 request.snapshotID,
184 [&request](const MemoryID& /*unused*/,
185 const std::vector<PredictionEngine>& supported) -> bool
186 {
187 for (const PredictionEngine& engine : supported)
188 {
189 if (engine.engineID ==
190 request.predictionSettings.predictionEngineID)
191 {
192 return true;
193 }
194 }
195 return false;
196 });
197
198 if (iter != engines.end())
199 {
200 const MemoryID& responsibleID = iter->first;
201
202 result = dispatchTargetedPrediction(request, responsibleID);
203 }
204 else
205 {
206 result.success = false;
207 std::stringstream sstream;
208 sstream << "Could not find segment offering prediction engine '"
209 << request.predictionSettings.predictionEngineID << "' for memory ID "
210 << request.snapshotID << ".";
211 result.errorMessage = sstream.str();
212 }
213 }
214 return results;
215 }
216
217 /**
218 * Semantics: This container or one of its children (target) is responsible
219 * for performing the prediction.
 *
 * - If `target` is this container's own ID, the request is handled by the
 *   `Prediction` base class (this container's own predictors).
 * - If this container's ID contains `target` (and differs from it), the request is
 *   forwarded to the direct child on the path to `target`, looked up by name, which
 *   dispatches it further.
 * - Otherwise, the result has `success == false` and a descriptive `errorMessage`.
220 */
223 {
224 PredictionResult result;
225 result.snapshotID = request.snapshotID;
226
// NOTE(review): `derivedThis` is a const reference, yet below `&otherChild` is stored
// in a non-const `ChildT*` and a non-const member is called through it. This only
// compiles if `forEachChild` on a const object passes non-const children — confirm.
227 const auto& derivedThis = base::detail::derived<DerivedT>(this);
228 MemoryID ownID = derivedThis.id();
229 if (ownID == target)
230 {
231 // Delegate to base class.
232 result = Prediction<DerivedT>::dispatchTargetedPrediction(request, target);
233 }
234 // Check whether one of this container's children is responsible for the request.
235 else if (contains(ownID, target))
236 {
237 std::string childName = _getChildName(ownID, target);
238
239 // TODO(phesch): Looping over all the children just to find the one
240 // with the right name isn't nice, but it's the interface we've got.
241 // TODO(RainerKartmann): Try to add findChild() to lookup mixins.
// The loop does not stop at the first match; if several children had the
// same name, the last one visited would be used.
242 typename DerivedT::ChildT* child = nullptr;
243 derivedThis.forEachChild(
244 [&child, &childName](auto& otherChild)
245 {
246 if (otherChild.name() == childName)
247 {
248 child = &otherChild;
249 }
250 });
251 if (child)
252 {
253 result = child->dispatchTargetedPrediction(request, target);
254 }
255 else
256 {
257 result.success = false;
258 std::stringstream sstream;
259 sstream << "Could not find memory item with ID " << target << ".";
260 result.errorMessage = sstream.str();
261 }
262 }
263 else
264 {
// `target` is neither this container nor below it.
265 result.success = false;
266 std::stringstream sstream;
267 sstream << "Could not dispatch prediction request for " << request.snapshotID
268 << " to " << target << " from " << ownID << ".";
269 result.errorMessage = sstream.str();
270 }
271 return result;
272 }
273
274 private:
275 std::string
276 _getChildName(const MemoryID& parent, const MemoryID& child)
277 {
278 ARMARX_CHECK(armem::contains(parent, child));
279 ARMARX_CHECK(parent != child);
280
281 size_t parentLength = parent.getItems().size();
282
283 // Get iterator to first entry of child ID (memory).
284 std::vector<std::string> childItems = child.getItems();
285
286 int index = parentLength;
287 ARMARX_CHECK_FITS_SIZE(index, childItems.size());
288
289 return childItems[index];
290 }
291 };
292
293} // namespace armarx::armem::server::wm::detail
std::vector< std::string > getItems(bool escapeDelimiters=false) const
Get the levels from root to first not defined level (excluding).
Definition MemoryID.cpp:233
Something that supports a set of prediction engines.
Definition Predictive.h:38
PredictionResult dispatchTargetedPrediction(const PredictionRequest &request, const MemoryID &target)
Semantics: This container or one of its children (target) is responsible for performing the predictio...
Definition Prediction.h:222
std::vector< PredictionResult > dispatchPredictions(const std::vector< PredictionRequest > &requests)
Definition Prediction.h:168
PredictionContainer(const std::map< PredictionEngine, Predictor > &predictors={})
Definition Prediction.h:162
Can do predictions, but has no children it could delegate predictions to.
Definition Prediction.h:52
void setPredictors(const std::map< PredictionEngine, Predictor > &predictors)
Definition Prediction.h:79
void addPredictor(const PredictionEngine &engine, Predictor &&predictor)
Definition Prediction.h:68
PredictionResult dispatchTargetedPrediction(const PredictionRequest &request, const MemoryID &target)
Dispatches a single prediction request (assuming resolution was done by the caller).
Definition Prediction.h:114
std::vector< PredictionResult > dispatchPredictions(const std::vector< PredictionRequest > &requests)
Resolves mapping of requests to predictors and dispatches them.
Definition Prediction.h:99
Prediction(const std::map< PredictionEngine, Predictor > &predictors={})
Definition Prediction.h:62
#define ARMARX_CHECK_FITS_SIZE(number, size)
Check whether number is nonnegative (>= 0) and less than size.
#define ARMARX_CHECK(expression)
Shortcut for ARMARX_CHECK_EXPRESSION.
DerivedT & derived(ThisT *t)
Definition derived.h:8
std::function< PredictionResult(const PredictionRequest &)> Predictor
Definition Prediction.h:41
bool contains(const MemoryID &general, const MemoryID &specific)
Indicates whether general is "less specific" than, or equal to, specific, i.e.
Definition MemoryID.cpp:563
std::map< MemoryID, ValueT >::const_iterator findMostSpecificEntryContainingIDAnd(const std::map< MemoryID, ValueT > &idMap, const MemoryID &id, const std::function< bool(const MemoryID &, const ValueT &)> &predicate)
Find the entry with the most specific key that contains the given ID and satisfies the predicate,...
PredictionSettings predictionSettings
Definition Prediction.h:51