Skip to content

Commit

Permalink
support XGB and LGBM frameworks
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Apr 30, 2024
1 parent 21ba9ea commit ab2dc86
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
21 changes: 17 additions & 4 deletions osiris/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

app = typer.Typer()


def check_file_format(file_path):
_, file_extension = os.path.splitext(file_path)

Expand Down Expand Up @@ -67,13 +68,13 @@ def convert_to_numpy(data):


@app.command()
def serialize(input_file: str, fp_impl: str = 'FP16x16'):
def serialize(input_file: str, framework='ONNX_ORION'):
"""
Serialize data from a file to a tensor representation.
Args:
input_file (str): The path to the input file.
fp_impl (str): Fixed-point implementation detail.
framework (str): Context of the framework used.
Returns:
Serialized tensor.
Expand All @@ -86,8 +87,20 @@ def serialize(input_file: str, fp_impl: str = 'FP16x16'):
numpy_array = convert_to_numpy(data)
typer.echo("✅ Conversion to numpy completed!")

tensor = create_tensor_from_array(numpy_array, fp_impl)
typer.echo("✅ Conversion to tensor completed!")
match framework:
case 'ONNX_ORION':
tensor = create_tensor_from_array(numpy_array, 'FP16x16')
typer.echo("✅ Conversion to tensor completed!")
case 'XGB':
numpy_array *= 100000
tensor = numpy_array.astype(np.int64)
case 'LGBM':
numpy_array *= 100000
tensor = numpy_array.astype(np.int64)
case _:
tensor = create_tensor_from_array(
numpy_array, 'FP16x16')
typer.echo("✅ Conversion to tensor completed!")

serialized = serializer(tensor)
typer.echo("✅ Serialized tensor successfully! 🎉")
Expand Down
11 changes: 9 additions & 2 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
from osiris.cairo.serde.utils import felt_to_int, from_fp


def deserializer(serialized, dtype):
def deserializer(serialized, dtype, framework='ONNX_ORION'):

if dtype in ["u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]:
return felt_to_int(int(serialized))

match framework:
case 'XGB':
return felt_to_int(int(serialized)) / 100000
case 'LGBM':
return felt_to_int(int(serialized)) / 100000
case _:
return felt_to_int(int(serialized))

elif dtype.startswith("FP"):
return deserialize_fp(serialized)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-osiris"
version = "0.2.6"
version = "0.2.7"
description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
authors = ["Fran Algaba <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit ab2dc86

Please sign in to comment.