Skip to content

Commit

Permalink
ruff linting
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle committed Aug 20, 2024
1 parent befd546 commit d04a3cf
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion examples/mnist-keras/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

NUM_CLASSES = 10


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
Expand All @@ -20,7 +21,7 @@ def get_data(out_dir="data"):


def load_data(data_path, is_train=True):
""" Load data from disk.
"""Load data from disk.
:param data_path: Path to data file.
:type data_path: str
Expand Down Expand Up @@ -49,6 +50,7 @@ def load_data(data_path, is_train=True):

return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n / parts)
Expand Down
5 changes: 4 additions & 1 deletion examples/mnist-keras/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def compile_model(img_rows=28, img_cols=28):
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"])
return model


def save_parameters(model, out_path):
"""Save model parameters to file.
Expand All @@ -41,6 +42,7 @@ def save_parameters(model, out_path):
weights = model.get_weights()
helper.save(weights, out_path)


def load_parameters(model_path):
"""Load model parameters from file and populate model.
Expand All @@ -54,6 +56,7 @@ def load_parameters(model_path):
model.set_weights(weights)
return model


def init_seed(out_path="../seed.npz"):
"""Initialize seed model and save it to file.
Expand All @@ -65,4 +68,4 @@ def init_seed(out_path="../seed.npz"):


if __name__ == "__main__":
init_seed("../seed.npz")
init_seed("../seed.npz")
2 changes: 1 addition & 1 deletion examples/mnist-keras/client/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ def predict(in_model_path, out_json_path, data_path=None):


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
predict(sys.argv[1], sys.argv[2])
2 changes: 1 addition & 1 deletion examples/mnist-keras/client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def validate(in_model_path, out_json_path, data_path=None):


if __name__ == "__main__":
validate(sys.argv[1], sys.argv[2])
validate(sys.argv[1], sys.argv[2])

0 comments on commit d04a3cf

Please sign in to comment.