Description
Right now, the test framework (#1841) only checks whether the metadata is a dict, but we should extend the metadata testing to check that the metadata of the datamodule is calculated correctly and passed correctly to the dataloaders and the model for initialisation.
The blocker is that different data_modules might have different metadata configurations, so we can't directly add these tests to the integration test.
One idea proposed by @fkiraly:
Create an intermediate model class (between the _BasePtForecasterV2 and ModelMetadata classes) where we add a method _check_metadata that runs the metadata tests specific to the data module the model is using, i.e., if the model is using EncoderDecoderDataModule, _check_metadata will test the metadata of the EncoderDecoderDataModule. As this class will be a parent of ModelMetadata, there will be no duplication of code.
So the idea is:
from skbase.base import BaseObject as _SkbaseBaseObject

class _BaseObject(_SkbaseBaseObject):
    pass

class _BasePtForecaster(_BaseObject):
    pass

class _BasePtForecasterV2(_BasePtForecaster):
    _tags = {
        "object_type": "forecaster_pytorch_v2",
    }

class _EncoderDecoderConfigBase(_BasePtForecasterV2):
    def _check_metadata(self, metadata):
        # metadata tests specific to EncoderDecoderDataModule
        ...

class TFTMetadata(_EncoderDecoderConfigBase):
    # testing logic
    ...
If the model uses a TSLibDataModule instead, the ConfigBase class will change and will have tests specific to TSLibDataModule.