diff --git a/embedder/ollama/api.go b/embedder/ollama/api.go index b0233112..5e6be62d 100644 --- a/embedder/ollama/api.go +++ b/embedder/ollama/api.go @@ -34,6 +34,7 @@ func (r *request) ContentType() string { type response struct { HTTPStatusCode int `json:"-"` acceptContentType string `json:"-"` + RawBody []byte `json:"-"` Embedding []float64 `json:"embedding"` CreatedAt string `json:"created_at"` } @@ -46,7 +47,13 @@ func (r *response) Decode(body io.Reader) error { return json.NewDecoder(body).Decode(r) } -func (r *response) SetBody(_ io.Reader) error { +func (r *response) SetBody(body io.Reader) error { + rawBody, err := io.ReadAll(body) + if err != nil { + return err + } + + r.RawBody = rawBody return nil } diff --git a/embedder/ollama/ollama.go b/embedder/ollama/ollama.go index 54379303..b15766e4 100644 --- a/embedder/ollama/ollama.go +++ b/embedder/ollama/ollama.go @@ -2,6 +2,9 @@ package ollamaembedder import ( "context" + "errors" + "fmt" + "net/http" "github.com/henomis/restclientgo" @@ -14,6 +17,14 @@ const ( defaultEndpoint = "http://localhost:11434/api" ) +type OllamaEmbedError struct { + Err error +} + +func (e *OllamaEmbedError) Error() string { + return fmt.Sprintf("Error embedding text: %v", e.Err) +} + type Embedder struct { model string restClient *restclientgo.RestClient @@ -88,5 +99,11 @@ func (e *Embedder) embed(ctx context.Context, text string) (embedder.Embedding, return nil, err } + if resp.HTTPStatusCode >= http.StatusBadRequest { + return nil, &OllamaEmbedError{ + Err: errors.New(string(resp.RawBody)), + } + } + return resp.Embedding, nil }