Skip to content

Commit

Permalink
[HF][fix] Allow kwargs in ModelParser.run()
Browse files Browse the repository at this point in the history
See #877 for context, and #880 for why that original fix needed to be reverted

Just a quick hack to get unblocked. I tried originally to make the ASR and Image2Text ParameterizedModel but that caused other errors.

## Test Plan

Before

https://github.com/lastmile-ai/aiconfig/assets/151060367/a23a5d5c-d9a2-415b-8a6e-9826da56e985

After

https://github.com/lastmile-ai/aiconfig/assets/151060367/f29580e9-5cf6-43c5-b848-bb525eb368f2
  • Loading branch information
Rossdan Craig [email protected] committed Jan 11, 2024
1 parent 823241b commit acf78ec
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def deserialize(
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
return completion_data

async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def deserialize(
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
return completion_params

async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
Expand Down
2 changes: 1 addition & 1 deletion python/src/aiconfig/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ async def run(
options,
params,
callback_manager=self.callback_manager,
**kwargs,
**kwargs, # TODO: We should remove and make argument explicit
)

event = CallbackEvent("on_run_complete", __name__, {"result": response})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def run(
aiconfig: AIConfig,
options: Optional[InferenceOptions] = None,
parameters: Dict = {},
**kwargs,
**kwargs, #TODO: We should remove and make arguments explicit
) -> List[Output]:
# maybe use prompt metadata instead of kwargs?
if kwargs.get("run_with_dependencies", False):
Expand Down
1 change: 1 addition & 0 deletions python/src/aiconfig/model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ async def run(
aiconfig: AIConfig,
options: Optional["InferenceOptions"] = None,
parameters: Dict = {},
**kwargs, # TODO: Remove this, just a hack for now to ensure that it doesn't break
) -> ExecuteResult:
"""
Execute model inference based on completion data to be constructed in deserialize(), which includes the input prompt and
Expand Down

0 comments on commit acf78ec

Please sign in to comment.