Skip to content

Enable NanoPET for atomic-basis spherical targets #527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

jwa7
Copy link
Member

@jwa7 jwa7 commented Mar 21, 2025

Extends NanoPET to enable predictions of spherical targets expressed on an atomic basis. Examples are the electron density decomposed on an auxiliary basis, and the Hamiltonian on the coupled atomic orbital basis.

Overview

  • Introduces a new target type "atomic_basis_spherical" to allow learning of spherical targets on an atomic basis set.
  • Only NanoPET is supported at present
  • Targets with 1 component axis are supported - this covers for example the electron density on a basis and the Hamiltonian/density matrix on a coupled basis
  • Per-pair targets can be either permutationally symmetrized or unsymmetrized. In both cases, only unique atom pairs are predicted

NanoPET architecture details

  • Re-uses the existing infrastructure for mapping last layer features to the spherical component heads
  • Infrastructural changes were required as follows:
    • For per-atom targets:
      • Last-layer PET features are sliced just before being passed through the output layer. Slicing is needed to extract the samples of the correct type based on the "center_type" the output block corresponds to.
    • For per-pair targets:
      • Again, the last-layer features are sliced according to the "first_atom_type" and "second_atom_type" the output block corresponds to. In the case of permutationally-symmetrized targets, the "s2_pi" key is also used to do this slicing.
      • Pre-last layer transformations are a little more complicated than for per-atom targets. Both the PET node and edge features are used. These are passed through separate heads before being combined in a single tensor that represents the last-layer features for the whole per-pair quantity.
      • Targets can be either permutationally symmetrized or not. In the former case, the PET edge features are symmetrized and thus the last layer features (if outputted) would carry the relevant metadata ("s2_pi" etc), but in the samples labels. as slicing into blocks only occurs at the output layer.
      • Only unique samples are predicted for per-pair targets. Practically this means that triangular off-site blocks in atom type are predicted where "first_atom_type" < "second_atom_type". For blocks where "first_atom_type" == "second_atom_type", samples are triangular in atom index such that "first_atom" <= "second_atom"
  • A new file "nanopet/modules/samples.py" has been created to handle the construction of samples for the features and outputs (which are in general different for these targets, per-pair ones in particular). These can stay here for now, and perhaps later be moved to metatomic.

Other metatrain infrastructure details

  • The dataloader has been modified to use join_kwargs={"different_keys": "union"} as this is required for targets on an atomic basis where different systems have different atom types (and therefore keys)
  • "per_atom.py" has been modified to not perform a sum over samples, also for per-pair targets
  • "augmentation.py" has been cleaned up a bit, specifically with regards to how the blocks are split along samples

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Documentation preview 📚: https://metatrain--527.org.readthedocs.build/en/527/

Copy link
Collaborator

@frostedoyster frostedoyster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked the model part for now, I will check the data augmentation and infrastructure parts later

jwa7 and others added 2 commits April 3, 2025 16:48
Co-authored-by: Filippo Bigi <[email protected]>
Co-authored-by: Paolo Pegolo <paolo.pegolo.epfl.ch>
Co-authored-by: Paolo Pegolo <[email protected]>
@jwa7 jwa7 force-pushed the pet_atomic_basis branch from 04d0e2f to 130b110 Compare April 4, 2025 14:07
@jwa7
Copy link
Member Author

jwa7 commented Apr 16, 2025

@Luthaf @frostedoyster @ppegolo here's an update - ready for review when you're ready!

On my side, still to do (but shouldn't affect review in the meantime):

  • Write some tests
  • Update the changelog
  • Fix the model export torchscript error - help appreciated, I can't seem to figure it out!

@jwa7 jwa7 marked this pull request as ready for review April 16, 2025 14:35
@jwa7 jwa7 self-assigned this Apr 17, 2025
Copy link
Member

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how much I understand the changes to PET, so I'll let someone else check this part.

One question is how much work would it be to port this to NativePET?

Comment on lines +412 to +416
if self.atomic_basis_target_info[output_name][
"type"
] == "atomic_basis_spherical" and self.atomic_basis_target_info[
output_name
]["sample_kind"].startswith("per_pair"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not great to read. Maybe you could extract two variable and change this to something like if atomic_basis_is_spherical and sample_is_per_pair

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

# symmetrize the PET edge features and pass through its head
if (
self.atomic_basis_target_info[output_name]["sample_kind"]
== "per_pair_sym"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this a different kind of sample?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in this case the PET features are still one whole block (i.e. we haven't sliced by s2_pi/first_atom_type/second_atom_type yet), and because it is symmetrized the normal samples aren't complete in info: we have "duplicated" samples that carry the index s2_pi=+/-1

from metatensor.torch.atomistic import System


def get_samples(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documentation please!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following suit on what is in concatenate_structures in structure.py in the same directory! :D But sure I can add

Comment on lines +180 to +182
If ``include_atom_type=True``, the atom types are prepended dimensions, either
corresponding to "center_type" if ``n_center=1`` or ["first_atom_type",
"second_atom_type"] if ``n_center=2``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no longer an n_center parameter

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

Comment on lines +256 to +258
torch.zeros(len(first_atom_type), dtype=torch.int32).reshape(
-1, 1
), # s2_pi = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
torch.zeros(len(first_atom_type), dtype=torch.int32).reshape(
-1, 1
), # s2_pi = 0
# s2_pi = 0
torch.zeros((len(first_atom_type), 1), dtype=torch.int32),

# ===== Slicing PET node/edge features for an atomic basis =====


def samples_for_atomic_basis_per_atom(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs please

# target keys.
for target_name in targets.keys():
if predictions[target_name].keys != targets[target_name].keys:
# TODO: use `metatensor.filter_blocks` once PR #XXX is available
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this metatensor/metatensor#885? I can make a release =)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Thanks :)

# First, build the indices that split the block samples by system
split_indices: List[int] = []

if target_type == "spherical":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between target_type == "spherical" and target_type == "atomic_basis_spherical"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atomic basis spherical doesn't have all the atomic samples in a given block - only those with the corresponding to the atom types. The splitting along the samples axis needs more care than for pure spherical targets

s2_pi: int,
) -> List[int]:
"""
Finds the indices that splits a TensorBlock along the samples axis by system index.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an implementation of metatensor/metatensor#627?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes :)

@@ -30,6 +30,8 @@ def __init__(
self.is_scalar = False
self.is_cartesian = False
self.is_spherical = False
self.is_atomic_basis_spherical_per_atom = False
self.is_atomic_basis_spherical_per_pair = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have a definition of what each of them mean

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure can do

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants