Skip to content

Commit

Permalink
[WIP] Flatbuffers
Browse files Browse the repository at this point in the history
  • Loading branch information
hlinander committed May 2, 2024
1 parent ee42ec8 commit bd7586a
Show file tree
Hide file tree
Showing 11 changed files with 1,035 additions and 134 deletions.
108 changes: 108 additions & 0 deletions experiments/artifact_test/flatbuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python
import torch
import numpy as np
import io

from lib.train_dataclasses import TrainConfig
from lib.train_dataclasses import TrainRun
from lib.train_dataclasses import OptimizerConfig
from lib.train_dataclasses import ComputeConfig

from lib.classification_metrics import create_classification_metrics
from lib.data_registry import DataSpiralsConfig
from lib.datasets.spiral_visualization import visualize_spiral
from lib.models.mlp import MLPClassConfig
from lib.generic_ablation import generic_ablation

from lib.distributed_trainer import distributed_train
from lib.ddp import ddp_setup
from lib.files import prepare_results
from lib.render_psql import setup_psql, add_artifact

from lib.flatbuffers import NPY, Component, Dimension
import flatbuffers


def create_config(mlp_dim, ensemble_id):
    """Assemble the TrainRun for one spiral-classification MLP experiment.

    Args:
        mlp_dim: Width used for both entries of the MLP ``widths`` list.
        ensemble_id: Value forwarded unchanged to ``TrainConfig.ensemble_id``.

    Returns:
        A ``TrainRun`` wrapping the training configuration, the
        classification metrics and a single-worker, non-distributed
        compute configuration.
    """
    cross_entropy = torch.nn.CrossEntropyLoss()

    def ce_loss(output, batch):
        # Compare the model's logits against the batch targets.
        return cross_entropy(output["logits"], batch["target"])

    model_config = MLPClassConfig(widths=[mlp_dim, mlp_dim])
    train_data_config = DataSpiralsConfig(seed=0, N=1000)
    val_data_config = DataSpiralsConfig(seed=1, N=500)
    optimizer_config = OptimizerConfig(
        optimizer=torch.optim.Adam,
        kwargs={"weight_decay": 0.0001},
    )
    train_config = TrainConfig(
        model_config=model_config,
        train_data_config=train_data_config,
        val_data_config=val_data_config,
        loss=ce_loss,
        optimizer=optimizer_config,
        batch_size=500,
        ensemble_id=ensemble_id,
    )

    train_eval = create_classification_metrics(visualize_spiral, 2)

    return TrainRun(
        compute_config=ComputeConfig(distributed=False, num_workers=1),
        train_config=train_config,
        train_eval=train_eval,
        epochs=1,
        save_nth_epoch=20,
        validate_nth_epoch=20,
    )


if __name__ == "__main__":
    # Train the single configuration of the ablation grid
    # (mlp_dim=10, ensemble_id=0).
    configs = generic_ablation(
        create_config,
        dict(mlp_dim=[10], ensemble_id=list(range(1))),
    )
    distributed_train(configs)

    # Still called as before; its return value was never used by this script.
    ddp_setup()
    path = prepare_results("test", configs[0])
    setup_psql()

    builder = flatbuffers.Builder()

    # Strings and byte vectors must be finished before the table that refers
    # to them is started, because FlatBuffers objects cannot be nested while
    # under construction.
    xstr = builder.CreateString("x")
    ystr = builder.CreateString("y")
    mstr = builder.CreateString("m")

    # The payload is the array serialized in NumPy's .npy format.
    array = np.array([[1, 2, 3], [4, 5, 6]])
    buffer = io.BytesIO()
    np.save(buffer, array)
    data = builder.CreateByteVector(buffer.getvalue())

    Component.Start(builder)
    Component.AddName(builder, xstr)
    Component.AddUnit(builder, mstr)
    x = Component.End(builder)

    Component.Start(builder)
    Component.AddName(builder, ystr)
    Component.AddUnit(builder, mstr)
    y = Component.End(builder)

    # FlatBuffers vectors are written back to front, so elements have to be
    # prepended in reverse for the stored order to read [x, y].
    Dimension.StartComponentsVector(builder, 2)
    for component in reversed((x, y)):
        builder.PrependUOffsetTRelative(component)
    components = builder.EndVector()

    Dimension.Start(builder)
    Dimension.AddComponents(builder, components)
    d1 = Dimension.End(builder)

    NPY.StartDimsVector(builder, 1)
    builder.PrependUOffsetTRelative(d1)
    dimensions = builder.EndVector()

    NPY.Start(builder)
    NPY.AddDims(builder, dimensions)
    NPY.AddData(builder, data)
    npy = NPY.End(builder)
    builder.Finish(npy)

    # Named ``payload`` rather than ``bytes`` so the builtin is not shadowed.
    payload = builder.Output()

    spec_path = path / "npyspec.npyspec"
    with open(spec_path, "wb") as f:
        f.write(payload)

    add_artifact(configs[0], "test.npyspec", spec_path)
19 changes: 19 additions & 0 deletions flatbuffers/npy.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// namespace EP;

// One labelled entry of a Dimension: a name together with a unit.
table Component {
// Label of the component.
name: string;
// Unit associated with the component.
unit: string;
}

// Description of one dimension as a list of Component entries.
table Dimension {
// The components that make up this dimension.
components: [Component];
}

// Root object: dimension descriptions plus an opaque byte payload.
table NPY {
// One Dimension entry per described dimension.
dims: [Dimension];
// Raw payload bytes.
// NOTE(review): presumably the contents of a NumPy .npy file — confirm with the writers.
data: [byte];
}

// NPY is the table found at the root of a serialized buffer.
root_type NPY;


13 changes: 13 additions & 0 deletions generate_flatbuffers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env python
import subprocess

# Regenerate the FlatBuffers bindings from the shared schema, once per target
# language. Each entry is (flatc language flag, output directory).
SCHEMA_PATH = "flatbuffers/npy.fbs"
FLATC_TARGETS = (
    ("--python", "lib/flatbuffers"),
    ("--rust", "rust/vis/src/flatbuffers"),
)

for language_flag, output_dir in FLATC_TARGETS:
    # check=True raises CalledProcessError when flatc exits non-zero, so a
    # failed generation stops the script instead of silently leaving stale
    # bindings behind.
    subprocess.run(
        ["flatc", language_flag, "-o", output_dir, SCHEMA_PATH],
        check=True,
    )
74 changes: 74 additions & 0 deletions lib/flatbuffers/ArraySpec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace:

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class ArraySpec(object):
    """Accessor for an ``ArraySpec`` table inside a FlatBuffers buffer."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ArraySpec initialised at the root table of ``buf``."""
        root_pos = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        spec = ArraySpec()
        spec.Init(buf, root_pos + offset)
        return spec

    @classmethod
    def GetRootAsArraySpec(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    def Init(self, buf, pos):
        """Attach this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Dims(self, j):
        """Return the ``j``-th entry of ``dims``, or None if the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        # Each vector slot is 4 bytes wide and holds an offset to the element.
        slot_pos = self._tab.Vector(field_off)
        slot_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        element_pos = self._tab.Indirect(slot_pos)
        # NOTE(review): absolute import as emitted by flatc; it resolves only
        # when a top-level module named ``Dimension`` is importable — confirm.
        from Dimension import Dimension
        dimension = Dimension()
        dimension.Init(self._tab.Bytes, element_pos)
        return dimension

    def DimsLength(self):
        """Return the number of entries in ``dims`` (0 if the field is absent)."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return self._tab.VectorLen(field_off) if field_off != 0 else 0

    def DimsIsNone(self):
        """Return True when the ``dims`` field is not present in the table."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return field_off == 0

def ArraySpecStart(builder):
    """Begin an ArraySpec table on ``builder`` with one field slot."""
    builder.StartObject(1)

def Start(builder):
    """Unprefixed alias for ``ArraySpecStart``."""
    ArraySpecStart(builder)

def ArraySpecAddDims(builder, dims):
    """Write the offset ``dims`` into field slot 0 of the table being built."""
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dims), 0)

def AddDims(builder, dims):
    """Unprefixed alias for ``ArraySpecAddDims``."""
    ArraySpecAddDims(builder, dims)

def ArraySpecStartDimsVector(builder, numElems):
    """Start the ``dims`` vector via ``builder.StartVector(4, numElems, 4)`` and return its result."""
    return builder.StartVector(4, numElems, 4)

def StartDimsVector(builder, numElems: int) -> int:
    """Unprefixed alias for ``ArraySpecStartDimsVector``."""
    return ArraySpecStartDimsVector(builder, numElems)

def ArraySpecEnd(builder):
    """Finish the table being built and return the value of ``builder.EndObject()``."""
    return builder.EndObject()

def End(builder):
    """Unprefixed alias for ``ArraySpecEnd``."""
    return ArraySpecEnd(builder)
63 changes: 63 additions & 0 deletions lib/flatbuffers/Component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace:

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class Component(object):
    """Accessor for a ``Component`` table inside a FlatBuffers buffer."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a Component initialised at the root table of ``buf``."""
        root_pos = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        component = Component()
        component.Init(buf, root_pos + offset)
        return component

    @classmethod
    def GetRootAsComponent(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    def Init(self, buf, pos):
        """Attach this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Name(self):
        """Return the ``name`` field, or None when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

    def Unit(self):
        """Return the ``unit`` field, or None when the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field_off == 0:
            return None
        return self._tab.String(field_off + self._tab.Pos)

def ComponentStart(builder):
    """Begin a Component table on ``builder`` with two field slots."""
    builder.StartObject(2)

def Start(builder):
    """Unprefixed alias for ``ComponentStart``."""
    ComponentStart(builder)

def ComponentAddName(builder, name):
    """Write the offset ``name`` into field slot 0 of the table being built."""
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)

def AddName(builder, name):
    """Unprefixed alias for ``ComponentAddName``."""
    ComponentAddName(builder, name)

def ComponentAddUnit(builder, unit):
    """Write the offset ``unit`` into field slot 1 of the table being built."""
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(unit), 0)

def AddUnit(builder, unit):
    """Unprefixed alias for ``ComponentAddUnit``."""
    ComponentAddUnit(builder, unit)

def ComponentEnd(builder):
    """Finish the table being built and return the value of ``builder.EndObject()``."""
    return builder.EndObject()

def End(builder):
    """Unprefixed alias for ``ComponentEnd``."""
    return ComponentEnd(builder)
74 changes: 74 additions & 0 deletions lib/flatbuffers/Dimension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace:

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class Dimension(object):
    """Accessor for a ``Dimension`` table inside a FlatBuffers buffer."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a Dimension initialised at the root table of ``buf``."""
        root_pos = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        dimension = Dimension()
        dimension.Init(buf, root_pos + offset)
        return dimension

    @classmethod
    def GetRootAsDimension(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    def Init(self, buf, pos):
        """Attach this accessor to the table located at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Components(self, j):
        """Return the ``j``-th entry of ``components``, or None if the field is absent."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return None
        # Each vector slot is 4 bytes wide and holds an offset to the element.
        slot_pos = self._tab.Vector(field_off)
        slot_pos += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        element_pos = self._tab.Indirect(slot_pos)
        # NOTE(review): absolute import as emitted by flatc; it resolves only
        # when a top-level module named ``Component`` is importable — confirm.
        from Component import Component
        component = Component()
        component.Init(self._tab.Bytes, element_pos)
        return component

    def ComponentsLength(self):
        """Return the number of entries in ``components`` (0 if the field is absent)."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return self._tab.VectorLen(field_off) if field_off != 0 else 0

    def ComponentsIsNone(self):
        """Return True when the ``components`` field is not present in the table."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return field_off == 0

def DimensionStart(builder):
    """Begin a Dimension table on ``builder`` with one field slot."""
    builder.StartObject(1)

def Start(builder):
    """Unprefixed alias for ``DimensionStart``."""
    DimensionStart(builder)

def DimensionAddComponents(builder, components):
    """Write the offset ``components`` into field slot 0 of the table being built."""
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(components), 0)

def AddComponents(builder, components):
    """Unprefixed alias for ``DimensionAddComponents``."""
    DimensionAddComponents(builder, components)

def DimensionStartComponentsVector(builder, numElems):
    """Start the ``components`` vector via ``builder.StartVector(4, numElems, 4)`` and return its result."""
    return builder.StartVector(4, numElems, 4)

def StartComponentsVector(builder, numElems: int) -> int:
    """Unprefixed alias for ``DimensionStartComponentsVector``."""
    return DimensionStartComponentsVector(builder, numElems)

def DimensionEnd(builder):
    """Finish the table being built and return the value of ``builder.EndObject()``."""
    return builder.EndObject()

def End(builder):
    """Unprefixed alias for ``DimensionEnd``."""
    return DimensionEnd(builder)
Loading

0 comments on commit bd7586a

Please sign in to comment.