
fix c++ interface bug #3613


Merged · 12 commits · Mar 28, 2024
43 changes: 41 additions & 2 deletions source/api_cc/src/DeepPotPT.cc
@@ -47,6 +47,12 @@ void DeepPotPT::init(const std::string& model,
std::cout << "load model from: " << model << " to gpu " << gpu_rank
<< std::endl;
}
  int gpu_num = -1;
  DPGetDeviceCount(gpu_num);
  if (gpu_id > gpu_num) {
    throw deepmd::deepmd_exception(
        "current rank " + std::to_string(gpu_id) +
        " is larger than the number of GPUs " + std::to_string(gpu_num));
  }
module = torch::jit::load(model, device);

torch::jit::FusionStrategy strategy;
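
For context, a minimal standalone sketch of the device-count guard added in this hunk, assuming only that `DPGetDeviceCount` writes the number of visible GPUs into its argument; the integers are formatted with `std::to_string` so they concatenate into the message instead of being added to the string pointer:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring the guard above (not part of the PR);
// deepmd::deepmd_exception is replaced by std::runtime_error here.
void check_gpu_rank(int gpu_id, int gpu_num) {
  if (gpu_id > gpu_num) {
    throw std::runtime_error("current rank " + std::to_string(gpu_id) +
                             " is larger than the number of GPUs " +
                             std::to_string(gpu_num));
  }
}
```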
@@ -107,7 +113,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
auto int32_options = torch::TensorOptions().dtype(torch::kInt32);

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam_, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
@@ -116,6 +121,24 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
  int nloc = nall_real - nghost_real;
  int nframes = 1;
  if (nloc == 0) {
    // no backward map needed
    // force of size nall * 3
    force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
    fill(force.begin(), force.end(), (VALUETYPE)0.0);
    // virial of size 9
    virial.resize(static_cast<size_t>(nframes) * 9);
    fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
    // atom_energy of size nall
    atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
    fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
    // atom_virial of size nall * 9
    atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
    fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
    return;
  }
std::vector<VALUETYPE> coord_wrapped = dcoord;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
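
A sketch of the zero-atom fast path introduced in this hunk, under the assumption that the output buffers follow the `nframes * nall * 3`, `nframes * 9`, `nframes * nall`, and `nframes * nall * 9` layout used above (the helper name is illustrative, not part of the PR):

```cpp
#include <vector>

// Resize and zero-fill the outputs without running the model, mirroring
// the nloc == 0 branch above; nall corresponds to fwd_map.size().
template <typename VALUETYPE>
void zero_outputs(std::vector<VALUETYPE>& force,
                  std::vector<VALUETYPE>& virial,
                  std::vector<VALUETYPE>& atom_energy,
                  std::vector<VALUETYPE>& atom_virial,
                  std::size_t nframes,
                  std::size_t nall) {
  force.assign(nframes * nall * 3, (VALUETYPE)0.0);
  virial.assign(nframes * 9, (VALUETYPE)0.0);
  atom_energy.assign(nframes * nall, (VALUETYPE)0.0);
  atom_virial.assign(nframes * nall * 9, (VALUETYPE)0.0);
}
```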
@@ -185,7 +208,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
datom_virial.assign(
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
int nframes = 1;
// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
@@ -249,6 +271,23 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
floatType = torch::kFloat32;
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
  int nframes = 1;
  if (natoms == 0) {
    // no backward map needed
    // force of size natoms * 3
    force_.resize(static_cast<size_t>(nframes) * natoms * 3);
    fill(force_.begin(), force_.end(), (VALUETYPE)0.0);
    // virial of size 9
    virial.resize(static_cast<size_t>(nframes) * 9);
    fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
    // atom_energy of size natoms
    atom_energy.resize(static_cast<size_t>(nframes) * natoms);
    fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
    // atom_virial of size natoms * 9
    atom_virial.resize(static_cast<size_t>(nframes) * natoms * 9);
    fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
    return;
  }
std::vector<torch::jit::IValue> inputs;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
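
And a hedged check of what the `natoms == 0` branch in this hunk implies for callers: all per-atom outputs end up empty (their sizes are multiples of `natoms`) and the 9-component virial is zero-filled. The assertion helper below is purely illustrative:

```cpp
#include <cassert>
#include <vector>

// Illustrative post-conditions of the natoms == 0 early return.
void check_zero_atom_outputs(const std::vector<double>& force,
                             const std::vector<double>& virial,
                             const std::vector<double>& atom_energy,
                             const std::vector<double>& atom_virial) {
  assert(force.empty());
  assert(atom_energy.empty());
  assert(atom_virial.empty());
  assert(virial.size() == 9);
  for (double v : virial) {
    assert(v == 0.0);
  }
}
```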