Fixes lookup for custom chains #560

Merged: 1 commit, Jan 3, 2024
28 changes: 16 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -421,7 +421,11 @@ def _append_exchange_openai(self, prompt: str, output: str):

     def _decompose_model_id(self, model_id: str):
         """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
-        if model_id in self.custom_model_registry:
+        # custom_model_registry maps keys to either a model name (a string) or an LLMChain.
+        # If this is an alias to another model, expand the full name of the model.
+        if model_id in self.custom_model_registry and isinstance(
+            self.custom_model_registry[model_id], str
+        ):
             model_id = self.custom_model_registry[model_id]
Comment on lines +424 to 429
Quick comment: it seems like it would be easier to re-define self.custom_model_registry to have values also be LLMChains in the case of an alias key. No need to do that in this PR if it's too big of a change though.


return decompose_model_id(model_id, self.providers)
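The hunk above narrows alias expansion so that only string-valued registry entries are expanded, while `LLMChain` entries pass through untouched. A minimal standalone sketch of that logic (names here are illustrative; `expand_alias` is a hypothetical helper, not jupyter-ai's API):

```python
def expand_alias(model_id: str, custom_model_registry: dict) -> str:
    """Expand model_id if it is a string alias; leave chain entries alone."""
    value = custom_model_registry.get(model_id)
    if isinstance(value, str):
        # The entry is an alias pointing at a full model name.
        return value
    # Missing keys and non-string (chain) entries are returned unchanged,
    # so provider decomposition sees the original ID.
    return model_id


# object() stands in for an LLMChain instance.
registry = {"gpt-alias": "openai-chat:gpt-4", "my-chain": object()}
expand_alias("gpt-alias", registry)  # -> "openai-chat:gpt-4"
expand_alias("my-chain", registry)   # -> "my-chain" (chain entries are not expanded)
```

This is the behavioral change the old `if model_id in self.custom_model_registry:` check got wrong: it would overwrite `model_id` with a chain object rather than a string.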
@@ -477,6 +481,17 @@ def handle_list(self, args: ListArgs):

     def run_ai_cell(self, args: CellArgs, prompt: str):
         provider_id, local_model_id = self._decompose_model_id(args.model_id)

+        # If this is a custom chain, send the message to the custom chain.
+        if args.model_id in self.custom_model_registry and isinstance(
+            self.custom_model_registry[args.model_id], LLMChain
+        ):
+            return self.display_output(
+                self.custom_model_registry[args.model_id].run(prompt),
+                args.format,
+                {"jupyter_ai": {"custom_chain_id": args.model_id}},
+            )
+
         Provider = self._get_provider(provider_id)
         if Provider is None:
             return TextOrMarkdown(
@@ -493,17 +508,6 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
             self.transcript_openai = []
             return

-        # Determine provider and local model IDs
-        # If this is a custom chain, send the message to the custom chain.
-        if args.model_id in self.custom_model_registry and isinstance(
-            self.custom_model_registry[args.model_id], LLMChain
-        ):
-            return self.display_output(
-                self.custom_model_registry[args.model_id].run(prompt),
-                args.format,
-                {"jupyter_ai": {"custom_chain_id": args.model_id}},
-            )
-
         # validate presence of authn credentials
         auth_strategy = self.providers[provider_id].auth_strategy
         if auth_strategy:
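The second hunk moves the custom-chain check ahead of the provider lookup, so a registered chain is dispatched before jupyter-ai tries to resolve a provider (which would fail for chain IDs). A hedged sketch of that dispatch order, with `SimpleChain` standing in for langchain's `LLMChain` and `run_cell` as a hypothetical stand-in for `run_ai_cell`:

```python
class SimpleChain:
    """Stand-in for an LLMChain registered under a custom model ID."""

    def run(self, prompt: str) -> str:
        return f"chain-output: {prompt}"


def run_cell(model_id: str, prompt: str, registry: dict, providers: dict) -> str:
    entry = registry.get(model_id)
    # Custom chains short-circuit first: no provider lookup is attempted,
    # which is exactly why the PR hoists this check above _get_provider.
    if isinstance(entry, SimpleChain):
        return entry.run(prompt)
    if model_id not in providers:
        return f"Cannot determine model provider from model ID `{model_id}`."
    return providers[model_id](prompt)


registry = {"my-chain": SimpleChain()}
run_cell("my-chain", "hello", registry, {})  # -> "chain-output: hello"
```

Before this PR, the chain check sat below the provider-resolution error path, so custom chain IDs (which have no provider) hit the "cannot determine provider" branch and the chain was never invoked.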