From ab2dc86383d701b96ef36129dcc7bf8abac3b485 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Tue, 30 Apr 2024 15:34:49 +0200 Subject: [PATCH] support XGB and LGBM frameworks --- osiris/app.py | 21 +++++++++++++++++---- osiris/cairo/serde/deserialize.py | 11 +++++++++-- pyproject.toml | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/osiris/app.py b/osiris/app.py index 886f39d..6972953 100644 --- a/osiris/app.py +++ b/osiris/app.py @@ -13,6 +13,7 @@ app = typer.Typer() + def check_file_format(file_path): _, file_extension = os.path.splitext(file_path) @@ -67,13 +68,13 @@ def convert_to_numpy(data): @app.command() -def serialize(input_file: str, fp_impl: str = 'FP16x16'): +def serialize(input_file: str, framework='ONNX_ORION'): """ Serialize data from a file to a tensor representation. Args: input_file (str): The path to the input file. - fp_impl (str): Fixed-point implementation detail. + framework (str): Context of the framework used. Returns: Serialized tensor. @@ -86,8 +87,20 @@ def serialize(input_file: str, fp_impl: str = 'FP16x16'): numpy_array = convert_to_numpy(data) typer.echo("✅ Conversion to numpy completed!") - tensor = create_tensor_from_array(numpy_array, fp_impl) - typer.echo("✅ Conversion to tensor completed!") + match framework: + case 'ONNX_ORION': + tensor = create_tensor_from_array(numpy_array, 'FP16x16') + typer.echo("✅ Conversion to tensor completed!") + case 'XGB': + numpy_array *= 100000 + tensor = numpy_array.astype(np.int64) + case 'LGBM': + numpy_array *= 100000 + tensor = numpy_array.astype(np.int64) + case _: + tensor = create_tensor_from_array( + numpy_array, 'FP16x16') + typer.echo("✅ Conversion to tensor completed!") serialized = serializer(tensor) typer.echo("✅ Serialized tensor successfully! 🎉") diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index ec7cd41..6c46515 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -5,10 +5,17 @@ from osiris.cairo.serde.utils import felt_to_int, from_fp -def deserializer(serialized, dtype): +def deserializer(serialized, dtype, framework='ONNX_ORION'): if dtype in ["u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]: - return felt_to_int(int(serialized)) + + match framework: + case 'XGB': + return felt_to_int(int(serialized)) / 100000 + case 'LGBM': + return felt_to_int(int(serialized)) / 100000 + case _: + return felt_to_int(int(serialized)) elif dtype.startswith("FP"): return deserialize_fp(serialized) diff --git a/pyproject.toml b/pyproject.toml index 7f0c5f4..6aed60c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "giza-osiris" -version = "0.2.6" +version = "0.2.7" description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs" authors = ["Fran Algaba "] readme = "README.md"