-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,035 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#!/usr/bin/env python | ||
import torch | ||
import numpy as np | ||
import io | ||
|
||
from lib.train_dataclasses import TrainConfig | ||
from lib.train_dataclasses import TrainRun | ||
from lib.train_dataclasses import OptimizerConfig | ||
from lib.train_dataclasses import ComputeConfig | ||
|
||
from lib.classification_metrics import create_classification_metrics | ||
from lib.data_registry import DataSpiralsConfig | ||
from lib.datasets.spiral_visualization import visualize_spiral | ||
from lib.models.mlp import MLPClassConfig | ||
from lib.generic_ablation import generic_ablation | ||
|
||
from lib.distributed_trainer import distributed_train | ||
from lib.ddp import ddp_setup | ||
from lib.files import prepare_results | ||
from lib.render_psql import setup_psql, add_artifact | ||
|
||
from lib.flatbuffers import NPY, Component, Dimension | ||
import flatbuffers | ||
|
||
|
||
def create_config(mlp_dim, ensemble_id): | ||
loss = torch.nn.CrossEntropyLoss() | ||
|
||
def ce_loss(output, batch): | ||
return loss(output["logits"], batch["target"]) | ||
|
||
train_config = TrainConfig( | ||
model_config=MLPClassConfig(widths=[mlp_dim, mlp_dim]), | ||
train_data_config=DataSpiralsConfig(seed=0, N=1000), | ||
val_data_config=DataSpiralsConfig(seed=1, N=500), | ||
loss=ce_loss, | ||
optimizer=OptimizerConfig( | ||
optimizer=torch.optim.Adam, kwargs=dict(weight_decay=0.0001) | ||
), | ||
batch_size=500, | ||
ensemble_id=ensemble_id, | ||
) | ||
train_eval = create_classification_metrics(visualize_spiral, 2) | ||
train_run = TrainRun( | ||
compute_config=ComputeConfig(distributed=False, num_workers=1), | ||
train_config=train_config, | ||
train_eval=train_eval, | ||
epochs=1, | ||
save_nth_epoch=20, | ||
validate_nth_epoch=20, | ||
) | ||
return train_run | ||
|
||
|
||
if __name__ == "__main__": | ||
configs = generic_ablation( | ||
create_config, | ||
dict(mlp_dim=[10], ensemble_id=list(range(1))), | ||
) | ||
distributed_train(configs) | ||
|
||
device = ddp_setup() | ||
path = prepare_results("test", configs[0]) | ||
setup_psql() | ||
|
||
builder = flatbuffers.Builder() | ||
xstr = builder.CreateString("x") | ||
ystr = builder.CreateString("y") | ||
mstr = builder.CreateString("m") | ||
array = np.array([[1, 2, 3], [4, 5, 6]]) | ||
buffer = io.BytesIO() | ||
np.save(buffer, array) | ||
data = builder.CreateByteVector(buffer.getvalue()) | ||
|
||
Component.Start(builder) | ||
Component.AddName(builder, xstr) | ||
Component.AddUnit(builder, mstr) | ||
x = Component.End(builder) | ||
|
||
Component.Start(builder) | ||
Component.AddName(builder, ystr) | ||
Component.AddUnit(builder, mstr) | ||
y = Component.End(builder) | ||
|
||
Dimension.StartComponentsVector(builder, 2) | ||
builder.PrependUOffsetTRelative(x) | ||
builder.PrependUOffsetTRelative(y) | ||
components = builder.EndVector() | ||
|
||
Dimension.Start(builder) | ||
Dimension.AddComponents(builder, components) | ||
d1 = Dimension.End(builder) | ||
|
||
NPY.StartDimsVector(builder, 1) | ||
builder.PrependUOffsetTRelative(d1) | ||
dimensions = builder.EndVector() | ||
NPY.Start(builder) | ||
NPY.AddDims(builder, dimensions) | ||
|
||
NPY.AddData(builder, data) | ||
npy = NPY.End(builder) | ||
builder.Finish(npy) | ||
bytes = builder.Output() | ||
|
||
with open(path / "npyspec.npyspec", "wb") as f: | ||
f.write(bytes) | ||
|
||
add_artifact(configs[0], "test.npyspec", path / "npyspec.npyspec") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// namespace EP; | ||
|
||
table Component { | ||
name: string; | ||
unit: string; | ||
} | ||
|
||
table Dimension { | ||
components: [Component]; | ||
} | ||
|
||
table NPY { | ||
dims: [Dimension]; | ||
data: [byte]; | ||
} | ||
|
||
root_type NPY; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/usr/bin/env python | ||
import subprocess | ||
|
||
subprocess.run( | ||
["flatc", "--python", "-o", "lib/flatbuffers", "flatbuffers/npy.fbs"], | ||
# stdout=subprocess.PIPE, | ||
# stderr=subprocess.PIPE, | ||
) | ||
subprocess.run( | ||
["flatc", "--rust", "-o", "rust/vis/src/flatbuffers", "flatbuffers/npy.fbs"], | ||
# stdout=subprocess.PIPE, | ||
# stderr=subprocess.PIPE, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class ArraySpec(object): | ||
__slots__ = ['_tab'] | ||
|
||
@classmethod | ||
def GetRootAs(cls, buf, offset=0): | ||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) | ||
x = ArraySpec() | ||
x.Init(buf, n + offset) | ||
return x | ||
|
||
@classmethod | ||
def GetRootAsArraySpec(cls, buf, offset=0): | ||
"""This method is deprecated. Please switch to GetRootAs.""" | ||
return cls.GetRootAs(buf, offset) | ||
# ArraySpec | ||
def Init(self, buf, pos): | ||
self._tab = flatbuffers.table.Table(buf, pos) | ||
|
||
# ArraySpec | ||
def Dims(self, j): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
if o != 0: | ||
x = self._tab.Vector(o) | ||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 | ||
x = self._tab.Indirect(x) | ||
from Dimension import Dimension | ||
obj = Dimension() | ||
obj.Init(self._tab.Bytes, x) | ||
return obj | ||
return None | ||
|
||
# ArraySpec | ||
def DimsLength(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
if o != 0: | ||
return self._tab.VectorLen(o) | ||
return 0 | ||
|
||
# ArraySpec | ||
def DimsIsNone(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
return o == 0 | ||
|
||
def ArraySpecStart(builder): | ||
builder.StartObject(1) | ||
|
||
def Start(builder): | ||
ArraySpecStart(builder) | ||
|
||
def ArraySpecAddDims(builder, dims): | ||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dims), 0) | ||
|
||
def AddDims(builder, dims): | ||
ArraySpecAddDims(builder, dims) | ||
|
||
def ArraySpecStartDimsVector(builder, numElems): | ||
return builder.StartVector(4, numElems, 4) | ||
|
||
def StartDimsVector(builder, numElems: int) -> int: | ||
return ArraySpecStartDimsVector(builder, numElems) | ||
|
||
def ArraySpecEnd(builder): | ||
return builder.EndObject() | ||
|
||
def End(builder): | ||
return ArraySpecEnd(builder) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class Component(object): | ||
__slots__ = ['_tab'] | ||
|
||
@classmethod | ||
def GetRootAs(cls, buf, offset=0): | ||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) | ||
x = Component() | ||
x.Init(buf, n + offset) | ||
return x | ||
|
||
@classmethod | ||
def GetRootAsComponent(cls, buf, offset=0): | ||
"""This method is deprecated. Please switch to GetRootAs.""" | ||
return cls.GetRootAs(buf, offset) | ||
# Component | ||
def Init(self, buf, pos): | ||
self._tab = flatbuffers.table.Table(buf, pos) | ||
|
||
# Component | ||
def Name(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
if o != 0: | ||
return self._tab.String(o + self._tab.Pos) | ||
return None | ||
|
||
# Component | ||
def Unit(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) | ||
if o != 0: | ||
return self._tab.String(o + self._tab.Pos) | ||
return None | ||
|
||
def ComponentStart(builder): | ||
builder.StartObject(2) | ||
|
||
def Start(builder): | ||
ComponentStart(builder) | ||
|
||
def ComponentAddName(builder, name): | ||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) | ||
|
||
def AddName(builder, name): | ||
ComponentAddName(builder, name) | ||
|
||
def ComponentAddUnit(builder, unit): | ||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(unit), 0) | ||
|
||
def AddUnit(builder, unit): | ||
ComponentAddUnit(builder, unit) | ||
|
||
def ComponentEnd(builder): | ||
return builder.EndObject() | ||
|
||
def End(builder): | ||
return ComponentEnd(builder) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class Dimension(object): | ||
__slots__ = ['_tab'] | ||
|
||
@classmethod | ||
def GetRootAs(cls, buf, offset=0): | ||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) | ||
x = Dimension() | ||
x.Init(buf, n + offset) | ||
return x | ||
|
||
@classmethod | ||
def GetRootAsDimension(cls, buf, offset=0): | ||
"""This method is deprecated. Please switch to GetRootAs.""" | ||
return cls.GetRootAs(buf, offset) | ||
# Dimension | ||
def Init(self, buf, pos): | ||
self._tab = flatbuffers.table.Table(buf, pos) | ||
|
||
# Dimension | ||
def Components(self, j): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
if o != 0: | ||
x = self._tab.Vector(o) | ||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 | ||
x = self._tab.Indirect(x) | ||
from Component import Component | ||
obj = Component() | ||
obj.Init(self._tab.Bytes, x) | ||
return obj | ||
return None | ||
|
||
# Dimension | ||
def ComponentsLength(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
if o != 0: | ||
return self._tab.VectorLen(o) | ||
return 0 | ||
|
||
# Dimension | ||
def ComponentsIsNone(self): | ||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) | ||
return o == 0 | ||
|
||
def DimensionStart(builder): | ||
builder.StartObject(1) | ||
|
||
def Start(builder): | ||
DimensionStart(builder) | ||
|
||
def DimensionAddComponents(builder, components): | ||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(components), 0) | ||
|
||
def AddComponents(builder, components): | ||
DimensionAddComponents(builder, components) | ||
|
||
def DimensionStartComponentsVector(builder, numElems): | ||
return builder.StartVector(4, numElems, 4) | ||
|
||
def StartComponentsVector(builder, numElems: int) -> int: | ||
return DimensionStartComponentsVector(builder, numElems) | ||
|
||
def DimensionEnd(builder): | ||
return builder.EndObject() | ||
|
||
def End(builder): | ||
return DimensionEnd(builder) |
Oops, something went wrong.