add additional documentation for the with_overrides feature #1181

Open · wants to merge 8 commits into base: master
128 changes: 128 additions & 0 deletions examples/productionizing/productionizing/customizing_resources.py
@@ -85,6 +85,7 @@ def my_workflow(x: typing.List[int]) -> int:
#
# ## Using `with_overrides`
#
# ### Override `Resources`
# You can use the `with_overrides` method to override the resources allocated to a task dynamically.
# Let's understand how resources can be overridden with an example.
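# The resource-override code itself sits in the collapsed portion of this diff. As a minimal,
# illustrative sketch of the calling pattern described above (the task and workflow names below
# are placeholders, not necessarily the ones used elsewhere in this file), a node-level resource
# override looks roughly like this:
# %%
import typing

from flytekit import Resources, task, workflow


@task(requests=Resources(cpu="1", mem="200Mi"), limits=Resources(cpu="2", mem="400Mi"))
def count_unique(x: typing.List[int]) -> int:
    # A small placeholder task whose decorator defaults we override at the call site.
    return len(set(x))


@workflow
def resource_override_wf(x: typing.List[int]) -> int:
    # `with_overrides` replaces this node's requests/limits for this invocation only;
    # the defaults declared on the task decorator stay in effect for other callers.
    return count_unique(x=x).with_overrides(
        requests=Resources(cpu="2", mem="600Mi"),
        limits=Resources(cpu="4", mem="1Gi"),
    )


# %% [markdown]
# When `resource_override_wf` runs, `count_unique` is scheduled with the overridden requests and
# limits rather than the values declared on its decorator.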

@@ -142,3 +143,130 @@ def my_pipeline(x: typing.List[int]) -> int:
# Resources allocated using the `with_overrides` method
# :::
#
# ### Override `task_config`
# You can also use the `with_overrides` method to override the `task_config` of a task.
# The following example uses distributed TensorFlow training to show how a `TfJob` can be initialized and then overridden.
#
# For `task_config`, refer to the {py:func}`flytekit:flytekit.task` documentation.
#
# First, define the necessary functions and dependencies.
# For more details, see the [distributed TensorFlow training example](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/kftensorflow_plugin/tf_mnist.html#run-distributed-tensorflow-training).
# Here, we focus on how to override the `task_config`.
# %%
import os
from dataclasses import dataclass
from typing import NamedTuple, Tuple

from dataclasses_json import dataclass_json
from flytekit import ImageSpec, Resources, dynamic, task, workflow
from flytekit.types.directory import FlyteDirectory

custom_image = ImageSpec(
name="kftensorflow-flyte-plugin",
packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
registry="ghcr.io/flyteorg",
)

if custom_image.is_container():
import tensorflow as tf
from flytekitplugins.kftensorflow import PS, Chief, TfJob, Worker

MODEL_FILE_PATH = "saved_model/"


@dataclass_json
@dataclass
class Hyperparameters(object):
# initialize a data class to store the hyperparameters.
batch_size_per_replica: int = 64
buffer_size: int = 10000
epochs: int = 10


def load_data(
hyperparameters: Hyperparameters,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.distribute.Strategy]:
# Fetch train and evaluation datasets
...


def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
# compile a model
...


def decay(epoch: int):
# define a function for decaying the learning rate
...


def train_model(
model: tf.keras.Model,
train_dataset: tf.data.Dataset,
hyperparameters: Hyperparameters,
) -> Tuple[tf.keras.Model, str]:
# define the train_model function
...


def test_model(model: tf.keras.Model, checkpoint_dir: str, eval_dataset: tf.data.Dataset) -> Tuple[float, float]:
# define the test_model function to evaluate loss and accuracy on the test dataset
...


# %% [markdown]
# To create a TensorFlow task, add the {py:class}`flytekitplugins:flytekitplugins.kftensorflow.TfJob` config to the Flyte task. This plugin runs distributed TensorFlow training on Kubernetes.
# %%
training_outputs = NamedTuple("TrainingOutputs", accuracy=float, loss=float, model_state=FlyteDirectory)

if os.getenv("SANDBOX") != "":
resources = Resources(gpu="0", mem="1000Mi", storage="500Mi", ephemeral_storage="500Mi")
else:
resources = Resources(gpu="2", mem="10Gi", storage="10Gi", ephemeral_storage="500Mi")


@task(
task_config=TfJob(worker=Worker(replicas=1), ps=PS(replicas=1), chief=Chief(replicas=1)),
retries=2,
cache=True,
cache_version="2.2",
requests=resources,
limits=resources,
container_image=custom_image,
)
def mnist_tensorflow_job(hyperparameters: Hyperparameters) -> training_outputs:
Contributor review comment: This should be a simpler task. Let's not make this complicated, and every task has to have a definition.

train_dataset, eval_dataset, strategy = load_data(hyperparameters=hyperparameters)
model = get_compiled_model(strategy=strategy)
model, checkpoint_dir = train_model(model=model, train_dataset=train_dataset, hyperparameters=hyperparameters)
eval_loss, eval_accuracy = test_model(model=model, checkpoint_dir=checkpoint_dir, eval_dataset=eval_dataset)
return training_outputs(accuracy=eval_accuracy, loss=eval_loss, model_state=MODEL_FILE_PATH)


# %% [markdown]
# You can use `@dynamic` to generate tasks at runtime with any custom configuration you want, and the `with_overrides` method replaces the previously defined configuration.
# Here, we override the worker replica count.
# %%
@workflow
def mnist_tensorflow_workflow(
hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
return mnist_tensorflow_job(hyperparameters=hyperparameters)


@dynamic
def dynamic_run(
new_worker: int,
hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
return mnist_tensorflow_job(hyperparameters=hyperparameters).with_overrides(
task_config=TfJob(worker=Worker(replicas=new_worker), ps=PS(replicas=1), chief=Chief(replicas=1))
)


# %% [markdown]
# You can execute the workflow locally.
# %%
if __name__ == "__main__":
print(mnist_tensorflow_workflow())
print(dynamic_run(new_worker=4))