Skip to content

Don't merge: static args and hacks #731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 143 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 140 commits
Commits
Show all changes
143 commits
Select commit Hold shift + click to select a range
f346ef1
Use new itir.Program everywhere
tehrengruber Nov 14, 2024
a80a1e0
Merge remote-tracking branch 'origin/main' into update_to_gtir
tehrengruber Dec 6, 2024
61e97e4
Use gt4py main again
tehrengruber Dec 6, 2024
416d7e7
fix connectivities
havogt Dec 9, 2024
38a162c
switch gt4py branch
havogt Dec 9, 2024
21c495d
fix more connectivities
edopao Dec 20, 2024
0e8fcba
fix more connectivities (1)
edopao Dec 20, 2024
d784012
Merge remote-tracking branch 'origin/main' into update_to_gtir
edopao Dec 20, 2024
7917c51
update versions (temporarily)
havogt Jan 8, 2025
d3933a4
Update pyproject.toml
havogt Jan 16, 2025
804e110
Merge remote-tracking branch 'upstream/main' into update_to_gtir
havogt Jan 27, 2025
b02a659
lock
havogt Jan 27, 2025
7a532be
point to new branch, fix table access
havogt Jan 27, 2025
2bc3a08
Merge origin/main
tehrengruber Feb 4, 2025
81eeebe
Merge remote-tracking branch 'refs/remotes/origin/update_to_gtir' int…
tehrengruber Feb 4, 2025
afd6c94
Small fixes
tehrengruber Feb 5, 2025
376274a
Merge remote-tracking branch 'origin/main' into update_to_gtir
tehrengruber Feb 14, 2025
b36330b
Small fixes
tehrengruber Feb 14, 2025
424c3ff
Upgrade ubuntu image used in CI (required for gt4py)
tehrengruber Feb 14, 2025
87dc687
Use staging branch
tehrengruber Feb 14, 2025
40c6523
Fix dependency for ubuntu 22.04
tehrengruber Feb 14, 2025
0855efc
Small fies
tehrengruber Feb 14, 2025
de4d301
Small fixes
tehrengruber Feb 14, 2025
230f7f4
Small fixes
tehrengruber Feb 14, 2025
7f0900b
Small fixes
tehrengruber Feb 14, 2025
414154a
Fix uv.lock
tehrengruber Feb 14, 2025
fc3d94e
Use dummy (or evil?) values for symbolic domain sizes
tehrengruber Feb 14, 2025
660d924
Fix dycore_tests.test_solve_nonhydro.test_run_solve_nonhydro_multi_st…
tehrengruber Feb 27, 2025
6abe1b7
add concat_where from gtir-dace branch
edopao Feb 27, 2025
54ca640
add concat_where from gtir-dace branch (1)
edopao Feb 27, 2025
ca5207f
change concat condition to range interval
edopao Feb 27, 2025
2b84bbf
change concat condition to range interval (1)
edopao Feb 27, 2025
c4a34c3
Use concat_where branch
tehrengruber Mar 2, 2025
40c15b2
Merge remote-tracking branch 'origin/concat_where' into concat_where
tehrengruber Mar 2, 2025
b841d6d
update uv lock
edopao Mar 3, 2025
723f0d3
apply concat_where in all places
edopao Mar 3, 2025
e3f3907
update testdata path to d126
edopao Mar 3, 2025
31442c1
fix previous commit
edopao Mar 3, 2025
3e70e08
fix previous commit
edopao Mar 3, 2025
85379dd
update uv lock
edopao Mar 3, 2025
a4ce301
fix previous commit
edopao Mar 3, 2025
a4761b9
update uv lock
edopao Mar 3, 2025
52a8895
Update
tehrengruber Mar 3, 2025
740c931
Update uv.lock
tehrengruber Mar 3, 2025
da70bf2
disable advection stencil tests
edopao Mar 4, 2025
b98355d
remove unsupported cuda device type CUDA_MANAGED
edopao Mar 5, 2025
ad41a53
build: enhance nox and CI configuration (#652)
egparedes Feb 14, 2025
04e1d35
CI: Fix for nox environment variables (#671)
edopao Feb 17, 2025
4e22608
CI: read env variable from os environment, not from nox (#672)
edopao Feb 17, 2025
039b0fe
fix ci config
edopao Mar 6, 2025
2d2399d
Merge remote-tracking branch 'origin/main' into update_to_gtir
edopao Mar 6, 2025
f49c977
remove unsupported cuda device type CUDA_MANAGED
edopao Mar 5, 2025
c3ec0d1
switch to gt4py main
edopao Mar 6, 2025
3111c92
Merge remote-tracking branch 'origin/update_to_gtir' into concat_where
edopao Mar 6, 2025
ddab65f
remove test marker requires_concat_where
edopao Mar 6, 2025
0a7dc1e
disable advection stencil tests
edopao Mar 6, 2025
e822d93
Merge remote-tracking branch 'origin/update_to_gtir' into concat_where
edopao Mar 6, 2025
7bf830a
Trial
tehrengruber Mar 11, 2025
725c54e
Simplify velocity advection
havogt Mar 18, 2025
de5d5be
re-introduce halo protection
havogt Mar 19, 2025
3507815
Merge branch 'main' into update_to_gtir
tehrengruber Mar 19, 2025
730eaf5
Merge remote-tracking branch 'origin/update_to_gtir' into concat_where
edopao Mar 19, 2025
12fb0e9
Merge remote-tracking branch 'origin/simplify_vel_adv' into concat_where
edopao Mar 19, 2025
16e18b6
use concat_where in new combined stencils
edopao Mar 19, 2025
98f50e9
Merge remote-tracking branch 'origin/main' into update_to_gtir
edopao Mar 21, 2025
a3d9eb6
update gt4py version in uv lock
edopao Mar 21, 2025
6d115df
Merge remote-tracking branch 'origin/update_to_gtir' into concat_where
edopao Mar 21, 2025
2d73307
update gt4py version in uv lock
edopao Mar 21, 2025
5b20325
Remove additional horizontal bounds
muellch Mar 21, 2025
bb36bdd
Merge remote-tracking branch 'origin/simplify_vel_adv' into concat_where
edopao Mar 21, 2025
e8af1a6
fix concat_where
edopao Mar 21, 2025
b6e5499
Revert "Remove additional horizontal bounds"
edopao Mar 21, 2025
cf7facf
Revert "re-introduce halo protection"
edopao Mar 21, 2025
f50f9b4
Revert "Simplify velocity advection"
edopao Mar 21, 2025
b5b19df
switch to gt4py branch icon4py_staging
edopao Mar 21, 2025
fc389ab
split concat_where on horizontal and vertical dimension
edopao Mar 21, 2025
5e31996
add gt4py_cache to gitignore
edopao Mar 21, 2025
fd62612
formatting
edopao Mar 21, 2025
a54d201
switch to gt4py branch icon4py_staging
edopao Mar 21, 2025
30c4964
review comments
edopao Mar 21, 2025
fe32197
add gt4py_cache to gitignore
edopao Mar 21, 2025
3a5724d
Dim -> field for where
havogt Mar 21, 2025
74fac87
add concat_where to horizontal
havogt Mar 21, 2025
cd7a077
remove unused args
halungge Mar 21, 2025
8007095
Merge branch 'concat_where' of github.com:C2SM/icon4py into concat_where
halungge Mar 21, 2025
477462d
fix velocity advection data tests
halungge Mar 21, 2025
6a258f0
fix infinite domain
havogt Mar 21, 2025
29e4d8d
Merge branch 'concat_where_horizontal' into concat_where
havogt Mar 21, 2025
5200394
cleanup horizontal unused index fields
havogt Mar 21, 2025
a6d90d1
cleanup more index fields
havogt Mar 21, 2025
9aa4ba0
fix more tests
havogt Mar 22, 2025
a6acc25
missing file
havogt Mar 22, 2025
5ac079a
Merge branch 'update_to_gtir' into concat_where
havogt Mar 22, 2025
050aa9f
gt4py update
havogt Mar 22, 2025
ee95d84
fix a few tests
havogt Mar 23, 2025
763e56d
excluded embedded diffusion tests
havogt Mar 23, 2025
2a0d122
Apply suggestions from code review
havogt Mar 24, 2025
f5ca897
format and remove requires_concat_where
havogt Mar 24, 2025
86b6a5c
switch to icon4py staging tag
havogt Mar 24, 2025
f626b3b
Update pyproject.toml
havogt Mar 24, 2025
584a848
update uv lock
havogt Mar 24, 2025
3025e35
freeze
havogt Mar 25, 2025
8b49059
add connectivities
havogt Mar 25, 2025
d63d9de
fix typos
havogt Mar 25, 2025
c3b5d2a
fix connectivities and diffusion init
havogt Mar 25, 2025
8eecaf1
fix typo
havogt Mar 25, 2025
8082dfa
fix offset providers
havogt Mar 25, 2025
ea0f291
fix typos
havogt Mar 25, 2025
c61cffe
fix connectivities
havogt Mar 25, 2025
e421629
freeze diffusion vn
havogt Mar 25, 2025
6da4239
compile VelocityAdvection first
havogt Mar 25, 2025
da2e4c6
Merge remote-tracking branch 'upstream/main' into update_to_gtir_froz…
havogt Mar 29, 2025
3d64d9d
change offset providers
havogt Apr 4, 2025
3d85eea
compile() diffusion and vel advection
havogt Apr 4, 2025
13d685e
solve_nonhydro compile
havogt Apr 4, 2025
7d093fb
Merge remote-tracking branch 'upstream/main' into update_to_gtir_froz…
havogt Apr 4, 2025
74d8914
fixes
havogt Apr 4, 2025
6ddda55
add back scale_k
havogt Apr 7, 2025
acccf0b
config
havogt Oct 29, 2024
c02a458
add more timers
havogt Oct 30, 2024
c6dd86c
remove first timer
havogt Nov 12, 2024
6c081ab
use variants
havogt Apr 7, 2025
0e66cf7
update gt4py branch
havogt Apr 8, 2025
0637e56
update gt4py branch
havogt Apr 9, 2025
f9f953a
Merge remote-tracking branch 'upstream/main' into use_precompile_perf…
havogt Apr 28, 2025
2b06253
point to gt4py branch
havogt Apr 28, 2025
69c0153
fix rename of nrdmax
havogt Apr 28, 2025
d66e270
fix precompile syntax
havogt Apr 28, 2025
8ddb330
fix list in vel adv
havogt Apr 28, 2025
a72622b
hack -1s
havogt Apr 28, 2025
d0d927f
update gt4py commit
havogt Apr 28, 2025
6c995e5
Merge branch 'use_precompile_perf_measurements' of github.com:C2SM/ic…
havogt Apr 28, 2025
027f938
Merge branch 'use_precompile_perf_measurements' of github.com:C2SM/ic…
havogt Apr 28, 2025
35ad4f3
use program in init_nabla2_factor_in_upper_damping_zone
havogt Apr 28, 2025
14e7a3a
as gtx.int32
havogt Apr 28, 2025
32ed276
make vn_incr optional
havogt Apr 28, 2025
2657d9b
make vn_incr optional
havogt Apr 28, 2025
9a66b96
update uv.lock
havogt Apr 28, 2025
7a925ab
update gt4py branch
havogt Apr 29, 2025
61a7982
less static args
havogt Apr 29, 2025
68877ac
Merge remote-tracking branch 'origin/main' into use_precompile_perf_m…
edopao May 5, 2025
f7a2b25
rename offset_provider_type -> offset_provider
edopao May 5, 2025
002d43c
Merge remote-tracking branch 'origin/main' into use_precompile_perf_m…
edopao May 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ def __init__(
self._edge_params = edge_params
self._cell_params = cell_params

import cupy as cp
# assert (m := cp.amin(self._grid.offset_providers["E2C"].ndarray)) >= 0, m
# assert (m := cp.amin(self._grid.offset_providers["E2C2E"].ndarray)) >= 0, m

self.halo_exchange_wait = decomposition.create_halo_exchange_wait(
self._exchange
) # wait on a communication handle
Expand All @@ -396,38 +400,80 @@ def __init__(

self.mo_intp_rbf_rbf_vec_interpol_vertex = mo_intp_rbf_rbf_vec_interpol_vertex.with_backend(
self._backend
).compile(
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
self.calculate_nabla2_and_smag_coefficients_for_vn = (
calculate_nabla2_and_smag_coefficients_for_vn.with_backend(self._backend)
calculate_nabla2_and_smag_coefficients_for_vn.with_backend(self._backend).compile(
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
)

self.calculate_diagnostic_quantities_for_turbulence = (
calculate_diagnostic_quantities_for_turbulence.with_backend(self._backend)
calculate_diagnostic_quantities_for_turbulence.with_backend(self._backend).compile(
vertical_start=[1],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
)
self.apply_diffusion_to_vn = apply_diffusion_to_vn.with_backend(self._backend).compile(
limited_area=[self._grid.limited_area],
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
self.apply_diffusion_to_vn = apply_diffusion_to_vn.with_backend(self._backend)
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = (
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence.with_backend(
self._backend
).compile(
type_shear=[int32(self.config.shear_type.value)],
nrdmax=[int32(self._vertical_grid.end_index_of_damping_layer + 1)],
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
)
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = (
calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_backend(
self._backend
).compile(
vertical_start=[(self._grid.num_levels - 2)],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
)
self.calculate_nabla2_for_theta = calculate_nabla2_for_theta.with_backend(self._backend)
self.calculate_nabla2_for_theta = calculate_nabla2_for_theta.with_backend(
self._backend
).compile(
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points = (
truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_backend(self._backend)
)
self.update_theta_and_exner = update_theta_and_exner.with_backend(self._backend)
self.copy_field = copy_field.with_backend(self._backend)
self.scale_k = scale_k.with_backend(self._backend)
).compile(
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type=self._grid.offset_providers,
)
self.update_theta_and_exner = update_theta_and_exner.with_backend(self._backend).compile(
vertical_start=[0],
vertical_end=[self._grid.num_levels],
offset_provider_type={},
)
self.copy_field = copy_field.with_backend(self._backend).compile(offset_provider_type={})
self.scale_k = scale_k.with_backend(self._backend).compile(offset_provider_type={})
self.setup_fields_for_initial_step = setup_fields_for_initial_step.with_backend(
self._backend
)
).compile(offset_provider_type={})

self.init_diffusion_local_fields_for_regular_timestep = (
init_diffusion_local_fields_for_regular_timestep.with_backend(self._backend)
)
).compile(offset_provider_type={"Koff": dims.KDim})

self._allocate_temporary_fields()

Expand All @@ -452,16 +498,17 @@ def __init__(
offset_provider={"Koff": dims.KDim},
)

diffusion_utils._init_nabla2_factor_in_upper_damping_zone.with_backend(self._backend)(
diffusion_utils.init_nabla2_factor_in_upper_damping_zone.with_backend(self._backend)(
physical_heights=self._vertical_grid.interface_physical_height,
diff_multfac_n2w=self.diff_multfac_n2w,
end_index_of_damping_layer=self._vertical_grid.end_index_of_damping_layer,
nshift=0,
heights_nrd_shift=self._vertical_grid.interface_physical_height.ndarray[
self._vertical_grid.end_index_of_damping_layer + 1
].item(),
heights_1=self._vertical_grid.interface_physical_height.ndarray[1].item(),
domain={dims.KDim: (1, self._vertical_grid.end_index_of_damping_layer + 1)},
out=self.diff_multfac_n2w,
vertical_start=gtx.int32(1),
vertical_end=gtx.int32(self._vertical_grid.end_index_of_damping_layer + 1),
offset_provider={},
)

Expand Down Expand Up @@ -635,15 +682,10 @@ def _do_diffusion_step(
smag_offset:

"""
# dtime dependent: enh_smag_factor,
self.scale_k.with_connectivities(self.compile_time_connectivities)(
self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={}
)
self.scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})

log.debug("rbf interpolation 1: start")
self.mo_intp_rbf_rbf_vec_interpol_vertex.with_connectivities(
self.compile_time_connectivities
)(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=prognostic_state.vn,
ptr_coeff_1=self._interpolation_state.rbf_coeff_1,
ptr_coeff_2=self._interpolation_state.rbf_coeff_2,
Expand All @@ -668,9 +710,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
self.calculate_nabla2_and_smag_coefficients_for_vn.with_connectivities(
self.compile_time_connectivities
)(
self.calculate_nabla2_and_smag_coefficients_for_vn(
diff_multfac_smag=self.diff_multfac_smag,
tangent_orientation=self._edge_params.tangent_orientation,
inv_primal_edge_length=self._edge_params.inverse_primal_edge_lengths,
Expand Down Expand Up @@ -702,9 +742,7 @@ def _do_diffusion_step(
log.debug(
"running stencils 02 03 (calculate_diagnostic_quantities_for_turbulence): start"
)
self.calculate_diagnostic_quantities_for_turbulence.with_connectivities(
self.compile_time_connectivities
)(
self.calculate_diagnostic_quantities_for_turbulence(
kh_smag_ec=self.kh_smag_ec,
vn=prognostic_state.vn,
e_bln_c_s=self._interpolation_state.e_bln_c_s,
Expand All @@ -731,9 +769,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
self.mo_intp_rbf_rbf_vec_interpol_vertex.with_connectivities(
self.compile_time_connectivities
)(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=self.z_nabla2_e,
ptr_coeff_1=self._interpolation_state.rbf_coeff_1,
ptr_coeff_2=self._interpolation_state.rbf_coeff_2,
Expand All @@ -758,7 +794,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
self.apply_diffusion_to_vn.with_connectivities(self.compile_time_connectivities)(
self.apply_diffusion_to_vn(
u_vert=self.u_vert,
v_vert=self.v_vert,
primal_normal_vert_v1=self._edge_params.primal_normal_vert[0],
Expand Down Expand Up @@ -790,13 +826,9 @@ def _do_diffusion_step(
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start"
)
# TODO (magdalena) get rid of this copying. So far passing an empty buffer instead did not verify?
self.copy_field.with_connectivities(self.compile_time_connectivities)(
prognostic_state.w, self.w_tmp, offset_provider={}
)
self.copy_field(prognostic_state.w, self.w_tmp, offset_provider={})

self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence.with_connectivities(
self.compile_time_connectivities
)(
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area=self._cell_params.area,
geofac_n2s=self._interpolation_state.geofac_n2s,
geofac_grg_x=self._interpolation_state.geofac_grg_x,
Expand Down Expand Up @@ -830,9 +862,7 @@ def _do_diffusion_step(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_connectivities(
self.compile_time_connectivities
)(
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self._metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
Expand All @@ -849,7 +879,7 @@ def _do_diffusion_step(
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
self.calculate_nabla2_for_theta.with_connectivities(self.compile_time_connectivities)(
self.calculate_nabla2_for_theta(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self._edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
Expand All @@ -866,9 +896,7 @@ def _do_diffusion_step(
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
if self.config.apply_zdiffusion_t:
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_connectivities(
self.compile_time_connectivities
)(
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
mask=self._metric_state.mask_hdiff,
zd_vertoffset=self._metric_state.zd_vertoffset,
zd_diffcoef=self._metric_state.zd_diffcoef,
Expand All @@ -888,7 +916,7 @@ def _do_diffusion_step(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
self.update_theta_and_exner.with_connectivities(self.compile_time_connectivities)(
self.update_theta_and_exner(
z_temp=self.z_temp,
area=self._cell_params.area,
theta_v=prognostic_state.theta_v,
Expand Down
Loading