Skip to content

Commit

Permalink
tf.keras Serialization fix (#432)
Browse files Browse the repository at this point in the history
* Fixes #422 tf.keras model serialization

Signed-off-by: Shah, Karan <[email protected]>

* Update Tensorflow_MNIST example with simple CNN

* Remove `layers.py`, declare model/opt/loss in notebook
* Update shape blobs to MNIST 2D input shape

Signed-off-by: Shah, Karan <[email protected]>

* Clear notebook cell outputs

Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista authored May 24, 2022
1 parent b930fbb commit dc63dcb
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 83 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
settings:
listen_host: localhost
listen_port: 50051
sample_shape: ['784']
sample_shape: ['28', '28', '1']
target_shape: ['1']
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_dataset(self, dataset_type='train'):
@property
def sample_shape(self):
"""Return the sample shape info."""
return ['784']
return ['28', '28', '1']

@property
def target_shape(self):
Expand All @@ -94,8 +94,8 @@ def download_data(self):
with np.load(local_file_path) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = np.reshape(x_test, (-1, 28, 28, 1))

os.remove(local_file_path) # remove mnist.npz
print('Mnist data was loaded!')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,49 @@
"id": "26fdd9ed",
"metadata": {},
"source": [
"# Federated Tensorflow Mnist Tutorial"
"# Federated Tensorflow MNIST Tutorial"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0570122",
"id": "1329f2e6",
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies if not already installed\n",
"!pip install tensorflow==2.3.1"
"# Install TF if not already. We recommend TF2.7 or greater.\n",
"# !pip install tensorflow==2.8"
]
},
{
"cell_type": "markdown",
"id": "e0d30942",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0833dfc9",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"print('TensorFlow', tf.__version__)"
]
},
{
"cell_type": "markdown",
"id": "246f9c98",
"metadata": {},
"source": [
"## Connect to the Federation"
"## Connect to the Federation\n",
"\n",
"Start `Director` and `Envoy` before proceeding with this cell. \n",
"\n",
"This cell connects this notebook to the Federation."
]
},
{
Expand All @@ -34,14 +57,13 @@
"metadata": {},
"outputs": [],
"source": [
"# Create a federation\n",
"from openfl.interface.interactive_api.federation import Federation\n",
"\n",
"# please use the same identificator that was used in signed certificate\n",
"client_id = 'api'\n",
"cert_dir = 'cert'\n",
"director_node_fqdn = 'localhost'\n",
"director_port=50051\n",
"director_port = 50051\n",
"# 1) Run with API layer - Director mTLS \n",
"# If the user wants to enable mTLS their must provide CA root chain, and signed key pair to the federation interface\n",
"# cert_chain = f'{cert_dir}/root_ca.crt'\n",
Expand All @@ -60,13 +82,22 @@
"# --------------------------------------------------------------------------------------------------------------------\n",
"\n",
"# 2) Run with TLS disabled (trusted environment)\n",
"# Federation can also determine local fqdn automatically\n",
"\n",
"# Create a Federation\n",
"federation = Federation(\n",
" client_id=client_id,\n",
" director_node_fqdn=director_node_fqdn,\n",
" director_port=director_port, \n",
" tls=False\n",
")\n"
")"
]
},
{
"cell_type": "markdown",
"id": "6efe22a8",
"metadata": {},
"source": [
"## Query Datasets from Shard Registry"
]
},
{
Expand Down Expand Up @@ -99,7 +130,7 @@
"id": "cc0dbdbd",
"metadata": {},
"source": [
"## Describing FL experimen"
"## Describing FL experiment"
]
},
{
Expand All @@ -109,7 +140,9 @@
"metadata": {},
"outputs": [],
"source": [
"from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment"
"from openfl.interface.interactive_api.experiment import TaskInterface\n",
"from openfl.interface.interactive_api.experiment import ModelInterface\n",
"from openfl.interface.interactive_api.experiment import FLExperiment"
]
},
{
Expand All @@ -127,9 +160,29 @@
"metadata": {},
"outputs": [],
"source": [
"from layers import create_model, optimizer\n",
"# Define model\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
" tf.keras.layers.MaxPooling2D((2, 2)),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
" tf.keras.layers.MaxPooling2D((2, 2)),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(10, activation=None),\n",
"], name='simplecnn')\n",
"model.summary()\n",
"\n",
"# Define optimizer\n",
"optimizer = tf.optimizers.Adam(learning_rate=1e-3)\n",
"\n",
"# Loss and metrics. These will be used later.\n",
"loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
"val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
"\n",
"# Create ModelInterface\n",
"framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'\n",
"model = create_model()\n",
"MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)"
]
},
Expand All @@ -151,6 +204,9 @@
"import numpy as np\n",
"from tensorflow.keras.utils import Sequence\n",
"\n",
"from openfl.interface.interactive_api.experiment import DataInterface\n",
"\n",
"\n",
"class DataGenerator(Sequence):\n",
"\n",
" def __init__(self, shard_descriptor, batch_size):\n",
Expand Down Expand Up @@ -269,20 +325,19 @@
"metadata": {},
"outputs": [],
"source": [
"TI = TaskInterface()\n",
"\n",
"import time\n",
"import tensorflow as tf\n",
"from layers import train_acc_metric, val_acc_metric, loss_fn\n",
"\n",
"\n",
"\n",
"TI = TaskInterface()\n",
"\n",
"# from openfl.component.aggregation_functions import AdagradAdaptiveAggregation # Uncomment this lines to use \n",
"# agg_fn = AdagradAdaptiveAggregation(model_interface=MI, learning_rate=0.4) # Adaptive Federated Optimization\n",
"# @TI.set_aggregation_function(agg_fn) # alghorithm!\n",
"# # See details in the:\n",
"# # https://arxiv.org/abs/2003.00295\n",
"\n",
"@TI.register_fl_task(model='model', data_loader='train_dataset', \\\n",
" device='device', optimizer='optimizer') \n",
"@TI.register_fl_task(model='model', data_loader='train_dataset', device='device', optimizer='optimizer') \n",
"def train(model, train_dataset, optimizer, device, loss_fn=loss_fn, warmup=False):\n",
" start_time = time.time()\n",
"\n",
Expand Down Expand Up @@ -379,8 +434,25 @@
}
],
"metadata": {
"interpreter": {
"hash": "f82a63373a71051274245dbf52f7a790e1979bab025fdff4da684b10eb9978bd"
},
"kernelspec": {
"display_name": "Python 3.8.10 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
Expand Down

This file was deleted.

35 changes: 3 additions & 32 deletions openfl/plugins/frameworks_adapters/keras_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,9 @@ def __init__(self) -> None:
@staticmethod
def serialization_setup():
"""Prepare model for serialization (optional)."""
# Source: https://github.com/tensorflow/tensorflow/issues/34697
from tensorflow.keras.models import Model
from tensorflow.python.keras.layers import deserialize
from tensorflow.python.keras.layers import serialize
from tensorflow.python.keras.saving import saving_utils

def unpack(model, training_config, weights):
restored_model = deserialize(model)
if training_config is not None:
restored_model.compile(
**saving_utils.compile_args_from_training_config(
training_config
)
)
restored_model.set_weights(weights)
return restored_model

# Hotfix function
def make_keras_picklable():

def __reduce__(self): # NOQA:N807
model_metadata = saving_utils.model_metadata(self)
training_config = model_metadata.get('training_config', None)
model = serialize(self)
weights = self.get_weights()
return (unpack, (model, training_config, weights))

cls = Model
cls.__reduce__ = __reduce__

# Run the function
make_keras_picklable()
# Keras supports serialization natively.
# https://github.com/keras-team/keras/pull/14748.
pass

@staticmethod
def get_tensor_dict(model, optimizer=None, suffix=''):
Expand Down

0 comments on commit dc63dcb

Please sign in to comment.