Skip to content

Commit e25db6e

Browse files
authored
Refactor: Remove global dependence of descriptor, orbital_precalc, v_delta_precalc in DeePKS. (deepmodeling#5812)
* Remove functions related to v_delta in LCAO_Deepks; Remove some redundent variables. * Remove some temporary variables for force/stress calculation in DeePKS and separate force&stress calculations. Remove global dependence of descriptor. * Use accessor to accelerate the manipulation of torch::Tensor variables in DeePKS. * Remove LCAO_deepks_mpi.cpp. * Update Unittest for DeePKS. * Clang-format change. * Update cal_gdmx and cal_gdmepsl. * Fix check_gvx() bug when using mpirun. * Move functions for calculating descriptor from LCAO_deepks to DeePKS_domain. * Add UT for cal_gdmepsl and modify the ref files to suit the new data structure.
1 parent 8905ddf commit e25db6e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+3767
-2500
lines changed

source/Makefile.Objects

+5-5
Original file line numberDiff line numberDiff line change
@@ -192,21 +192,21 @@ OBJS_CELL=atom_pseudo.o\
192192

193193
OBJS_DEEPKS=LCAO_deepks.o\
194194
deepks_force.o\
195+
deepks_descriptor.o\
195196
deepks_orbital.o\
197+
deepks_orbpre.o\
198+
deepks_vdpre.o\
199+
deepks_hmat.o\
196200
LCAO_deepks_io.o\
197-
LCAO_deepks_mpi.o\
198201
LCAO_deepks_pdm.o\
199202
LCAO_deepks_phialpha.o\
200203
LCAO_deepks_torch.o\
201204
LCAO_deepks_vdelta.o\
202-
deepks_hmat.o\
203205
LCAO_deepks_interface.o\
204-
deepks_orbpre.o\
205206
cal_gdmx.o\
207+
cal_gdmepsl.o\
206208
cal_gedm.o\
207209
cal_gvx.o\
208-
cal_descriptor.o\
209-
v_delta_precalc.o\
210210

211211

212212
OBJS_ELECSTAT=elecstate.o\

source/module_esolver/esolver_ks.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class ESolver_KS : public ESolver_FP
5555
virtual void after_scf(UnitCell& ucell, const int istep) override;
5656

5757
//! <Temporary> It should be replaced by a function in Hamilt Class
58-
virtual void update_pot(UnitCell& ucell, const int istep, const int iter) {};
58+
virtual void update_pot(UnitCell& ucell, const int istep, const int iter){};
5959

6060
//! Hamiltonian
6161
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
@@ -72,7 +72,7 @@ class ESolver_KS : public ESolver_FP
7272
//! Electronic wavefunctions
7373
psi::Psi<T>* psi = nullptr;
7474

75-
//! plane wave or LCAO
75+
//! plane wave or LCAO
7676
std::string basisname;
7777

7878
//! number of electrons
@@ -83,18 +83,18 @@ class ESolver_KS : public ESolver_FP
8383

8484
//! the start time of scf iteration
8585
#ifdef __MPI
86-
double iter_time;
86+
double iter_time;
8787
#else
8888
std::chrono::system_clock::time_point iter_time;
8989
#endif
9090

91-
double diag_ethr; //! the threshold for diagonalization
92-
double scf_thr; //! scf density threshold
93-
double scf_ene_thr; //! scf energy threshold
94-
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
95-
double hsolver_error; //! the error of HSolver
96-
int maxniter; //! maximum iter steps for scf
97-
int niter; //! iter steps actually used in scf
91+
double diag_ethr; //! the threshold for diagonalization
92+
double scf_thr; //! scf density threshold
93+
double scf_ene_thr; //! scf energy threshold
94+
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
95+
double hsolver_error; //! the error of HSolver
96+
int maxniter; //! maximum iter steps for scf
97+
int niter; //! iter steps actually used in scf
9898
};
9999
} // namespace ModuleESolver
100100
#endif

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

+56-40
Original file line numberDiff line numberDiff line change
@@ -513,18 +513,14 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
513513

514514
if (!PARAM.inp.deepks_equiv) // training with force label not supported by equivariant version now
515515
{
516+
torch::Tensor gdmx;
516517
if (PARAM.globalv.gamma_only_local)
517518
{
518519
const std::vector<std::vector<double>>& dm_gamma
519520
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
520-
GlobalC::ld.cal_gdmx(dm_gamma,
521-
ucell,
522-
orb,
523-
gd,
524-
kv.get_nks(),
525-
kv.kvec_d,
526-
GlobalC::ld.phialpha,
527-
isstress);
521+
522+
GlobalC::ld
523+
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
528524
}
529525
else
530526
{
@@ -533,25 +529,25 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
533529
->get_DM()
534530
->get_DMK_vector();
535531

536-
GlobalC::ld
537-
.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, isstress);
532+
GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
538533
}
539534
if (PARAM.inp.deepks_out_unittest)
540535
{
541-
GlobalC::ld.check_gdmx(ucell.nat);
536+
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
542537
}
543538
std::vector<torch::Tensor> gevdm;
544539
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
545-
GlobalC::ld.cal_gvx(ucell.nat, gevdm);
540+
torch::Tensor gvx;
541+
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);
546542

547543
if (PARAM.inp.deepks_out_unittest)
548544
{
549-
GlobalC::ld.check_gvx(ucell.nat);
545+
GlobalC::ld.check_gvx(ucell.nat, gvx);
550546
}
551547

552548
LCAO_deepks_io::save_npy_gvx(ucell.nat,
553549
GlobalC::ld.des_per_atom,
554-
GlobalC::ld.gvx_tensor,
550+
gvx,
555551
PARAM.globalv.global_out_dir,
556552
GlobalV::MY_RANK);
557553
}
@@ -715,6 +711,12 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
715711
{
716712
scs(i, j) += stress_exx(i, j);
717713
}
714+
#endif
715+
#ifdef __DEEPKS
716+
if (PARAM.inp.deepks_scf)
717+
{
718+
scs(i, j) += svnl_dalpha(i, j);
719+
}
718720
#endif
719721
}
720722
}
@@ -726,47 +728,61 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
726728
#ifdef __DEEPKS
727729
if (PARAM.inp.deepks_out_labels) // not parallelized yet
728730
{
729-
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
731+
const std::string file_stot = PARAM.globalv.global_out_dir + "deepks_stot.npy";
730732
LCAO_deepks_io::save_npy_s(scs,
731-
file_s,
732-
ucell.omega,
733-
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
734-
}
735-
if (PARAM.inp.deepks_scf)
736-
{
737-
if (ModuleSymmetry::Symmetry::symm_flag == 1)
738-
{
739-
symm->symmetrize_mat3(svnl_dalpha, ucell.lat);
740-
} // end symmetry
741-
for (int i = 0; i < 3; i++)
742-
{
743-
for (int j = 0; j < 3; j++)
744-
{
745-
scs(i, j) += svnl_dalpha(i, j);
746-
}
747-
}
748-
}
749-
if (PARAM.inp.deepks_out_labels) // not parallelized yet
750-
{
751-
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_stot.npy";
752-
LCAO_deepks_io::save_npy_s(scs,
753-
file_s,
733+
file_stot,
754734
ucell.omega,
755735
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model
756736

757737
// wenfei add 2021/11/2
758738
if (PARAM.inp.deepks_scf)
759739
{
740+
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
741+
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
742+
file_sbase,
743+
ucell.omega,
744+
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
760745

761746
if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
762747
{
748+
torch::Tensor gdmepsl;
749+
if (PARAM.globalv.gamma_only_local)
750+
{
751+
const std::vector<std::vector<double>>& dm_gamma
752+
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
753+
754+
GlobalC::ld.cal_gdmepsl(dm_gamma,
755+
ucell,
756+
orb,
757+
gd,
758+
kv.get_nks(),
759+
kv.kvec_d,
760+
GlobalC::ld.phialpha,
761+
gdmepsl);
762+
}
763+
else
764+
{
765+
const std::vector<std::vector<std::complex<double>>>& dm_k
766+
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
767+
->get_DM()
768+
->get_DMK_vector();
769+
770+
GlobalC::ld
771+
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
772+
}
773+
if (PARAM.inp.deepks_out_unittest)
774+
{
775+
GlobalC::ld.check_gdmepsl(gdmepsl);
776+
}
777+
763778
std::vector<torch::Tensor> gevdm;
764779
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
765-
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm);
780+
torch::Tensor gvepsl;
781+
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);
766782

767783
LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
768784
GlobalC::ld.des_per_atom,
769-
GlobalC::ld.gvepsl_tensor,
785+
gvepsl,
770786
PARAM.globalv.global_out_dir,
771787
GlobalV::MY_RANK); // unitless, grad_vepsl
772788
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,19 @@ void Force_LCAO<double>::ftable(const bool isforce,
248248

249249
#ifdef __DEEPKS
250250
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
251+
std::vector<torch::Tensor> descriptor;
251252
if (PARAM.inp.deepks_scf)
252253
{
253254
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
254255
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
255256

256-
GlobalC::ld.cal_descriptor(ucell.nat);
257-
GlobalC::ld.cal_gedm(ucell.nat);
257+
DeePKS_domain::cal_descriptor(ucell.nat,
258+
GlobalC::ld.inlmax,
259+
GlobalC::ld.inl_l,
260+
GlobalC::ld.pdm,
261+
descriptor,
262+
GlobalC::ld.des_per_atom);
263+
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
258264

259265
const int nks = 1;
260266
DeePKS_domain::cal_f_delta<double>(dm_gamma,
@@ -305,7 +311,12 @@ void Force_LCAO<double>::ftable(const bool isforce,
305311

306312
GlobalC::ld.check_projected_dm();
307313

308-
GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
314+
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
315+
GlobalC::ld.des_per_atom,
316+
GlobalC::ld.inl_l,
317+
ucell,
318+
PARAM.globalv.global_out_dir,
319+
descriptor);
309320

310321
GlobalC::ld.check_gedm();
311322

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,14 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
349349
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
350350
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
351351

352-
GlobalC::ld.cal_descriptor(ucell.nat);
353-
354-
GlobalC::ld.cal_gedm(ucell.nat);
352+
std::vector<torch::Tensor> descriptor;
353+
DeePKS_domain::cal_descriptor(ucell.nat,
354+
GlobalC::ld.inlmax,
355+
GlobalC::ld.inl_l,
356+
GlobalC::ld.pdm,
357+
descriptor,
358+
GlobalC::ld.des_per_atom);
359+
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
355360

356361
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
357362
ucell,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
5858
this->H_V_delta = new HContainer<TR>(paraV);
5959
if (std::is_same<TK, double>::value)
6060
{
61-
//this->H_V_delta = new HContainer<TR>(paraV);
61+
// this->H_V_delta = new HContainer<TR>(paraV);
6262
this->H_V_delta->fix_gamma();
6363
}
6464

@@ -138,8 +138,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
138138
// if (std::is_same<TK, double>::value)
139139
// {
140140
this->H_V_delta->allocate(nullptr, true);
141-
// expand hR with H_V_delta
142-
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
141+
// expand hR with H_V_delta
142+
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
143143
this->hR->add(*this->H_V_delta);
144144
this->hR->allocate(nullptr, false);
145145
// }
@@ -161,8 +161,15 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
161161
ModuleBase::timer::tick("DeePKS", "contributeHR");
162162

163163
GlobalC::ld.cal_projected_DM<TK>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
164-
GlobalC::ld.cal_descriptor(this->ucell->nat);
165-
GlobalC::ld.cal_gedm(this->ucell->nat);
164+
165+
std::vector<torch::Tensor> descriptor;
166+
DeePKS_domain::cal_descriptor(this->ucell->nat,
167+
GlobalC::ld.inlmax,
168+
GlobalC::ld.inl_l,
169+
GlobalC::ld.pdm,
170+
descriptor,
171+
GlobalC::ld.des_per_atom);
172+
GlobalC::ld.cal_gedm(this->ucell->nat, descriptor);
166173

167174
// // recalculate the H_V_delta
168175
// if (this->H_V_delta == nullptr)

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
if(ENABLE_DEEPKS)
22
list(APPEND objects
33
LCAO_deepks.cpp
4+
deepks_descriptor.cpp
45
deepks_force.cpp
56
deepks_orbital.cpp
7+
deepks_orbpre.cpp
8+
deepks_vdpre.cpp
9+
deepks_hmat.cpp
610
LCAO_deepks_io.cpp
7-
LCAO_deepks_mpi.cpp
811
LCAO_deepks_pdm.cpp
912
LCAO_deepks_phialpha.cpp
1013
LCAO_deepks_torch.cpp
1114
LCAO_deepks_vdelta.cpp
12-
deepks_hmat.cpp
1315
LCAO_deepks_interface.cpp
14-
deepks_orbpre.cpp
1516
cal_gdmx.cpp
17+
cal_gdmepsl.cpp
1618
cal_gedm.cpp
1719
cal_gvx.cpp
18-
cal_descriptor.cpp
19-
v_delta_precalc.cpp
2020
)
2121

2222
add_library(

0 commit comments

Comments
 (0)