Skip to content

Commit

Permalink
Fixes lookup for custom chains
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonWeill committed Jan 3, 2024
1 parent f9c8033 commit 4ea716a
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ def _append_exchange_openai(self, prompt: str, output: str):

def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in self.custom_model_registry:
# custom_model_registry maps keys to either a model name (a string) or an LLMChain.
# If this is an alias to another model, expand the full name of the model.
if model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[model_id], str
):
model_id = self.custom_model_registry[model_id]

return decompose_model_id(model_id, self.providers)
Expand Down Expand Up @@ -477,6 +481,17 @@ def handle_list(self, args: ListArgs):

def run_ai_cell(self, args: CellArgs, prompt: str):
provider_id, local_model_id = self._decompose_model_id(args.model_id)

# If this is a custom chain, send the message to the custom chain.
if args.model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[args.model_id], LLMChain
):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
{"jupyter_ai": {"custom_chain_id": args.model_id}},
)

Provider = self._get_provider(provider_id)
if Provider is None:
return TextOrMarkdown(
Expand All @@ -493,17 +508,6 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
self.transcript_openai = []
return

# Determine provider and local model IDs
# If this is a custom chain, send the message to the custom chain.
if args.model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[args.model_id], LLMChain
):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
{"jupyter_ai": {"custom_chain_id": args.model_id}},
)

# validate presence of authn credentials
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
Expand Down

0 comments on commit 4ea716a

Please sign in to comment.