
Overrides with !new that also require arguments are not working #28

Open
@egaznep

Context: I want to instantiate a derivative of a SpeechBrain pretrained model. According to the Overrides section of the HyperPyYAML documentation, this should be possible:

Overrides
In order to run experiments with various values for a hyperparameter, we have a system for overriding the values that are listed in the yaml file.

from hyperpyyaml import load_hyperpyyaml

overrides = {"foo": 7}
fake_file = """
foo: 2
bar: 5
"""
load_hyperpyyaml(fake_file, overrides)
As shown in this example, overrides can take an ordinary python dictionary. However, this form does not support python objects. To override a python object, overrides can also take a yaml-formatted string with the HyperPyYAML syntax.

load_hyperpyyaml(fake_file, "foo: !new:collections.Counter")
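As a rough mental model, based on the `recursive_update` excerpt quoted further down, an override behaves like a nested `dict.update`: top-level keys in the override replace the file's values, but a nested mapping is merged key by key instead of being replaced as a whole. The sketch below uses plain dicts only; `deep_update` is a hypothetical helper, not part of HyperPyYAML, and it ignores YAML tags entirely.

```python
import collections.abc


def deep_update(base, override):
    """Hypothetical, simplified model of override merging (plain dicts only)."""
    for key, value in override.items():
        if isinstance(value, collections.abc.Mapping) and isinstance(base.get(key), dict):
            # Nested mapping on both sides: merge key by key, keep the rest of base[key]
            deep_update(base[key], value)
        else:
            # Scalars, lists and new keys simply replace or add the value
            base[key] = value
    return base


flat = deep_update({"foo": 2, "bar": 5}, {"foo": 7})
# flat == {"foo": 7, "bar": 5}

nested = deep_update(
    {"model": {"channels": 512, "lin_neurons": 192}},
    {"model": {"lin_neurons": 256}},
)
# nested == {"model": {"channels": 512, "lin_neurons": 256}}
```

Because the nested case merges into the existing node, whatever identifies that node (in HyperPyYAML, its `!new:` tag) is kept from the original file rather than taken from the override.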


Minimal example:

import torch
from speechbrain.pretrained import EncoderClassifier  # speechbrain.inference.speaker in SpeechBrain >= 1.0

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier: EncoderClassifier = EncoderClassifier.from_hparams( # type: ignore
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device": device},
    # ecapa.ECAPA_TDNN is expected to come from a local ecapa.py
    overrides='''
embedding_model: !new:ecapa.ECAPA_TDNN
    input_size: !ref <n_mels>
    channels: [1024, 1024, 1024, 1024, 3072]
    kernel_sizes: [5, 3, 3, 3, 1]
    dilations: [1, 2, 3, 4, 1]
    attention_channels: 128
    lin_neurons: 192
'''
)

This code block should instantiate the ECAPA_TDNN class defined in a local ecapa.py, or raise an error if no such definition exists. However, whether or not a local ECAPA_TDNN definition exists, the override is silently ignored and the result is identical to:

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier: EncoderClassifier = EncoderClassifier.from_hparams( # type: ignore
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device":device},
)

Tracing the issue, I think the problem is in `recursive_update`: it recursively merges the override's entries into `hparams['embedding_model']`, but it does not copy over the override's `!new:` tag, if there is one, so the original class is still instantiated.

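import collections.abc


def recursive_update(d, u, must_match=False):
    # Excerpt from hyperpyyaml: merge the override mapping `u` into `d`
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and k in d:
            # Nested mapping: keys are merged into d[k], which keeps its own tag; v's tag is ignored
            recursive_update(d.get(k, {}), v)
        elif must_match and k not in d:
            raise KeyError(f"Override '{k}' not found in: {[key for key in d.keys()]}")
        else:
            d[k] = v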
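The suspected tag loss can be reproduced without SpeechBrain or ruamel.yaml. In the sketch below, `TaggedMap` is a hypothetical stand-in for a parsed YAML mapping that remembers its tag (the real ruamel.yaml node types expose tags differently), the tag strings are illustrative, and `recursive_update_keep_tags` is only one possible patch, not an official fix: when the override node carries a tag that differs from the existing one, it copies that tag across before merging the keys.

```python
import collections.abc

ORIGINAL_TAG = "!new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN"  # illustrative
OVERRIDE_TAG = "!new:ecapa.ECAPA_TDNN"


class TaggedMap(dict):
    """Hypothetical stand-in for a parsed YAML mapping that remembers its tag."""

    def __init__(self, *args, tag=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.tag = tag


def recursive_update(d, u, must_match=False):
    """Same merge logic as the excerpt quoted above."""
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and k in d:
            # Keys are merged into d[k]; d[k].tag is never read or changed
            recursive_update(d.get(k, {}), v)
        elif must_match and k not in d:
            raise KeyError(f"Override '{k}' not found in: {list(d.keys())}")
        else:
            d[k] = v


def recursive_update_keep_tags(d, u, must_match=False):
    """Hypothetical variant: carry the override's tag across, then merge keys."""
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and k in d:
            new_tag = getattr(v, "tag", None)
            if new_tag is not None and new_tag != getattr(d[k], "tag", None):
                # Rebuild the node with the same contents but the override's tag
                d[k] = TaggedMap(d[k], tag=new_tag)
            recursive_update_keep_tags(d[k], v, must_match)
        elif must_match and k not in d:
            raise KeyError(f"Override '{k}' not found in: {list(d.keys())}")
        else:
            d[k] = v


def make_hparams():
    # Minimal fake of the pretrained hyperparameters
    return {
        "n_mels": 80,
        "embedding_model": TaggedMap(
            {"input_size": 80, "lin_neurons": 192}, tag=ORIGINAL_TAG
        ),
    }


def make_override():
    # Minimal fake of the `!new:ecapa.ECAPA_TDNN` override
    return {"embedding_model": TaggedMap({"lin_neurons": 256}, tag=OVERRIDE_TAG)}


if __name__ == "__main__":
    plain = make_hparams()
    recursive_update(plain, make_override())
    print(plain["embedding_model"].tag)  # still the original tag

    patched = make_hparams()
    recursive_update_keep_tags(patched, make_override())
    print(patched["embedding_model"].tag)  # the override's tag
```

With the unpatched merge, `lin_neurons` is updated but the node keeps `ORIGINAL_TAG`, which matches the silent fallback described above. Replacing the whole node whenever the override is tagged would be an alternative design; it avoids mixing the old class's arguments into the new one, at the cost of requiring the override to list every argument.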
