Skip to content

Commit a29dddc

Browse files
chore(test): Update torch validate UTs.
Signed-off-by: Electronic-Waste <[email protected]>
1 parent 206822e commit a29dddc

File tree

2 files changed

+85
-13
lines changed

2 files changed

+85
-13
lines changed

pkg/runtime/framework/plugins/torch/torch.go

+24-12
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,8 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
7171
}
7272
}
7373

74-
// Check reserved envs for torchrun.
75-
// TODO(Electronic-Waste): Add validation for torchtune args.
7674
if !slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
75+
// Check reserved envs for torchrun.
7776
torchEnvs := sets.New[string]()
7877
for _, env := range newObj.Spec.Trainer.Env {
7978
if constants.TorchRunReservedEnvNames.Has(env.Name) {
@@ -85,6 +84,17 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
8584
trainerEnvsPath := specPath.Child("trainer").Child("env")
8685
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs))))
8786
}
87+
} else {
88+
// Check supported pretrained models for torchtune.
89+
// TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments.
90+
argPath := specPath.Child("trainer").Child("args")
91+
model := getModelFromArgs(newObj.Spec.Trainer.Args)
92+
93+
if model == nil {
94+
allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, "must specify a pretrained model"))
95+
} else if !constants.TorchTuneSupportedPretrainedModels.Has(*model) {
96+
allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", *model)))
97+
}
8898
}
8999
}
90100

@@ -246,15 +256,6 @@ func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []st
246256

247257
// getConfigFileFromArgs builds the config file name from the distributed parameters, recipe and command line arguments.
248258
func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string {
249-
// Extract model from command line args.
250-
model := constants.MODEL_LLAMA3_2_1B
251-
for _, arg := range args {
252-
if strings.HasPrefix(arg, "model") {
253-
model = strings.Split(arg, "=")[1]
254-
break
255-
}
256-
}
257-
258259
// Determine the config file name based on the recipe and number of nodes.
259260
var suffix string
260261
switch recipe {
@@ -268,5 +269,16 @@ func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string
268269
suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix
269270
}
270271

271-
return fmt.Sprintf("%s%s.yaml", model, suffix)
272+
return fmt.Sprintf("%s%s.yaml", *getModelFromArgs(args), suffix)
273+
}
274+
275+
// getModelFromArgs returns the pretrained model configured in the command line
// arguments, or nil when no model is configured.
//
// The first argument that starts with "model" and contains "=" wins, and the
// returned value is everything after that first "=". Arguments that start with
// "model" but carry no "=" are skipped instead of being indexed blindly, so
// malformed input yields nil (or a later well-formed match) rather than an
// index-out-of-range panic.
func getModelFromArgs(args []string) *string {
	for _, arg := range args {
		if !strings.HasPrefix(arg, "model") {
			continue
		}
		// strings.Cut reports whether the separator was present, which lets us
		// ignore arguments without a value instead of panicking on them.
		if _, value, found := strings.Cut(arg, "="); found {
			return &value
		}
	}
	return nil
}

pkg/runtime/framework/plugins/torch/torch_test.go

+61-1
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,8 @@ func TestValidate(t *testing.T) {
15681568
Container(
15691569
"ghcr.io/kubeflow/trainer/torchtune-trainer",
15701570
[]string{"tune", "run"},
1571-
nil, corev1.ResourceList{},
1571+
[]string{"model=llama3_2/1B"},
1572+
corev1.ResourceList{},
15721573
).
15731574
Env(
15741575
[]corev1.EnvVar{
@@ -1586,6 +1587,65 @@ func TestValidate(t *testing.T) {
15861587
).
15871588
Obj(),
15881589
},
1590+
"missing pretrained model": {
1591+
info: runtime.NewInfo(
1592+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
1593+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
1594+
TorchPolicy(ptr.To(intstr.FromString("auto")), nil).
1595+
Obj(),
1596+
).
1597+
Obj(),
1598+
),
1599+
),
1600+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").
1601+
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
1602+
NumProcPerNode(intstr.FromString("auto")).
1603+
Container(
1604+
"ghcr.io/kubeflow/trainer/torchtune-trainer",
1605+
[]string{"tune", "run"},
1606+
nil, corev1.ResourceList{},
1607+
).
1608+
Obj(),
1609+
).
1610+
Obj(),
1611+
wantError: field.ErrorList{
1612+
field.Invalid(
1613+
field.NewPath("spec").Child("trainer").Child("args"),
1614+
[]string(nil),
1615+
"must specify a pretrained model",
1616+
),
1617+
},
1618+
},
1619+
"unsupported pretrained model": {
1620+
info: runtime.NewInfo(
1621+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
1622+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
1623+
TorchPolicy(ptr.To(intstr.FromString("auto")), nil).
1624+
Obj(),
1625+
).
1626+
Obj(),
1627+
),
1628+
),
1629+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").
1630+
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
1631+
NumProcPerNode(intstr.FromString("auto")).
1632+
Container(
1633+
"ghcr.io/kubeflow/trainer/torchtune-trainer",
1634+
[]string{"tune", "run"},
1635+
[]string{"model=llama3_1/70B"},
1636+
corev1.ResourceList{},
1637+
).
1638+
Obj(),
1639+
).
1640+
Obj(),
1641+
wantError: field.ErrorList{
1642+
field.Invalid(
1643+
field.NewPath("spec").Child("trainer").Child("args"),
1644+
[]string{"model=llama3_1/70B"},
1645+
fmt.Sprintf("must have a supported pretrained model, invalid model configured: %s", "llama3_1/70B"),
1646+
),
1647+
},
1648+
},
15891649
}
15901650
for name, tc := range cases {
15911651
t.Run(name, func(t *testing.T) {

0 commit comments

Comments
 (0)