-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_single.py
66 lines (52 loc) · 1.99 KB
/
run_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import time
import click
import json
from dimreducers_crusher import datasets, metrics, reducers, plotters
from dimreducers_crusher.utils.py_utils import get_registry
# Lookup tables mapping a name to a factory. Entries are looked up by the CLI
# option value and called with no arguments to build the dataset generator,
# metric, or reducer instance. The registries are also passed to click.Choice
# to enumerate the valid option values.
# NOTE(review): the exact container type returned by get_registry (dict of
# classes?) is not visible here -- confirm in dimreducers_crusher.utils.py_utils.
DATASET_REGISTRY, METRIC_REGISTRY, REDUCER_REGISTRY = (
    get_registry(datasets),
    # exclude_substr=[] presumably turns off get_registry's default name
    # filtering for metrics -- TODO confirm against py_utils.get_registry.
    get_registry(metrics, exclude_substr=[]),
    get_registry(reducers),
)
@click.command()
@click.option(
    "--dataset-name",
    type=click.Choice(DATASET_REGISTRY),
    required=True,  # a missing value previously crashed with KeyError: None
)
@click.option(
    "--metric-name",
    type=click.Choice(METRIC_REGISTRY),
    multiple=True,
)
@click.option(
    "--reducer-name",
    type=click.Choice(REDUCER_REGISTRY),
    required=True,  # a missing value previously crashed with KeyError: None
)
@click.option(
    "--n-samples",
    type=click.IntRange(min=1),
    default=10000,
    show_default=True,
    help="Number of samples requested from the dataset generator.",
)
@click.option(
    "--n-dims",
    type=click.IntRange(min=1),
    default=10,
    show_default=True,
    help="Dimensionality requested from the dataset generator.",
)
def main(dataset_name, metric_name, reducer_name, n_samples=10000, n_dims=10):
    """Run one dataset/reducer combination, score it, and save a plot and report.

    Steps:
      1. Build the dataset generator named by ``dataset_name`` and draw
         ``n_samples`` x ``n_dims`` data from it.
      2. Fit/transform the data with the reducer named by ``reducer_name``.
      3. Score the (original, reduced) pair with every metric in
         ``metric_name`` (may be empty).
      4. Plot the reduced data to ``./pics/<dataset>_<reducer>_<timestamp>.png``.
      5. Write a JSON report to ``./reports/<dataset>_<reducer>_<timestamp>.json``
         with keys ``dataset``, ``reducer``, ``metrics`` (name -> score) and
         ``plots`` (list of image paths).

    Args:
        dataset_name: Key into DATASET_REGISTRY.
        metric_name: Tuple of keys into METRIC_REGISTRY (click ``multiple=True``).
        reducer_name: Key into REDUCER_REGISTRY.
        n_samples: Value passed as ``n`` to the dataset generator's ``get``.
        n_dims: Value passed as ``d`` to the dataset generator's ``get``.
    """
    print('---{}'.format(metric_name))
    # Timestamp shared by the plot and report file names so they can be paired.
    now = time.strftime("%y%m%d_%H%M%S")
    picsdir = "./pics"
    repdir = "./reports"
    print("====================")
    print("=====DIMCRUSHER=====")
    print("====================")

    # Data generation.
    datagen = DATASET_REGISTRY[dataset_name]()
    data = datagen.get(n=n_samples, d=n_dims)
    print(data.shape, data.min(), data.max())

    # Dimensionality reduction.
    reducer = REDUCER_REGISTRY[reducer_name]()
    data_reduced = reducer.fit_transform(data)
    print(data_reduced.shape, data_reduced.min(), data_reduced.max())

    # Metrics: accumulate every requested metric. The dict must exist even when
    # no --metric-name is given so the report is still written.
    metric_report = {}
    for mn in metric_name:
        metric = METRIC_REGISTRY[mn]()
        metric_value = metric.score(data, data_reduced)
        print(metric_value)
        metric_report[mn] = metric_value

    # Plotter
    # TODO: Allow multiple plotters
    fname = "{}_{}_{}".format(dataset_name, reducer_name, now)
    p = plotters.DefaultPlotter()
    os.makedirs(picsdir, exist_ok=True)
    picpath = os.path.join(picsdir, "{}.png".format(fname))
    p.plot(data_reduced, picpath)

    # Report
    os.makedirs(repdir, exist_ok=True)
    report = {
        "dataset": dataset_name,
        "reducer": reducer_name,
        # NOTE(review): json.dump requires each metric.score() result to be
        # JSON-serializable (e.g. a Python float) -- confirm for all metrics.
        "metrics": metric_report,
        "plots": [picpath],
    }
    reppath = os.path.join(repdir, "{}.json".format(fname))
    with open(reppath, "w", encoding="utf-8") as fp:
        json.dump(report, fp)


if __name__ == "__main__":
    main()