From f12e916f81a4f878e96c94f03bd208f41b73c3d4 Mon Sep 17 00:00:00 2001
From: Sumana Sree Angajala <110307215+sumana-2705@users.noreply.github.com>
Date: Thu, 17 Oct 2024 16:23:19 +0530
Subject: [PATCH] Added examples for tensorflow types in Datatypes and IO section (#1739)

* Added examples for tensorflow types in Datatypes and IO section

Signed-off-by: sumana sree

* Fixed linting errors

Signed-off-by: sumana sree

* updated tensorflow_type.py file to avoid linting errors

Signed-off-by: sumana sree

* Apply lint corrections using pre-commit hooks

Signed-off-by: sumana sree

* added required comments

Signed-off-by: sumana sree

* fixed error on importing tensorflow

Signed-off-by: sumana sree

---------

Signed-off-by: sumana sree
---
 .../data_types_and_io/tensorflow_type.py      | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 examples/data_types_and_io/data_types_and_io/tensorflow_type.py

diff --git a/examples/data_types_and_io/data_types_and_io/tensorflow_type.py b/examples/data_types_and_io/data_types_and_io/tensorflow_type.py
new file mode 100644
index 000000000..349f34b67
--- /dev/null
+++ b/examples/data_types_and_io/data_types_and_io/tensorflow_type.py
@@ -0,0 +1,60 @@
+# Import necessary libraries and modules
+
+from flytekit import ImageSpec, task, workflow
+from flytekit.types.directory import TFRecordsDirectory
+from flytekit.types.file import TFRecordFile
+
+# Image specification listing the packages the TensorFlow examples depend on
+custom_image = ImageSpec(
+    packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
+    registry="ghcr.io/flyteorg",
+)
+
+# Import TensorFlow and define the examples only when is_container() is true
+if custom_image.is_container():
+    import tensorflow as tf
+
+    # TensorFlow Model
+    @task
+    def train_model() -> tf.keras.Model:
+        """Build and compile a two-layer Keras classifier and return it."""
+        model = tf.keras.Sequential(
+            [tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")]
+        )
+        model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
+        return model
+
+    @task
+    def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
+        """Evaluate ``model`` on ``x``/``y`` and return the accuracy metric."""
+        _loss, accuracy = model.evaluate(x, y)
+        return accuracy
+
+    @workflow
+    def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
+        """Build the model, then evaluate it on the given tensors."""
+        model = train_model()
+        return evaluate_model(model=model, x=x, y=y)
+
+    # TFRecord Files
+    @task
+    def process_tfrecord(file: TFRecordFile) -> int:
+        """Return the number of records read from ``file``."""
+        return sum(1 for _ in tf.data.TFRecordDataset(file))
+
+    @workflow
+    def tfrecord_workflow(file: TFRecordFile) -> int:
+        """Count the records in a single TFRecord file."""
+        return process_tfrecord(file=file)
+
+    # TFRecord Directories
+    @task
+    def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
+        """Return the number of records read from ``dir.path``."""
+        # NOTE(review): TFRecordDataset takes file names; confirm a directory path works here.
+        return sum(1 for _ in tf.data.TFRecordDataset(dir.path))
+
+    @workflow
+    def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
+        """Count the records in a directory of TFRecord files."""
+        return process_tfrecords_dir(dir=dir)