diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py
index 01423795..d745885c 100644
--- a/aixplain/factories/model_factory/utils.py
+++ b/aixplain/factories/model_factory/utils.py
@@ -31,10 +31,13 @@ def create_model_from_response(response: Dict) -> Model:
             parameters[param["name"]] = [w["value"] for w in param["values"]]
 
     function = Function(response["function"]["id"])
-    inputs = []
+    inputs, temperature = [], None
     ModelClass = Model
     if function == Function.TEXT_GENERATION:
         ModelClass = LLM
+        f = [p for p in response.get("params", []) if p["name"] == "temperature"]
+        if len(f) > 0 and len(f[0].get("defaultValues", [])) > 0:
+            temperature = float(f[0]["defaultValues"][0]["value"])
     elif function == Function.UTILITIES:
         ModelClass = UtilityModel
         inputs = [
@@ -67,6 +70,7 @@ def create_model_from_response(response: Dict) -> Model:
         is_subscribed=True if "subscription" in response else False,
         version=response["version"]["id"],
         inputs=inputs,
+        temperature=temperature,
     )
diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py
index 1f64f246..cf60d0a2 100644
--- a/aixplain/modules/model/llm_model.py
+++ b/aixplain/modules/model/llm_model.py
@@ -61,6 +61,7 @@ def __init__(
         function: Optional[Function] = None,
         is_subscribed: bool = False,
         cost: Optional[Dict] = None,
+        temperature: float = 0.001,
         **additional_info,
     ) -> None:
         """LLM Init
@@ -92,6 +93,7 @@ def __init__(
         )
         self.url = config.MODELS_RUN_URL
         self.backend_url = config.BACKEND_URL
+        self.temperature = temperature
 
     def run(
         self,
@@ -99,7 +101,7 @@ def run(
         context: Optional[Text] = None,
         prompt: Optional[Text] = None,
         history: Optional[List[Dict]] = None,
-        temperature: float = 0.001,
+        temperature: Optional[float] = None,
         max_tokens: int = 128,
         top_p: float = 1.0,
         name: Text = "model_process",
@@ -114,7 +116,7 @@ def run(
             context (Optional[Text], optional): System message. Defaults to None.
             prompt (Optional[Text], optional): Prompt Message which comes on the left side of the last utterance. Defaults to None.
             history (Optional[List[Dict]], optional): Conversation history in OpenAI format ([{ "role": "assistant", "content": "Hello, world!"}]). Defaults to None.
-            temperature (float, optional): LLM temperature. Defaults to 0.001.
+            temperature (Optional[float], optional): LLM temperature. Defaults to None.
             max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
             top_p (float, optional): Top P. Defaults to 1.0.
             name (Text, optional): ID given to a call. Defaults to "model_process".
@@ -135,7 +137,7 @@ def run(
         parameters.setdefault("context", context)
         parameters.setdefault("prompt", prompt)
         parameters.setdefault("history", history)
-        parameters.setdefault("temperature", temperature)
+        parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
         parameters.setdefault("max_tokens", max_tokens)
         parameters.setdefault("top_p", top_p)
 
@@ -173,7 +175,7 @@ def run_async(
         context: Optional[Text] = None,
         prompt: Optional[Text] = None,
         history: Optional[List[Dict]] = None,
-        temperature: float = 0.001,
+        temperature: Optional[float] = None,
         max_tokens: int = 128,
         top_p: float = 1.0,
         name: Text = "model_process",
@@ -186,7 +188,7 @@ def run_async(
            context (Optional[Text], optional): System message. Defaults to None.
            prompt (Optional[Text], optional): Prompt Message which comes on the left side of the last utterance. Defaults to None.
            history (Optional[List[Dict]], optional): Conversation history in OpenAI format ([{ "role": "assistant", "content": "Hello, world!"}]). Defaults to None.
-           temperature (float, optional): LLM temperature. Defaults to 0.001.
+           temperature (Optional[float], optional): LLM temperature. Defaults to None.
            max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
            top_p (float, optional): Top P. Defaults to 1.0.
            name (Text, optional): ID given to a call. Defaults to "model_process".
@@ -206,7 +208,7 @@ def run_async(
        parameters.setdefault("context", context)
        parameters.setdefault("prompt", prompt)
        parameters.setdefault("history", history)
-       parameters.setdefault("temperature", temperature)
+       parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
        parameters.setdefault("max_tokens", max_tokens)
        parameters.setdefault("top_p", top_p)
        payload = build_payload(data=data, parameters=parameters)
diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py
index 266b04ea..a826ad19 100644
--- a/tests/functional/general_assets/asset_functional_test.py
+++ b/tests/functional/general_assets/asset_functional_test.py
@@ -19,7 +19,7 @@ def inputs():
 
 
 def __get_asset_factory(asset_name):
-    if asset_name == "model":
+    if "model" in asset_name:
         AssetFactory = ModelFactory
     elif asset_name == "dataset":
         AssetFactory = DatasetFactory
@@ -40,7 +40,7 @@ def test_list(asset_name):
     assert asset_list["page_total"] == len(asset_list["results"])
 
 
-@pytest.mark.parametrize("asset_name", ["model", "pipeline", "metric"])
+@pytest.mark.parametrize("asset_name", ["model", "model2", "model3", "pipeline", "metric"])
 def test_run(inputs, asset_name):
     asset_details = inputs[asset_name]
     AssetFactory = __get_asset_factory(asset_name)
diff --git a/tests/functional/general_assets/data/asset_run_test_data.json b/tests/functional/general_assets/data/asset_run_test_data.json
index e24df1ef..c9db273d 100644
--- a/tests/functional/general_assets/data/asset_run_test_data.json
+++ b/tests/functional/general_assets/data/asset_run_test_data.json
@@ -7,6 +7,10 @@
         "id" : "60ddefab8d38c51c5885ee38",
         "data": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/myname.mp3"
     },
+    "model3" : {
+        "id" : "6736411cf127849667606689",
+        "data": "How to cook a shrimp risotto?"
+    },
     "pipeline": {
         "name": "SingleNodePipeline",
         "data": "This is a test sentence."
diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py
index 6ca9e6fe..985e4a91 100644
--- a/tests/functional/pipelines/run_test.py
+++ b/tests/functional/pipelines/run_test.py
@@ -251,3 +251,20 @@ def test_run_script(version: str):
     assert response["status"] == "SUCCESS"
     data = response["data"][0]["segments"][0]["response"]
     assert data.startswith("SCRIPT MODIFIED:")
+
+
+@pytest.mark.parametrize("version", ["2.0", "3.0"])
+def test_run_text_reconstruction(version: str):
+    pipeline = PipelineFactory.list(query="Text Reconstruction - DO NOT DELETE")["results"][0]
+    response = pipeline.run("Segment A\nSegment B\nSegment C", **{"version": version})
+
+    assert response["status"] == "SUCCESS"
+    labels = [d["label"] for d in response["data"]]
+    assert "Audio (Direct)" in labels
+    assert "Audio (Text Reconstruction)" in labels
+    assert "Audio (Audio Reconstruction)" in labels
+    assert "Text Reconstruction" in labels
+
+    for d in response["data"]:
+        assert len(d["segments"]) > 0
+        assert d["segments"][0]["success"] is True
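
A minimal usage sketch of the resulting behavior (not part of the diff): after this change, an LLM built by the factory carries the backend's default temperature, run() falls back to that stored value when no per-call temperature is given, and an explicit argument still wins. The model ID below is reused from the test data above purely for illustration, and a configured aiXplain SDK (TEAM_API_KEY set) is assumed.

    from aixplain.factories import ModelFactory

    # Any TEXT_GENERATION model ID works; this one comes from the test data above.
    llm = ModelFactory.get("6736411cf127849667606689")

    # Seeded by create_model_from_response() from the model's "temperature"
    # default parameter (may be None if the backend exposes no default).
    print(llm.temperature)

    # No per-call temperature: run() now falls back to llm.temperature.
    llm.run("How to cook a shrimp risotto?")

    # An explicit per-call temperature still overrides the stored default.
    llm.run("How to cook a shrimp risotto?", temperature=0.7)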