Skip to content

Commit c01fae1

Browse files
authored
add tests test_apply_scissor_insulator and test_apply_scissor_spin_polarized (#3082)
1 parent 0baff29 commit c01fae1

File tree

2 files changed

+93
-70
lines changed

2 files changed

+93
-70
lines changed

pymatgen/electronic_structure/bandstructure.py

Lines changed: 62 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,10 @@ def is_metal(self, efermi_tol=1e-4) -> bool:
335335
True if a metal, False if not
336336
"""
337337
for values in self.bands.values():
338-
for i in range(self.nb_bands):
339-
if np.any(values[i, :] - self.efermi < -efermi_tol) and np.any(values[i, :] - self.efermi > efermi_tol):
338+
for idx in range(self.nb_bands):
339+
if np.any(values[idx, :] - self.efermi < -efermi_tol) and np.any(
340+
values[idx, :] - self.efermi > efermi_tol
341+
):
340342
return True
341343
return False
342344

@@ -690,7 +692,7 @@ def from_dict(cls, dct):
690692
def from_old_dict(cls, dct):
691693
"""
692694
Args:
693-
dct (dict): A dict with all data for a band structure symm line object.
695+
dct (dict): A dict with all data for a band structure symmetry line object.
694696
695697
Returns:
696698
A BandStructureSymmLine object
@@ -863,57 +865,56 @@ def get_branch(self, index):
863865
branches
864866
"""
865867
to_return = []
866-
for i in self.get_equivalent_kpoints(index):
868+
for idx in self.get_equivalent_kpoints(index):
867869
for b in self.branches:
868-
if b["start_index"] <= i <= b["end_index"]:
870+
if b["start_index"] <= idx <= b["end_index"]:
869871
to_return.append(
870872
{
871873
"name": b["name"],
872874
"start_index": b["start_index"],
873875
"end_index": b["end_index"],
874-
"index": i,
876+
"index": idx,
875877
}
876878
)
877879
return to_return
878880

879881
def apply_scissor(self, new_band_gap):
880882
"""
881883
Apply a scissor operator (shift of the CBM) to fit the given band gap.
882-
If it's a metal. We look for the band crossing the fermi level
884+
If it's a metal, we look for the band crossing the Fermi level
883885
and shift this one up. This will not work all the time for metals!
884886
885887
Args:
886888
new_band_gap: the band gap the scissor band structure need to have.
887889
888890
Returns:
889-
a BandStructureSymmLine object with the applied scissor shift
891+
BandStructureSymmLine: with the applied scissor shift
890892
"""
891893
if self.is_metal():
892-
# moves then the highest index band crossing the fermi level
893-
# find this band...
894+
# moves then the highest index band crossing the fermi level find this band...
894895
max_index = -1000
895896
# spin_index = None
896-
for i in range(self.nb_bands):
897+
for idx in range(self.nb_bands):
897898
below = False
898899
above = False
899900
for j in range(len(self.kpoints)):
900-
if self.bands[Spin.up][i][j] < self.efermi:
901+
if self.bands[Spin.up][idx][j] < self.efermi:
901902
below = True
902-
if self.bands[Spin.up][i][j] > self.efermi:
903+
if self.bands[Spin.up][idx][j] > self.efermi:
903904
above = True
904-
if above and below and i > max_index:
905-
max_index = i
905+
if above and below and idx > max_index:
906+
max_index = idx
906907
# spin_index = Spin.up
907908
if self.is_spin_polarized:
908909
below = False
909910
above = False
910911
for j in range(len(self.kpoints)):
911-
if self.bands[Spin.down][i][j] < self.efermi:
912+
if self.bands[Spin.down][idx][j] < self.efermi:
912913
below = True
913-
if self.bands[Spin.down][i][j] > self.efermi:
914+
if self.bands[Spin.down][idx][j] > self.efermi:
914915
above = True
915-
if above and below and i > max_index:
916-
max_index = i
916+
if above and below and idx > max_index:
917+
max_index = idx
917918
# spin_index = Spin.down
918919
old_dict = self.as_dict()
919920
shift = new_band_gap
@@ -937,9 +938,9 @@ def as_dict(self):
937938
"""
938939
JSON-serializable dict representation of BandStructureSymmLine.
939940
"""
940-
d = super().as_dict()
941-
d["branches"] = self.branches
942-
return d
941+
dct = super().as_dict()
942+
dct["branches"] = self.branches
943+
return dct
943944

944945

945946
class LobsterBandStructureSymmLine(BandStructureSymmLine):
@@ -951,7 +952,7 @@ def as_dict(self):
951952
"""
952953
JSON-serializable dict representation of BandStructureSymmLine.
953954
"""
954-
d = {
955+
dct = {
955956
"@module": type(self).__module__,
956957
"@class": type(self).__name__,
957958
"lattice_rec": self.lattice_rec.as_dict(),
@@ -961,62 +962,62 @@ def as_dict(self):
961962
# kpoints are not kpoint objects dicts but are frac coords (this makes
962963
# the dict smaller and avoids the repetition of the lattice
963964
for k in self.kpoints:
964-
d["kpoints"].append(k.as_dict()["fcoords"])
965-
d["branches"] = self.branches
966-
d["bands"] = {str(int(spin)): self.bands[spin].tolist() for spin in self.bands}
967-
d["is_metal"] = self.is_metal()
965+
dct["kpoints"].append(k.as_dict()["fcoords"])
966+
dct["branches"] = self.branches
967+
dct["bands"] = {str(int(spin)): self.bands[spin].tolist() for spin in self.bands}
968+
dct["is_metal"] = self.is_metal()
968969
vbm = self.get_vbm()
969-
d["vbm"] = {
970+
dct["vbm"] = {
970971
"energy": vbm["energy"],
971972
"kpoint_index": [int(x) for x in vbm["kpoint_index"]],
972973
"band_index": {str(int(spin)): vbm["band_index"][spin] for spin in vbm["band_index"]},
973974
"projections": {str(spin): v for spin, v in vbm["projections"].items()},
974975
}
975976
cbm = self.get_cbm()
976-
d["cbm"] = {
977+
dct["cbm"] = {
977978
"energy": cbm["energy"],
978979
"kpoint_index": [int(x) for x in cbm["kpoint_index"]],
979980
"band_index": {str(int(spin)): cbm["band_index"][spin] for spin in cbm["band_index"]},
980981
"projections": {str(spin): v for spin, v in cbm["projections"].items()},
981982
}
982-
d["band_gap"] = self.get_band_gap()
983-
d["labels_dict"] = {}
984-
d["is_spin_polarized"] = self.is_spin_polarized
983+
dct["band_gap"] = self.get_band_gap()
984+
dct["labels_dict"] = {}
985+
dct["is_spin_polarized"] = self.is_spin_polarized
985986
# MongoDB does not accept keys starting with $. Add a blank space to fix the problem
986987
for c, label in self.labels_dict.items():
987988
mongo_key = c if not c.startswith("$") else " " + c
988-
d["labels_dict"][mongo_key] = label.as_dict()["fcoords"]
989+
dct["labels_dict"][mongo_key] = label.as_dict()["fcoords"]
989990
if len(self.projections) != 0:
990-
d["structure"] = self.structure.as_dict()
991-
d["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()}
992-
return d
991+
dct["structure"] = self.structure.as_dict()
992+
dct["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()}
993+
return dct
993994

994995
@classmethod
995-
def from_dict(cls, d):
996+
def from_dict(cls, dct):
996997
"""
997998
Args:
998-
d (dict): A dict with all data for a band structure symm line
999+
dct (dict): A dict with all data for a band structure symmetry line
9991000
object.
10001001
10011002
Returns:
10021003
A BandStructureSymmLine object
10031004
"""
10041005
try:
10051006
# Strip the label to recover initial string (see trick used in as_dict to handle $ chars)
1006-
labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()}
1007+
labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()}
10071008
projections = {}
10081009
structure = None
1009-
if d.get("projections"):
1010-
if isinstance(d["projections"]["1"][0][0], dict):
1010+
if dct.get("projections"):
1011+
if isinstance(dct["projections"]["1"][0][0], dict):
10111012
raise ValueError("Old band structure dict format detected!")
1012-
structure = Structure.from_dict(d["structure"])
1013-
projections = {Spin(int(spin)): np.array(v) for spin, v in d["projections"].items()}
1013+
structure = Structure.from_dict(dct["structure"])
1014+
projections = {Spin(int(spin)): np.array(v) for spin, v in dct["projections"].items()}
10141015

10151016
return LobsterBandStructureSymmLine(
1016-
d["kpoints"],
1017-
{Spin(int(k)): d["bands"][k] for k in d["bands"]},
1018-
Lattice(d["lattice_rec"]["matrix"]),
1019-
d["efermi"],
1017+
dct["kpoints"],
1018+
{Spin(int(k)): dct["bands"][k] for k in dct["bands"]},
1019+
Lattice(dct["lattice_rec"]["matrix"]),
1020+
dct["efermi"],
10201021
labels_dict,
10211022
structure=structure,
10221023
projections=projections,
@@ -1028,39 +1029,39 @@ def from_dict(cls, d):
10281029
"format. The old format will be retired in pymatgen "
10291030
"5.0."
10301031
)
1031-
return LobsterBandStructureSymmLine.from_old_dict(d)
1032+
return LobsterBandStructureSymmLine.from_old_dict(dct)
10321033

10331034
@classmethod
1034-
def from_old_dict(cls, d):
1035+
def from_old_dict(cls, dct):
10351036
"""
10361037
Args:
1037-
d (dict): A dict with all data for a band structure symm line
1038+
dct (dict): A dict with all data for a band structure symmetry line
10381039
object.
10391040
10401041
Returns:
10411042
A BandStructureSymmLine object
10421043
"""
10431044
# Strip the label to recover initial string (see trick used in as_dict to handle $ chars)
1044-
labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()}
1045+
labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()}
10451046
projections = {}
10461047
structure = None
1047-
if "projections" in d and len(d["projections"]) != 0:
1048-
structure = Structure.from_dict(d["structure"])
1048+
if "projections" in dct and len(dct["projections"]) != 0:
1049+
structure = Structure.from_dict(dct["structure"])
10491050
projections = {}
1050-
for spin in d["projections"]:
1051+
for spin in dct["projections"]:
10511052
dd = []
1052-
for i in range(len(d["projections"][spin])):
1053+
for i in range(len(dct["projections"][spin])):
10531054
ddd = []
1054-
for j in range(len(d["projections"][spin][i])):
1055-
ddd.append(d["projections"][spin][i][j])
1055+
for j in range(len(dct["projections"][spin][i])):
1056+
ddd.append(dct["projections"][spin][i][j])
10561057
dd.append(np.array(ddd))
10571058
projections[Spin(int(spin))] = np.array(dd)
10581059

10591060
return LobsterBandStructureSymmLine(
1060-
d["kpoints"],
1061-
{Spin(int(k)): d["bands"][k] for k in d["bands"]},
1062-
Lattice(d["lattice_rec"]["matrix"]),
1063-
d["efermi"],
1061+
dct["kpoints"],
1062+
{Spin(int(k)): dct["bands"][k] for k in dct["bands"]},
1063+
Lattice(dct["lattice_rec"]["matrix"]),
1064+
dct["efermi"],
10641065
labels_dict,
10651066
structure=structure,
10661067
projections=projections,

pymatgen/electronic_structure/tests/test_bandstructure.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,14 @@ def test_from_dict(self):
6161

6262
class BandStructureSymmLineTest(PymatgenTest):
6363
def setUp(self):
64-
self.bs = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "Cu2O_361_bandstructure.json"))
65-
self.bs2 = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "CaO_2605_bandstructure.json"))
66-
self.bs_spin = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "NiO_19009_bandstructure.json"))
67-
self.bs_cbm0 = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "InN_22205_bandstructure.json"))
68-
self.bs_cu = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "Cu_30_bandstructure.json"))
69-
self.bs_diff_spins = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "VBr2_971787_bandstructure.json"))
64+
self.bs: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/Cu2O_361_bandstructure.json")
65+
self.bs2: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/CaO_2605_bandstructure.json")
66+
self.bs_spin: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/NiO_19009_bandstructure.json")
67+
self.bs_cbm0: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/InN_22205_bandstructure.json")
68+
self.bs_cu: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/Cu_30_bandstructure.json")
69+
self.bs_diff_spins: BandStructureSymmLine = loadfn(
70+
f"{PymatgenTest.TEST_FILES_DIR}/VBr2_971787_bandstructure.json"
71+
)
7072
warnings.simplefilter("ignore")
7173

7274
def tearDown(self):
@@ -188,7 +190,7 @@ def test_get_sym_eq_kpoints_and_degeneracy(self):
188190
cbm_k = bs.get_cbm()["kpoint"].frac_coords
189191
vbm_k = bs.get_vbm()["kpoint"].frac_coords
190192
assert bs.get_kpoint_degeneracy(cbm_k) is None
191-
bs.structure = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "CaO_2605_structure.json"))
193+
bs.structure: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/CaO_2605_structure.json")
192194
assert bs.get_kpoint_degeneracy(cbm_k) == 3
193195
assert bs.get_kpoint_degeneracy(vbm_k) == 1
194196
cbm_eqs = bs.get_sym_eq_kpoints(cbm_k)
@@ -229,11 +231,31 @@ def test_old_format_load(self):
229231
bs_old = BandStructureSymmLine.from_dict(d)
230232
assert bs_old.get_projection_on_elements()[Spin.up][0][0]["Zn"] == 0.0971
231233

234+
def test_apply_scissor_insulator(self):
235+
# test applying a scissor operator to a metal
236+
for scissor in (1, 3):
237+
bs_scissored = self.bs.apply_scissor(scissor)
238+
assert not bs_scissored.is_metal()
239+
assert bs_scissored.nb_bands == 48
240+
assert bs_scissored.efermi == approx(3.75640309 + scissor)
241+
orig_efermi = self.bs_spin.efermi
242+
assert bs_scissored.efermi != approx(orig_efermi)
243+
244+
def test_apply_scissor_spin_polarized(self):
245+
# test applying a scissor operator to a spin-polarized system
246+
bs_scissored = self.bs_spin.apply_scissor(1.0)
247+
assert bs_scissored.is_metal()
248+
assert bs_scissored.nb_bands == 27
249+
assert {*bs_scissored.bands} == {Spin.up, Spin.down}
250+
assert bs_scissored.efermi == approx(4.64005999)
251+
orig_efermi = self.bs_spin.efermi
252+
assert bs_scissored.efermi != approx(orig_efermi)
253+
232254

233255
class ReconstructBandStructureTest(PymatgenTest):
234256
def setUp(self):
235-
self.bs_cu = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "Cu_30_bandstructure.json"))
236-
self.bs_cu2 = loadfn(os.path.join(PymatgenTest.TEST_FILES_DIR, "Cu_30_bandstructure.json"))
257+
self.bs_cu: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/Cu_30_bandstructure.json")
258+
self.bs_cu2: BandStructureSymmLine = loadfn(f"{PymatgenTest.TEST_FILES_DIR}/Cu_30_bandstructure.json")
237259
warnings.simplefilter("ignore")
238260

239261
def tearDown(self):

0 commit comments

Comments
 (0)