8
8
#include " wavemap/core/indexing/index_conversions.h"
9
9
#include " wavemap/core/utils/iterate/grid_iterator.h"
10
10
#include " wavemap/core/utils/shape/aabb.h"
11
+ #include " wavemap/core/utils/shape/intersection_tests.h"
11
12
12
13
namespace wavemap ::edit {
13
14
namespace detail {
15
+ template <typename MapT>
16
+ void sumNodeRecursive (
17
+ typename MapT::Block::OctreeType::NodeRefType node_A,
18
+ typename MapT::Block::OctreeType::NodeConstRefType node_B) {
19
+ using NodeRefType = decltype (node_A);
20
+ using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
21
+
22
+ // Sum
23
+ node_A.data () += node_B.data ();
24
+
25
+ // Recursively handle all child nodes
26
+ for (NdtreeIndexRelativeChild child_idx = 0 ;
27
+ child_idx < OctreeIndex::kNumChildren ; ++child_idx) {
28
+ NodeConstPtrType child_node_B = node_B.getChild (child_idx);
29
+ if (!child_node_B) {
30
+ continue ;
31
+ }
32
+ NodeRefType child_node_A = node_A.getOrAllocateChild (child_idx);
33
+ sumNodeRecursive<MapT>(child_node_A, *child_node_B);
34
+ }
35
+ }
36
+
14
37
template <typename MapT, typename SamplingFn>
15
38
void sumLeavesBatch (typename MapT::Block::OctreeType::NodeRefType node,
16
39
const OctreeIndex& node_index, FloatingPoint& node_value,
@@ -73,29 +96,125 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
73
96
node_value = new_value;
74
97
}
75
98
76
- template <typename MapT>
77
- void sumNodeRecursive (
78
- typename MapT::Block::OctreeType::NodeRefType node_A,
79
- typename MapT::Block::OctreeType::NodeConstRefType node_B) {
80
- using NodeRefType = decltype (node_A);
81
- using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
99
+ template <typename MapT, typename ShapeT>
100
+ void sumLeavesBatch (typename MapT::Block::OctreeType::NodeRefType node,
101
+ const OctreeIndex& node_index, FloatingPoint& node_value,
102
+ ShapeT&& mask, FloatingPoint summand,
103
+ FloatingPoint min_cell_width) {
104
+ // Decompress child values
105
+ using Transform = typename MapT::Block::Transform;
106
+ auto & node_details = node.data ();
107
+ auto child_values = Transform::backward ({node_value, {node_details}});
82
108
83
- // Sum
84
- node_A.data () += node_B.data ();
109
+ // Sum all children
110
+ for (NdtreeIndexRelativeChild child_idx = 0 ;
111
+ child_idx < OctreeIndex::kNumChildren ; ++child_idx) {
112
+ const OctreeIndex child_index = node_index.computeChildIndex (child_idx);
113
+ const Point3D t_W_child =
114
+ convert::nodeIndexToCenterPoint (child_index, min_cell_width);
115
+ if (shape::is_inside (t_W_child, mask)) {
116
+ child_values[child_idx] += summand;
117
+ }
118
+ }
85
119
86
- // Recursively handle all child nodes
120
+ // Compress
121
+ const auto [new_value, new_details] =
122
+ MapT::Block::Transform::forward (child_values);
123
+ node_details = new_details;
124
+ node_value = new_value;
125
+ }
126
+
127
+ template <typename MapT, typename ShapeT>
128
+ void sumNodeRecursive (typename MapT::Block::OctreeType::NodeRefType node,
129
+ const OctreeIndex& node_index, FloatingPoint& node_value,
130
+ ShapeT&& mask, FloatingPoint summand,
131
+ FloatingPoint min_cell_width,
132
+ IndexElement termination_height) {
133
+ using NodeRefType = decltype (node);
134
+
135
+ // Decompress child values
136
+ using Transform = typename MapT::Block::Transform;
137
+ auto & node_details = node.data ();
138
+ auto child_values = Transform::backward ({node_value, {node_details}});
139
+
140
+ // Handle each child
87
141
for (NdtreeIndexRelativeChild child_idx = 0 ;
88
142
child_idx < OctreeIndex::kNumChildren ; ++child_idx) {
89
- NodeConstPtrType child_node_B = node_B.getChild (child_idx);
90
- if (!child_node_B) {
143
+ // If the node is fully outside the shape, skip it
144
+ const OctreeIndex child_index = node_index.computeChildIndex (child_idx);
145
+ const AABB<Point3D> child_aabb =
146
+ convert::nodeIndexToAABB (child_index, min_cell_width);
147
+ if (!shape::overlaps (child_aabb, mask)) {
91
148
continue ;
92
149
}
93
- NodeRefType child_node_A = node_A.getOrAllocateChild (child_idx);
94
- sumNodeRecursive<MapT>(child_node_A, *child_node_B);
150
+
151
+ // If the node is fully inside the shape, sum at the current resolution
152
+ auto & child_value = child_values[child_idx];
153
+ if (shape::is_inside (child_aabb, mask)) {
154
+ child_value += summand;
155
+ continue ;
156
+ }
157
+
158
+ // Otherwise, continue at a higher resolution
159
+ NodeRefType child_node = node.getOrAllocateChild (child_idx);
160
+ if (child_index.height <= termination_height + 1 ) {
161
+ sumLeavesBatch<MapT>(child_node, child_index, child_value,
162
+ std::forward<ShapeT>(mask), summand, min_cell_width);
163
+ } else {
164
+ sumNodeRecursive<MapT>(child_node, child_index, child_value,
165
+ std::forward<ShapeT>(mask), summand,
166
+ min_cell_width, termination_height);
167
+ }
95
168
}
169
+
170
+ // Compress
171
+ const auto [new_value, new_details] = Transform::forward (child_values);
172
+ node_details = new_details;
173
+ node_value = new_value;
96
174
}
97
175
} // namespace detail
98
176
177
+ template <typename MapT>
178
+ void sum (MapT& map_A, const MapT& map_B,
179
+ const std::shared_ptr<ThreadPool>& thread_pool) {
180
+ CHECK_EQ (map_A.getTreeHeight (), map_B.getTreeHeight ());
181
+ CHECK_EQ (map_A.getMinCellWidth (), map_B.getMinCellWidth ());
182
+ using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
183
+ using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
184
+
185
+ // Process all blocks
186
+ map_B.forEachBlock (
187
+ [&map_A, &thread_pool](const Index3D& block_index, const auto & block_B) {
188
+ auto & block_A = map_A.getOrAllocateBlock (block_index);
189
+
190
+ // Indicate that the block has changed
191
+ block_A.setLastUpdatedStamp ();
192
+ block_A.setNeedsPruning ();
193
+
194
+ // Sum the blocks' average values (wavelet scale coefficient)
195
+ block_A.getRootScale () += block_B.getRootScale ();
196
+
197
+ // Recursively sum all node values (wavelet detail coefficients)
198
+ NodePtrType root_node_ptr_A = &block_A.getRootNode ();
199
+ NodeConstPtrType root_node_ptr_B = &block_B.getRootNode ();
200
+ if (thread_pool) {
201
+ thread_pool->add_task ([root_node_ptr_A, root_node_ptr_B,
202
+ block_ptr_A = &block_A]() {
203
+ detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
204
+ block_ptr_A->prune ();
205
+ });
206
+ } else {
207
+ detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
208
+ block_A.prune ();
209
+ }
210
+ });
211
+
212
+ // Wait for all parallel jobs to finish
213
+ if (thread_pool) {
214
+ thread_pool->wait_all ();
215
+ }
216
+ }
217
+
99
218
template <typename MapT, typename SamplingFn>
100
219
void sum (MapT& map, SamplingFn sampling_function,
101
220
const std::unordered_set<Index3D, IndexHash<3 >>& block_indices,
@@ -122,7 +241,7 @@ void sum(MapT& map, SamplingFn sampling_function,
122
241
NodePtrType root_node_ptr = &block.getRootNode ();
123
242
const OctreeIndex root_node_index{tree_height, block_index};
124
243
125
- // Recursively crop all nodes
244
+ // Recursively sum all nodes
126
245
if (thread_pool) {
127
246
thread_pool->add_task ([root_node_ptr, root_node_index, root_value_ptr,
128
247
block_ptr = &block,
@@ -147,55 +266,14 @@ void sum(MapT& map, SamplingFn sampling_function,
147
266
}
148
267
}
149
268
150
- template <typename MapT>
151
- void sum (MapT& map_A, const MapT& map_B,
152
- const std::shared_ptr<ThreadPool>& thread_pool) {
153
- CHECK_EQ (map_A.getTreeHeight (), map_B.getTreeHeight ());
154
- CHECK_EQ (map_A.getMinCellWidth (), map_B.getMinCellWidth ());
155
- using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
156
- using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
157
-
158
- // Process all blocks
159
- map_B.forEachBlock (
160
- [&map_A, &thread_pool](const Index3D& block_index, const auto & block_B) {
161
- auto & block_A = map_A.getOrAllocateBlock (block_index);
162
-
163
- // Indicate that the block has changed
164
- block_A.setLastUpdatedStamp ();
165
- block_A.setNeedsPruning ();
166
-
167
- // Sum the blocks' average values (wavelet scale coefficient)
168
- block_A.getRootScale () += block_B.getRootScale ();
169
-
170
- // Recursively sum all node values (wavelet detail coefficients)
171
- NodePtrType root_node_ptr_A = &block_A.getRootNode ();
172
- NodeConstPtrType root_node_ptr_B = &block_B.getRootNode ();
173
- if (thread_pool) {
174
- thread_pool->add_task ([root_node_ptr_A, root_node_ptr_B,
175
- block_ptr_A = &block_A]() {
176
- detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
177
- block_ptr_A->prune ();
178
- });
179
- } else {
180
- detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
181
- block_A.prune ();
182
- }
183
- });
184
-
185
- // Wait for all parallel jobs to finish
186
- if (thread_pool) {
187
- thread_pool->wait_all ();
188
- }
189
- }
190
-
191
269
template <typename MapT, typename ShapeT>
192
- void sum (MapT& map, ShapeT shape , FloatingPoint update ,
270
+ void sum (MapT& map, ShapeT mask , FloatingPoint summand ,
193
271
const std::shared_ptr<ThreadPool>& thread_pool) {
194
272
// Find the blocks that overlap with the shape
195
273
const FloatingPoint block_width =
196
274
convert::heightToCellWidth (map.getMinCellWidth (), map.getTreeHeight ());
197
275
const FloatingPoint block_width_inv = 1 .f / block_width;
198
- const auto aabb = static_cast <AABB<Point3D>>(shape );
276
+ const auto aabb = static_cast <AABB<Point3D>>(mask );
199
277
const Index3D block_index_min =
200
278
convert::pointToFloorIndex (aabb.min , block_width_inv);
201
279
const Index3D block_index_max =
@@ -205,14 +283,47 @@ void sum(MapT& map, ShapeT shape, FloatingPoint update,
205
283
block_indices.emplace (block_index);
206
284
}
207
285
208
- // Add the update to all cells whose centers lie inside the shape
209
- auto sampling_function = [&shape, update](const Point3D& t_A_p) {
210
- if (shape.contains (t_A_p)) {
211
- return update;
286
+ // Make sure all overlapping blocks have been allocated
287
+ for (const Index3D& block_index : block_indices) {
288
+ map.getOrAllocateBlock (block_index);
289
+ }
290
+
291
+ // Apply the sum to each overlapping block
292
+ using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
293
+ const IndexElement tree_height = map.getTreeHeight ();
294
+ const FloatingPoint min_cell_width = map.getMinCellWidth ();
295
+ for (const Index3D& block_index : block_indices) {
296
+ // Indicate that the block has changed
297
+ auto & block = *CHECK_NOTNULL (map.getBlock (block_index));
298
+ block.setLastUpdatedStamp ();
299
+ block.setNeedsPruning ();
300
+
301
+ // Get pointers to the root value and node, which contain the wavelet
302
+ // scale and detail coefficients, respectively
303
+ FloatingPoint* root_value_ptr = &block.getRootScale ();
304
+ NodePtrType root_node_ptr = &block.getRootNode ();
305
+ const OctreeIndex root_node_index{tree_height, block_index};
306
+
307
+ // Recursively sum all nodes
308
+ if (thread_pool) {
309
+ thread_pool->add_task ([root_node_ptr, root_node_index, root_value_ptr,
310
+ block_ptr = &block, &mask, summand,
311
+ min_cell_width]() mutable {
312
+ detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
313
+ *root_value_ptr, mask, summand,
314
+ min_cell_width, 0 );
315
+ });
316
+ } else {
317
+ detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
318
+ *root_value_ptr, mask, summand,
319
+ min_cell_width, 0 );
212
320
}
213
- return 0 .f ;
214
- };
215
- sum (map, sampling_function, block_indices, 0 , thread_pool);
321
+ }
322
+
323
+ // Wait for all parallel jobs to finish
324
+ if (thread_pool) {
325
+ thread_pool->wait_all ();
326
+ }
216
327
}
217
328
} // namespace wavemap::edit
218
329
0 commit comments