Skip to content

Read site labels from cif file #3136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymatgen/alchemy/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def from_cif_string(
raw_string = re.sub(r"'", '"', cif_string)
cif_dict = parser.as_dict()
cif_keys = list(cif_dict)
s = parser.get_structures(primitive)[0]
struct = parser.get_structures(primitive)[0]
partial_cif = cif_dict[cif_keys[0]]
if "_database_code_ICSD" in partial_cif:
source = partial_cif["_database_code_ICSD"] + "-ICSD"
Expand All @@ -302,7 +302,7 @@ def from_cif_string(
"original_file": raw_string,
"cif_data": cif_dict[cif_keys[0]],
}
return TransformedStructure(s, transformations, history=[source_info])
return TransformedStructure(struct, transformations, history=[source_info])

@staticmethod
def from_poscar_string(
Expand Down
42 changes: 21 additions & 21 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3998,8 +3998,8 @@ def rotate_sites(
theta %= 2 * np.pi

rm = expm(cross(eye(3), axis / norm(axis)) * theta)
for i in indices:
site = self._sites[i]
for idx in indices:
site = self._sites[idx]
coords = ((np.dot(rm, np.array(site.coords - anchor).T)).T + anchor).ravel()
new_site = PeriodicSite(
site.species,
Expand All @@ -4010,7 +4010,7 @@ def rotate_sites(
properties=site.properties,
skip_checks=True,
)
self._sites[i] = new_site
self._sites[idx] = new_site

def perturb(self, distance: float, min_distance: float | None = None) -> None:
"""
Expand All @@ -4035,8 +4035,8 @@ def get_rand_vec():
dist = np.random.uniform(min_distance, dist)
return vector / vnorm * dist if vnorm != 0 else get_rand_vec()

for i in range(len(self._sites)):
self.translate_sites([i], get_rand_vec(), frac_coords=False)
for idx in range(len(self._sites)):
self.translate_sites([idx], get_rand_vec(), frac_coords=False)

def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True) -> None:
"""
Expand Down Expand Up @@ -4090,9 +4090,9 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

d = self.distance_matrix
np.fill_diagonal(d, 0)
clusters = fcluster(linkage(squareform((d + d.T) / 2)), tol, "distance")
dist_mat = self.distance_matrix
np.fill_diagonal(dist_mat, 0)
clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance")
sites = []
for c in np.unique(clusters):
inds = np.where(clusters == c)[0]
Expand Down Expand Up @@ -4318,7 +4318,7 @@ def __setitem__( # type: ignore
return
elif isinstance(idx, slice):
to_mod = self[idx]
indices = [ii for ii, s in enumerate(self._sites) if s in to_mod]
indices = [idx for idx, site in enumerate(self._sites) if site in to_mod]
else:
indices = list(idx)

Expand Down Expand Up @@ -4378,22 +4378,22 @@ def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = Non
if there are unpaired electrons.
"""
self._charge = charge
nelectrons = 0.0
n_electrons = 0.0
for site in self._sites:
for sp, amt in site.species.items():
if not isinstance(sp, DummySpecies):
nelectrons += sp.Z * amt
nelectrons -= charge
self._nelectrons = nelectrons
n_electrons += sp.Z * amt
n_electrons -= charge
self._nelectrons = n_electrons
if spin_multiplicity:
if self._charge_spin_check and (nelectrons + spin_multiplicity) % 2 != 1:
if self._charge_spin_check and (n_electrons + spin_multiplicity) % 2 != 1:
raise ValueError(
f"Charge of {self._charge} and spin multiplicity of {spin_multiplicity} is"
" not possible for this molecule"
)
self._spin_multiplicity = spin_multiplicity
else:
self._spin_multiplicity = 1 if nelectrons % 2 == 0 else 2
self._spin_multiplicity = 1 if n_electrons % 2 == 0 else 2

def insert( # type: ignore
self,
Expand Down Expand Up @@ -4506,11 +4506,11 @@ def rotate_sites(

rm = expm(cross(eye(3), axis / norm(axis)) * theta)

for i in indices:
site = self._sites[i]
for idx in indices:
site = self._sites[idx]
s = ((np.dot(rm, (site.coords - anchor).T)).T + anchor).ravel()
new_site = Site(site.species, s, properties=site.properties)
self._sites[i] = new_site
self._sites[idx] = new_site

def perturb(self, distance: float):
"""
Expand All @@ -4528,8 +4528,8 @@ def get_rand_vec():
vnorm = np.linalg.norm(vector)
return vector / vnorm * distance if vnorm != 0 else get_rand_vec()

for i in range(len(self._sites)):
self.translate_sites([i], get_rand_vec())
for idx in range(len(self._sites)):
self.translate_sites([idx], get_rand_vec())

def apply_operation(self, symmop: SymmOp):
"""
Expand All @@ -4543,7 +4543,7 @@ def operate_site(site):
new_cart = symmop.operate(site.coords)
return Site(site.species, new_cart, properties=site.properties)

self._sites = [operate_site(s) for s in self._sites]
self._sites = [operate_site(site) for site in self._sites]

def copy(self):
"""
Expand Down
56 changes: 33 additions & 23 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@

sub_spgrp = partial(re.sub, r"[\s_]", "")

space_groups = {sub_spgrp(k): k for k in SYMM_DATA["space_group_encoding"]} # type: ignore
space_groups = {sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]} # type: ignore

space_groups.update({sub_spgrp(k): k for k in SYMM_DATA["space_group_encoding"]}) # type: ignore
space_groups.update({sub_spgrp(key): key for key in SYMM_DATA["space_group_encoding"]}) # type: ignore


class CifBlock:
Expand Down Expand Up @@ -895,6 +895,7 @@ def get_num_implicit_hydrogens(sym):

coord_to_species = {}
coord_to_magmoms = {}
labels = {}

def get_matching_coord(coord):
keys = list(coord_to_species)
Expand All @@ -908,15 +909,15 @@ def get_matching_coord(coord):
return keys[inds[0]]
return False

for i in range(len(data["_atom_site_label"])):
for i, label in enumerate(data["_atom_site_label"]):
try:
# If site type symbol exists, use it. Otherwise, we use the
# label.
symbol = self._parse_symbol(data["_atom_site_type_symbol"][i])
num_h = get_num_implicit_hydrogens(data["_atom_site_type_symbol"][i])
except KeyError:
symbol = self._parse_symbol(data["_atom_site_label"][i])
num_h = get_num_implicit_hydrogens(data["_atom_site_label"][i])
symbol = self._parse_symbol(label)
num_h = get_num_implicit_hydrogens(label)
if not symbol:
continue

Expand All @@ -936,7 +937,7 @@ def get_matching_coord(coord):
x = str2float(data["_atom_site_fract_x"][i])
y = str2float(data["_atom_site_fract_y"][i])
z = str2float(data["_atom_site_fract_z"][i])
magmom = magmoms.get(data["_atom_site_label"][i], np.array([0, 0, 0]))
magmom = magmoms.get(label, np.array([0, 0, 0]))

try:
occu = str2float(data["_atom_site_occupancy"][i])
Expand All @@ -955,18 +956,21 @@ def get_matching_coord(coord):
"in calculations unless hydrogens added."
)
comp = Composition(comp_d)

if not match:
coord_to_species[coord] = comp
coord_to_magmoms[coord] = magmom
labels[coord] = label
else:
coord_to_species[match] += comp
# disordered magnetic not currently supported
coord_to_magmoms[match] = None
labels[match] = label

sum_occu = [
sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")}
]
if any(o > 1 for o in sum_occu):
if any(occu > 1 for occu in sum_occu):
msg = (
f"Some occupancies ({sum_occu}) sum to > 1! If they are within "
"the occupancy_tolerance, they will be rescaled. "
Expand All @@ -975,11 +979,12 @@ def get_matching_coord(coord):
warnings.warn(msg)
self.warnings.append(msg)

allspecies = []
allcoords = []
allmagmoms = []
allhydrogens = []
all_species = []
all_coords = []
all_magmoms = []
all_hydrogens = []
equivalent_indices = []
all_labels = []

# check to see if magCIF file is disordered
if self.feature_flags["magcif"]:
Expand Down Expand Up @@ -1026,30 +1031,35 @@ def get_matching_coord(coord):
# it is equivalent.
equivalent_indices += len(coords) * [idx]

allhydrogens.extend(len(coords) * [im_h])
allcoords.extend(coords)
allspecies.extend(len(coords) * [species])
allmagmoms.extend(magmoms)
all_hydrogens.extend(len(coords) * [im_h])
all_coords.extend(coords)
all_species.extend(len(coords) * [species])
all_magmoms.extend(magmoms)
all_labels.extend(len(coords) * [labels[tmp_coords[0]]])

# rescale occupancies if necessary
for i, species in enumerate(allspecies):
for i, species in enumerate(all_species):
total_occu = sum(species.values())
if 1 < total_occu <= self._occupancy_tolerance:
allspecies[i] = species / total_occu
all_species[i] = species / total_occu

if allspecies and len(allspecies) == len(allcoords) and len(allspecies) == len(allmagmoms):
if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms):
site_properties = {}
if any(allhydrogens):
assert len(allhydrogens) == len(allcoords)
site_properties["implicit_hydrogens"] = allhydrogens
if any(all_hydrogens):
assert len(all_hydrogens) == len(all_coords)
site_properties["implicit_hydrogens"] = all_hydrogens

if self.feature_flags["magcif"]:
site_properties["magmom"] = allmagmoms
site_properties["magmom"] = all_magmoms

if any(all_labels):
assert len(all_labels) == len(all_species)
site_properties["labels"] = all_labels

if len(site_properties) == 0:
site_properties = None

struct = Structure(lattice, allspecies, allcoords, site_properties=site_properties)
struct = Structure(lattice, all_species, all_coords, site_properties=site_properties)

if symmetrized:
# Wyckoff labels not currently parsed, note that not all CIFs will contain Wyckoff labels
Expand Down
37 changes: 25 additions & 12 deletions pymatgen/io/tests/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,22 +300,35 @@ def test_site_symbol_preference(self):

def test_implicit_hydrogen(self):
parser = CifParser(f"{self.TEST_FILES_DIR}/Senegalite_implicit_hydrogen.cif")
for s in parser.get_structures():
assert s.formula == "Al8 P4 O32"
assert sum(s.site_properties["implicit_hydrogens"]) == 20
for struct in parser.get_structures():
assert struct.formula == "Al8 P4 O32"
assert sum(struct.site_properties["implicit_hydrogens"]) == 20
assert (
"Structure has implicit hydrogens defined, "
"parsed structure unlikely to be suitable for use "
"in calculations unless hydrogens added." in parser.warnings
)
parser = CifParser(f"{self.TEST_FILES_DIR}/cif_implicit_hydrogens_cod_1011130.cif")
s = parser.get_structures()[0]
struct = parser.get_structures()[0]
assert (
"Structure has implicit hydrogens defined, "
"parsed structure unlikely to be suitable for use "
"in calculations unless hydrogens added." in parser.warnings
)

def test_site_labels(self):
parser = CifParser(f"{self.TEST_FILES_DIR}/garnet.cif")
struct = parser.get_structures()[0]

assert "labels" in struct.site_properties
assert (
len(struct.site_properties["labels"]) == len(struct) == 80
), "Mismatch between number of labels and sites."
assert len(set(struct.site_properties["labels"])) == 4, "Expecting only 4 unique labels"

for label, specie in zip(struct.site_properties["labels"], struct.species):
assert label.startswith(specie.name)

def test_CifParserSpringerPauling(self):
# Below are 10 tests for CIFs from the Springer Materials/Pauling file DBs.

Expand Down Expand Up @@ -780,26 +793,26 @@ def test_bad_cif(self):
with pytest.raises(ValueError, match="Invalid cif file with no structures"):
parser.get_structures()
parser = CifParser(f, occupancy_tolerance=2)
s = parser.get_structures()[0]
assert s[0].species["Al3+"] == approx(0.5)
struct = parser.get_structures()[0]
assert struct[0].species["Al3+"] == approx(0.5)

def test_one_line_symm(self):
f = f"{self.TEST_FILES_DIR}/OneLineSymmP1.cif"
parser = CifParser(f)
s = parser.get_structures()[0]
assert s.formula == "Ga4 Pb2 O8"
struct = parser.get_structures()[0]
assert struct.formula == "Ga4 Pb2 O8"

def test_no_symmops(self):
f = f"{self.TEST_FILES_DIR}/nosymm.cif"
parser = CifParser(f)
s = parser.get_structures()[0]
assert s.formula == "H96 C60 O8"
struct = parser.get_structures()[0]
assert struct.formula == "H96 C60 O8"

def test_dot_positions(self):
f = f"{self.TEST_FILES_DIR}/ICSD59959.cif"
parser = CifParser(f)
s = parser.get_structures()[0]
assert s.formula == "K1 Mn1 F3"
struct = parser.get_structures()[0]
assert struct.formula == "K1 Mn1 F3"

def test_replacing_finite_precision_frac_coords(self):
cif = f"{self.TEST_FILES_DIR}/cif_finite_precision_frac_coord_error.cif"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"tabulate",
"tqdm",
"uncertainties>=3.1.4",
"joblib"
"joblib",
],
extras_require={
"ase": ["ase>=3.3"],
Expand Down