From 7e7c28fe67faffc1ac248a28602b74c1ba2ad9f3 Mon Sep 17 00:00:00 2001 From: Weirui Kuang <39145382+rayrayraykk@users.noreply.github.com> Date: Fri, 19 Aug 2022 12:16:30 +0800 Subject: [PATCH] [Feature] Add utils for draw landscape (#338) --- benchmark/FedHPOB/fedhpob/utils/draw.py | 71 ++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/benchmark/FedHPOB/fedhpob/utils/draw.py b/benchmark/FedHPOB/fedhpob/utils/draw.py index cf4a569f2..776e7ce15 100644 --- a/benchmark/FedHPOB/fedhpob/utils/draw.py +++ b/benchmark/FedHPOB/fedhpob/utils/draw.py @@ -1,5 +1,6 @@ import os import json +import datetime import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm @@ -23,7 +24,6 @@ def logloader(file): def ecdf(model, data_list, algo, sample_client=None, key='test_acc'): - import datetime from fedhpob.benchmarks import TabularBenchmark # Draw ECDF from target data_list @@ -288,5 +288,74 @@ def rank_over_time(root, suffix, Y_label) +def landscape(model='cnn', + dname='femnist', + algo='avg', + sample_client=None, + key='test_acc'): + import plotly.graph_objects as go + from fedhpob.config import fhb_cfg + from fedhpob.benchmarks import TabularBenchmark + + z = [] + benchmark = TabularBenchmark(model, dname, algo, device=-1) + + def get_best_config(benchmark): + results, config = [], [] + for idx in tqdm(range(len(benchmark.table))): + row = benchmark.table.iloc[idx] + if sample_client is not None and row[ + 'sample_client'] != sample_client: + continue + result = eval(row['result']) + val_loss = result['val_avg_loss'] + try: + best_round = np.argmin(val_loss) + except: + continue + results.append(result[key][best_round]) + config.append(row) + best_index = np.argmax(results) + return config[best_index], results[best_index] + + # config, _ = get_best_config(benchmark) + config = {'wd': 0.0, 'dropout': 0.5, 'step': 1.0} + config_space = benchmark.get_configuration_space() + X, Y = sorted(list(config_space['batch'])), sorted(list( + config_space['lr'])) + print(X, Y) + for lr in Y: + y = [] + for batch in X: + xy = {'lr': lr, 'batch': batch} + print({**config, **xy}) + res = benchmark({ + **config, + **xy + }, { + 'sample_client': 1.0, + 'round': 249 + }, + fhb_cfg=fhb_cfg, + seed=12345) + y.append(res['function_value']) + z.append(y) + Z = np.array(z) + fig = go.Figure(data=[go.Surface(z=Z, x=X, y=Y)]) + fig.update_layout(title='FEMNIST (FedAvg)', + autosize=False, + width=900, + height=900, + margin=dict(l=65, r=50, b=65, t=90), + scene=dict( + xaxis_title='BS', + yaxis_title='LR', + zaxis_title='ACC', + )) + fig.write_image(os.path.join('figures', 'femnist_fedavg_landscape.pdf')) + + return + + if __name__ == '__main__': ecdf('gcn', ['cora', 'citeseer', 'pubmed'], sample_client=1.0)