Skip to content

Commit ed0043c

Browse files
committed
fix large beam tilt scaling problem
1 parent b74140e commit ed0043c

File tree

4 files changed

+15
-13
lines changed

4 files changed

+15
-13
lines changed

src/cgpu_fcns.cuh

+3-2
Original file line numberDiff line numberDiff line change
@@ -1867,12 +1867,13 @@ namespace mt
18671867
template <class TGrid, class TVector_c>
18681868
DEVICE_CALLABLE FORCE_INLINE
18691869
void exp_r_factor_2d(const int &ix, const int &iy, const TGrid &grid_2d,
1870-
const Value_type<TGrid> &gx, const Value_type<TGrid> &gy, TVector_c &psi_i, TVector_c &psi_o)
1870+
const Value_type<TGrid> &gx, const Value_type<TGrid> &gy, TVector_c &psi_i, TVector_c &psi_o, bool scaling)
18711871
{
18721872
const int ixy = grid_2d.ind_col(ix, iy);
18731873
const auto Rx = grid_2d.Rx_shift(ix)-grid_2d.Rx_c();
18741874
const auto Ry = grid_2d.Ry_shift(iy)-grid_2d.Ry_c();
1875-
psi_o[ixy] = psi_i[ixy]*euler(gx*Rx + gy*Ry)/grid_2d.nxy_r();
1875+
const auto scale = (scaling)?grid_2d.nxy_r():1;
1876+
psi_o[ixy] = psi_i[ixy]*euler(gx*Rx + gy*Ry)/scale;
18761877
}
18771878

18781879
template <class TGrid, class TVector_r, class TVector_c>

src/cpu_fcns.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -921,11 +921,11 @@ namespace mt
921921
template <class TGrid, class TVector_c>
922922
enable_if_host_vector<TVector_c, void>
923923
exp_r_factor_2d(Stream<e_host> &stream, TGrid &grid_2d, Value_type<TGrid> gx, Value_type<TGrid> gy,
924-
TVector_c &fPsi_i, TVector_c &fPsi_o)
924+
TVector_c &fPsi_i, TVector_c &fPsi_o, bool scaling=true)
925925
{
926926
stream.set_n_act_stream(grid_2d.nx);
927927
stream.set_grid(grid_2d.nx, grid_2d.ny);
928-
stream.exec_matrix(host_device_detail::exp_r_factor_2d<TGrid, TVector_c>, grid_2d, gx, gy, fPsi_i, fPsi_o);
928+
stream.exec_matrix(host_device_detail::exp_r_factor_2d<TGrid, TVector_c>, grid_2d, gx, gy, fPsi_i, fPsi_o, scaling);
929929
}
930930

931931
template <class TGrid, class TVector_c>

src/gpu_fcns.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -649,14 +649,14 @@ namespace mt
649649
// phase factor 2d
650650
template <class TGrid, class T>
651651
__global__ void exp_r_factor_2d(TGrid grid_2d, Value_type<TGrid> gx,
652-
Value_type<TGrid> gy, rVector<T> psi_i, rVector<T> psi_o)
652+
Value_type<TGrid> gy, rVector<T> psi_i, rVector<T> psi_o, scaling)
653653
{
654654
int iy = threadIdx.x + blockIdx.x*blockDim.x;
655655
int ix = threadIdx.y + blockIdx.y*blockDim.y;
656656

657657
if((ix < grid_2d.nx) && (iy < grid_2d.ny))
658658
{
659-
host_device_detail::exp_r_factor_2d(ix, iy, grid_2d, gx, gy, psi_i, psi_o);
659+
host_device_detail::exp_r_factor_2d(ix, iy, grid_2d, gx, gy, psi_i, psi_o, scaling);
660660
}
661661
}
662662

@@ -1325,11 +1325,11 @@ namespace mt
13251325

13261326
template <class TGrid, class TVector_c>
13271327
enable_if_device_vector<TVector_c, void>
1328-
exp_r_factor_2d(Stream<e_device> &stream, TGrid &grid_2d, Value_type<TGrid> gx, Value_type<TGrid> gy, TVector_c &fPsi_i, TVector_c &fPsi_o)
1328+
exp_r_factor_2d(Stream<e_device> &stream, TGrid &grid_2d, Value_type<TGrid> gx, Value_type<TGrid> gy, TVector_c &fPsi_i, TVector_c &fPsi_o, bool scaling=true)
13291329
{
13301330
auto grid_bt = grid_2d.cuda_grid();
13311331

1332-
device_detail::exp_r_factor_2d<TGrid, typename TVector_c::value_type><<<grid_bt.Blk, grid_bt.Thr>>>(grid_2d, gx, gy, fPsi_i, fPsi_o);
1332+
device_detail::exp_r_factor_2d<TGrid, typename TVector_c::value_type><<<grid_bt.Blk, grid_bt.Thr>>>(grid_2d, gx, gy, fPsi_i, fPsi_o, scaling);
13331333
}
13341334

13351335
template <class TGrid, class TVector_c>

src/wave_function.cuh

+6-5
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ namespace mt
7474
Transmission_Function<T, dev>::set_input_data(input_multislice_i, stream_i, fft2_i);
7575
}
7676

77-
void phase_multiplication(const T_r &gxu, const T_r &gyu, TVector_c &psi_i, TVector_c &psi_o)
77+
void phase_multiplication(const T_r &gxu, const T_r &gyu, TVector_c &psi_i, TVector_c &psi_o, bool scaling=true)
7878
{
7979
if(this->input_multislice->dp_Shift || isZero(gxu, gyu))
8080
{
@@ -85,19 +85,20 @@ namespace mt
8585
return;
8686
}
8787

88-
mt::exp_r_factor_2d(*(this->stream), this->input_multislice->grid_2d, c_2Pi*gxu, c_2Pi*gyu, psi_i, psi_o);
88+
mt::exp_r_factor_2d(*(this->stream), this->input_multislice->grid_2d, c_2Pi*gxu, c_2Pi*gyu, psi_i, psi_o, scaling);
8989
}
9090

91-
void phase_multiplication(const T_r &gxu, const T_r &gyu, TVector_c &psi_io)
91+
void phase_multiplication(const T_r &gxu, const T_r &gyu, TVector_c &psi_io, bool scaling=true)
9292
{
93-
phase_multiplication(gxu, gyu, psi_io, psi_io);
93+
phase_multiplication(gxu, gyu, psi_io, psi_io, scaling);
9494
}
9595

9696
TVector_c* get_psi(const eSpace &space, const T_r &gxu, const T_r &gyu,
9797
T_r z, TVector_c &psi_i)
9898
{
9999
TVector_c *psi_o = &(this->trans_0);
100-
phase_multiplication(gxu, gyu, psi_i, *psi_o);
100+
// real space not need to include scaling and phase multiplication oposite sign as the propagation
101+
phase_multiplication(-gxu, -gyu, psi_i, *psi_o, false);
101102
propagator(space, gxu, gyu, z, *psi_o);
102103

103104
return psi_o;

0 commit comments

Comments
 (0)