Skip to content

Commit

Permalink
feat: add ability to list models from providers
Browse files Browse the repository at this point in the history
Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Aug 16, 2024
1 parent 48c714d commit 4f05179
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 14 deletions.
1 change: 0 additions & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ jobs:
- name: Install gptscript
run: |
curl https://get.gptscript.ai/releases/default_windows_amd64_v1/gptscript.exe -o gptscript.exe
- name: Create config file
- name: Run Tests
env:
GPTSCRIPT_BIN: .\gptscript.exe
Expand Down
40 changes: 28 additions & 12 deletions gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ var (
const relativeToBinaryPath = "<me>"

type GPTScript struct {
url string
globalEnv []string
url string
globalOpts GlobalOptions
}

func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) {
Expand All @@ -40,7 +40,7 @@ func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) {

disableServer := os.Getenv("GPTSCRIPT_DISABLE_SERVER") == "true"

if serverURL == "" && disableServer {
if serverURL == "" {
serverURL = os.Getenv("GPTSCRIPT_URL")
}

Expand Down Expand Up @@ -96,11 +96,8 @@ func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) {
serverURL = strings.TrimSpace(serverURL)
}
g := &GPTScript{
url: "http://" + serverURL,
}

if disableServer {
g.globalEnv = opt.Env[:]
url: "http://" + serverURL,
globalOpts: opt,
}

return g, nil
Expand Down Expand Up @@ -132,7 +129,7 @@ func (g *GPTScript) Close() {
}

func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef) (*Run, error) {
opts.Env = append(g.globalEnv, opts.Env...)
opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions)
return (&Run{
url: g.url,
requestPath: "evaluate",
Expand All @@ -143,7 +140,7 @@ func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef
}

func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Run, error) {
opts.Env = append(g.globalEnv, opts.Env...)
opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions)
return (&Run{
url: g.url,
requestPath: "run",
Expand Down Expand Up @@ -281,9 +278,28 @@ func (g *GPTScript) ListTools(ctx context.Context) (string, error) {
return out, nil
}

type ListModelsOptions struct {
Providers []string
CredentialOverrides []string
}

// ListModels will list all the available models.
func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) {
out, err := g.runBasicCommand(ctx, "list-models", nil)
func (g *GPTScript) ListModels(ctx context.Context, opts ...ListModelsOptions) ([]string, error) {
var o ListModelsOptions
for _, opt := range opts {
o.Providers = append(o.Providers, opt.Providers...)
o.CredentialOverrides = append(o.CredentialOverrides, opt.CredentialOverrides...)
}

if g.globalOpts.DefaultModelProvider != "" {
o.Providers = append(o.Providers, g.globalOpts.DefaultModelProvider)
}

out, err := g.runBasicCommand(ctx, "list-models", map[string]any{
"providers": o.Providers,
"env": g.globalOpts.Env,
"credentialOverrides": o.CredentialOverrides,
})
if err != nil {
return nil, err
}
Expand Down
63 changes: 62 additions & 1 deletion gptscript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,22 @@ func TestMain(m *testing.M) {
panic("OPENAI_API_KEY or GPTSCRIPT_URL environment variable must be set")
}

var err error
// Start an initial GPTScript instance.
// This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases.
gFirst, err := NewGPTScript(GlobalOptions{})
if err != nil {
panic(fmt.Sprintf("error creating gptscript: %s", err))
}

g, err = NewGPTScript(GlobalOptions{OpenAIAPIKey: os.Getenv("OPENAI_API_KEY")})
if err != nil {
gFirst.Close()
panic(fmt.Sprintf("error creating gptscript: %s", err))
}

exitCode := m.Run()
g.Close()
gFirst.Close()
os.Exit(exitCode)
}

Expand Down Expand Up @@ -80,6 +88,59 @@ func TestListModels(t *testing.T) {
}
}

func TestListModelsWithProvider(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("ANTHROPIC_API_KEY not set")
}
models, err := g.ListModels(context.Background(), ListModelsOptions{
Providers: []string{"github.com/gptscript-ai/claude3-anthropic-provider"},
CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"},
})
if err != nil {
t.Errorf("Error listing models: %v", err)
}

if len(models) == 0 {
t.Error("No models found")
}

for _, model := range models {
if !strings.HasPrefix(model, "claude-3-") || !strings.HasSuffix(model, "from github.com/gptscript-ai/claude3-anthropic-provider") {
t.Errorf("Unexpected model name: %s", model)
}
}
}

func TestListModelsWithDefaultProvider(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("ANTHROPIC_API_KEY not set")
}
g, err := NewGPTScript(GlobalOptions{
DefaultModelProvider: "github.com/gptscript-ai/claude3-anthropic-provider",
})
if err != nil {
t.Fatalf("Error creating gptscript: %v", err)
}
defer g.Close()

models, err := g.ListModels(context.Background(), ListModelsOptions{
CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"},
})
if err != nil {
t.Errorf("Error listing models: %v", err)
}

if len(models) == 0 {
t.Error("No models found")
}

for _, model := range models {
if !strings.HasPrefix(model, "claude-3-") || !strings.HasSuffix(model, "from github.com/gptscript-ai/claude3-anthropic-provider") {
t.Errorf("Unexpected model name: %s", model)
}
}
}

func TestAbortRun(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

Expand Down

0 comments on commit 4f05179

Please sign in to comment.