Skip to content

Commit ec491c6

Browse files
KEP-2401: Complement torch plugin to support torchtune config mutation (kubeflow/trainer#2587)
* chore(plugin): Add torchtune-related constants & update current torch plugin. Signed-off-by: Electronic-Waste <[email protected]> * chore(plugin): Add EnforceMLPolicy for torchtune. Signed-off-by: Electronic-Waste <[email protected]> * chore(plugin): Add UTs in torch plugin. Signed-off-by: Electronic-Waste <[email protected]> * fix(test): fix error in torch plugin UTs. Signed-off-by: Electronic-Waste <[email protected]> * chore(plugin): Choose recipe according to numNodes & numProcPerNode & Args. Signed-off-by: Electronic-Waste <[email protected]> * chore(sdk): Add PretrainedModel enum type. Signed-off-by: Electronic-Waste <[email protected]> * chore(plugin): Add torchtune config arg. Signed-off-by: Electronic-Waste <[email protected]> * chore(test): add UT for single-device full fine-tuning with torchtune. Signed-off-by: Electronic-Waste <[email protected]> * chore(test): Add test for multi-nodes full fine-tuning with torchtune. Signed-off-by: Electronic-Waste <[email protected]> * chore(test): Update torch validate UTs. Signed-off-by: Electronic-Waste <[email protected]> * fix(lint): fix lint error. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): remove pretrained model enum type in sdk. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugin): retrieve model name from runtimeRef. Signed-off-by: Electronic-Waste <[email protected]> * fix(lint): fix typo. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugin): make some adjustments according to the review. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): remove runtime in get_trainer_crd_from_builtin_trainer. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugin): pass PET_ env variables in torch plugin for torchtune. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugin): add env validation for torchtune. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugin): update comments. Signed-off-by: Electronic-Waste <[email protected]> * fix(plugins): fix the implementation according to the review. Signed-off-by: Electronic-Waste <[email protected]> * test(plugins): fix UT error in torch plugin. Signed-off-by: Electronic-Waste <[email protected]> * fix: fix UT and e2e tests error. Signed-off-by: Electronic-Waste <[email protected]> * fix: remove debug info. Signed-off-by: Electronic-Waste <[email protected]> * fix(test): add args in UTs related to torchtune. Signed-off-by: Electronic-Waste <[email protected]> * fix(test): update torchtune related args. Signed-off-by: Electronic-Waste <[email protected]> * fix(test): Add a UT for multi-node mode check in torch plugin. Signed-off-by: Electronic-Waste <[email protected]> --------- Signed-off-by: Electronic-Waste <[email protected]>
1 parent 761eea3 commit ec491c6

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

kubeflow/trainer/types/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ class Initializer:
220220
"ghcr.io/kubeflow/trainer/torchtune-trainer": Trainer(
221221
trainer_type=TrainerType.BUILTIN_TRAINER,
222222
framework=Framework.TORCHTUNE,
223+
entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND,
223224
),
224225
}
225226

0 commit comments

Comments
 (0)