
pt: explicitly set device #3307


Merged (8 commits) on Feb 21, 2024
46 changes: 26 additions & 20 deletions deepmd/pt/infer/deep_eval.py
@@ -343,13 +343,17 @@ def _eval_model(
         natoms = len(atom_types[0])

         coord_input = torch.tensor(
-            coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
-        ).to(DEVICE)
-        type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
+            coords.reshape([-1, natoms, 3]),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+        type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
         if cells is not None:
             box_input = torch.tensor(
-                cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
-            ).to(DEVICE)
+                cells.reshape([-1, 3, 3]),
+                dtype=GLOBAL_PT_FLOAT_PRECISION,
+                device=DEVICE,
+            )
         else:
             box_input = None

@@ -420,7 +424,7 @@ def eval_model(
         if cells is not None:
             assert isinstance(cells, torch.Tensor), err_msg
         assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
-        atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
+        atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
     elif isinstance(coords, np.ndarray):
         if cells is not None:
             assert isinstance(cells, np.ndarray), err_msg
@@ -441,17 +445,17 @@ def eval_model(
     natoms = len(atom_types[0])

     coord_input = torch.tensor(
-        coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
-    ).to(DEVICE)
-    type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
+        coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+    )
+    type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
     box_input = None
     if cells is None:
         pbc = False
     else:
         pbc = True
         box_input = torch.tensor(
-            cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
-        ).to(DEVICE)
+            cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+        )
     num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

     for ii in range(num_iter):
@@ -527,35 +531,37 @@ def eval_model(
     energy_out = (
         torch.cat(energy_out)
         if energy_out
-        else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
+        else torch.zeros(
+            [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+        )
     )
     atomic_energy_out = (
         torch.cat(atomic_energy_out)
         if atomic_energy_out
-        else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
-            DEVICE
+        else torch.zeros(
+            [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
         )
     )
     force_out = (
         torch.cat(force_out)
         if force_out
-        else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
-            DEVICE
+        else torch.zeros(
+            [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
        )
     )
     virial_out = (
         torch.cat(virial_out)
         if virial_out
-        else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
-            DEVICE
+        else torch.zeros(
+            [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
         )
     )
     atomic_virial_out = (
         torch.cat(atomic_virial_out)
         if atomic_virial_out
         else torch.zeros(
-            [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION
-        ).to(DEVICE)
+            [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+        )
     )
     updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
     logits_out = torch.cat(logits_out) if logits_out else None
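The recurring change in this file is passing `device=` to the tensor factory instead of allocating on the default device and then calling `.to(DEVICE)`. A minimal standalone sketch of the difference, using a placeholder `DEVICE` and dummy data rather than the deepmd-kit code itself:

```python
import numpy as np
import torch

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
coords = np.zeros((2, 5, 3))

# Two steps: the tensor is first materialized on the current default device,
# then copied over to DEVICE.
a = torch.tensor(coords, dtype=torch.float64).to(DEVICE)

# One step: the tensor is created directly on DEVICE, skipping the
# intermediate allocation and copy.
b = torch.tensor(coords, dtype=torch.float64, device=DEVICE)

assert a.device == b.device and torch.equal(a.cpu(), b.cpu())
```

Both forms give the same result; creating directly on the target device also keeps working when the test harness sets an unusable default device, which is what the test change at the bottom of this PR relies on.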
6 changes: 4 additions & 2 deletions deepmd/pt/infer/inference.py
@@ -171,7 +171,8 @@ def __init__(

     @staticmethod
     def get_data(data):
-        batch_data = next(iter(data))
+        with torch.device("cpu"):
+            batch_data = next(iter(data))
         for key in batch_data.keys():
             if key == "sid" or key == "fid":
                 continue
@@ -235,7 +236,8 @@ def run(self):
             ),  # setting to 0 diverges the behavior of its iterator; should be >=1
             drop_last=False,
         )
-        data = iter(dataloader)
+        with torch.device("cpu"):
+            data = iter(dataloader)

         single_results = {}
         sum_natoms = 0
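Since PyTorch 2.0 a `torch.device` object can be used as a context manager that temporarily overrides the default device for factory calls made inside the block. Wrapping `iter(dataloader)` this way keeps the loader's internal tensors (for example the shuffling permutation drawn by its sampler) on the CPU even when a different default device is active. A rough illustration with a toy dataset, not the deepmd-kit loader:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8.0).reshape(4, 2))
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Inside the block, tensors created without an explicit device go to "cpu",
# regardless of any torch.set_default_device(...) made elsewhere.
with torch.device("cpu"):
    batch = next(iter(loader))

print(batch[0].device)  # cpu
```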
19 changes: 15 additions & 4 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -16,6 +16,9 @@
     FittingOutputDef,
     OutputVariableDef,
 )
+from deepmd.pt.utils import (
+    env,
+)
 from deepmd.pt.utils.nlist import (
     build_multiple_neighbor_list,
     get_multiple_nlist_key,
@@ -91,9 +94,17 @@ def get_model_sels(self) -> List[List[int]]:

     def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
         # sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
-        rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64)
-        nsels = torch.tensor(self.get_model_nsels())
-        zipped = torch.stack([torch.tensor(rcuts), torch.tensor(nsels)], dim=0).T
+        rcuts = torch.tensor(
+            self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
+        )
+        nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
+        zipped = torch.stack(
+            [
+                torch.tensor(rcuts, device=env.DEVICE),
+                torch.tensor(nsels, device=env.DEVICE),
+            ],
+            dim=0,
+        ).T
         inner_sorting = torch.argsort(zipped[:, 1], dim=0)
         inner_sorted = zipped[inner_sorting]
         outer_sorting = torch.argsort(inner_sorted[:, 0], stable=True)
@@ -285,7 +296,7 @@ def __init__(
         self.smin_alpha = smin_alpha

         # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
-        self.zbl_weight = torch.empty(0, dtype=torch.float64)
+        self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

     def serialize(self) -> dict:
         return {
8 changes: 6 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -12,6 +12,9 @@
     FittingOutputDef,
     OutputVariableDef,
 )
+from deepmd.pt.utils import (
+    env,
+)
 from deepmd.utils.pair_tab import (
     PairTab,
 )
@@ -156,15 +159,16 @@ def forward_atomic(
         pairwise_rr = self._get_pairwise_dist(
             extended_coord, masked_nlist
         )  # (nframes, nloc, nnei)
-        self.tab_data = self.tab_data.view(
+        self.tab_data = self.tab_data.to(device=env.DEVICE).view(
             int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
         )

         # to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
         # i_type : (nframes, nloc), this is atype.
         # j_type : (nframes, nloc, nnei)
         j_type = extended_atype[
-            torch.arange(extended_atype.size(0))[:, None, None], masked_nlist
+            torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None],
+            masked_nlist,
         ]

         raw_atomic_energy = self._pair_tabulated_inter(
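When a tensor is indexed with other tensors, the index tensors generally need to live on the same device; a `torch.arange` created without `device=` lands on the default device and, depending on the PyTorch version, either forces an implicit copy or raises a device-mismatch error. A small device-agnostic sketch of the same indexing pattern, with made-up shapes rather than the model's real data:

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

extended_atype = torch.randint(0, 4, (2, 10), device=device)   # (nframes, nall)
masked_nlist = torch.randint(0, 10, (2, 6, 5), device=device)  # (nframes, nloc, nnei)

# Broadcast a per-frame index against the neighbor list; keeping it on the
# same device as extended_atype avoids any cross-device indexing.
frame_idx = torch.arange(extended_atype.size(0), device=device)[:, None, None]
j_type = extended_atype[frame_idx, masked_nlist]               # (nframes, nloc, nnei)
print(j_type.shape)
```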
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformer_layer.py
@@ -326,7 +326,7 @@ def __init__(
         sel = [sel] if isinstance(sel, int) else sel
         self.nnei = sum(sel)
         assert len(sel) == 1
-        self.sel = torch.tensor(sel)
+        self.sel = torch.tensor(sel, device=env.DEVICE)
         self.sec = self.sel
         self.axis_dim = axis_dim
         self.set_davg_zero = set_davg_zero
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -448,6 +448,7 @@ def forward(
         )  # shape is [nframes*nall, self.ndescrpt]
         xyz_scatter = torch.empty(
             1,
+            device=env.DEVICE,
         )
         ret = self.filter_layers_old[0](dmatrix)
         xyz_scatter = ret
6 changes: 4 additions & 2 deletions deepmd/pt/model/model/make_model.py
@@ -240,8 +240,10 @@ def _format_nlist(
                     nlist,
                     -1
                     * torch.ones(
-                        [n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype
-                    ).to(nlist.device),
+                        [n_nf, n_nloc, nnei - n_nnei],
+                        dtype=nlist.dtype,
+                        device=nlist.device,
+                    ),
                 ],
                 dim=-1,
             )
13 changes: 10 additions & 3 deletions deepmd/pt/model/network/network.py
@@ -32,7 +32,7 @@


 def Tensor(*shape):
-    return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
+    return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)


 class Dropout(nn.Module):
@@ -332,7 +332,13 @@ def __init__(
         bias: bool = True,
         init: str = "default",
     ):
-        super().__init__(d_in, d_out, bias=bias, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
+        super().__init__(
+            d_in,
+            d_out,
+            bias=bias,
+            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+            device=env.DEVICE,
+        )

         self.use_bias = bias

@@ -552,6 +558,7 @@ def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0):
             embed_dim,
             padding_idx=type_nums,
             dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+            device=env.DEVICE,
         )
         # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)

@@ -799,7 +806,7 @@ def __init__(
             temperature=temperature,
         )
         self.attn_layer_norm = nn.LayerNorm(
-            self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
+            self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
         )
         if self.ffn:
             self.ffn_embed_dim = ffn_embed_dim
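`nn.Linear`, `nn.LayerNorm`, and `nn.Embedding` accept the `device=`/`dtype=` factory keyword arguments, so parameters can be allocated on the target device at construction time instead of moving the finished module with `.to()`. A minimal sketch with arbitrary sizes and a local stand-in for `env.DEVICE`:

```python
import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

# Parameters are created directly on `device`, never on the default device.
layer = nn.Linear(16, 8, bias=True, dtype=dtype, device=device)
norm = nn.LayerNorm(8, dtype=dtype, device=device)

x = torch.randn(4, 16, dtype=dtype, device=device)
print(norm(layer(x)).device)  # parameters and activations share one device
```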
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/ener.py
@@ -268,7 +268,7 @@ def __init__(
             bias_atom_e = np.zeros([self.ntypes])
         if not use_tebd:
             assert self.ntypes == len(bias_atom_e), "Element count mismatches!"
-        bias_atom_e = torch.tensor(bias_atom_e)
+        bias_atom_e = torch.tensor(bias_atom_e, device=env.DEVICE)
         self.register_buffer("bias_atom_e", bias_atom_e)

         filter_layers_dipole = []
9 changes: 6 additions & 3 deletions deepmd/pt/train/training.py
@@ -156,7 +156,8 @@ def get_data_loader(_training_data, _validation_data, _training_params):
                 drop_last=False,
                 pin_memory=True,
             )
-            training_data_buffered = BufferedIterator(iter(training_dataloader))
+            with torch.device("cpu"):
+                training_data_buffered = BufferedIterator(iter(training_dataloader))
             validation_dataloader = DataLoader(
                 _validation_data,
                 sampler=valid_sampler,
@@ -166,7 +167,8 @@ def get_data_loader(_training_data, _validation_data, _training_params):
                 pin_memory=True,
             )

-            validation_data_buffered = BufferedIterator(iter(validation_dataloader))
+            with torch.device("cpu"):
+                validation_data_buffered = BufferedIterator(iter(validation_dataloader))
             if _training_params.get("validation_data", None) is not None:
                 valid_numb_batch = _training_params["validation_data"].get(
                     "numb_btch", 1
@@ -519,7 +521,8 @@ def step(_step_id, task_key="Default"):
                 if not torch.isfinite(grad_norm).all():
                     # check local gradnorm single GPU case, trigger NanDetector
                     raise FloatingPointError("gradients are Nan/Inf")
-                self.optimizer.step()
+                with torch.device("cpu"):
+                    self.optimizer.step()
                 self.scheduler.step()
             elif self.opt_type == "LKF":
                 if isinstance(self.loss, EnergyStdLoss):
8 changes: 5 additions & 3 deletions deepmd/pt/utils/dataloader.py
@@ -120,8 +120,9 @@ def construct_dataset(system):
             self.total_batch += len(system_dataloader)
         # Initialize iterator instances for DataLoader
         self.iters = []
-        for item in self.dataloaders:
-            self.iters.append(iter(item))
+        with torch.device("cpu"):
+            for item in self.dataloaders:
+                self.iters.append(iter(item))

     def set_noise(self, noise_settings):
         # noise_settings['noise_type'] # "trunc_normal", "normal", "uniform"
@@ -250,5 +251,6 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
     log.info("Generated weighted sampler with prob array: " + str(probs))
     # training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iteraters
     len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1)
-    sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
+    with torch.device("cpu"):
+        sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
     return sampler
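`WeightedRandomSampler` stores its weights as a tensor and draws indices with `torch.multinomial`, so constructing and consuming it under a `"cpu"` device context keeps those tensors on the CPU no matter what the global default device is. Roughly, with toy weights rather than the training probabilities:

```python
import torch
from torch.utils.data import WeightedRandomSampler

probs = [0.1, 0.2, 0.3, 0.4]  # toy per-sample weights

with torch.device("cpu"):
    sampler = WeightedRandomSampler(probs, num_samples=8, replacement=True)
    indices = list(sampler)  # plain Python ints sampled on the CPU

print(indices)
```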
8 changes: 4 additions & 4 deletions deepmd/pt/utils/nlist.py
@@ -116,14 +116,14 @@ def build_neighbor_list(
         nlist = nlist[:, :, :nsel]
     else:
         rr = torch.cat(
-            [rr, torch.ones([batch_size, nloc, nsel - nnei]).to(rr.device) + rcut],
+            [rr, torch.ones([batch_size, nloc, nsel - nnei], device=rr.device) + rcut],
             dim=-1,
         )
         nlist = torch.cat(
             [
                 nlist,
-                torch.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype).to(
-                    rr.device
+                torch.ones(
+                    [batch_size, nloc, nsel - nnei], dtype=nlist.dtype, device=rr.device
                 ),
             ],
             dim=-1,
@@ -289,7 +289,7 @@

     """
     nf, nloc = atype.shape
-    aidx = torch.tile(torch.arange(nloc).unsqueeze(0), [nf, 1])
+    aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1])
     if cell is None:
         nall = nloc
         extend_coord = coord.clone()
21 changes: 11 additions & 10 deletions deepmd/pt/utils/stat.py
@@ -32,16 +32,17 @@ def make_stat_input(datasets, dataloaders, nbatches):
     log.info(f"Packing data for statistics from {len(datasets)} systems")
     for i in range(len(datasets)):
         sys_stat = {key: [] for key in keys}
-        iterator = iter(dataloaders[i])
-        for _ in range(nbatches):
-            try:
-                stat_data = next(iterator)
-            except StopIteration:
-                iterator = iter(dataloaders[i])
-                stat_data = next(iterator)
-            for dd in stat_data:
-                if dd in keys:
-                    sys_stat[dd].append(stat_data[dd])
+        with torch.device("cpu"):
+            iterator = iter(dataloaders[i])
+            for _ in range(nbatches):
+                try:
+                    stat_data = next(iterator)
+                except StopIteration:
+                    iterator = iter(dataloaders[i])
+                    stat_data = next(iterator)
+                for dd in stat_data:
+                    if dd in keys:
+                        sys_stat[dd].append(stat_data[dd])
         for key in keys:
             if not isinstance(sys_stat[key][0], list):
                 if sys_stat[key][0] is None:
2 changes: 2 additions & 0 deletions source/tests/pt/__init__.py
@@ -3,3 +3,5 @@

 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
+# testing purposes; device should always be set explicitly
+torch.set_default_device("cuda:9999999")
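The new default device in the test package is intentionally bogus: nothing touches `cuda:9999999` until a tensor is created without an explicit `device=`, at which point the test fails loudly instead of silently running on the wrong device. A sketch of the behaviour this relies on (the ordinal is arbitrary; any out-of-range index works):

```python
import torch

torch.set_default_device("cuda:9999999")  # nonexistent ordinal; only recorded, not validated here

ok = torch.zeros(3, device="cpu")  # explicit device: unaffected by the default
try:
    bad = torch.zeros(3)  # no device given -> falls back to the bogus default
except (AssertionError, RuntimeError) as err:  # error type depends on the build having CUDA
    print("caught tensor created without an explicit device:", err)
```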