Skip to content

Commit

Permalink
enhance: get more information about models (#68)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Dec 19, 2024
1 parent 64f1b28 commit d9098db
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
8 changes: 5 additions & 3 deletions gptscript/gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gptscript.datasets import DatasetElementMeta, DatasetElement, DatasetMeta
from gptscript.fileinfo import FileInfo
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
from gptscript.openai import Model
from gptscript.opts import GlobalOptions
from gptscript.prompt import PromptResponse
from gptscript.run import Run, RunBasicCommand, Options
Expand Down Expand Up @@ -164,16 +165,17 @@ async def _run_basic_command(self, sub_command: str, request_body: Any = None):
async def version(self) -> str:
return await self._run_basic_command("version")

async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]:
async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[Model]:
if self.opts.DefaultModelProvider != "":
if providers is None:
providers = []
providers.append(self.opts.DefaultModelProvider)

return (await self._run_basic_command(
res = await self._run_basic_command(
"list-models",
{"providers": providers, "credentialOverrides": credential_overrides}
)).split("\n")
)
return [Model(**model) for model in json.loads(res)]

async def list_credentials(self, contexts: List[str] = None, all_contexts: bool = False) -> list[Credential] | str:
if contexts is None:
Expand Down
28 changes: 28 additions & 0 deletions gptscript/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pydantic import BaseModel, conlist
from typing import Any, Dict, Optional


class Permission(BaseModel):
created: int
id: str
object: str
allow_create_engine: bool
allow_sampling: bool
allow_logprobs: bool
allow_search_indices: bool
allow_view: bool
allow_fine_tuning: bool
organization: str
group: Any
is_blocking: bool


class Model(BaseModel):
created: Optional[int]
id: str
object: str
owned_by: str
permission: Optional[conlist(Permission)]
root: Optional[str]
parent: Optional[str]
metadata: Optional[Dict[str, str]]
8 changes: 4 additions & 4 deletions tests/test_gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ async def test_list_models_from_provider(gptscript):
)
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
for model in models:
assert model.startswith("claude-3-"), "Unexpected model name"
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
assert model.id.startswith("claude-3-"), "Unexpected model name"
assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"


@pytest.mark.asyncio
Expand All @@ -140,8 +140,8 @@ async def test_list_models_from_default_provider():
)
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
for model in models:
assert model.startswith("claude-3-"), "Unexpected model name"
assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
assert model.id.startswith("claude-3-"), "Unexpected model name"
assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
finally:
g.close()

Expand Down

0 comments on commit d9098db

Please sign in to comment.