@@ -1870,10 +1870,10 @@ def normalize_multi_task(data):
1870
1870
data ["model" ]["fitting_net_dict" ].keys (), data ["learning_rate_dict" ]
1871
1871
)
1872
1872
elif single_learning_rate :
1873
- data [
1874
- "learning_rate_dict"
1875
- ] = normalize_learning_rate_dict_with_single_learning_rate (
1876
- data [ "model" ][ "fitting_net_dict" ]. keys (), data [ "learning_rate" ]
1873
+ data ["learning_rate_dict" ] = (
1874
+ normalize_learning_rate_dict_with_single_learning_rate (
1875
+ data [ "model" ][ "fitting_net_dict" ]. keys (), data [ "learning_rate" ]
1876
+ )
1877
1877
)
1878
1878
fitting_weight = (
1879
1879
data ["training" ]["fitting_weight" ] if multi_fitting_weight else None
@@ -1916,11 +1916,7 @@ def normalize_data_dict(data_dict):
1916
1916
def normalize_loss_dict (fitting_keys , loss_dict ):
1917
1917
# check the loss dict
1918
1918
failed_loss_keys = [item for item in loss_dict if item not in fitting_keys ]
1919
- assert (
1920
- not failed_loss_keys
1921
- ), "Loss dict key(s) {} not have corresponding fitting keys in {}! " .format (
1922
- str (failed_loss_keys ), str (list (fitting_keys ))
1923
- )
1919
+ assert not failed_loss_keys , f"Loss dict key(s) { failed_loss_keys !s} not have corresponding fitting keys in { list (fitting_keys )!s} ! "
1924
1920
new_dict = {}
1925
1921
base = Argument ("base" , dict , [], [loss_variant_type_args ()], doc = "" )
1926
1922
for item in loss_dict :
@@ -1935,9 +1931,7 @@ def normalize_learning_rate_dict(fitting_keys, learning_rate_dict):
1935
1931
failed_learning_rate_keys = [
1936
1932
item for item in learning_rate_dict if item not in fitting_keys
1937
1933
]
1938
- assert not failed_learning_rate_keys , "Learning rate dict key(s) {} not have corresponding fitting keys in {}! " .format (
1939
- str (failed_learning_rate_keys ), str (list (fitting_keys ))
1940
- )
1934
+ assert not failed_learning_rate_keys , f"Learning rate dict key(s) { failed_learning_rate_keys !s} not have corresponding fitting keys in { list (fitting_keys )!s} ! "
1941
1935
new_dict = {}
1942
1936
base = Argument ("base" , dict , [], [learning_rate_variant_type_args ()], doc = "" )
1943
1937
for item in learning_rate_dict :
@@ -1960,11 +1954,7 @@ def normalize_learning_rate_dict_with_single_learning_rate(fitting_keys, learnin
1960
1954
def normalize_fitting_weight (fitting_keys , data_keys , fitting_weight = None ):
1961
1955
# check the mapping
1962
1956
failed_data_keys = [item for item in data_keys if item not in fitting_keys ]
1963
- assert (
1964
- not failed_data_keys
1965
- ), "Data dict key(s) {} not have corresponding fitting keys in {}! " .format (
1966
- str (failed_data_keys ), str (list (fitting_keys ))
1967
- )
1957
+ assert not failed_data_keys , f"Data dict key(s) { failed_data_keys !s} not have corresponding fitting keys in { list (fitting_keys )!s} ! "
1968
1958
empty_fitting_keys = []
1969
1959
valid_fitting_keys = []
1970
1960
for item in fitting_keys :
@@ -1974,9 +1964,7 @@ def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None):
1974
1964
valid_fitting_keys .append (item )
1975
1965
if empty_fitting_keys :
1976
1966
log .warning (
1977
- "Fitting net(s) {} have no data and will not be used in training." .format (
1978
- str (empty_fitting_keys )
1979
- )
1967
+ f"Fitting net(s) { empty_fitting_keys !s} have no data and will not be used in training."
1980
1968
)
1981
1969
num_pair = len (valid_fitting_keys )
1982
1970
assert num_pair > 0 , "No valid training data systems for fitting nets!"
@@ -1991,9 +1979,7 @@ def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None):
1991
1979
failed_weight_keys = [
1992
1980
item for item in fitting_weight if item not in fitting_keys
1993
1981
]
1994
- assert not failed_weight_keys , "Fitting weight key(s) {} not have corresponding fitting keys in {}! " .format (
1995
- str (failed_weight_keys ), str (list (fitting_keys ))
1996
- )
1982
+ assert not failed_weight_keys , f"Fitting weight key(s) { failed_weight_keys !s} not have corresponding fitting keys in { list (fitting_keys )!s} ! "
1997
1983
sum_prob = 0.0
1998
1984
for item in fitting_keys :
1999
1985
if item in valid_fitting_keys :
0 commit comments