@@ -47,6 +47,7 @@ def rho_evol_riemann_fn(
47
47
wall_mask_j ,
48
48
n_w_j ,
49
49
g_ext_i ,
50
+ u_tilde_j ,
50
51
** kwargs ,
51
52
):
52
53
# Compute unit vector, above eq. (6), Zhang (2017)
@@ -56,18 +57,22 @@ def rho_evol_riemann_fn(
56
57
kernel_grad = kernel_fn .grad_w (d_ij ) * (e_ij )
57
58
58
59
# Compute average states eq. (6)/(12)/(13), Zhang (2017)
59
- u_L = jnp .where (wall_mask_j == 1 , jnp .dot (u_i , - n_w_j ), jnp .dot (u_i , - e_ij ))
60
+ u_L = jnp .where (
61
+ jnp .isin (wall_mask_j , wall_tags ), jnp .dot (u_i , - n_w_j ), jnp .dot (u_i , - e_ij )
62
+ )
60
63
p_L = p_i
61
64
rho_L = rho_i
62
65
63
66
# u_w from eq. (15), Yang (2020)
64
67
u_R = jnp .where (
65
- wall_mask_j == 1 ,
68
+ jnp . isin ( wall_mask_j , wall_tags ) ,
66
69
- u_L + 2 * jnp .dot (u_j , n_w_j ),
67
70
jnp .dot (u_j , - e_ij ),
68
71
)
69
- p_R = jnp .where (wall_mask_j == 1 , p_L + rho_L * jnp .dot (g_ext_i , - r_ij ), p_j )
70
- rho_R = jnp .where (wall_mask_j == 1 , eos .rho_fn (p_R ), rho_j )
72
+ p_R = jnp .where (
73
+ jnp .isin (wall_mask_j , wall_tags ), p_L + rho_L * jnp .dot (g_ext_i , - r_ij ), p_j
74
+ )
75
+ rho_R = jnp .where (jnp .isin (wall_mask_j , wall_tags ), eos .rho_fn (p_R ), rho_j )
71
76
72
77
U_avg = (u_L + u_R ) / 2
73
78
v_avg = (u_i + u_j ) / 2
@@ -197,6 +202,7 @@ def acceleration_fn_riemann(
197
202
mask ,
198
203
n_w_j ,
199
204
g_ext_i ,
205
+ u_tilde_j ,
200
206
):
201
207
# Compute unit vector, above eq. (6), Zhang (2017)
202
208
e_ij = e_s
@@ -206,18 +212,22 @@ def acceleration_fn_riemann(
206
212
kernel_grad = kernel_part_diff * (e_ij )
207
213
208
214
# Compute average states eq. (6)/(12)/(13), Zhang (2017)
209
- u_L = jnp .where (wall_mask_j == 1 , jnp .dot (u_i , - n_w_j ), jnp .dot (u_i , - e_ij ))
215
+ u_L = jnp .where (
216
+ jnp .isin (wall_mask_j , wall_tags ), jnp .dot (u_i , - n_w_j ), jnp .dot (u_i , - e_ij )
217
+ )
210
218
p_L = p_i
211
219
rho_L = rho_i
212
220
213
- # u_w from eq. (15), Yang (2020)
221
+ # u_w from eq. (15), Yang (2020)
214
222
u_R = jnp .where (
215
- wall_mask_j == 1 ,
223
+ jnp . isin ( wall_mask_j , wall_tags ) ,
216
224
- u_L + 2 * jnp .dot (u_j , n_w_j ),
217
225
jnp .dot (u_j , - e_ij ),
218
226
)
219
- p_R = jnp .where (wall_mask_j == 1 , p_L + rho_L * jnp .dot (g_ext_i , - r_ij ), p_j )
220
- rho_R = jnp .where (wall_mask_j == 1 , eos .rho_fn (p_R ), rho_j )
227
+ p_R = jnp .where (
228
+ jnp .isin (wall_mask_j , wall_tags ), p_L + rho_L * jnp .dot (g_ext_i , - r_ij ), p_j
229
+ )
230
+ rho_R = jnp .where (jnp .isin (wall_mask_j , wall_tags ), eos .rho_fn (p_R ), rho_j )
221
231
222
232
P_avg = (p_L + p_R ) / 2
223
233
rho_avg = (rho_L + rho_R ) / 2
@@ -227,16 +237,18 @@ def acceleration_fn_riemann(
227
237
eta_ij = 2 * eta_i * eta_j / (eta_i + eta_j + EPS )
228
238
229
239
# Compute Riemann states eq. (7) and (10), Zhang (2017)
230
- # u_R = jnp.where(
231
- # wall_mask_j == 1, -u_L - 2 * jnp.dot(v_j, -n_w_j), jnp.dot(v_j, -e_ij)
232
- # )
233
240
P_star = P_avg + 0.5 * rho_avg * (u_L - u_R ) * beta_fn (u_L , u_R , eta_limiter )
234
241
235
242
# pressure term with linear Riemann solver eq. (9), Zhang (2017)
236
243
eq_9 = - 2 * m_j * (P_star / (rho_i * rho_j )) * kernel_grad
237
244
238
245
# viscosity term eq. (6), Zhang (2019)
239
- v_ij = u_i - u_j
246
+ u_d = 2 * u_j - u_tilde_j
247
+ v_ij = jnp .where (
248
+ jnp .isin (wall_mask_j , wall_tags ),
249
+ u_i - u_d ,
250
+ u_i - u_j ,
251
+ )
240
252
eq_6 = 2 * m_j * eta_ij / (rho_i * rho_j ) * v_ij / (d_ij + EPS )
241
253
eq_6 *= kernel_part_diff * mask
242
254
@@ -388,11 +400,21 @@ def gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction):
388
400
389
401
def free_weight (fluid_mask_i , tag_i ):
390
402
return fluid_mask_i
403
+
404
+ def riemann_velocities (u , w_dist , fluid_mask , i_s , j_s , N ):
405
+ return u
391
406
else :
392
407
393
408
def free_weight (fluid_mask_i , tag_i ):
394
409
return jnp .ones_like (tag_i )
395
410
411
+ def riemann_velocities (u , w_dist , fluid_mask , i_s , j_s , N ):
412
+ w_dist_fluid = w_dist * fluid_mask [j_s ]
413
+ u_wall_nom = ops .segment_sum (w_dist_fluid [:, None ] * u [j_s ], i_s , N )
414
+ u_wall_denom = ops .segment_sum (w_dist_fluid , i_s , N )
415
+ u_tilde = u_wall_nom / (u_wall_denom [:, None ] + EPS )
416
+ return u_tilde
417
+
396
418
if is_heat_conduction :
397
419
398
420
def heat_bc (mask_j_s_fluid , w_dist , temperature , i_s , j_s , tag , N ):
@@ -410,7 +432,7 @@ def heat_bc(mask_j_s_fluid, w_dist, temperature, i_s, j_s, tag, N):
410
432
def heat_bc (mask_j_s_fluid , w_dist , temperature , i_s , j_s , tag , N ):
411
433
return temperature
412
434
413
- return free_weight , heat_bc
435
+ return free_weight , riemann_velocities , heat_bc
414
436
415
437
416
438
def limiter_fn_wrapper (eta_limiter , c_ref ):
@@ -503,9 +525,11 @@ def __init__(
503
525
self ._kernel_fn = SuperGaussianKernel (h = dx , dim = dim )
504
526
505
527
self ._gwbc_fn = gwbc_fn_wrapper (is_free_slip , is_heat_conduction , eos )
506
- self ._free_weight , self ._heat_bc = gwbc_fn_riemann_wrapper (
507
- is_free_slip , is_heat_conduction
508
- )
528
+ (
529
+ self ._free_weight ,
530
+ self ._riemann_velocities ,
531
+ self ._heat_bc ,
532
+ ) = gwbc_fn_riemann_wrapper (is_free_slip , is_heat_conduction )
509
533
self ._acceleration_tvf_fn = acceleration_tvf_fn_wrapper (self ._kernel_fn )
510
534
self ._acceleration_riemann_fn = acceleration_riemann_fn_wrapper (
511
535
self ._kernel_fn , eos , _beta_fn , eta_limiter
@@ -572,6 +596,10 @@ def forward(state, neighbors):
572
596
)
573
597
n_w = jnp .where (jnp .absolute (n_w ) < EPS , 0.0 , n_w )
574
598
599
+ ##### Riemann velocity BCs
600
+ if self .is_bc_trick and (self .solver == "RIE" ):
601
+ u_tilde = self ._riemann_velocities (u , w_dist , fluid_mask , i_s , j_s , N )
602
+
575
603
##### Density summation or evolution
576
604
577
605
# update evolution
@@ -598,6 +626,7 @@ def forward(state, neighbors):
598
626
wall_mask [j_s ],
599
627
n_w [j_s ],
600
628
g_ext [i_s ],
629
+ u_tilde [j_s ],
601
630
)
602
631
drhodt = ops .segment_sum (temp , i_s , N ) * fluid_mask
603
632
rho = rho + self .dt * drhodt
@@ -687,6 +716,7 @@ def forward(state, neighbors):
687
716
mask ,
688
717
n_w [j_s ],
689
718
g_ext [i_s ],
719
+ u_tilde [j_s ],
690
720
)
691
721
dudt = ops .segment_sum (out , i_s , N )
692
722
0 commit comments