@@ -71,9 +71,8 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
71
71
}
72
72
}
73
73
74
- // Check reserved envs for torchrun.
75
- // TODO(Electronic-Waste): Add validation for torchtune args.
76
74
if ! slices .Equal (newObj .Spec .Trainer .Command , constants .TorchTuneEntrypoint ) {
75
+ // Check reserved envs for torchrun.
77
76
torchEnvs := sets .New [string ]()
78
77
for _ , env := range newObj .Spec .Trainer .Env {
79
78
if constants .TorchRunReservedEnvNames .Has (env .Name ) {
@@ -85,6 +84,17 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
85
84
trainerEnvsPath := specPath .Child ("trainer" ).Child ("env" )
86
85
allErrs = append (allErrs , field .Invalid (trainerEnvsPath , newObj .Spec .Trainer .Env , fmt .Sprintf ("must not have reserved envs, invalid envs configured: %v" , sets .List (torchEnvs ))))
87
86
}
87
+ } else {
88
+ // Check supported pretrained models for torchtune.
89
+ // TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments.
90
+ argPath := specPath .Child ("trainer" ).Child ("args" )
91
+ model := getModelFromArgs (newObj .Spec .Trainer .Args )
92
+
93
+ if model == nil {
94
+ allErrs = append (allErrs , field .Invalid (argPath , newObj .Spec .Trainer .Args , "must specify a pretrained model" ))
95
+ } else if ! constants .TorchTuneSupportedPretrainedModels .Has (* model ) {
96
+ allErrs = append (allErrs , field .Invalid (argPath , newObj .Spec .Trainer .Args , fmt .Sprintf ("must have a supported pretrained model, invalid model configured: %v" , * model )))
97
+ }
88
98
}
89
99
}
90
100
@@ -246,15 +256,6 @@ func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []st
246
256
247
257
// getConfigFileFromArgs builds the config file name from the distributed parameters, recipe and command line arguments.
248
258
func getConfigFileFromArgs (numNodes int32 , recipe string , args []string ) string {
249
- // Extract model from command line args.
250
- model := constants .MODEL_LLAMA3_2_1B
251
- for _ , arg := range args {
252
- if strings .HasPrefix (arg , "model" ) {
253
- model = strings .Split (arg , "=" )[1 ]
254
- break
255
- }
256
- }
257
-
258
259
// Determine the config file name based on the recipe and number of nodes.
259
260
var suffix string
260
261
switch recipe {
@@ -268,5 +269,16 @@ func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string
268
269
suffix = constants .TorchTuneFullFinetuneSingleDeviceConfigSuffix
269
270
}
270
271
271
- return fmt .Sprintf ("%s%s.yaml" , model , suffix )
272
+ return fmt .Sprintf ("%s%s.yaml" , * getModelFromArgs (args ), suffix )
273
+ }
274
+
275
// getModelFromArgs extracts the model from the command line arguments.
//
// It scans args in order and returns a pointer to the value of the first
// argument that starts with "model" and is written as "key=value". The value is
// the text between the first and the second "=" (anything after a second "="
// is dropped). It returns nil when no such argument exists, so callers must
// check for nil before dereferencing the result.
//
// NOTE(review): matching is by the "model" prefix only, so any key that merely
// begins with "model" is treated as the model argument — confirm this is the
// intended key format.
func getModelFromArgs(args []string) *string {
	for _, arg := range args {
		if !strings.HasPrefix(arg, "model") {
			continue
		}
		parts := strings.Split(arg, "=")
		if len(parts) < 2 {
			// No "=" in the argument: there is no value to extract. Indexing
			// parts[1] here would panic, so skip it and keep scanning.
			continue
		}
		// Copy the value into a local so the returned pointer does not alias
		// the temporary slice produced by Split.
		model := parts[1]
		return &model
	}
	return nil
}
0 commit comments