@@ -513,18 +513,14 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
513
513
514
514
if (!PARAM.inp .deepks_equiv ) // training with force label not supported by equivariant version now
515
515
{
516
+ torch::Tensor gdmx;
516
517
if (PARAM.globalv .gamma_only_local )
517
518
{
518
519
const std::vector<std::vector<double >>& dm_gamma
519
520
= dynamic_cast <const elecstate::ElecStateLCAO<double >*>(pelec)->get_DM ()->get_DMK_vector ();
520
- GlobalC::ld.cal_gdmx (dm_gamma,
521
- ucell,
522
- orb,
523
- gd,
524
- kv.get_nks (),
525
- kv.kvec_d ,
526
- GlobalC::ld.phialpha ,
527
- isstress);
521
+
522
+ GlobalC::ld
523
+ .cal_gdmx (dm_gamma, ucell, orb, gd, kv.get_nks (), kv.kvec_d , GlobalC::ld.phialpha , gdmx);
528
524
}
529
525
else
530
526
{
@@ -533,25 +529,25 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
533
529
->get_DM ()
534
530
->get_DMK_vector ();
535
531
536
- GlobalC::ld
537
- .cal_gdmx (dm_k, ucell, orb, gd, kv.get_nks (), kv.kvec_d , GlobalC::ld.phialpha , isstress);
532
+ GlobalC::ld.cal_gdmx (dm_k, ucell, orb, gd, kv.get_nks (), kv.kvec_d , GlobalC::ld.phialpha , gdmx);
538
533
}
539
534
if (PARAM.inp .deepks_out_unittest )
540
535
{
541
- GlobalC::ld.check_gdmx (ucell.nat );
536
+ GlobalC::ld.check_gdmx (ucell.nat , gdmx );
542
537
}
543
538
std::vector<torch::Tensor> gevdm;
544
539
GlobalC::ld.cal_gevdm (ucell.nat , gevdm);
545
- GlobalC::ld.cal_gvx (ucell.nat , gevdm);
540
+ torch::Tensor gvx;
541
+ GlobalC::ld.cal_gvx (ucell.nat , gevdm, gdmx, gvx);
546
542
547
543
if (PARAM.inp .deepks_out_unittest )
548
544
{
549
- GlobalC::ld.check_gvx (ucell.nat );
545
+ GlobalC::ld.check_gvx (ucell.nat , gvx );
550
546
}
551
547
552
548
LCAO_deepks_io::save_npy_gvx (ucell.nat ,
553
549
GlobalC::ld.des_per_atom ,
554
- GlobalC::ld. gvx_tensor ,
550
+ gvx ,
555
551
PARAM.globalv .global_out_dir ,
556
552
GlobalV::MY_RANK);
557
553
}
@@ -715,6 +711,12 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
715
711
{
716
712
scs (i, j) += stress_exx (i, j);
717
713
}
714
+ #endif
715
+ #ifdef __DEEPKS
716
+ if (PARAM.inp .deepks_scf )
717
+ {
718
+ scs (i, j) += svnl_dalpha (i, j);
719
+ }
718
720
#endif
719
721
}
720
722
}
@@ -726,47 +728,61 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
726
728
#ifdef __DEEPKS
727
729
if (PARAM.inp .deepks_out_labels ) // not parallelized yet
728
730
{
729
- const std::string file_s = PARAM.globalv .global_out_dir + " deepks_sbase .npy" ;
731
+ const std::string file_stot = PARAM.globalv .global_out_dir + " deepks_stot .npy" ;
730
732
LCAO_deepks_io::save_npy_s (scs,
731
- file_s,
732
- ucell.omega ,
733
- GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
734
- }
735
- if (PARAM.inp .deepks_scf )
736
- {
737
- if (ModuleSymmetry::Symmetry::symm_flag == 1 )
738
- {
739
- symm->symmetrize_mat3 (svnl_dalpha, ucell.lat );
740
- } // end symmetry
741
- for (int i = 0 ; i < 3 ; i++)
742
- {
743
- for (int j = 0 ; j < 3 ; j++)
744
- {
745
- scs (i, j) += svnl_dalpha (i, j);
746
- }
747
- }
748
- }
749
- if (PARAM.inp .deepks_out_labels ) // not parallelized yet
750
- {
751
- const std::string file_s = PARAM.globalv .global_out_dir + " deepks_stot.npy" ;
752
- LCAO_deepks_io::save_npy_s (scs,
753
- file_s,
733
+ file_stot,
754
734
ucell.omega ,
755
735
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model
756
736
757
737
// wenfei add 2021/11/2
758
738
if (PARAM.inp .deepks_scf )
759
739
{
740
+ const std::string file_sbase = PARAM.globalv .global_out_dir + " deepks_sbase.npy" ;
741
+ LCAO_deepks_io::save_npy_s (scs - svnl_dalpha,
742
+ file_sbase,
743
+ ucell.omega ,
744
+ GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
760
745
761
746
if (!PARAM.inp .deepks_equiv ) // training with stress label not supported by equivariant version now
762
747
{
748
+ torch::Tensor gdmepsl;
749
+ if (PARAM.globalv .gamma_only_local )
750
+ {
751
+ const std::vector<std::vector<double >>& dm_gamma
752
+ = dynamic_cast <const elecstate::ElecStateLCAO<double >*>(pelec)->get_DM ()->get_DMK_vector ();
753
+
754
+ GlobalC::ld.cal_gdmepsl (dm_gamma,
755
+ ucell,
756
+ orb,
757
+ gd,
758
+ kv.get_nks (),
759
+ kv.kvec_d ,
760
+ GlobalC::ld.phialpha ,
761
+ gdmepsl);
762
+ }
763
+ else
764
+ {
765
+ const std::vector<std::vector<std::complex<double >>>& dm_k
766
+ = dynamic_cast <const elecstate::ElecStateLCAO<std::complex<double >>*>(pelec)
767
+ ->get_DM ()
768
+ ->get_DMK_vector ();
769
+
770
+ GlobalC::ld
771
+ .cal_gdmepsl (dm_k, ucell, orb, gd, kv.get_nks (), kv.kvec_d , GlobalC::ld.phialpha , gdmepsl);
772
+ }
773
+ if (PARAM.inp .deepks_out_unittest )
774
+ {
775
+ GlobalC::ld.check_gdmepsl (gdmepsl);
776
+ }
777
+
763
778
std::vector<torch::Tensor> gevdm;
764
779
GlobalC::ld.cal_gevdm (ucell.nat , gevdm);
765
- GlobalC::ld.cal_gvepsl (ucell.nat , gevdm);
780
+ torch::Tensor gvepsl;
781
+ GlobalC::ld.cal_gvepsl (ucell.nat , gevdm, gdmepsl, gvepsl);
766
782
767
783
LCAO_deepks_io::save_npy_gvepsl (ucell.nat ,
768
784
GlobalC::ld.des_per_atom ,
769
- GlobalC::ld. gvepsl_tensor ,
785
+ gvepsl ,
770
786
PARAM.globalv .global_out_dir ,
771
787
GlobalV::MY_RANK); // unitless, grad_vepsl
772
788
}
0 commit comments