Skip to content

Commit

Permalink
updated checkpoint saving extension, and data_augmentation implementa…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
divyashreepathihalli committed Sep 22, 2023
1 parent 382992b commit c73a176
Showing 1 changed file with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@
image_size = 32
auto = tf_data.AUTOTUNE

data_augmentation = keras.Sequential(
[
data_augmentation = [
layers.RandomCrop(image_size, image_size),
layers.RandomFlip("horizontal"),
],
name="data_augmentation",
)
]

def apply_augmentation(x):
for aug in data_augmentation:
x = aug(x)
return x

def make_datasets(images, labels, is_train=False):
dataset = tf_data.Dataset.from_tensor_slices((images, labels))
Expand All @@ -94,7 +95,7 @@ def make_datasets(images, labels, is_train=False):
dataset = dataset.batch(batch_size)
if is_train:
dataset = dataset.map(
lambda x, y: (data_augmentation(x), y), num_parallel_calls=auto
lambda x, y: (apply_augmentation(x), y), num_parallel_calls=auto
)
return dataset.prefetch(auto)

Expand Down Expand Up @@ -189,7 +190,7 @@ def run_experiment(model):
metrics=["accuracy"],
)

checkpoint_filepath = "/tmp/checkpoint"
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
Expand Down

0 comments on commit c73a176

Please sign in to comment.