From 9aadec205a6e208c62e29f52873fb3d675965a51 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:59:00 -0400 Subject: [PATCH] Fix lint warning and import error in data_types_and_io tf example (#1762) * Fix lint warning and import error in data_types_and_io tf example Signed-off-by: Eduardo Apolinario * Remove use of is_container in tensorflow_type.py example Signed-off-by: Eduardo Apolinario * Fix lint warning Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- .../data_types_and_io/tensorflow_type.py | 98 ++++++++++--------- examples/data_types_and_io/requirements.in | 1 + examples/kfmpi_plugin/README.md | 1 - 3 files changed, 53 insertions(+), 47 deletions(-) diff --git a/examples/data_types_and_io/data_types_and_io/tensorflow_type.py b/examples/data_types_and_io/data_types_and_io/tensorflow_type.py index 349f34b67..3ec8aea71 100644 --- a/examples/data_types_and_io/data_types_and_io/tensorflow_type.py +++ b/examples/data_types_and_io/data_types_and_io/tensorflow_type.py @@ -1,6 +1,6 @@ # Import necessary libraries and modules -from flytekit import task, workflow +from flytekit import ImageSpec, task, workflow from flytekit.types.directory import TFRecordsDirectory from flytekit.types.file import TFRecordFile @@ -9,48 +9,54 @@ registry="ghcr.io/flyteorg", ) -if custom_image.is_container(): - import tensorflow as tf - - # TensorFlow Model - @task - def train_model() -> tf.keras.Model: - model = tf.keras.Sequential( - [tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")] - ) - model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) - return model - - @task - def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float: - loss, accuracy = model.evaluate(x, y) - return accuracy - - @workflow - def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float: - model = train_model() - return evaluate_model(model=model, x=x, y=y) - - # TFRecord Files - @task - def process_tfrecord(file: TFRecordFile) -> int: - count = 0 - for record in tf.data.TFRecordDataset(file): - count += 1 - return count - - @workflow - def tfrecord_workflow(file: TFRecordFile) -> int: - return process_tfrecord(file=file) - - # TFRecord Directories - @task - def process_tfrecords_dir(dir: TFRecordsDirectory) -> int: - count = 0 - for record in tf.data.TFRecordDataset(dir.path): - count += 1 - return count - - @workflow - def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int: - return process_tfrecords_dir(dir=dir) +import tensorflow as tf + + +# TensorFlow Model +@task +def train_model() -> tf.keras.Model: + model = tf.keras.Sequential( + [tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")] + ) + model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) + return model + + +@task +def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float: + loss, accuracy = model.evaluate(x, y) + return accuracy + + +@workflow +def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float: + model = train_model() + return evaluate_model(model=model, x=x, y=y) + + +# TFRecord Files +@task +def process_tfrecord(file: TFRecordFile) -> int: + count = 0 + for record in tf.data.TFRecordDataset(file): + count += 1 + return count + + +@workflow +def tfrecord_workflow(file: TFRecordFile) -> int: + return process_tfrecord(file=file) + + +# TFRecord Directories +@task +def process_tfrecords_dir(dir: TFRecordsDirectory) -> int: + count = 0 + for record in tf.data.TFRecordDataset(dir.path): + count += 1 + return count + + +@workflow +def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int: + return process_tfrecords_dir(dir=dir) diff --git a/examples/data_types_and_io/requirements.in b/examples/data_types_and_io/requirements.in index 79bd303e5..2bcce8b12 100644 --- a/examples/data_types_and_io/requirements.in +++ b/examples/data_types_and_io/requirements.in @@ -1,4 +1,5 @@ pandas torch tabulate +tensorflow pyarrow diff --git a/examples/kfmpi_plugin/README.md b/examples/kfmpi_plugin/README.md index 0a43ff0ba..1eeeac68b 100644 --- a/examples/kfmpi_plugin/README.md +++ b/examples/kfmpi_plugin/README.md @@ -88,4 +88,3 @@ If your MPI workflow hangs or times out, it may be caused by an incorrect workfl 1. Verify Registration Method: When using a custom image, refer to the Flyte documentation on [Registering workflows](https://docs.flyte.org/en/latest/user_guide/flyte_fundamentals/registering_workflows.html#registration-patterns) to ensure you're following the correct registration method. -