From 9c91b6a5b59a7027f8e69f98539f0fad59f259f4 Mon Sep 17 00:00:00 2001 From: Jason Weill Date: Tue, 2 Jan 2024 17:37:15 -0800 Subject: [PATCH] Fixes lookup for custom chains --- .../jupyter_ai_magics/magics.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index f34e28a22..922d3da25 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -421,7 +421,11 @@ def _append_exchange_openai(self, prompt: str, output: str): def _decompose_model_id(self, model_id: str): """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in self.custom_model_registry: + # custom_model_registry maps keys to either a model name (a string) or an LLMChain. + # If this is an alias to another model, expand the full name of the model. + if model_id in self.custom_model_registry and isinstance( + self.custom_model_registry[model_id], str + ): model_id = self.custom_model_registry[model_id] return decompose_model_id(model_id, self.providers) @@ -477,6 +481,17 @@ def handle_list(self, args: ListArgs): def run_ai_cell(self, args: CellArgs, prompt: str): provider_id, local_model_id = self._decompose_model_id(args.model_id) + + # If this is a custom chain, send the message to the custom chain. + if args.model_id in self.custom_model_registry and isinstance( + self.custom_model_registry[args.model_id], LLMChain + ): + return self.display_output( + self.custom_model_registry[args.model_id].run(prompt), + args.format, + {"jupyter_ai": {"custom_chain_id": args.model_id}}, + ) + Provider = self._get_provider(provider_id) if Provider is None: return TextOrMarkdown( @@ -493,17 +508,6 @@ def run_ai_cell(self, args: CellArgs, prompt: str): self.transcript_openai = [] return - # Determine provider and local model IDs - # If this is a custom chain, send the message to the custom chain. - if args.model_id in self.custom_model_registry and isinstance( - self.custom_model_registry[args.model_id], LLMChain - ): - return self.display_output( - self.custom_model_registry[args.model_id].run(prompt), - args.format, - {"jupyter_ai": {"custom_chain_id": args.model_id}}, - ) - # validate presence of authn credentials auth_strategy = self.providers[provider_id].auth_strategy if auth_strategy: