diff --git a/docs/docs.go b/docs/docs.go index f0f47417..8a51f1ac 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -125,6 +125,57 @@ const docTemplate = `{ } }, "definitions": { + "anthropic.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/anthropic.Params" + }, + "model": { + "type": "string" + } + } + }, + "anthropic.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "metadata": { + "type": "string" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "system": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + } + }, "azureopenai.Config": { "type": "object", "required": [ @@ -458,6 +509,9 @@ const docTemplate = `{ "id" ], "properties": { + "anthropic": { + "$ref": "#/definitions/anthropic.Config" + }, "azureopenai": { "$ref": "#/definitions/azureopenai.Config" }, diff --git a/docs/swagger.json b/docs/swagger.json index 90177540..4722323c 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -122,6 +122,57 @@ } }, "definitions": { + "anthropic.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/anthropic.Params" + }, + "model": { + "type": "string" + } + } + }, + "anthropic.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "metadata": { + "type": "string" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "system": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + 
} + }, "azureopenai.Config": { "type": "object", "required": [ @@ -455,6 +506,9 @@ "id" ], "properties": { + "anthropic": { + "$ref": "#/definitions/anthropic.Config" + }, "azureopenai": { "$ref": "#/definitions/azureopenai.Config" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 8efd2c5e..74616bb9 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,5 +1,39 @@ basePath: / definitions: + anthropic.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/anthropic.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + anthropic.Params: + properties: + max_tokens: + type: integer + metadata: + type: string + stop: + items: + type: string + type: array + system: + type: string + temperature: + type: number + top_k: + type: integer + top_p: + type: number + type: object azureopenai.Config: properties: apiVersion: @@ -224,6 +258,8 @@ definitions: type: object providers.LangModelConfig: properties: + anthropic: + $ref: '#/definitions/anthropic.Config' azureopenai: $ref: '#/definitions/azureopenai.Config' client: diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index f0ff6b2a..068e0588 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -31,7 +31,7 @@ type UnifiedChatResponse struct { // ProviderResponse is the unified response from the provider. 
type ProviderResponse struct { - SystemID map[string]string `json:"responseId,omitempty"` + SystemID map[string]string `json:"responseId,omitempty"` Message ChatMessage `json:"message"` TokenCount TokenCount `json:"tokenCount"` } @@ -146,16 +146,16 @@ type ConnectorsResponse struct { // Anthropic Chat Response type AnthropicChatCompletion struct { - ID string `json:"id"` - Type string `json:"type"` - Model string `json:"model"` - Role string `json:"role"` - Content []Content `json:"content"` - StopReason string `json:"stop_reason"` - StopSequence string `json:"stop_sequence"` + ID string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Role string `json:"role"` + Content []Content `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence string `json:"stop_sequence"` } type Content struct { - Type string `json:"type"` - Text string `json:"text"` + Type string `json:"type"` + Text string `json:"text"` } diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index 28cb0eb2..11c742f0 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -37,15 +37,15 @@ type ChatRequest struct { // NewChatRequestFromConfig fills the struct from the config. 
Not using reflection because of performance penalty it gives func NewChatRequestFromConfig(cfg *Config) *ChatRequest { return &ChatRequest{ - Model: cfg.Model, - System: cfg.DefaultParams.System, - Temperature: cfg.DefaultParams.Temperature, - TopP: cfg.DefaultParams.TopP, - TopK: cfg.DefaultParams.TopK, - MaxTokens: cfg.DefaultParams.MaxTokens, - Metadata: cfg.DefaultParams.Metadata, + Model: cfg.Model, + System: cfg.DefaultParams.System, + Temperature: cfg.DefaultParams.Temperature, + TopP: cfg.DefaultParams.TopP, + TopK: cfg.DefaultParams.TopK, + MaxTokens: cfg.DefaultParams.MaxTokens, + Metadata: cfg.DefaultParams.Metadata, StopSequences: cfg.DefaultParams.StopSequences, - Stream: false, // unsupported right now + Stream: false, // unsupported right now } } diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go new file mode 100644 index 00000000..467119c7 --- /dev/null +++ b/pkg/providers/anthropic/client_test.go @@ -0,0 +1,68 @@ +package anthropic + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestAnthropicClient_ChatRequest(t *testing.T) { + // Anthropic Messages API: https://docs.anthropic.com/claude/reference/messages_post + AnthropicMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading anthropic chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + 
t.Errorf("error on sending chat response: %v", err) + } + }) + + AnthropicServer := httptest.NewServer(AnthropicMock) + defer AnthropicServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = AnthropicServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, "msg_013Zva2CMHLNnXjNJJKqJ2EF", response.ID) +} diff --git a/pkg/providers/anthropic/testdata/chat.req.json b/pkg/providers/anthropic/testdata/chat.req.json new file mode 100644 index 00000000..e4aac07b --- /dev/null +++ b/pkg/providers/anthropic/testdata/chat.req.json @@ -0,0 +1,12 @@ +{ + "model": "claude-instant-1.2", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 1, + "top_p": 0, + "max_tokens": 100 +} diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/providers/anthropic/testdata/chat.success.json new file mode 100644 index 00000000..eaf0f6c9 --- /dev/null +++ b/pkg/providers/anthropic/testdata/chat.success.json @@ -0,0 +1,14 @@ +{ + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "type": "message", + "model": "claude-2.1", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Blue is often seen as a calming and soothing color." 
+ } + ], + "stop_reason": "end_turn", + "stop_sequence": null +} \ No newline at end of file diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 406d84ac..957cfd66 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -48,56 +48,31 @@ func DefaultLangModelConfig() *LangModelConfig { } func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error) { - var client LangModelProvider - - var err error - - if c.OpenAI != nil { - client, err = openai.NewClient(c.OpenAI, c.Client, tel) - - if err != nil { - return nil, fmt.Errorf("error initing openai client: %v", err) - } - } - - if c.AzureOpenAI != nil { - client, err = azureopenai.NewClient(c.AzureOpenAI, c.Client, tel) - - if err != nil { - return nil, fmt.Errorf("error initing azureopenai client: %v", err) - } - } - - if c.Cohere != nil { - client, err = cohere.NewClient(c.Cohere, c.Client, tel) - - if err != nil { - return nil, fmt.Errorf("error initing cohere client: %v", err) - } - } - - if c.OctoML != nil { - client, err = octoml.NewClient(c.OctoML, c.Client, tel) - - if err != nil { - return nil, fmt.Errorf("error initing OctoML client: %v", err) - } - } - - if c.Anthropic != nil { - client, err = octoml.NewClient(c.OctoML, c.Client, tel) - - if err != nil { - return nil, fmt.Errorf("error initing Anthropic client: %v", err) - } + client, err := c.initClient(tel) + if err != nil { + return nil, fmt.Errorf("error initializing client: %v", err) } + return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil +} - if client != nil { - return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil +// initClient initializes the language model client based on the provided configuration. +// It takes a telemetry object as input and returns a LangModelProvider and an error. 
+func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangModelProvider, error) { + switch { + case c.OpenAI != nil: + return openai.NewClient(c.OpenAI, c.Client, tel) + case c.AzureOpenAI != nil: + return azureopenai.NewClient(c.AzureOpenAI, c.Client, tel) + case c.Cohere != nil: + return cohere.NewClient(c.Cohere, c.Client, tel) + case c.OctoML != nil: + return octoml.NewClient(c.OctoML, c.Client, tel) + case c.Anthropic != nil: + return anthropic.NewClient(c.Anthropic, c.Client, tel) + default: + return nil, ErrProviderNotFound } - - return nil, ErrProviderNotFound } func (c *LangModelConfig) validateOneProvider() error {