Skip to content

Commit ea42d11

Browse files
authored
KEP-2170: Add validation to Torch numProcPerNode field (#2409)
Signed-off-by: Antonin Stefanutti <[email protected]>
1 parent 9b3b1de commit ea42d11

20 files changed

+63
-49
lines changed

api/openapi-spec/swagger.json

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,7 @@
517517
},
518518
"numProcPerNode": {
519519
"description": "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.",
520-
"type": "string"
520+
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
521521
}
522522
}
523523
},
@@ -716,7 +716,7 @@
716716
},
717717
"numProcPerNode": {
718718
"description": "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.",
719-
"type": "string"
719+
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
720720
},
721721
"resourcesPerNode": {
722722
"description": "Compute resources for each training node.",

manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -587,17 +587,20 @@ spec:
587587
type: integer
588588
type: object
589589
numProcPerNode:
590+
anyOf:
591+
- type: integer
592+
- type: string
590593
default: auto
591594
description: |-
592595
Number of processes per node.
593596
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
594597
Supported values: `auto`, `cpu`, `gpu`, or int value.
595598
Defaults to `auto`.
596-
type: string
599+
x-kubernetes-int-or-string: true
597600
x-kubernetes-validations:
598601
- message: NumProcPerNode must be equal to auto, cpu, gpu,
599602
or int value
600-
rule: self in ['auto', 'cpu', 'gpu'] || type(self) == int
603+
rule: self > 0 || self in ['auto', 'cpu', 'gpu']
601604
type: object
602605
type: object
603606
x-kubernetes-validations:

manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -587,17 +587,20 @@ spec:
587587
type: integer
588588
type: object
589589
numProcPerNode:
590+
anyOf:
591+
- type: integer
592+
- type: string
590593
default: auto
591594
description: |-
592595
Number of processes per node.
593596
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
594597
Supported values: `auto`, `cpu`, `gpu`, or int value.
595598
Defaults to `auto`.
596-
type: string
599+
x-kubernetes-int-or-string: true
597600
x-kubernetes-validations:
598601
- message: NumProcPerNode must be equal to auto, cpu, gpu,
599602
or int value
600-
rule: self in ['auto', 'cpu', 'gpu'] || type(self) == int
603+
rule: self > 0 || self in ['auto', 'cpu', 'gpu']
601604
type: object
602605
type: object
603606
x-kubernetes-validations:

manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3138,11 +3138,14 @@ spec:
31383138
format: int32
31393139
type: integer
31403140
numProcPerNode:
3141+
anyOf:
3142+
- type: integer
3143+
- type: string
31413144
description: |-
31423145
Number of processes/workers/slots on every training node.
31433146
For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
31443147
For the MPI runtime only int value can be set.
3145-
type: string
3148+
x-kubernetes-int-or-string: true
31463149
resourcesPerNode:
31473150
description: Compute resources for each training node.
31483151
properties:

pkg/apis/trainer/v1alpha1/trainingruntime_types.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ package v1alpha1
1919
import (
2020
autoscalingv2 "k8s.io/api/autoscaling/v2"
2121
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
22+
"k8s.io/apimachinery/pkg/util/intstr"
2223
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
2324
)
2425

@@ -174,11 +175,10 @@ type TorchMLPolicySource struct {
174175
// Number of processes per node.
175176
// This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
176177
// Supported values: `auto`, `cpu`, `gpu`, or int value.
177-
// TODO (andreyvelich): Add kubebuilder validation.
178178
// Defaults to `auto`.
179179
// +kubebuilder:default="auto"
180-
// +kubebuilder:validation:XValidation:rule="self in ['auto', 'cpu', 'gpu'] || type(self) == int", message="NumProcPerNode must be equal to auto, cpu, gpu, or int value"
181-
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
180+
// +kubebuilder:validation:XValidation:rule="self > 0 || self in ['auto', 'cpu', 'gpu']", message="NumProcPerNode must be equal to auto, cpu, gpu, or int value"
181+
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
182182

183183
// Elastic policy for the PyTorch training.
184184
ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"`

pkg/apis/trainer/v1alpha1/trainjob_types.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ package v1alpha1
1919
import (
2020
corev1 "k8s.io/api/core/v1"
2121
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
22+
"k8s.io/apimachinery/pkg/util/intstr"
2223
)
2324

2425
const (
@@ -194,7 +195,7 @@ type Trainer struct {
194195
// Number of processes/workers/slots on every training node.
195196
// For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
196197
// For the MPI runtime only int value can be set.
197-
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
198+
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
198199
}
199200

200201
// DatasetConfig represents the desired dataset configuration.

pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Lines changed: 4 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/client/applyconfiguration/trainer/v1alpha1/torchmlpolicysource.go

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/client/applyconfiguration/trainer/v1alpha1/trainer.go

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/runtime/core/trainingruntime_test.go

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,14 +19,14 @@ package core
1919
import (
2020
"context"
2121
"fmt"
22-
"k8s.io/utils/ptr"
2322
"testing"
2423

2524
"github.com/google/go-cmp/cmp"
2625
"github.com/google/go-cmp/cmp/cmpopts"
2726
corev1 "k8s.io/api/core/v1"
2827
"k8s.io/apimachinery/pkg/api/resource"
2928
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
29+
"k8s.io/apimachinery/pkg/util/intstr"
3030
"sigs.k8s.io/controller-runtime/pkg/client"
3131
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
3232

@@ -264,7 +264,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
264264
"succeeded to build JobSet with Torch values from the TrainJob": {
265265
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
266266
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
267-
TorchPolicy(100, ptr.To("auto")).
267+
TorchPolicy(100, intstr.FromString("auto")).
268268
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
269269
Obj(),
270270
).Obj(),
@@ -274,7 +274,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
274274
Trainer(
275275
testingutil.MakeTrainJobTrainerWrapper().
276276
NumNodes(30).
277-
NumProcPerNode(ptr.To("3")).
277+
NumProcPerNode(intstr.FromInt32(3)).
278278
Obj(),
279279
).
280280
Obj(),
@@ -318,7 +318,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
318318
"succeeded to build JobSet with Torch values from the Runtime and envs.": {
319319
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
320320
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
321-
TorchPolicy(100, ptr.To("auto")).
321+
TorchPolicy(100, intstr.FromString("auto")).
322322
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
323323
ContainerTrainerEnv(
324324
[]corev1.EnvVar{

pkg/runtime/framework/plugins/mpi/mpi.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
9494

9595
numProcPerNode := strconv.Itoa(int(*info.RuntimePolicy.MLPolicy.MPI.NumProcPerNode))
9696
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
97-
numProcPerNode = *trainJob.Spec.Trainer.NumProcPerNode
97+
numProcPerNode = (*trainJob.Spec.Trainer.NumProcPerNode).String()
9898
}
9999
info.Trainer.NumProcPerNode = numProcPerNode
100100

pkg/runtime/framework/plugins/torch/torch.go

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ import (
2121
"fmt"
2222

2323
corev1 "k8s.io/api/core/v1"
24+
"k8s.io/apimachinery/pkg/util/intstr"
2425
"k8s.io/apimachinery/pkg/util/sets"
2526
"k8s.io/apimachinery/pkg/util/validation/field"
2627
"k8s.io/utils/ptr"
@@ -66,9 +67,9 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
6667
}
6768
info.Trainer.NumNodes = numNodes
6869

69-
numProcPerNode := info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode
70+
numProcPerNode := ptr.Deref(info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode, intstr.FromString("auto"))
7071
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
71-
numProcPerNode = trainJob.Spec.Trainer.NumProcPerNode
72+
numProcPerNode = ptr.Deref(trainJob.Spec.Trainer.NumProcPerNode, intstr.FromString("auto"))
7273
}
7374

7475
// Update envs for Info object.
@@ -84,7 +85,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
8485
},
8586
{
8687
Name: constants.TorchEnvNumProcPerNode,
87-
Value: ptr.Deref(numProcPerNode, "auto"),
88+
Value: numProcPerNode.String(),
8889
},
8990
{
9091
Name: constants.TorchEnvNodeRank,

pkg/util/testing/wrapper.go

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import (
2222
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2323
"k8s.io/apimachinery/pkg/runtime/schema"
2424
"k8s.io/apimachinery/pkg/types"
25+
"k8s.io/apimachinery/pkg/util/intstr"
2526
"k8s.io/utils/ptr"
2627
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
2728
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
@@ -392,8 +393,8 @@ func (t *TrainJobTrainerWrapper) NumNodes(numNodes int32) *TrainJobTrainerWrappe
392393
return t
393394
}
394395

395-
func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode *string) *TrainJobTrainerWrapper {
396-
t.Trainer.NumProcPerNode = numProcPerNode
396+
func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode intstr.IntOrString) *TrainJobTrainerWrapper {
397+
t.Trainer.NumProcPerNode = &numProcPerNode
397398
return t
398399
}
399400

@@ -689,12 +690,12 @@ func (s *TrainingRuntimeSpecWrapper) NumNodes(numNodes int32) *TrainingRuntimeSp
689690
return s
690691
}
691692

692-
func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode *string) *TrainingRuntimeSpecWrapper {
693+
func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode intstr.IntOrString) *TrainingRuntimeSpecWrapper {
693694
s.MLPolicy = &trainer.MLPolicy{
694695
NumNodes: &numNodes,
695696
MLPolicySource: trainer.MLPolicySource{
696697
Torch: &trainer.TorchMLPolicySource{
697-
NumProcPerNode: numProcPerNode,
698+
NumProcPerNode: &numProcPerNode,
698699
},
699700
},
700701
}

sdk/docs/TrainerV1alpha1TorchMLPolicySource.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ TorchMLPolicySource represents a PyTorch runtime configuration.
55
Name | Type | Description | Notes
66
------------ | ------------- | ------------- | -------------
77
**elastic_policy** | [**TrainerV1alpha1TorchElasticPolicy**](TrainerV1alpha1TorchElasticPolicy.md) | | [optional]
8-
**num_proc_per_node** | **str** | Number of processes per node. This value is inserted into the &#x60;--nproc-per-node&#x60; argument of the &#x60;torchrun&#x60; CLI. Supported values: &#x60;auto&#x60;, &#x60;cpu&#x60;, &#x60;gpu&#x60;, or int value. Defaults to &#x60;auto&#x60;. | [optional]
8+
**num_proc_per_node** | [**K8sIoApimachineryPkgUtilIntstrIntOrString**](K8sIoApimachineryPkgUtilIntstrIntOrString.md) | | [optional]
99

1010
[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)
1111

sdk/docs/TrainerV1alpha1Trainer.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ Name | Type | Description | Notes
99
**env** | [**list[V1EnvVar]**](V1EnvVar.md) | List of environment variables to set in the training container. These values will be merged with the TrainingRuntime&#39;s trainer environments. | [optional]
1010
**image** | **str** | Docker image for the training container. | [optional]
1111
**num_nodes** | **int** | Number of training nodes. | [optional]
12-
**num_proc_per_node** | **str** | Number of processes/workers/slots on every training node. For the Torch runtime: &#x60;auto&#x60;, &#x60;cpu&#x60;, &#x60;gpu&#x60;, or int value can be set. For the MPI runtime only int value can be set. | [optional]
12+
**num_proc_per_node** | [**K8sIoApimachineryPkgUtilIntstrIntOrString**](K8sIoApimachineryPkgUtilIntstrIntOrString.md) | | [optional]
1313
**resources_per_node** | [**V1ResourceRequirements**](V1ResourceRequirements.md) | | [optional]
1414

1515
[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)

sdk/kubeflow/trainer/models/trainer_v1alpha1_torch_ml_policy_source.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ class TrainerV1alpha1TorchMLPolicySource(object):
3434
"""
3535
openapi_types = {
3636
'elastic_policy': 'TrainerV1alpha1TorchElasticPolicy',
37-
'num_proc_per_node': 'str'
37+
'num_proc_per_node': 'K8sIoApimachineryPkgUtilIntstrIntOrString'
3838
}
3939

4040
attribute_map = {
@@ -82,21 +82,19 @@ def elastic_policy(self, elastic_policy):
8282
def num_proc_per_node(self):
8383
"""Gets the num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
8484
85-
Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. # noqa: E501
8685
8786
:return: The num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
88-
:rtype: str
87+
:rtype: K8sIoApimachineryPkgUtilIntstrIntOrString
8988
"""
9089
return self._num_proc_per_node
9190

9291
@num_proc_per_node.setter
9392
def num_proc_per_node(self, num_proc_per_node):
9493
"""Sets the num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource.
9594
96-
Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. # noqa: E501
9795
9896
:param num_proc_per_node: The num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
99-
:type: str
97+
:type: K8sIoApimachineryPkgUtilIntstrIntOrString
10098
"""
10199

102100
self._num_proc_per_node = num_proc_per_node

0 commit comments

Comments
 (0)