From 4f05179de3611fb4de9f514ff66e45ab88e51c56 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 16 Aug 2024 18:05:40 -0400 Subject: [PATCH] feat: add ability to list models from providers Signed-off-by: Donnie Adams --- .github/workflows/run_tests.yaml | 1 - gptscript.go | 40 ++++++++++++++------ gptscript_test.go | 63 +++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 14 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 8e80a83..c3b3123 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -49,7 +49,6 @@ jobs: - name: Install gptscript run: | curl https://get.gptscript.ai/releases/default_windows_amd64_v1/gptscript.exe -o gptscript.exe - - name: Create config file - name: Run Tests env: GPTSCRIPT_BIN: .\gptscript.exe diff --git a/gptscript.go b/gptscript.go index cc0d620..f3a9e44 100644 --- a/gptscript.go +++ b/gptscript.go @@ -28,8 +28,8 @@ var ( const relativeToBinaryPath = "" type GPTScript struct { - url string - globalEnv []string + url string + globalOpts GlobalOptions } func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) { @@ -40,7 +40,7 @@ func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) { disableServer := os.Getenv("GPTSCRIPT_DISABLE_SERVER") == "true" - if serverURL == "" && disableServer { + if serverURL == "" { serverURL = os.Getenv("GPTSCRIPT_URL") } @@ -96,11 +96,8 @@ func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) { serverURL = strings.TrimSpace(serverURL) } g := &GPTScript{ - url: "http://" + serverURL, - } - - if disableServer { - g.globalEnv = opt.Env[:] + url: "http://" + serverURL, + globalOpts: opt, } return g, nil @@ -132,7 +129,7 @@ func (g *GPTScript) Close() { } func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef) (*Run, error) { - opts.Env = append(g.globalEnv, opts.Env...) + opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions) return (&Run{ url: g.url, requestPath: "evaluate", @@ -143,7 +140,7 @@ func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef } func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Run, error) { - opts.Env = append(g.globalEnv, opts.Env...) + opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions) return (&Run{ url: g.url, requestPath: "run", @@ -281,9 +278,28 @@ func (g *GPTScript) ListTools(ctx context.Context) (string, error) { return out, nil } +type ListModelsOptions struct { + Providers []string + CredentialOverrides []string +} + // ListModels will list all the available models. -func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) { - out, err := g.runBasicCommand(ctx, "list-models", nil) +func (g *GPTScript) ListModels(ctx context.Context, opts ...ListModelsOptions) ([]string, error) { + var o ListModelsOptions + for _, opt := range opts { + o.Providers = append(o.Providers, opt.Providers...) + o.CredentialOverrides = append(o.CredentialOverrides, opt.CredentialOverrides...) + } + + if g.globalOpts.DefaultModelProvider != "" { + o.Providers = append(o.Providers, g.globalOpts.DefaultModelProvider) + } + + out, err := g.runBasicCommand(ctx, "list-models", map[string]any{ + "providers": o.Providers, + "env": g.globalOpts.Env, + "credentialOverrides": o.CredentialOverrides, + }) if err != nil { return nil, err } diff --git a/gptscript_test.go b/gptscript_test.go index f298cab..700f0ea 100644 --- a/gptscript_test.go +++ b/gptscript_test.go @@ -19,14 +19,22 @@ func TestMain(m *testing.M) { panic("OPENAI_API_KEY or GPTSCRIPT_URL environment variable must be set") } - var err error + // Start an initial GPTScript instance. + // This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases. + gFirst, err := NewGPTScript(GlobalOptions{}) + if err != nil { + panic(fmt.Sprintf("error creating gptscript: %s", err)) + } + g, err = NewGPTScript(GlobalOptions{OpenAIAPIKey: os.Getenv("OPENAI_API_KEY")}) if err != nil { + gFirst.Close() panic(fmt.Sprintf("error creating gptscript: %s", err)) } exitCode := m.Run() g.Close() + gFirst.Close() os.Exit(exitCode) } @@ -80,6 +88,59 @@ func TestListModels(t *testing.T) { } } +func TestListModelsWithProvider(t *testing.T) { + if os.Getenv("ANTHROPIC_API_KEY") == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + models, err := g.ListModels(context.Background(), ListModelsOptions{ + Providers: []string{"github.com/gptscript-ai/claude3-anthropic-provider"}, + CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"}, + }) + if err != nil { + t.Errorf("Error listing models: %v", err) + } + + if len(models) == 0 { + t.Error("No models found") + } + + for _, model := range models { + if !strings.HasPrefix(model, "claude-3-") || !strings.HasSuffix(model, "from github.com/gptscript-ai/claude3-anthropic-provider") { + t.Errorf("Unexpected model name: %s", model) + } + } +} + +func TestListModelsWithDefaultProvider(t *testing.T) { + if os.Getenv("ANTHROPIC_API_KEY") == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + g, err := NewGPTScript(GlobalOptions{ + DefaultModelProvider: "github.com/gptscript-ai/claude3-anthropic-provider", + }) + if err != nil { + t.Fatalf("Error creating gptscript: %v", err) + } + defer g.Close() + + models, err := g.ListModels(context.Background(), ListModelsOptions{ + CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"}, + }) + if err != nil { + t.Errorf("Error listing models: %v", err) + } + + if len(models) == 0 { + t.Error("No models found") + } + + for _, model := range models { + if !strings.HasPrefix(model, "claude-3-") || !strings.HasSuffix(model, "from github.com/gptscript-ai/claude3-anthropic-provider") { + t.Errorf("Unexpected model name: %s", model) + } + } +} + func TestAbortRun(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"}