
SevenNetModel does not work with torch.float64 #92


Open
orionarcher opened this issue Apr 2, 2025 · 8 comments
Labels
bug Something isn't working

Comments

@orionarcher
Collaborator

SevenNetModel currently fails when torch.float64 is set as the dtype.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 20
     12 pretrained_sevenn_model = model_loaded.to(device, dtype=torch.float64)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18     dtype=torch.float64,
     19 )
---> 20 model(state)
     22 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     23 #     model=model,
     24 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     25 #     metric_values=[1],
     26 #     max_atoms=1000000,
     27 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
...
               ~~~~~~~~~~~~~ <--- HERE
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
RuntimeError: both inputs should have same dtype
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 19
     12 pretrained_sevenn_model = model_loaded.to(device, dtype=torch.float64)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18 )
---> 19 model(state)
     21 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     22 #     model=model,
     23 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     24 #     metric_values=[1],
     25 #     max_atoms=1000000,
     26 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
...
---> 42     x = x @ w
     43     x = self.act(x)
     44     x = x * self.var_out**0.5

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: double != float
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 19
     12 pretrained_sevenn_model = model_loaded.to(device)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18 )
---> 19 model(state)
     21 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     22 #     model=model,
     23 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     24 #     metric_values=[1],
     25 #     max_atoms=1000000,
     26 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/torch-sim/build/__editable__.torch_sim-0.0.0rc0-py3-none-any/torch_sim/models/sevennet.py:226, in SevenNetModel.forward(self, state)
    223     batched_data = batched_data.to_dict()
    224     del batched_data["data_info"]
--> 226 output = self.model(batched_data)
    228 results = {}
    229 # Process energy

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/sevenn/nn/sequential.py:182, in AtomGraphSequential.forward(self, input)
    180 data = self._preprocess(input)
    181 for module in self:
--> 182     data = module(data)
    183 return data

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/sevenn/nn/convolution.py:125, in IrrepsConvolution.forward(self, data)
    123 assert self.convolution is not None, 'Convolution is not instantiated'
    124 assert self.weight_nn is not None, 'Weight_nn is not instantiated'
--> 125 weight = self.weight_nn(data[self.key_weight_input])
    126 x = data[self.key_x]
    127 if self.is_parallel:

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/e3nn/nn/_fc.py:42, in _Layer.forward(self, x)
     40 if self.act is not None:
     41     w = self.weight / (self.h_in * self.var_in) ** 0.5
---> 42     x = x @ w
     43     x = self.act(x)
     44     x = x * self.var_out**0.5

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: double != float
@orionarcher orionarcher added the bug Something isn't working label Apr 3, 2025
@YutackPark

Hi 👋 I am the main developer of 7net. We don't currently support double precision, and even if it runs somehow, I can't say the values would be reliable.

By the way, this project is really great and exactly what I've been wanting recently. I'd like to help, especially with the SevenNet part, if any help is needed. Thanks for releasing this great work as open source.

@abhijeetgangan
Collaborator

@YutackPark Thank you for your comments. What’s the best way to contact you?

@YutackPark

Here's my email: [email protected]

As I'm a Ph.D. candidate, I may have to CC my supervisor. For simple code maintenance or features, I can help freely on weekends; just contact me via GitHub in that case.

@orionarcher
Collaborator Author

orionarcher commented Apr 4, 2025

Thanks @YutackPark! I am working on a rewrite of our model testing logic and would love your feedback. I'll tag you when I open the PR.

I'll also include a modification to the SevenNetModel to disallow float64 use.
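Such a guard could be a simple dtype check at construction time. The sketch below uses a stand-in class, not the actual torch_sim implementation; the constructor parameters are only inferred from the snippets earlier in this thread:

```python
import torch

class SevenNetModelSketch:
    """Stand-in for torch_sim's SevenNetModel, illustrating a float64 guard."""

    # Assumption: only float32 is reliable, per the 7net developer's comment above.
    SUPPORTED_DTYPES = (torch.float32,)

    def __init__(self, model, modal=None, device=None, dtype=torch.float32):
        if dtype not in self.SUPPORTED_DTYPES:
            raise ValueError(
                f"SevenNet does not support dtype={dtype}; use torch.float32 instead."
            )
        self.model = model
        self.modal = modal
        self.device = device
        self.dtype = dtype
```

Constructing with `dtype=torch.float64` then fails immediately with a clear message, instead of a mid-forward `RuntimeError` deep inside e3nn.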

@CompRhys
Contributor

CompRhys commented Apr 4, 2025

@YutackPark One thing I did in the torch-sim wrapper of SevenNet was to use the neighbor-list implementation from torch_sim.neighbors.vesin_nl_ts, so as not to move tensors off the GPU. The potential for this implementation to differ subtly is one reason we might want to make the default model-testing configurations more extensive.

@YutackPark

@CompRhys, @orionarcher
Thanks for the clarification; I understand the point. While constructing atomistic graphs is standard practice for MLIPs, there's still a lack of consistency across implementations, and our current approach could also be improved. (I mean, I'm open to changing 7net itself; our code is not mature enough.) Have you tried applying 7net to molecular systems? That's typically where things start to get messy.

Other than that, the 7net part looks good to me.

By the way, do you know why 7net or GemNet shows less speed-up compared to MACE?

@CompRhys
Contributor

CompRhys commented Apr 5, 2025

As far as I understand, the speed-up is mainly a function of how much more efficiently we can use GPU memory through batching at small batch sizes; at larger batch sizes it is more of an open question that we are still exploring and profiling. Hopefully, sharing the efficient batching implementations here will make it easier for developers like yourself to help identify the bottlenecks and optimizations that can then be applied to all models.

In #112 I think I will revert the neighbor list to the default you used, just to be on the safe side; we don't want to accidentally impact model performance if there are subtle differences.

@AdeeshKolluru
Contributor

By the way, do you know why 7net or GemNet shows less speed-up compared to MACE?

It's likely due to the higher-order terms used in these architectures. The GemNet used here calculates triplets, and 7net has an lmax of 3, whereas the MACE work has an lmax of 1.

Development

No branches or pull requests

5 participants