Skip to content

Commit

Permalink
Update Cohere model IDs (#584)
Browse files Browse the repository at this point in the history
* update Cohere model IDs

* get provider name from class attr instead of instance attr
  • Loading branch information
dlqqq authored Jan 18, 2024
1 parent 814eb44 commit effc609
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def allows_concurrency(self):
class CohereProvider(BaseProvider, Cohere):
id = "cohere"
name = "Cohere"
models = ["medium", "xlarge"]
# Source: https://docs.cohere.com/reference/generate
models = ["command", "command-nightly", "command-light", "command-light-nightly"]
model_id_key = "model"
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_llm_chain(
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
provider_name=llm.name, local_model_id=llm.model_id
provider_name=provider.name, local_model_id=llm.model_id
),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
Expand All @@ -64,7 +64,7 @@ def create_llm_chain(
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=SYSTEM_PROMPT.format(
provider_name=llm.name, local_model_id=llm.model_id
provider_name=provider.name, local_model_id=llm.model_id
)
+ "\n\n"
+ DEFAULT_TEMPLATE,
Expand Down

0 comments on commit effc609

Please sign in to comment.