Add support for options and dynamic fetching of Gemini models #12

Open · wants to merge 1 commit into base: main
50 changes: 43 additions & 7 deletions llm_gemini.py
@@ -1,6 +1,8 @@
 import httpx
 import ijson
 import llm
+from typing import Optional, List
+from pydantic import Field
 import urllib.parse

 # We disable all of these to avoid random unexpected errors
@@ -23,17 +25,49 @@
     },
 ]

+class GeminiOptions(llm.Options):
+    max_tokens: Optional[int] = Field(
+        description="The maximum number of tokens to generate before stopping",
+        default=8000,
+    )
+
+    temperature: Optional[float] = Field(
+        description=(
+            "Amount of randomness injected into the response. Ranges from 0.0 "
+            "to 1.0; use values closer to 0.0 for analytical or multiple-choice "
+            "tasks and closer to 1.0 for creative tasks. Even at 0.0 the results "
+            "are not fully deterministic."
+        ),
+        default=1.0,
+    )
+
+    top_p: Optional[float] = Field(
+        description=(
+            "Nucleus sampling: consider the smallest set of tokens whose "
+            "cumulative probability is at least top_p."
+        ),
+        default=None,
+    )
+
+    top_k: Optional[int] = Field(
+        description=(
+            "Top-k sampling: consider only the top_k most probable tokens. "
+            "Leave unset to use the model's default behavior."
+        ),
+        default=None,
+    )
+
+
+def fetch_available_models(api_key: str, method: str = "generateContent") -> List[str]:
+    # Ask the API which models are available rather than hard-coding them.
+    url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
+    with httpx.Client() as client:
+        response = client.get(url)
+        response.raise_for_status()
+    models = response.json().get("models", [])
+    # Keep only models that support the requested method, stripping the
+    # "models/" prefix from each name (e.g. "models/gemini-pro" -> "gemini-pro").
+    return [
+        model["name"].split("/")[-1]
+        for model in models
+        if method in model.get("supportedGenerationMethods", [])
+    ]
+
+
 @llm.hookimpl
 def register_models(register):
-    register(GeminiPro("gemini-pro"))
-    register(GeminiPro("gemini-1.5-pro-latest"))
-    register(GeminiPro("gemini-1.5-flash-latest"))
+    api_key = llm.get_key("", "gemini", "LLM_GEMINI_KEY")
+    if not api_key:
+        # Hooks run on every llm invocation; without a key we cannot query
+        # the models endpoint, so skip dynamic registration.
+        return
+    available_models = fetch_available_models(api_key)
+    for model_id in available_models:
+        register(GeminiPro(model_id))


 class GeminiPro(llm.Model):
     can_stream = True

+    class Options(GeminiOptions): ...
+
     def __init__(self, model_id):
         self.model_id = model_id
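Note for reviewers: the ListModels endpoint that fetch_available_models() calls returns JSON shaped roughly as below. The models, name, and supportedGenerationMethods fields are the ones the patch relies on; the concrete model entries here are illustrative, not a real API capture.

# A sketch of the response shape parsed by fetch_available_models();
# the two model entries are examples only.
sample_response = {
    "models": [
        {
            "name": "models/gemini-1.5-pro-latest",
            "supportedGenerationMethods": ["generateContent", "countTokens"],
        },
        {
            "name": "models/text-embedding-004",
            "supportedGenerationMethods": ["embedContent"],
        },
    ]
}

chat_ids = [
    m["name"].split("/")[-1]
    for m in sample_response["models"]
    if "generateContent" in m.get("supportedGenerationMethods", [])
]
assert chat_ids == ["gemini-1.5-pro-latest"]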
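Once GeminiPro.Options is in place, the new parameters should be settable per prompt through llm's Python API. A minimal sketch, assuming LLM_GEMINI_KEY is configured and that execute() forwards these options into the request's generation config (not shown in this diff):

import llm

model = llm.get_model("gemini-1.5-pro-latest")
response = model.prompt(
    "Summarize Hamlet in two sentences.",
    temperature=0.2,  # low randomness for a factual summary
    max_tokens=100,   # cap the length of the reply
)
print(response.text())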
@@ -88,9 +122,11 @@ def execute(self, prompt, stream, response, conversation):

 @llm.hookimpl
 def register_embedding_models(register):
-    register(
-        GeminiEmbeddingModel("text-embedding-004", "text-embedding-004"),
-    )
+    api_key = llm.get_key("", "gemini", "LLM_GEMINI_KEY")
+    if not api_key:
+        return
+    # Only register models that advertise the embedContent method.
+    embedding_models = fetch_available_models(api_key, method="embedContent")
+    for model_id in embedding_models:
+        register(GeminiEmbeddingModel(model_id, model_id))


 class GeminiEmbeddingModel(llm.EmbeddingModel):
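With dynamic registration of embedding models, anything the endpoint reports as supporting embedContent becomes usable through llm's embedding API. A minimal sketch, assuming LLM_GEMINI_KEY is set and text-embedding-004 is among the models returned:

import llm

embedding_model = llm.get_embedding_model("text-embedding-004")
vector = embedding_model.embed("hello world")  # returns a list of floats
print(len(vector))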