@@ -401,6 +401,8 @@ def _train_worker(
     include_package = include_package or []
 
     if distributed:
+        assert distributed_device_ids is not None
+
         # Since the worker is spawned and not forked, the extra imports need to be done again.
         # Both the ones from the plugins and the ones from `include_package`.
         import_plugins()
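(A note on this first hunk: `distributed_device_ids` is presumably typed `Optional[...]`, and the added assert narrows it to a concrete list for mypy before the distributed branch uses it. A minimal standalone sketch of that narrowing pattern, with a hypothetical `pick_device` helper:

    from typing import List, Optional

    def pick_device(distributed_device_ids: Optional[List[int]], rank: int) -> int:
        # Without the assert, indexing an Optional[List[int]] is a type error.
        assert distributed_device_ids is not None
        return distributed_device_ids[rank]  # treated as List[int] from here on

)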
@@ -556,7 +558,7 @@ def from_partial_objects(
         model: Lazy[Model],
         data_loader: Lazy[DataLoader],
         trainer: Lazy[Trainer],
-        vocabulary: Lazy[Vocabulary] = None,
+        vocabulary: Lazy[Vocabulary] = Lazy(Vocabulary),
         datasets_for_vocab_creation: List[str] = None,
         validation_dataset_reader: DatasetReader = None,
         validation_data_path: str = None,
@@ -610,7 +612,7 @@ def from_partial_objects(
         trainer: `Lazy[Trainer]`
             The `Trainer` that actually implements the training loop. This is a lazy object because
             it depends on the model that's going to be trained.
-        vocabulary: `Lazy[Vocabulary]`, optional (default=`None`)
+        vocabulary: `Lazy[Vocabulary]`, optional (default=`Lazy(Vocabulary)`)
            The `Vocabulary` that we will use to convert strings in the data to integer ids (and
            possibly set sizes of embedding matrices in the `Model`). By default we construct the
            vocabulary from the instances that we read.
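(For context, a toy sketch of the `Lazy[T]` pattern this default relies on, simplified and not AllenNLP's actual implementation: a `Lazy` wraps a constructor and defers calling it until the remaining keyword arguments are available, so `Lazy(Vocabulary)` always yields a real `Vocabulary`:

    from typing import Callable, Generic, TypeVar

    T = TypeVar("T")

    class Lazy(Generic[T]):
        # Defers construction until the missing kwargs are supplied.
        def __init__(self, constructor: Callable[..., T]) -> None:
            self._constructor = constructor

        def construct(self, **kwargs) -> T:
            return self._constructor(**kwargs)

With this default, `vocabulary.construct(instances=...)` can no longer return `None`, which is why the `Vocabulary.from_instances` fallback disappears in the next hunk.)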
@@ -664,8 +666,7 @@ def from_partial_objects(
         )
 
         vocabulary_ = vocabulary.construct(instances=instance_generator)
-        if not vocabulary_:
-            vocabulary_ = Vocabulary.from_instances(instance_generator)
+
         model_ = model.construct(vocab=vocabulary_, serialization_dir=serialization_dir)
 
         # Initializing the model can have side effect of expanding the vocabulary.
@@ -682,13 +683,9 @@ def from_partial_objects(
 
         data_loader_ = data_loader.construct(dataset=datasets["train"])
         validation_data = datasets.get("validation")
+        validation_data_loader_: Optional[DataLoader] = None
         if validation_data is not None:
-            # Because of the way Lazy[T] works, we can't check it's existence
-            # _before_ we've tried to construct it. It returns None if it is not
-            # present, so we try to construct it first, and then afterward back off
-            # to the data_loader configuration used for training if it returns None.
-            validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
-            if validation_data_loader_ is None:
+            if validation_data_loader is None:
                 validation_data_loader_ = data_loader.construct(dataset=validation_data)
                 if getattr(validation_data_loader_, "_batches_per_epoch", None) is not None:
                     warnings.warn(
@@ -698,16 +695,16 @@ def from_partial_objects(
                         "validation datasets for each epoch.",
                         UserWarning,
                     )
-        else:
-            validation_data_loader_ = None
+            else:
+                validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
 
         test_data = datasets.get("test")
+        test_data_loader: Optional[DataLoader] = None
         if test_data is not None:
-            test_data_loader = validation_data_loader.construct(dataset=test_data)
-            if test_data_loader is None:
+            if validation_data_loader is None:
                 test_data_loader = data_loader.construct(dataset=test_data)
-            else:
-                test_data_loader = None
+            else:
+                test_data_loader = validation_data_loader.construct(dataset=test_data)
 
         # We don't need to pass serialization_dir and local_rank here, because they will have been
         # passed through the trainer by from_params already, because they were keyword arguments to
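(The deleted comment described the old construct-then-check workaround; with the new semantics an omitted loader leaves the parameter itself as `None`, so the code can check before constructing anything. A condensed, self-contained sketch of the fallback rule shared by the validation and test branches, using stub types and a hypothetical `build_eval_loader` helper:

    from typing import Any, Callable, Optional

    # Stand-in for Lazy[DataLoader]: calling it plays the role of .construct().
    LazyLoader = Callable[..., Any]

    def build_eval_loader(configured: Optional[LazyLoader],
                          train_loader: LazyLoader,
                          dataset: Any) -> Any:
        if configured is None:
            # No dedicated loader configured: reuse the training configuration.
            return train_loader(dataset=dataset)
        return configured(dataset=dataset)

)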
@@ -717,6 +714,7 @@ def from_partial_objects(
             data_loader=data_loader_,
             validation_data_loader=validation_data_loader_,
         )
+        assert trainer_ is not None
 
         return cls(
             serialization_dir=serialization_dir,