Skip to content

Commit e1b93c1

Browse files
linpeizePeizeLin
andauthored
Update namespace ModuleGint (#6169)
* Update namespace ModuleGint * Update class BlasConnector * Update namespace ModuleGint --------- Co-authored-by: linpz <[email protected]>
1 parent 8fc7c42 commit e1b93c1

15 files changed

+586
-259
lines changed

source/module_base/blas_connector.cpp

+264-44
Large diffs are not rendered by default.

source/module_base/blas_connector.h

+44-5
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ extern "C"
6565
void cgemv_(const char *trans, const int *m, const int *n, const std::complex<float> *alpha,
6666
const std::complex<float> *a, const int *lda, const std::complex<float> *x, const int *incx,
6767
const std::complex<float> *beta, std::complex<float> *y, const int *incy);
68-
68+
6969
void zgemv_(const char *trans, const int *m, const int *n, const std::complex<double> *alpha,
7070
const std::complex<double> *a, const int *lda, const std::complex<double> *x, const int *incx,
7171
const std::complex<double> *beta, std::complex<double> *y, const int *incy);
@@ -180,11 +180,36 @@ class BlasConnector
180180
// Peize Lin add 2017-10-27
181181
// d=x*y
182182
static
183-
float dot( const int n, const float *X, const int incX, const float *Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
183+
float dot( const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
184+
185+
static
186+
double dot( const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
187+
188+
// d=x*y
189+
static
190+
float dotu( const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
191+
192+
static
193+
double dotu( const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
194+
195+
static
196+
std::complex<float> dotu( const int n, const std::complex<float>*const X, const int incX, const std::complex<float>*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
197+
198+
static
199+
std::complex<double> dotu( const int n, const std::complex<double>*const X, const int incX, const std::complex<double>*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
200+
201+
// d=x.conj()*y
202+
static
203+
float dotc( const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
204+
205+
static
206+
double dotc( const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
184207

185208
static
186-
double dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
209+
std::complex<float> dotc( const int n, const std::complex<float>*const X, const int incX, const std::complex<float>*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
187210

211+
static
212+
std::complex<double> dotc( const int n, const std::complex<double>*const X, const int incX, const std::complex<double>*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
188213

189214
// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
190215
// C = a * A.? * B.? + b * C
@@ -231,6 +256,9 @@ class BlasConnector
231256
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
232257
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
233258

259+
// side=='L': C = a * A * B + b * C.
260+
// side=='R': C = a * B * A + b * C.
261+
// A == A^T
234262
// Because you cannot pack symm or hemm into a row-major kernel by exchanging parameters, so only col-major functions are provided.
235263
static
236264
void symm_cm(const char side, const char uplo, const int m, const int n,
@@ -252,6 +280,19 @@ class BlasConnector
252280
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
253281
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
254282

283+
// side=='L': C = a * A * B + b * C.
284+
// side=='R': C = a * B * A + b * C.
285+
// A == A^H
286+
static
287+
void hemm_cm(const char side, const char uplo, const int m, const int n,
288+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
289+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
290+
291+
static
292+
void hemm_cm(const char side, const char uplo, const int m, const int n,
293+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
294+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
295+
255296
static
256297
void hemm_cm(char side, char uplo, int m, int n,
257298
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
@@ -263,7 +304,6 @@ class BlasConnector
263304
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
264305

265306
// y = A*x + beta*y
266-
267307
static
268308
void gemv(const char trans, const int m, const int n,
269309
const float alpha, const float* A, const int lda, const float* X, const int incx,
@@ -283,7 +323,6 @@ class BlasConnector
283323
void gemv(const char trans, const int m, const int n,
284324
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incx,
285325
const std::complex<double> beta, std::complex<double> *Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
286-
287326

288327
// Peize Lin add 2018-06-12
289328
// out = ||x||_2

source/module_hamilt_lcao/module_gint/temp_gint/gint_atom.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,6 @@ void GintAtom::set_phi_dphi(
203203

204204
// explicit instantiation
205205
template void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, double* phi) const;
206+
template void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, std::complex<double>* phi) const;
206207
template void GintAtom::set_phi_dphi(const std::vector<Vec3d>& coords, const int stride, double* phi, double* dphi_x, double* dphi_y, double* dphi_z) const;
207208
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_common.cpp

+25-12
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ void compose_hr_gint(std::shared_ptr<HContainer<double>> hr_gint)
2323
assert(upper_ap != nullptr);
2424
#endif
2525
for (int ir = 0; ir < ap.get_R_size(); ir++)
26-
{
26+
{
2727
auto R_index = ap.get_R_index(ir);
2828
auto upper_mat = upper_ap->find_matrix(-R_index);
2929
auto lower_mat = lower_ap->find_matrix(R_index);
@@ -36,7 +36,7 @@ void compose_hr_gint(std::shared_ptr<HContainer<double>> hr_gint)
3636
}
3737
}
3838
}
39-
}
39+
}
4040
}
4141

4242
void compose_hr_gint(std::vector<std::shared_ptr<HContainer<double>>> hr_gint_part,
@@ -54,7 +54,7 @@ void compose_hr_gint(std::vector<std::shared_ptr<HContainer<double>>> hr_gint_pa
5454
const hamilt::AtomPair<double>* ap_nspin_0 = hr_gint_part[0]->find_pair(iat1, iat2);
5555
const hamilt::AtomPair<double>* ap_nspin_3 = hr_gint_part[3]->find_pair(iat1, iat2);
5656
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
57-
{
57+
{
5858
const auto R_index = upper_ap->get_R_index(ir);
5959
auto upper_mat = upper_ap->find_matrix(R_index);
6060
auto mat_nspin_0 = ap_nspin_0->find_matrix(R_index);
@@ -124,10 +124,11 @@ void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<T>> hr_gint, HConta
124124

125125
// gint_info should not have been a parameter, but it was added to initialize dm_gint_full
126126
// In the future, we might try to remove the gint_info parameter
127+
template<typename T>
127128
void transfer_dm_2d_to_gint(
128129
std::shared_ptr<const GintInfo> gint_info,
129-
std::vector<HContainer<double>*> dm,
130-
std::vector<std::shared_ptr<HContainer<double>>> dm_gint)
130+
std::vector<HContainer<T>*> dm,
131+
std::vector<std::shared_ptr<HContainer<T>>> dm_gint)
131132
{
132133
// To check whether input parameter dm_2d has been initialized
133134
#ifdef __DEBUG
@@ -150,12 +151,12 @@ void transfer_dm_2d_to_gint(
150151
{
151152
#ifdef __MPI
152153
const int npol = 2;
153-
std::shared_ptr<HContainer<double>> dm_full = gint_info->get_hr<double>(npol);
154+
std::shared_ptr<HContainer<T>> dm_full = gint_info->get_hr<T>(npol);
154155
hamilt::transferParallels2Serials(*dm[0], dm_full.get());
155156
#else
156-
HContainer<double>* dm_full = dm[0];
157+
HContainer<T>* dm_full = dm[0];
157158
#endif
158-
std::vector<double*> tmp_pointer(4, nullptr);
159+
std::vector<T*> tmp_pointer(4, nullptr);
159160
for (int iap = 0; iap < dm_full->size_atom_pairs(); iap++)
160161
{
161162
auto& ap = dm_full->get_atom_pair(iap);
@@ -166,10 +167,10 @@ void transfer_dm_2d_to_gint(
166167
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
167168
for (int is = 0; is < 4; is++)
168169
{
169-
tmp_pointer[is] =
170+
tmp_pointer[is] =
170171
dm_gint[is]->find_matrix(iat1, iat2, r_index)->get_pointer();
171172
}
172-
double* data_full = ap.get_pointer(ir);
173+
T* data_full = ap.get_pointer(ir);
173174
for (int irow = 0; irow < ap.get_row_size(); irow += 2)
174175
{
175176
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
@@ -191,6 +192,18 @@ void transfer_dm_2d_to_gint(
191192
}
192193

193194

194-
template void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<double>> hr_gint, HContainer<double>* hR);
195-
template void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<std::complex<double>>> hr_gint, HContainer<std::complex<double>>* hR);
195+
template void transfer_hr_gint_to_hR(
196+
std::shared_ptr<const HContainer<double>> hr_gint,
197+
HContainer<double>* hR);
198+
template void transfer_hr_gint_to_hR(
199+
std::shared_ptr<const HContainer<std::complex<double>>> hr_gint,
200+
HContainer<std::complex<double>>* hR);
201+
template void transfer_dm_2d_to_gint(
202+
std::shared_ptr<const GintInfo> gint_info,
203+
std::vector<HContainer<double>*> dm,
204+
std::vector<std::shared_ptr<HContainer<double>>> dm_gint);
205+
template void transfer_dm_2d_to_gint(
206+
std::shared_ptr<const GintInfo> gint_info,
207+
std::vector<HContainer<std::complex<double>>*> dm,
208+
std::vector<std::shared_ptr<HContainer<std::complex<double>>>> dm_gint);
196209
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_common.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ namespace ModuleGint
1313
template <typename T>
1414
void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<T>> hr_gint, HContainer<T>* hR);
1515

16+
template<typename T>
1617
void transfer_dm_2d_to_gint(
1718
std::shared_ptr<const GintInfo> gint_info,
18-
std::vector<HContainer<double>*> dm,
19-
std::vector<std::shared_ptr<HContainer<double>>> dm_gint);
19+
std::vector<HContainer<T>*> dm,
20+
std::vector<std::shared_ptr<HContainer<T>>> dm_gint);
2021

2122
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_rho.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void Gint_rho::cal_rho_()
4444
for (int is = 0; is < nspin_; is++)
4545
{
4646
phi_op.phi_mul_dm(phi.data(), *dm_gint_vec_[is], true, phi_dm.data());
47-
phi_op.phi_dot_phi_dm(phi.data(), phi_dm.data(), rho_[is]);
47+
phi_op.phi_dot_phi(phi.data(), phi_dm.data(), rho_[is]);
4848
}
4949
}
5050
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_tau.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ void Gint_tau::cal_tau_()
5454
phi_op.phi_mul_dm(dphi_x.data(), *dm_gint_vec_[is], true, dphi_x_dm.data());
5555
phi_op.phi_mul_dm(dphi_y.data(), *dm_gint_vec_[is], true, dphi_y_dm.data());
5656
phi_op.phi_mul_dm(dphi_z.data(), *dm_gint_vec_[is], true, dphi_z_dm.data());
57-
phi_op.phi_dot_phi_dm(dphi_x.data(), dphi_x_dm.data(), kin_[is]);
58-
phi_op.phi_dot_phi_dm(dphi_y.data(), dphi_y_dm.data(), kin_[is]);
59-
phi_op.phi_dot_phi_dm(dphi_z.data(), dphi_z_dm.data(), kin_[is]);
57+
phi_op.phi_dot_phi(dphi_x.data(), dphi_x_dm.data(), kin_[is]);
58+
phi_op.phi_dot_phi(dphi_y.data(), dphi_y_dm.data(), kin_[is]);
59+
phi_op.phi_dot_phi(dphi_z.data(), dphi_z_dm.data(), kin_[is]);
6060
}
6161
}
6262
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void Gint_vl::cal_hr_gint_()
3636
#pragma omp for schedule(dynamic)
3737
for(const auto& biggrid: gint_info_->get_biggrids())
3838
{
39-
if(biggrid->get_atoms().size() == 0)
39+
if(biggrid->get_atoms().empty())
4040
{
4141
continue;
4242
}
@@ -46,7 +46,7 @@ void Gint_vl::cal_hr_gint_()
4646
phi_vldr3.resize(phi_len);
4747
phi_op.set_phi(phi.data());
4848
phi_op.phi_mul_vldr3(vr_eff_, dr3_, phi.data(), phi_vldr3.data());
49-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_);
49+
phi_op.phi_mul_phi(phi.data(), phi_vldr3.data(), *hr_gint_, PhiOperator::Triangular_Matrix::Upper);
5050
}
5151
}
5252
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_metagga.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ void Gint_vl_metagga::cal_hr_gint_()
6161
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_x.data(), dphi_x_vldr3.data());
6262
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_y.data(), dphi_y_vldr3.data());
6363
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_z.data(), dphi_z_vldr3.data());
64-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_);
65-
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_);
66-
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_);
67-
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_);
64+
phi_op.phi_mul_phi(phi.data(), phi_vldr3.data(), *hr_gint_, PhiOperator::Triangular_Matrix::Upper);
65+
phi_op.phi_mul_phi(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_, PhiOperator::Triangular_Matrix::Upper);
66+
phi_op.phi_mul_phi(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_, PhiOperator::Triangular_Matrix::Upper);
67+
phi_op.phi_mul_phi(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_, PhiOperator::Triangular_Matrix::Upper);
6868
}
6969
}
7070
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ void Gint_vl_metagga_nspin4::cal_hr_gint_()
6565
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_x.data(), dphi_x_vldr3.data());
6666
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_y.data(), dphi_y_vldr3.data());
6767
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_z.data(), dphi_z_vldr3.data());
68-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_part_[is]);
69-
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_part_[is]);
70-
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_part_[is]);
71-
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_part_[is]);
68+
phi_op.phi_mul_phi(phi.data(), phi_vldr3.data(), *hr_gint_part_[is], PhiOperator::Triangular_Matrix::Upper);
69+
phi_op.phi_mul_phi(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_part_[is], PhiOperator::Triangular_Matrix::Upper);
70+
phi_op.phi_mul_phi(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_part_[is], PhiOperator::Triangular_Matrix::Upper);
71+
phi_op.phi_mul_phi(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_part_[is], PhiOperator::Triangular_Matrix::Upper);
7272
}
7373
}
7474
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_nspin4.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void Gint_vl_nspin4::cal_hr_gint_()
4949
for(int is = 0; is < nspin_; is++)
5050
{
5151
phi_op.phi_mul_vldr3(vr_eff_[is], dr3_, phi.data(), phi_vldr3.data());
52-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_part_[is]);
52+
phi_op.phi_mul_phi(phi.data(), phi_vldr3.data(), *hr_gint_part_[is], PhiOperator::Triangular_Matrix::Upper);
5353
}
5454
}
5555
}

0 commit comments

Comments
 (0)