forked from lukeconibear/distributed_deep_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorflow_tune_mnist_example.py
62 lines (53 loc) · 1.73 KB
/
tensorflow_tune_mnist_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Example taken from Ray
# https://docs.ray.io/en/latest/train/examples/tune_tensorflow_mnist_example.html
import argparse
import ray
from ray import tune
from ray.train import Trainer
from tensorflow_mnist_example import train_func
def tune_tensorflow_mnist(num_workers, num_samples):
    """Hyperparameter-tune the distributed TensorFlow MNIST training function.

    Launches a Ray Tune search over learning rate and batch size; each trial
    runs ``train_func`` across ``num_workers`` distributed workers.

    Args:
        num_workers: Number of distributed training workers per trial.
        num_samples: Number of hyperparameter configurations to sample.

    Returns:
        The analysis object produced by ``tune.run``.
    """
    # Wrap the distributed training function as a Tune-compatible trainable.
    trainable = Trainer(
        backend="tensorflow", num_workers=num_workers
    ).to_tune_trainable(train_func)

    search_space = {
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "epochs": 3,
    }
    analysis = tune.run(trainable, num_samples=num_samples, config=search_space)

    # Report which sampled configuration was best under each metric.
    best_loss_config = analysis.get_best_config(metric="loss", mode="min")
    best_accuracy_config = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss_config}")
    print(f"Best accuracy config: {best_accuracy_config}")
    return analysis
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.")
parser.add_argument(
"--address",
required=False,
type=str,
help="the address to use for Ray")
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.")
parser.add_argument(
"--num-samples",
type=int,
default=2,
help="Sets number of samples for training.")
args = parser.parse_args()
if args.smoke_test:
ray.init(num_cpus=4)
else:
ray.init(address=args.address)
tune_tensorflow_mnist(
num_workers=args.num_workers, num_samples=args.num_samples)