Skip to content

Commit 76da4ff

Browse files
committed
direct transition setup
1 parent e739511 commit 76da4ff

File tree

1 file changed

+326
-30
lines changed

1 file changed

+326
-30
lines changed

projects/CUDA/utils/Primitives.cpp

+326-30
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,7 @@ struct ParticleSegmentation : zeno::INode {
13161316
exclusive_scan(pol, std::begin(numNeighbors), std::end(numNeighbors), std::begin(spmat._ptrs));
13171317

13181318
auto numEntries = spmat._ptrs[pos.size()];
1319+
13191320
spmat._inds.resize(numEntries);
13201321

13211322
pol(range(pos.size()),
@@ -2373,32 +2374,223 @@ struct AssociateParticles : INode {
23732374
auto &dstIndices = srcPrim->add_attr<int>(indexTag);
23742375

23752376
auto n = srcPrim->size();
2376-
auto m = dstPrim->size();
23772377

2378-
auto dims = compute_dimensions(*srcPrim, *dstPrim);
2379-
auto furthestDistance = std::sqrt(dims[0] * dims[0] + dims[1] * dims[1] + dims[2] * dims[2]);
2380-
auto N = std::max(n, m);
2378+
if (n) {
2379+
auto m = dstPrim->size();
2380+
2381+
auto dims = compute_dimensions(*srcPrim, *dstPrim);
2382+
auto furthestDistance = std::sqrt(dims[0] * dims[0] + dims[1] * dims[1] + dims[2] * dims[2]) * 1.1f;
2383+
auto N = std::max(n, m);
2384+
2385+
const auto &src = srcPrim->attr<vec3f>("pos");
2386+
const auto &dst = dstPrim->attr<vec3f>("pos");
2387+
2388+
KuhnMunkres km{(int)N, [&src, &dst, n, m, v = -furthestDistance](int i, int j) {
2389+
if (i < n && j < m)
2390+
return -length(src[i] - dst[j]);
2391+
else
2392+
return v;
2393+
}};
2394+
km.solve();
2395+
2396+
float refSum = 0.f;
2397+
for (int i = 0; i != n; ++i)
2398+
refSum += length(src[i] - dst[i]);
2399+
float curSum = 0.f;
2400+
for (int i = 0; i != n; ++i)
2401+
curSum += length(src[i] - dst[km.find_l[i]]);
2402+
fmt::print(fg(fmt::color::red), "ref: {}, calc: {}\n", refSum, curSum);
2403+
2404+
auto pol = zs::omp_exec();
2405+
pol(zs::range(n), [&](int i) {
2406+
int id = km.find_l[i];
2407+
if (id < m) {
2408+
dstIndices[i] = id;
2409+
dstPos[i] = dst[id];
2410+
} else {
2411+
dstIndices[i] = -1;
2412+
dstPos[i] = src[i];
2413+
}
2414+
});
2415+
}
2416+
set_output("srcPrim", std::move(srcPrim));
2417+
}
2418+
};
2419+
ZENDEFNODE(AssociateParticles, {
2420+
{{"PrimitiveObject", "srcPrim"},
2421+
{"string", "target_pos_tag", "target_pos"},
2422+
{"string", "target_index_tag", "target_index"},
2423+
{"PrimitiveObject", "dstPrim"}},
2424+
{{"PrimitiveObject", "srcPrim"}},
2425+
{},
2426+
{"zs_geom"},
2427+
});
23812428

2382-
const auto &src = srcPrim->attr<vec3f>("pos");
2383-
const auto &dst = dstPrim->attr<vec3f>("pos");
2429+
#if 0
2430+
struct SetupParticleTransition : INode {
2431+
void apply() override {
2432+
auto srcPars = get_input2<PrimitiveObject>("src_particles");
2433+
auto srcClusters = get_input2<PrimitiveObject>("src_clusters");
23842434

2385-
KuhnMunkres km{(int)N, [&src, &dst, n, m, v = -furthestDistance](int i, int j) {
2386-
if (i < n && j < m)
2387-
return -length(src[i] - dst[j]);
2388-
else
2389-
return v;
2390-
}};
2391-
km.solve();
2435+
auto dstPars = get_input2<PrimitiveObject>("dst_particles");
2436+
auto dstClusters = get_input2<PrimitiveObject>("dst_clusters");
23922437

2393-
float refSum = 0.f;
2394-
for (int i = 0; i != n; ++i)
2395-
refSum += length(src[i] - dst[i]);
2396-
float curSum = 0.f;
2397-
for (int i = 0; i != n; ++i)
2398-
curSum += length(src[i] - dst[km.find_l[i]]);
2399-
fmt::print(fg(fmt::color::red), "ref: {}, calc: {}\n", refSum, curSum);
2438+
auto prim = get_input2<PrimitiveObject>("anim_particles");
2439+
2440+
auto particleClusterIndexTag = get_input2<std::string>("particle_cluster_index_tag");
2441+
auto clusterTargetIndexTag = get_input2<std::string>("cluster_target_index_tag");
2442+
auto transTag = get_input2<std::string>("per_frame_translation_tag");
2443+
2444+
auto numTransFrames = get_input2<int>("num_transition_frames");
2445+
auto numFrames = get_input2<int>("num_animating_frames");
2446+
2447+
// sizes
2448+
auto nSrcPars = srcPars->size();
2449+
auto nSrcClusters = srcClusters->size();
2450+
auto nDstPars = dstPars->size();
2451+
auto nDstClusters = dstClusters->size();
2452+
2453+
auto nPars = std::max(nSrcPars, nDstPars);
2454+
auto nClusters = std::max(nSrcClusters, nDstClusters);
2455+
2456+
// attribs
2457+
// prim->resize(nPars);
2458+
// auto &pos = prim->attr<vec3f>("pos");
2459+
// auto &trans = prim->add_attr<vec3f>(transTag);
2460+
2461+
const auto &srcParClusterIds = srcPars->attr<float>(particleClusterIndexTag);
2462+
const auto &dstParClusterIds = dstPars->attr<float>(particleClusterIndexTag);
2463+
const auto &srcParPos = srcPars->attr<vec3f>("pos");
2464+
const auto &dstParPos = dstPars->attr<vec3f>("pos");
2465+
2466+
const auto &targetClusterIds = srcClusters->attr<float>(particleClusterIndexTag);
2467+
// const auto &srcClusterPos = srcClusters->attr<vec3f>("pos");
2468+
// const auto &dstClusterPos = dstClusters->attr<vec3f>("pos");
2469+
2470+
auto dims = compute_dimensions(*srcPars, *dstPars);
2471+
auto furthestDistance = std::sqrt(dims[0] * dims[0] + dims[1] * dims[1] + dims[2] * dims[2]) * 1.1f;
24002472

24012473
auto pol = zs::omp_exec();
2474+
2475+
struct P {
2476+
vec3f pos{0, 0, 0};
2477+
int dstPar{-1};
2478+
float rad{5}; // 30 is visible, 5 is barely visible
2479+
vec3f deltaP{0, 0, 0};
2480+
};
2481+
std::vector<std::vector<P>> parGrps(nSrcClusters); // for constructing result
2482+
2483+
using namespace zs;
2484+
std::vector<int> missingDstClusters(nDstClusters + 1, 1); // default no cover
2485+
pol(range(nSrcClusters), [&](int ci) {
2486+
int dstClusterId = targetClusterIds[ci];
2487+
if (dstClusterId >= 0)
2488+
missingDstClusters[dstClusterId] = 0;
2489+
});
2490+
std::vector<int> missingDstClusterOffsets(nDstClusters + 1);
2491+
exclusive_scan(pol, std::begin(missingDstClusters), std::end(missingDstClusters),
2492+
std::begin(missingDstClusterOffsets));
2493+
auto numTotalMissingClusters = missingDstClusterOffsets.back();
2494+
2495+
std::vector<int> parGrpSizes(nSrcClusters + numTotalMissingClusters);
2496+
std::vector<int> dstClusterSizes(nDstClusters);
2497+
2498+
std::vector<std::vector<int>> srcClusterIndices(nSrcClusters), dstClusterIndices(nDstClusters);
2499+
2500+
// prepare first half of [parGrps]
2501+
pol(range(nSrcPars), [&](int i) {
2502+
int ci = srcParClusterIds[i];
2503+
atomic_add(exec_omp, &parGrpSizes[ci], 1);
2504+
});
2505+
for (int i = 0; i < nSrcClusters; ++i) {
2506+
parGrps[i].resize(parGrpSizes[i]);
2507+
srcClusterIndices[i].resize(parGrpSizes[i]);
2508+
}
2509+
2510+
// prepare second half of [parGrps]
2511+
pol(range(nDstPars), [&](int i) {
2512+
int ci = dstParClusterIds[i];
2513+
2514+
atomic_add(exec_omp, &dstClusterSizes[ci], 1);
2515+
2516+
if (missingDstClusters[ci]) {
2517+
int id = missingDstClusterOffsets[ci] + nSrcClusters;
2518+
atomic_add(exec_omp, &parGrpSizes[id], 1);
2519+
}
2520+
});
2521+
for (int i = 0; i < numTotalMissingClusters; ++i)
2522+
parGrps[nSrcClusters + i].resize(parGrpSizes[nSrcClusters + i]);
2523+
for (int i = 0; i < nDstClusters; ++i)
2524+
dstClusterIndices[i].resize(dstClusterSizes[i]);
2525+
2526+
// init particle data
2527+
std::memset(parGrpSizes.data(), 0, sizeof(int) * parGrpSizes.size());
2528+
pol(range(nSrcPars), [&](int i) {
2529+
int id = srcParClusterIds[i];
2530+
auto offset = atomic_add(exec_omp, &parGrpSizes[id], 1);
2531+
parGrps[id][offset] = P{srcParPos[i], -1};
2532+
srcClusterIndices[id][offset] = i;
2533+
});
2534+
std::memset(dstClusterSizes.data(), 0, sizeof(int) * dstClusterSizes.size());
2535+
pol(range(nDstPars), [&](int i) {
2536+
int ci = dstParClusterIds[i];
2537+
auto offset = atomic_add(exec_omp, &dstClusterSizes[ci], 1);
2538+
dstClusterIndices[ci][offset] = i;
2539+
if (missingDstClusters[ci]) {
2540+
auto id = missingDstClusterOffsets[ci] + nSrcClusters;
2541+
// auto offset = atomic_add(exec_omp, &parGrpSizes[id], 1);
2542+
parGrps[id][offset] = P{dstParPos[i], -1};
2543+
}
2544+
});
2545+
2546+
/// compute first half
2547+
pol(range(nSrcClusters), [&](int ci) {
2548+
int dstClusterId = targetClusterIds[ci];
2549+
int n = parGrpSizes[ci]; // srcClusterIndices.size()
2550+
auto &grp = parGrps[ci];
2551+
if (dstClusterId >= 0) {
2552+
const auto &srcIndices = srcClusterIndices[ci];
2553+
const auto &dstIndices = dstClusterIndices[dstClusterId];
2554+
int m = dstClusterSizes[dstClusterId]; // dstClusterIndices.size()
2555+
int N = std::max(m, n);
2556+
KuhnMunkres km{(int)N, [&, v = -furthestDistance](int i, int j) {
2557+
if (i < n && j < m)
2558+
return -length(srcParPos[srcIndices[i]] - dstParPos[dstIndices[j]]);
2559+
else
2560+
return v;
2561+
}};
2562+
km.solve();
2563+
2564+
std::vector<int> dstPicked(m);
2565+
for (int i = 0; i != n; ++i) {
2566+
int j = km.find_l[i];
2567+
if (j < m) {
2568+
grp[i].deltaP = (dstParPos[dstIndices[j]] - srcParPos[srcIndices[i]]) / numTransFrames;
2569+
dstPicked[j] = 1;
2570+
} else {
2571+
// no longer required, to be removed when transition is done
2572+
grp[i].pos = srcParPos[srcIndices[i]];
2573+
grp[i].rad = 5;
2574+
}
2575+
}
2576+
for (int j = 0; j != m; ++j) {
2577+
if (!dstPicked[j])
2578+
// directly emerge at the destination
2579+
grp.push_back(P{dstParPos[dstIndices[j]], -1, 30});
2580+
}
2581+
} else {
2582+
for (int i = 0; i != n; ++i) {
2583+
// no longer required, to be removed when transition is done
2584+
grp[i].pos = srcParPos[srcIndices[i]];
2585+
grp[i].rad = 5;
2586+
}
2587+
}
2588+
});
2589+
2590+
#if 0
2591+
// update grps to prim
2592+
const auto &dst = dstPrim->attr<vec3f>("pos");
2593+
24022594
pol(zs::range(n), [&](int i) {
24032595
int id = km.find_l[i];
24042596
if (id < m) {
@@ -2409,18 +2601,122 @@ struct AssociateParticles : INode {
24092601
dstPos[i] = src[i];
24102602
}
24112603
});
2412-
set_output("srcPrim", std::move(srcPrim));
2604+
#endif
2605+
set_output("anim_particles", std::move(prim));
24132606
}
24142607
};
2415-
ZENDEFNODE(AssociateParticles, {
2416-
{{"PrimitiveObject", "srcPrim"},
2417-
{"string", "target_pos_tag", "target_pos"},
2418-
{"string", "target_index_tag", "target_index"},
2419-
{"PrimitiveObject", "dstPrim"}},
2420-
{{"PrimitiveObject", "srcPrim"}},
2421-
{},
2422-
{"zs_geom"},
2423-
});
2608+
ZENDEFNODE(SetupParticleTransition, {
2609+
{
2610+
{"PrimitiveObject", "src_particles"},
2611+
{"PrimitiveObject", "src_clusters"},
2612+
{"PrimitiveObject", "dst_particles"},
2613+
{"PrimitiveObject", "dst_clusters"},
2614+
{"string", "particle_cluster_index_tag", "segment_index"}, // for pars
2615+
{"string", "cluster_target_index_tag", "target_index"}, // for clusters
2616+
{"string", "per_frame_translation_tag", "frame_translation"},
2617+
{"int", "num_transition_frames", "20"},
2618+
{"int", "num_animating_frames", "100"},
2619+
{"PrimitiveObject", "anim_particles"},
2620+
},
2621+
{{"PrimitiveObject", "anim_particles"}},
2622+
{},
2623+
{"zs_geom"},
2624+
});
2625+
#endif
2626+
2627+
struct SetupParticleTransitionDirect : INode {
2628+
void apply() override {
2629+
auto srcPars = get_input2<PrimitiveObject>("src_particles");
2630+
2631+
auto dstPars = get_input2<PrimitiveObject>("dst_particles");
2632+
2633+
auto prim = get_input2<PrimitiveObject>("anim_particles");
2634+
2635+
auto indexTag = get_input2<std::string>("target_index_tag");
2636+
auto transTag = get_input2<std::string>("per_frame_translation_tag");
2637+
auto clrTransTag = get_input2<std::string>("per_frame_clr_trans_tag");
2638+
2639+
auto numTransFrames = get_input2<int>("num_transition_frames");
2640+
auto radius = get_input2<float>("rad");
2641+
2642+
// sizes
2643+
auto nSrcPars = srcPars->size();
2644+
auto nDstPars = dstPars->size();
2645+
2646+
auto nPars = std::max(nSrcPars, nDstPars);
2647+
2648+
// attribs
2649+
prim->resize(nPars);
2650+
auto &pos = prim->attr<vec3f>("pos");
2651+
auto &rads = prim->add_attr<float>("rad");
2652+
auto &clrs = prim->add_attr<vec3f>("clr");
2653+
auto &trans = prim->add_attr<vec3f>(transTag);
2654+
auto &clrTrans = prim->add_attr<vec3f>(clrTransTag);
2655+
2656+
const auto &dstIndices = srcPars->attr<int>(indexTag);
2657+
const auto &srcParPos = srcPars->attr<vec3f>("pos");
2658+
const auto &dstParPos = dstPars->attr<vec3f>("pos");
2659+
2660+
std::memcpy(pos.data(), srcParPos.data(), sizeof(vec3f) * srcParPos.size());
2661+
2662+
auto pol = zs::omp_exec();
2663+
2664+
using namespace zs;
2665+
2666+
std::fill(std::begin(rads), std::end(rads), radius);
2667+
2668+
const vec3f onColor{0, 1, 0};
2669+
const vec3f offColor{1, 0, 0};
2670+
2671+
std::vector<int> missingDstPars(nDstPars + 1, 1); // default no cover
2672+
pol(range(nSrcPars), [&](int i) {
2673+
int j = dstIndices[i];
2674+
if (j >= 0) {
2675+
auto xi = srcParPos[i];
2676+
auto xj = dstParPos[j];
2677+
trans[i] = (xj - xi) / numTransFrames;
2678+
clrs[i] = onColor;
2679+
clrTrans[i] = vec3f{0, 0, 0};
2680+
missingDstPars[j] = 0;
2681+
} else {
2682+
trans[i] = vec3f{0, 0, 0};
2683+
clrs[i] = offColor;
2684+
clrTrans[i] = offColor / (-numTransFrames); // towards full black
2685+
}
2686+
});
2687+
2688+
std::vector<int> missingDstParOffsets(nDstPars + 1);
2689+
exclusive_scan(pol, std::begin(missingDstPars), std::end(missingDstPars), std::begin(missingDstParOffsets));
2690+
auto numTotalMissingPars = missingDstParOffsets.back();
2691+
2692+
pol(range(nDstPars), [&](int j) {
2693+
if (missingDstPars[j]) {
2694+
int i = missingDstParOffsets[j] + nSrcPars;
2695+
pos[i] = dstParPos[j];
2696+
trans[i] = vec3f{0, 0, 0};
2697+
clrs[i] = vec3f{0, 0, 0};
2698+
clrTrans[i] = onColor / numTransFrames; // towards full on (green)
2699+
}
2700+
});
2701+
2702+
set_output("anim_particles", std::move(prim));
2703+
}
2704+
};
2705+
ZENDEFNODE(SetupParticleTransitionDirect, {
2706+
{
2707+
{"PrimitiveObject", "src_particles"},
2708+
{"PrimitiveObject", "dst_particles"},
2709+
{"string", "target_index_tag", "target_index"},
2710+
{"string", "per_frame_translation_tag", "frame_translation"},
2711+
{"string", "per_frame_clr_trans_tag", "trans_clr"},
2712+
{"float", "rad", "2"},
2713+
{"int", "num_transition_frames", "20"},
2714+
{"PrimitiveObject", "anim_particles"},
2715+
},
2716+
{{"PrimitiveObject", "anim_particles"}},
2717+
{},
2718+
{"zs_geom"},
2719+
});
24242720

24252721
struct AssociateParticlesFast : INode {
24262722
void apply() override {

0 commit comments

Comments
 (0)