Skip to content

Ewm6834 fix default record #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/snapred/backend/dao/calibration/CalibrationDefaultRecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, List

from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.indexing.Record import Record
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType


class CalibrationDefaultRecord(Record, extra="ignore"):
    """

    Refer to the CalibrationRecord class for a more in-depth explanation of Calibration Records.
    This class contains only the most basic information: that which is held by the default Calibration.

    """

    # inherits from Record
    # - runNumber
    # - useLiteMode
    # - version
    # override this to point at the correct daughter class
    # NOTE the version on the calculationParameters MUST match the version on the record
    # this should be enforced by a validator
    calculationParameters: Calibration

    # specific to calibration records
    # maps each workspace type to the list of workspace names recorded for that type
    workspaces: Dict[WorkspaceType, List[WorkspaceName]]
20 changes: 4 additions & 16 deletions src/snapred/backend/dao/calibration/CalibrationRecord.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Dict, List, Optional
from typing import List, Optional

from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.CalibrationDefaultRecord import CalibrationDefaultRecord
from snapred.backend.dao.calibration.FocusGroupMetric import FocusGroupMetric
from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo
from snapred.backend.dao.indexing.Record import Record
from snapred.backend.dao.state.PixelGroup import PixelGroup
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType


class CalibrationRecord(Record, extra="ignore"):
class CalibrationRecord(CalibrationDefaultRecord, extra="ignore"):
"""

The CalibrationRecord class, serves as a comprehensive log of the inputs and parameters employed
Expand All @@ -22,17 +20,7 @@ class CalibrationRecord(Record, extra="ignore"):

"""

# inherits from Record
# - runNumber
# - useLiteMode
# - version
# override this to point at the correct daughter class
# NOTE the version on the calculationParameters MUST match the version on the record
# this should be enforced by a validator
calculationParameters: Calibration

# specific to calibration records
# specific to full calibration records
crystalInfo: CrystallographicInfo
pixelGroups: Optional[List[PixelGroup]] = None # TODO: really shouldn't be optional, will be when sns data fixed
focusGroupCalibrationMetrics: FocusGroupMetric
workspaces: Dict[WorkspaceType, List[WorkspaceName]]
2 changes: 2 additions & 0 deletions src/snapred/backend/dao/ingredients/ReductionIngredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ReductionIngredients(BaseModel):

# these should come from calibration / normalization records
# But will not exist if we proceed without calibration / normalization
# NOTE: These are peaks for normalization, and thus should use the
# Calibrant Sample for the Normalization
detectorPeaksMany: Optional[List[List[GroupPeakList]]] = None
smoothingParameter: Optional[float]
calibrantSamplePath: Optional[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class NormalizationRecord(Record, extra="ignore"):
workspaceNames: List[WorkspaceName] = []
calibrationVersionUsed: int = VERSION_DEFAULT
crystalDBounds: Limit[float]
normalizationCalibrantSamplePath: str
Copy link
Contributor

@ekapadi ekapadi Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm tempted to suggest, that because this is a NormalizationRecord, the "normalization" part of normalizationCalibrantSamplePath is redundant. And that to clarify the usage, the "normalization" only be added at the ReductionRecord.normalizationCalibrantSamplePath. (However, I can go either way on this.)


# must also parse integers as background run numbers
@field_validator("backgroundRunNumber", mode="before")
Expand Down
3 changes: 2 additions & 1 deletion src/snapred/backend/dao/reduction/ReductionRecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import BaseModel, ConfigDict, Field

from snapred.backend.dao.calibration import CalibrationDefaultRecord
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord
from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord
from snapred.backend.dao.state.PixelGroupingParameters import PixelGroupingParameters
Expand All @@ -24,7 +25,7 @@ class ReductionRecord(BaseModel):
timestamp: float = Field(frozen=True, default=None)

# specific to reduction records
calibration: Optional[CalibrationRecord] = None
calibration: Optional[CalibrationRecord | CalibrationDefaultRecord] = None
normalization: Optional[NormalizationRecord] = None
pixelGroupingParameters: Dict[str, List[PixelGroupingParameters]]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@
from typing import List, Optional
from pydantic import ConfigDict

from pydantic import BaseModel, ConfigDict
from snapred.backend.dao.normalization import NormalizationRecord

from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT
from snapred.backend.dao.Limit import Limit
from snapred.backend.dao.normalization.Normalization import Normalization
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName


class CreateNormalizationRecordRequest(BaseModel, extra="forbid"):
class CreateNormalizationRecordRequest(NormalizationRecord, extra="forbid"):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do see the purpose of this, but the necessity for this extra layer struck me as a bit weird. That is, if this is basically just a NormalizationRecord, then why do we need the additional request class. (For example, instead we could have Indexer.updateRecordVersions, instead of Indexer.createRecord(<from this record??>) which might be more clear, in terms of what is actually happening.)

Copy link
Collaborator Author

@walshmm walshmm Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did think this data object was kinda odd, but I didn't look too closely at it. Looking further into how it's used, there's a bit of a smell.
I think this had been just passing a record around up until the indexer stuff?
I'll see if I can rectify this.

The needed data to create a NormalizationRecord.
"""

runNumber: str
useLiteMode: bool
version: Optional[int] = None
calculationParameters: Normalization
backgroundRunNumber: str
smoothingParameter: float
workspaceNames: List[WorkspaceName] = []
calibrationVersionUsed: Optional[int] = VERSION_DEFAULT
crystalDBounds: Limit[float]

model_config = ConfigDict(
# required in order to use 'WorkspaceName'
arbitrary_types_allowed=True,
)

@classmethod
def parseVersion(cls, version, *, exclude_none: bool = False, exclude_default: bool = False) -> int | None: # noqa: ARG003
return version
71 changes: 60 additions & 11 deletions src/snapred/backend/data/GroceryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from pydantic import validate_call

from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT
from snapred.backend.dao.ingredients import GroceryListItem
from snapred.backend.dao.state import DetectorState
from snapred.backend.data.LocalDataService import LocalDataService
Expand All @@ -26,6 +27,7 @@
from snapred.meta.mantid.WorkspaceNameGenerator import (
NameBuilder,
WorkspaceName,
WorkspaceType,
)
from snapred.meta.mantid.WorkspaceNameGenerator import (
WorkspaceNameGenerator as wng,
Expand All @@ -44,6 +46,8 @@ class GroceryService:
Just send me a list.
"""

diffcalTableFileExtension: str = ".h5"

def __init__(self, dataService: LocalDataService = None):
# 'LocalDataService' is a singleton:
# declare it here as an instance attribute, rather than a class attribute,
Expand Down Expand Up @@ -212,23 +216,36 @@ def _createGroupingFilename(self, runNumber: str, groupingScheme: str, useLiteMo
def _createDiffcalOutputWorkspaceFilename(self, item: GroceryListItem) -> str:
ext = Config["calibration.diffraction.output.extension"]
return str(
Path(self._getCalibrationDataPath(item.runNumber, item.useLiteMode, item.version))
self._getCalibrationDataPath(item.runNumber, item.useLiteMode, item.version)
/ (self._createDiffcalOutputWorkspaceName(item) + ext)
)

@validate_call
def _createDiffcalDiagnosticWorkspaceFilename(self, item: GroceryListItem) -> str:
ext = Config["calibration.diffraction.diagnostic.extension"]
return str(
Path(self._getCalibrationDataPath(item.runNumber, item.useLiteMode, item.version))
self._getCalibrationDataPath(item.runNumber, item.useLiteMode, item.version)
/ (self._createDiffcalOutputWorkspaceName(item) + ext)
)

def _createDiffcalTableFilepathFromWsName(
    self, runNumber: str, useLiteMode: bool, version: Optional[int], wsName: WorkspaceName
) -> str:
    """
    Build the path of the diffcal table file that corresponds to an existing workspace name.

    The file is placed in the calibration data directory for the given run number, lite mode
    and version, and is named after ``wsName`` plus ``diffcalTableFileExtension``.

    :param runNumber: run number used to locate the calibration data directory
    :param useLiteMode: whether the lite-mode calibration data directory is used
    :param version: calibration version used to locate the calibration data directory
    :param wsName: name of the diffcal table workspace; must equal the name generated by
        ``createDiffcalTableWorkspaceName`` for the other three arguments
    :raises ValueError: if ``wsName`` differs from the generated diffcal table workspace name
    :return: the file path, as a string
    :rtype: str
    """
    # NOTE(review): a reviewer asked why runNumber / useLiteMode / version are passed at all,
    # since they might be recoverable from the tokens of wsName, and what happens when the
    # arguments contradict each other. The comparison below is what rejects a contradiction.
    calibrationDataPath = self._getCalibrationDataPath(runNumber, useLiteMode, version)
    expectedWsName = self.createDiffcalTableWorkspaceName(runNumber, useLiteMode, version)
    if wsName != expectedWsName:
        # Build ONE message string: two positional arguments to ValueError
        # would make the exception text render as a tuple.
        raise ValueError(
            f"Workspace name {wsName} does not match the expected diffcal table workspace name for run {runNumber} "
            f"(i.e. {expectedWsName})"
        )

    return str(calibrationDataPath / (wsName + self.diffcalTableFileExtension))

@validate_call
def _createDiffcalTableFilename(self, runNumber: str, useLiteMode: bool, version: Optional[int]) -> str:
def _createDiffcalTableFilepath(self, runNumber: str, useLiteMode: bool, version: Optional[int]) -> str:
return str(
Path(self._getCalibrationDataPath(runNumber, useLiteMode, version))
/ (self._createDiffcalTableWorkspaceName(runNumber, useLiteMode, version) + ".h5")
/ (self.createDiffcalTableWorkspaceName(runNumber, useLiteMode, version) + self.diffcalTableFileExtension)
)

@validate_call
Expand All @@ -245,7 +262,10 @@ def _createNormalizationWorkspaceFilename(self, runNumber: str, useLiteMode: boo
def _createReductionPixelMaskWorkspaceFilename(self, runNumber: str, useLiteMode: bool, timestamp: float) -> str:
return str(
Path(self._getReductionDataPath(runNumber, useLiteMode, timestamp))
/ (self._createReductionPixelMaskWorkspaceName(runNumber, useLiteMode, timestamp) + ".h5")
/ (
self._createReductionPixelMaskWorkspaceName(runNumber, useLiteMode, timestamp)
+ self.diffcalTableFileExtension
)
)

## WORKSPACE NAME METHODS
Expand Down Expand Up @@ -285,14 +305,41 @@ def _createDiffcalOutputWorkspaceName(self, item: GroceryListItem) -> WorkspaceN
.build()
)

def lookupDiffcalTableWorkspaceName(
    self, runNumber: str, useLiteMode: bool, version: Optional[int]
) -> WorkspaceName:
    """
    Look up the diffcal table workspace name stored in a calibration record.

    :param runNumber: run number used to select the calibration indexer
    :param useLiteMode: lite-mode flag used to select the calibration indexer
    :param version: calibration version to read; when this is not an integer (e.g. ``None``),
        the indexer's latest version applicable to ``runNumber`` is used instead
    :raises RuntimeError: if no record is found for the version, or if the record lists
        no diffcal table workspace
    :return: the first diffcal table workspace name listed in the record
    :rtype: WorkspaceName
    """
    indexer = self.dataService.calibrationIndexer(runNumber, useLiteMode)
    if not isinstance(version, int):
        version = indexer.latestApplicableVersion(runNumber)

    record = indexer.readRecord(version)
    if record is None:
        raise RuntimeError(f"Could not find calibration record for run {runNumber} and version {version}")

    # find the first diffcal table entry in the record
    tableNames = next(
        (names for wsType, names in record.workspaces.items() if wsType == WorkspaceType.DIFFCAL_TABLE),
        None,
    )
    # An entry that exists but holds an empty list is reported like a missing entry,
    # rather than surfacing as an unexplained IndexError.
    if not tableNames:
        raise RuntimeError(
            f"Could not find diffcal table in record for run {runNumber} in workspaces: {record.workspaces}"
        )
    return tableNames[0]

@validate_call
def createDiffcalTableWorkspaceName(
    self,
    runNumber: str,
    useLiteMode: bool,  # noqa: ARG002
    version: Optional[int],
) -> WorkspaceName:
    """
    Build the diffcal table workspace name for a run number and calibration version.

    NOTE: This method will IGNORE runNumber if the provided version is VERSION_DEFAULT

    :param runNumber: run number placed in the name, unless the version is the default version
    :param useLiteMode: unused; kept for signature parity with the other name builders
    :param version: calibration version placed in the name
    :return: the generated workspace name
    :rtype: WorkspaceName
    """
    if version == VERSION_DEFAULT:
        # the default diffcal table is named with "default" in place of the run number
        return wng.diffCalTable().runNumber("default").version(VERSION_DEFAULT).build()
    return wng.diffCalTable().runNumber(runNumber).version(version).build()

@validate_call
def _createDiffcalMaskWorkspaceName(
Expand Down Expand Up @@ -854,7 +901,7 @@ def fetchCalibrationWorkspaces(self, item: GroceryListItem) -> Dict[str, Any]:
:rtype: Dict[str, Any]
"""
runNumber, version, useLiteMode = item.runNumber, item.version, item.useLiteMode
tableWorkspaceName = self._createDiffcalTableWorkspaceName(runNumber, useLiteMode, version)
tableWorkspaceName = self.lookupDiffcalTableWorkspaceName(runNumber, useLiteMode, version)
maskWorkspaceName = self._createDiffcalMaskWorkspaceName(runNumber, useLiteMode, version)

if self.workspaceDoesExist(tableWorkspaceName):
Expand All @@ -865,7 +912,7 @@ def fetchCalibrationWorkspaces(self, item: GroceryListItem) -> Dict[str, Any]:
}
else:
# table + mask are in the same hdf5 file:
filename = self._createDiffcalTableFilename(runNumber, useLiteMode, version)
filename = self._createDiffcalTableFilepathFromWsName(runNumber, useLiteMode, version, tableWorkspaceName)

# Unless overridden: use a cached workspace as the instrument donor.
instrumentPropertySource, instrumentSource = (
Expand All @@ -888,9 +935,10 @@ def fetchCalibrationWorkspaces(self, item: GroceryListItem) -> Dict[str, Any]:

return data

# this isn't really a fetch method: it generates data
@validate_call
def fetchDefaultDiffCalTable(self, runNumber: str, useLiteMode: bool, version: int) -> WorkspaceName:
tableWorkspaceName = self._createDiffcalTableWorkspaceName("default", useLiteMode, version)
tableWorkspaceName = self.createDiffcalTableWorkspaceName("default", useLiteMode, version)
self.mantidSnapper.CalculateDiffCalTable(
"Generate the default diffcal table",
InputWorkspace=self._fetchInstrumentDonor(runNumber, useLiteMode),
Expand Down Expand Up @@ -1074,7 +1122,8 @@ def fetchGroceryList(self, groceryList: Iterable[GroceryListItem]) -> List[Works
# NOTE: fetchCalibrationWorkspaces will set the workspace name
# to that of the table workspace. Because of possible confusion with
# the behavior of the mask workspace, the workspace name is overridden here.
tableWorkspaceName = self._createDiffcalTableWorkspaceName(

tableWorkspaceName = self.lookupDiffcalTableWorkspaceName(
item.runNumber, item.useLiteMode, item.version
)
res = self.fetchCalibrationWorkspaces(item)
Expand Down
17 changes: 15 additions & 2 deletions src/snapred/backend/data/Indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import validate_call

from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationDefaultRecord, CalibrationRecord
from snapred.backend.dao.indexing.CalculationParameters import CalculationParameters
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
from snapred.backend.dao.indexing.Record import Record
Expand Down Expand Up @@ -59,6 +59,10 @@ class IndexerType(StrEnum):
IndexerType.DEFAULT: Record,
}

DEFAULT_RECORD_TYPE = {
IndexerType.CALIBRATION: CalibrationDefaultRecord,
}

# the params type for each indexer type
PARAMS_TYPE = {
IndexerType.CALIBRATION: Calibration,
Expand Down Expand Up @@ -354,6 +358,15 @@ def createRecord(self, *, version, **other_arguments):
record.calculationParameters.version = record.version
return record

def _determineRecordType(self, version: Optional[int] = None):
    """
    Select the record class used to parse the record of a given version.

    For the default version, the class registered for this indexer type in
    ``DEFAULT_RECORD_TYPE`` is used when there is one; in every other case
    the class registered in ``RECORD_TYPE`` is used.

    :param version: the version to resolve, passed through ``thisOrCurrentVersion``
    :return: the record class to parse with
    """
    resolvedVersion = self.thisOrCurrentVersion(version)
    if resolvedVersion == VERSION_DEFAULT:
        defaultType = DEFAULT_RECORD_TYPE.get(self.indexerType, None)
        if defaultType is not None:
            return defaultType
    # not the default version, or no dedicated default record class for this indexer type
    return RECORD_TYPE[self.indexerType]

def readRecord(self, version: Optional[int] = None) -> Record:
"""
If no version given, defaults to current version
Expand All @@ -362,7 +375,7 @@ def readRecord(self, version: Optional[int] = None) -> Record:
filePath = self.recordPath(version)
record = None
if filePath.exists():
record = parse_file_as(RECORD_TYPE[self.indexerType], filePath)
record = parse_file_as(self._determineRecordType(version), filePath)
return record

def writeRecord(self, record: Record):
Expand Down
17 changes: 13 additions & 4 deletions src/snapred/backend/data/LocalDataService.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
StateConfig,
StateId,
)
from snapred.backend.dao.calibration import Calibration, CalibrationRecord
from snapred.backend.dao.calibration import Calibration, CalibrationDefaultRecord, CalibrationRecord
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
from snapred.backend.dao.indexing.Record import Record
from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT
from snapred.backend.dao.Limit import Limit, Pair
from snapred.backend.dao.normalization import Normalization, NormalizationRecord
Expand Down Expand Up @@ -57,6 +56,7 @@
)
from snapred.meta.mantid.WorkspaceNameGenerator import (
WorkspaceName,
WorkspaceType,
)
from snapred.meta.mantid.WorkspaceNameGenerator import (
WorkspaceNameGenerator as wng,
Expand Down Expand Up @@ -845,7 +845,7 @@ def _writeDefaultDiffCalTable(self, runNumber: str, useLiteMode: bool):
indexer = self.calibrationIndexer(runNumber, useLiteMode)
version = indexer.defaultVersion()
grocer = GroceryService()
filename = Path(grocer._createDiffcalTableWorkspaceName("default", useLiteMode, version) + ".h5")
filename = Path(grocer.createDiffcalTableWorkspaceName("default", useLiteMode, version) + ".h5")
Copy link
Contributor

@ekapadi ekapadi Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get this: why does "default" subsume the <run number> position (especially as "default" as a string is not a legitimate <run number>), but not just use the VERSION_DEFAULT which is provided for that reason particularly? (And then use the real <run number> which might be useful information. What am I missing here?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized later that possibly there's no meaningful <run number> where this is actually required. I would like to see these "default" names centralized somewhere (we have the same requirement for < version > ), but that doesn't have much to do with this PR.

outWS = grocer.fetchDefaultDiffCalTable(runNumber, useLiteMode, version)

calibrationDataPath = indexer.versionPath(version)
Expand Down Expand Up @@ -897,6 +897,9 @@ def generateInstrumentStateFromRoot(self, runId: str):
@ExceptionHandler(StateValidationException)
# NOTE if you are debugging and got here, comment out the ExceptionHandler and try again
def initializeState(self, runId: str, useLiteMode: bool, name: str = None):
from snapred.backend.data.GroceryService import GroceryService

grocer = GroceryService()
stateId, _ = self.generateStateId(runId)

instrumentState = self.generateInstrumentStateFromRoot(runId)
Expand All @@ -922,12 +925,18 @@ def initializeState(self, runId: str, useLiteMode: bool, name: str = None):
creationDate=datetime.datetime.now(),
version=version,
)

# NOTE: this creates a bare record without any other CalibrationRecord data
record = Record(
defaultDiffCalTableName = grocer.createDiffcalTableWorkspaceName("default", liteMode, version)
workspaces: Dict[WorkspaceType, List[WorkspaceName]] = {
wngt.DIFFCAL_TABLE: [defaultDiffCalTableName],
}
record = CalibrationDefaultRecord(
runNumber=runId,
useLiteMode=liteMode,
version=version,
calculationParameters=calibration,
workspaces=workspaces,
)
entry = indexer.createIndexEntry(
runNumber=runId,
Expand Down
Loading