diff --git a/pymatgen/io/lobster/inputs.py b/pymatgen/io/lobster/inputs.py index 38b2201df87..90c64ae3613 100644 --- a/pymatgen/io/lobster/inputs.py +++ b/pymatgen/io/lobster/inputs.py @@ -12,7 +12,9 @@ import itertools import os +import re import warnings +from collections import UserDict from typing import TYPE_CHECKING, Any import numpy as np @@ -49,7 +51,7 @@ ) -class Lobsterin(dict, MSONable): +class Lobsterin(UserDict, MSONable): """ This class can handle and generate lobsterin files Furthermore, it can also modify INCAR files for lobster, generate KPOINT files for fatband calculations in Lobster, @@ -159,7 +161,10 @@ def __getitem__(self, item): if not found: new_key = item - return dict.__getitem__(self, new_key) + return super().__getitem__(new_key) + + def __delitem__(self, key): + del self.data[key.lower()] def diff(self, other): """ @@ -580,31 +585,31 @@ def from_file(cls, lobsterin: str): Lobsterindict: dict[str, Any] = {} for datum in data: - # will remove all comments to avoid complications - raw_datum = datum.split("!")[0] - raw_datum = raw_datum.split("//")[0] - raw_datum = raw_datum.split("#")[0] - raw_datum = raw_datum.split(" ") - while "" in raw_datum: - raw_datum.remove("") - if len(raw_datum) > 1: - # check which type of keyword this is, handle accordingly - if raw_datum[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]: - if raw_datum[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]: - if raw_datum[0].lower() not in Lobsterindict: - Lobsterindict[raw_datum[0].lower()] = " ".join(raw_datum[1:]) + # Remove all comments + if not datum.startswith(("!", "#", "//")): + pattern = r"\b[^!#//]+" # exclude comments after commands + matched_pattern = re.findall(pattern, datum) + if matched_pattern: + raw_datum = matched_pattern[0].replace("\t", " ") # handle tab in between and end of command + key_word = raw_datum.strip().split(" ") # extract keyword + if len(key_word) > 1: + # check which type of keyword this is, handle accordingly + if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]: + if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]: + if key_word[0].lower() not in Lobsterindict: + Lobsterindict[key_word[0].lower()] = " ".join(key_word[1:]) + else: + raise ValueError(f"Same keyword {key_word[0].lower()} twice!") + elif key_word[0].lower() not in Lobsterindict: + Lobsterindict[key_word[0].lower()] = float(key_word[1]) + else: + raise ValueError(f"Same keyword {key_word[0].lower()} twice!") + elif key_word[0].lower() not in Lobsterindict: + Lobsterindict[key_word[0].lower()] = [" ".join(key_word[1:])] else: - raise ValueError(f"Same keyword {raw_datum[0].lower()} twice!") - elif raw_datum[0].lower() not in Lobsterindict: - Lobsterindict[raw_datum[0].lower()] = float(raw_datum[1]) - else: - raise ValueError(f"Same keyword {raw_datum[0].lower()} twice!") - elif raw_datum[0].lower() not in Lobsterindict: - Lobsterindict[raw_datum[0].lower()] = [" ".join(raw_datum[1:])] - else: - Lobsterindict[raw_datum[0].lower()].append(" ".join(raw_datum[1:])) - elif len(raw_datum) > 0: - Lobsterindict[raw_datum[0].lower()] = True + Lobsterindict[key_word[0].lower()].append(" ".join(key_word[1:])) + elif len(key_word) > 0: + Lobsterindict[key_word[0].lower()] = True return cls(Lobsterindict) diff --git a/tests/files/cohp/lobsterin.2 b/tests/files/cohp/lobsterin.2 index 726a016834d..53da1377490 100644 --- a/tests/files/cohp/lobsterin.2 +++ b/tests/files/cohp/lobsterin.2 @@ -1,11 +1,12 @@ COHPstartEnergy -15.0 COHPendEnergy 5.0 -basisSet pbeVaspFit2015 +basisSet pbeVaspFit2015 gaussianSmearingWidth 0.1 -basisfunctions Fe 3d 4p 4s ! This is a comment +basisfunctions Fe 3d 4p 4s ! This is a comment basisfunctions Co 3d 4p 4s # This is another comment skipdos // Here, we comment again skipcohp skipcoop skipPopulationAnalysis skipGrossPopulation +! cohpsteps diff --git a/tests/io/lobster/test_inputs.py b/tests/io/lobster/test_inputs.py index 3316207f916..20c51cbbd99 100644 --- a/tests/io/lobster/test_inputs.py +++ b/tests/io/lobster/test_inputs.py @@ -1651,6 +1651,32 @@ def test_diff(self): == self.Lobsterinfromfile3.diff(self.Lobsterinfromfile)["Different"]["SKIPCOHP"]["lobsterin2"] ) + def test_dict_functionality(self): + assert self.Lobsterinfromfile.get("COHPstartEnergy") == -15.0 + assert self.Lobsterinfromfile.get("COHPstartEnergy") == -15.0 + assert self.Lobsterinfromfile.get("COhPstartenergy") == -15.0 + lobsterincopy = self.Lobsterinfromfile.copy() + lobsterincopy.update({"cohpstarteNergy": -10.00}) + assert lobsterincopy["cohpstartenergy"] == -10.0 + lobsterincopy.pop("cohpstarteNergy") + assert "cohpstartenergy" not in lobsterincopy + lobsterincopy.pop("cohpendenergY") + lobsterincopy["cohpsteps"] = 100 + assert lobsterincopy["cohpsteps"] == 100 + before = len(lobsterincopy.items()) + lobsterincopy.popitem() + after = len(lobsterincopy.items()) + assert before != after + + def test_read_write_lobsterin(self): + outfile_path = tempfile.mkstemp()[1] + lobsterin1 = Lobsterin.from_file(f"{TEST_FILES_DIR}/cohp/lobsterin.1") + lobsterin1.write_lobsterin(outfile_path) + lobsterin2 = Lobsterin.from_file(outfile_path) + assert lobsterin1.diff(lobsterin2)["Different"] == {} + + # TODO: will integer vs float break cohpsteps? + def test_get_basis(self): # get basis functions lobsterin1 = Lobsterin({})