Skip to content

Commit c19f5bb

Browse files
Speed up masked sums using multi-resolution
1 parent 3d81e74 commit c19f5bb

File tree

4 files changed

+206
-76
lines changed

4 files changed

+206
-76
lines changed

examples/cpp/edit/crop_map.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,9 @@ int main(int, char**) {
2020
// Crop the map
2121
const Point3D t_W_center{-2.2, -1.4, 0.0};
2222
const FloatingPoint radius = 3.0;
23+
const Sphere cropping_sphere{t_W_center, radius};
2324
auto thread_pool = std::make_shared<ThreadPool>(); // Optional
24-
edit::crop_to_sphere(*map, t_W_center, radius, 0, thread_pool);
25+
edit::crop(*map, cropping_sphere, 0, thread_pool);
2526

2627
// Save the map
2728
const std::filesystem::path output_map_path =

examples/cpp/edit/sum_map.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,9 @@ int main(int, char**) {
2424
// Crop the map
2525
const Point3D t_W_center{-2.2, -1.4, 0.0};
2626
const FloatingPoint radius = 3.0;
27+
const Sphere cropping_sphere{t_W_center, radius};
2728
auto thread_pool = std::make_shared<ThreadPool>(); // Optional
28-
edit::crop_to_sphere(*map, t_W_center, radius, 0, thread_pool);
29+
edit::crop(*map, cropping_sphere, 0, thread_pool);
2930

3031
// Create a translated copy
3132
Transformation3D T_AB;

library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h

Lines changed: 175 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,32 @@
88
#include "wavemap/core/indexing/index_conversions.h"
99
#include "wavemap/core/utils/iterate/grid_iterator.h"
1010
#include "wavemap/core/utils/shape/aabb.h"
11+
#include "wavemap/core/utils/shape/intersection_tests.h"
1112

1213
namespace wavemap::edit {
1314
namespace detail {
15+
// Two-map overload: adds the data stored in node_B to node_A in place, then
// repeats this for every child that exists in node_B. Children that are
// missing in node_B are skipped; the matching child in node_A is obtained
// through getOrAllocateChild() before recursing.
//   node_A: destination node (modified)
//   node_B: source node (read-only)
template <typename MapT>
16+
void sumNodeRecursive(
17+
typename MapT::Block::OctreeType::NodeRefType node_A,
18+
typename MapT::Block::OctreeType::NodeConstRefType node_B) {
19+
using NodeRefType = decltype(node_A);
20+
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
21+
22+
// Sum: accumulate node_B's data into node_A
23+
node_A.data() += node_B.data();
24+
25+
// Recursively handle all child nodes
26+
for (NdtreeIndexRelativeChild child_idx = 0;
27+
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
28+
NodeConstPtrType child_node_B = node_B.getChild(child_idx);
29+
// A null child pointer means node_B has nothing to add for this child
30+
if (!child_node_B) {
31+
continue;
32+
}
33+
// node_A's child is fetched only after node_B's child is known to exist
34+
NodeRefType child_node_A = node_A.getOrAllocateChild(child_idx);
35+
sumNodeRecursive<MapT>(child_node_A, *child_node_B);
36+
}
37+
}
36+
1437
template <typename MapT, typename SamplingFn>
1538
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
1639
const OctreeIndex& node_index, FloatingPoint& node_value,
@@ -73,29 +96,125 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
7396
node_value = new_value;
7497
}
7598

76-
template <typename MapT>
77-
void sumNodeRecursive(
78-
typename MapT::Block::OctreeType::NodeRefType node_A,
79-
typename MapT::Block::OctreeType::NodeConstRefType node_B) {
80-
using NodeRefType = decltype(node_A);
81-
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
99+
template <typename MapT, typename ShapeT>
100+
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
101+
const OctreeIndex& node_index, FloatingPoint& node_value,
102+
ShapeT&& mask, FloatingPoint summand,
103+
FloatingPoint min_cell_width) {
104+
// Decompress child values
105+
using Transform = typename MapT::Block::Transform;
106+
auto& node_details = node.data();
107+
auto child_values = Transform::backward({node_value, {node_details}});
82108

83-
// Sum
84-
node_A.data() += node_B.data();
109+
// Sum all children
110+
for (NdtreeIndexRelativeChild child_idx = 0;
111+
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
112+
const OctreeIndex child_index = node_index.computeChildIndex(child_idx);
113+
const Point3D t_W_child =
114+
convert::nodeIndexToCenterPoint(child_index, min_cell_width);
115+
if (shape::is_inside(t_W_child, mask)) {
116+
child_values[child_idx] += summand;
117+
}
118+
}
85119

86-
// Recursively handle all child nodes
120+
// Compress
121+
const auto [new_value, new_details] =
122+
MapT::Block::Transform::forward(child_values);
123+
node_details = new_details;
124+
node_value = new_value;
125+
}
126+
127+
template <typename MapT, typename ShapeT>
128+
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
129+
const OctreeIndex& node_index, FloatingPoint& node_value,
130+
ShapeT&& mask, FloatingPoint summand,
131+
FloatingPoint min_cell_width,
132+
IndexElement termination_height) {
133+
using NodeRefType = decltype(node);
134+
135+
// Decompress child values
136+
using Transform = typename MapT::Block::Transform;
137+
auto& node_details = node.data();
138+
auto child_values = Transform::backward({node_value, {node_details}});
139+
140+
// Handle each child
87141
for (NdtreeIndexRelativeChild child_idx = 0;
88142
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
89-
NodeConstPtrType child_node_B = node_B.getChild(child_idx);
90-
if (!child_node_B) {
143+
// If the node is fully outside the shape, skip it
144+
const OctreeIndex child_index = node_index.computeChildIndex(child_idx);
145+
const AABB<Point3D> child_aabb =
146+
convert::nodeIndexToAABB(child_index, min_cell_width);
147+
if (!shape::overlaps(child_aabb, mask)) {
91148
continue;
92149
}
93-
NodeRefType child_node_A = node_A.getOrAllocateChild(child_idx);
94-
sumNodeRecursive<MapT>(child_node_A, *child_node_B);
150+
151+
// If the node is fully inside the shape, sum at the current resolution
152+
auto& child_value = child_values[child_idx];
153+
if (shape::is_inside(child_aabb, mask)) {
154+
child_value += summand;
155+
continue;
156+
}
157+
158+
// Otherwise, continue at a higher resolution
159+
NodeRefType child_node = node.getOrAllocateChild(child_idx);
160+
if (child_index.height <= termination_height + 1) {
161+
sumLeavesBatch<MapT>(child_node, child_index, child_value,
162+
std::forward<ShapeT>(mask), summand, min_cell_width);
163+
} else {
164+
sumNodeRecursive<MapT>(child_node, child_index, child_value,
165+
std::forward<ShapeT>(mask), summand,
166+
min_cell_width, termination_height);
167+
}
95168
}
169+
170+
// Compress
171+
const auto [new_value, new_details] = Transform::forward(child_values);
172+
node_details = new_details;
173+
node_value = new_value;
96174
}
97175
} // namespace detail
98176

177+
// Adds map_B to map_A in place, one block at a time.
// Both maps are checked (CHECK_EQ) to have the same tree height and minimum
// cell width. For every block of map_B, the block with the same index is
// fetched or allocated in map_A, flagged as updated and as needing pruning,
// summed (root scale first, then all nodes recursively) and finally pruned.
//   map_A:       destination map (modified)
//   map_B:       source map (read-only)
//   thread_pool: if non-null, the per-block node summation and pruning are
//                queued as tasks on it and this function waits for them
//                before returning; otherwise they run inline.
template <typename MapT>
178+
void sum(MapT& map_A, const MapT& map_B,
179+
const std::shared_ptr<ThreadPool>& thread_pool) {
180+
CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight());
181+
CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth());
182+
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
183+
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
184+
185+
// Process all blocks
186+
map_B.forEachBlock(
187+
[&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) {
188+
auto& block_A = map_A.getOrAllocateBlock(block_index);
189+
190+
// Indicate that the block has changed
191+
block_A.setLastUpdatedStamp();
192+
block_A.setNeedsPruning();
193+
194+
// Sum the blocks' average values (wavelet scale coefficient)
195+
block_A.getRootScale() += block_B.getRootScale();
196+
197+
// Recursively sum all node values (wavelet detail coefficients)
198+
NodePtrType root_node_ptr_A = &block_A.getRootNode();
199+
NodeConstPtrType root_node_ptr_B = &block_B.getRootNode();
200+
// The task captures plain pointers by value
201+
// NOTE(review): presumably so that the queued task does not refer to this
// callback's locals after the callback has returned - confirm
if (thread_pool) {
202+
thread_pool->add_task([root_node_ptr_A, root_node_ptr_B,
203+
block_ptr_A = &block_A]() {
204+
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
205+
block_ptr_A->prune();
206+
});
207+
} else {
208+
// No thread pool: sum and prune this block right away
209+
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
210+
block_A.prune();
211+
}
212+
});
213+
214+
// Wait for all parallel jobs to finish
215+
if (thread_pool) {
216+
thread_pool->wait_all();
217+
}
218+
}
217+
99218
template <typename MapT, typename SamplingFn>
100219
void sum(MapT& map, SamplingFn sampling_function,
101220
const std::unordered_set<Index3D, IndexHash<3>>& block_indices,
@@ -122,7 +241,7 @@ void sum(MapT& map, SamplingFn sampling_function,
122241
NodePtrType root_node_ptr = &block.getRootNode();
123242
const OctreeIndex root_node_index{tree_height, block_index};
124243

125-
// Recursively crop all nodes
244+
// Recursively sum all nodes
126245
if (thread_pool) {
127246
thread_pool->add_task([root_node_ptr, root_node_index, root_value_ptr,
128247
block_ptr = &block,
@@ -147,55 +266,14 @@ void sum(MapT& map, SamplingFn sampling_function,
147266
}
148267
}
149268

150-
template <typename MapT>
151-
void sum(MapT& map_A, const MapT& map_B,
152-
const std::shared_ptr<ThreadPool>& thread_pool) {
153-
CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight());
154-
CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth());
155-
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
156-
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
157-
158-
// Process all blocks
159-
map_B.forEachBlock(
160-
[&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) {
161-
auto& block_A = map_A.getOrAllocateBlock(block_index);
162-
163-
// Indicate that the block has changed
164-
block_A.setLastUpdatedStamp();
165-
block_A.setNeedsPruning();
166-
167-
// Sum the blocks' average values (wavelet scale coefficient)
168-
block_A.getRootScale() += block_B.getRootScale();
169-
170-
// Recursively sum all node values (wavelet detail coefficients)
171-
NodePtrType root_node_ptr_A = &block_A.getRootNode();
172-
NodeConstPtrType root_node_ptr_B = &block_B.getRootNode();
173-
if (thread_pool) {
174-
thread_pool->add_task([root_node_ptr_A, root_node_ptr_B,
175-
block_ptr_A = &block_A]() {
176-
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
177-
block_ptr_A->prune();
178-
});
179-
} else {
180-
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
181-
block_A.prune();
182-
}
183-
});
184-
185-
// Wait for all parallel jobs to finish
186-
if (thread_pool) {
187-
thread_pool->wait_all();
188-
}
189-
}
190-
191269
template <typename MapT, typename ShapeT>
192-
void sum(MapT& map, ShapeT shape, FloatingPoint update,
270+
void sum(MapT& map, ShapeT mask, FloatingPoint summand,
193271
const std::shared_ptr<ThreadPool>& thread_pool) {
194272
// Find the blocks that overlap with the shape
195273
const FloatingPoint block_width =
196274
convert::heightToCellWidth(map.getMinCellWidth(), map.getTreeHeight());
197275
const FloatingPoint block_width_inv = 1.f / block_width;
198-
const auto aabb = static_cast<AABB<Point3D>>(shape);
276+
const auto aabb = static_cast<AABB<Point3D>>(mask);
199277
const Index3D block_index_min =
200278
convert::pointToFloorIndex(aabb.min, block_width_inv);
201279
const Index3D block_index_max =
@@ -205,14 +283,47 @@ void sum(MapT& map, ShapeT shape, FloatingPoint update,
205283
block_indices.emplace(block_index);
206284
}
207285

208-
// Add the update to all cells whose centers lie inside the shape
209-
auto sampling_function = [&shape, update](const Point3D& t_A_p) {
210-
if (shape.contains(t_A_p)) {
211-
return update;
286+
// Make sure all overlapping blocks have been allocated
287+
for (const Index3D& block_index : block_indices) {
288+
map.getOrAllocateBlock(block_index);
289+
}
290+
291+
// Apply the sum to each overlapping block
292+
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
293+
const IndexElement tree_height = map.getTreeHeight();
294+
const FloatingPoint min_cell_width = map.getMinCellWidth();
295+
for (const Index3D& block_index : block_indices) {
296+
// Indicate that the block has changed
297+
auto& block = *CHECK_NOTNULL(map.getBlock(block_index));
298+
block.setLastUpdatedStamp();
299+
block.setNeedsPruning();
300+
301+
// Get pointers to the root value and node, which contain the wavelet
302+
// scale and detail coefficients, respectively
303+
FloatingPoint* root_value_ptr = &block.getRootScale();
304+
NodePtrType root_node_ptr = &block.getRootNode();
305+
const OctreeIndex root_node_index{tree_height, block_index};
306+
307+
// Recursively sum all nodes
308+
if (thread_pool) {
309+
thread_pool->add_task([root_node_ptr, root_node_index, root_value_ptr,
310+
block_ptr = &block, &mask, summand,
311+
min_cell_width]() mutable {
312+
detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
313+
*root_value_ptr, mask, summand,
314+
min_cell_width, 0);
315+
});
316+
} else {
317+
detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
318+
*root_value_ptr, mask, summand,
319+
min_cell_width, 0);
212320
}
213-
return 0.f;
214-
};
215-
sum(map, sampling_function, block_indices, 0, thread_pool);
321+
}
322+
323+
// Wait for all parallel jobs to finish
324+
if (thread_pool) {
325+
thread_pool->wait_all();
326+
}
216327
}
217328
} // namespace wavemap::edit
218329

library/cpp/include/wavemap/core/utils/edit/sum.h

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,54 @@
1111

1212
namespace wavemap::edit {
1313
namespace detail {
14+
// Recursively sum two maps together
15+
template <typename MapT>
16+
void sumNodeRecursive(
17+
typename MapT::Block::OctreeType::NodeRefType node_A,
18+
typename MapT::Block::OctreeType::NodeConstRefType node_B);
19+
20+
// Recursively add a sampled value
1421
template <typename MapT, typename SamplingFn>
1522
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
1623
const OctreeIndex& node_index, FloatingPoint& node_value,
1724
SamplingFn&& sampling_function,
1825
FloatingPoint min_cell_width);
19-
2026
template <typename MapT, typename SamplingFn>
2127
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
2228
const OctreeIndex& node_index, FloatingPoint& node_value,
2329
SamplingFn&& sampling_function,
2430
FloatingPoint min_cell_width,
2531
IndexElement termination_height = 0);
2632

27-
template <typename MapT>
28-
void sumNodeRecursive(
29-
typename MapT::Block::OctreeType::NodeRefType node_A,
30-
typename MapT::Block::OctreeType::NodeConstRefType node_B);
33+
// Recursively add a scalar value to all cells within a given shape
34+
template <typename MapT, typename ShapeT>
35+
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
36+
const OctreeIndex& node_index, FloatingPoint& node_value,
37+
ShapeT&& mask, FloatingPoint summand,
38+
FloatingPoint min_cell_width);
39+
template <typename MapT, typename ShapeT>
40+
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
41+
const OctreeIndex& node_index, FloatingPoint& node_value,
42+
ShapeT&& mask, FloatingPoint summand,
43+
FloatingPoint min_cell_width,
44+
IndexElement termination_height = 0);
3145
} // namespace detail
3246

47+
// Sum two maps together
48+
template <typename MapT>
49+
void sum(MapT& map_A, const MapT& map_B,
50+
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
51+
52+
// Add a sampled value to all cells within a given list of blocks
3353
template <typename MapT, typename SamplingFn>
3454
void sum(MapT& map, SamplingFn sampling_function,
3555
const std::unordered_set<Index3D, IndexHash<3>>& block_indices,
3656
IndexElement termination_height = 0,
3757
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
3858

39-
template <typename MapT>
40-
void sum(MapT& map_A, const MapT& map_B,
41-
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
42-
59+
// Add a scalar value to all cells within a given shape
4360
template <typename MapT, typename ShapeT>
44-
void sum(MapT& map, ShapeT shape, FloatingPoint update,
61+
void sum(MapT& map, ShapeT mask, FloatingPoint summand,
4562
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
4663
} // namespace wavemap::edit
4764

0 commit comments

Comments
 (0)