diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go index aa9d1f0ca..2ac9a9874 100644 --- a/go/plugins/ollama/ollama.go +++ b/go/plugins/ollama/ollama.go @@ -45,6 +45,7 @@ var state struct { mu sync.Mutex initted bool serverAddress string + timeout int } func DefineModel(model ModelDefinition, caps *ai.ModelCapabilities) ai.Model { @@ -67,7 +68,7 @@ func DefineModel(model ModelDefinition, caps *ai.ModelCapabilities) ai.Model { Label: "Ollama - " + model.Name, Supports: mc, } - g := &generator{model: model, serverAddress: state.serverAddress} + g := &generator{model: model, serverAddress: state.serverAddress, timeout: state.timeout} return ai.DefineModel(provider, model.Name, meta, g.generate) } @@ -92,6 +93,7 @@ type ModelDefinition struct { type generator struct { model ModelDefinition serverAddress string + timeout int } type ollamaMessage struct { @@ -148,6 +150,7 @@ type ollamaGenerateResponse struct { type Config struct { // Server Address of oLLama. ServerAddress string + Timeout int } // Init initializes the plugin. @@ -162,6 +165,10 @@ func Init(ctx context.Context, cfg *Config) (err error) { if cfg == nil || cfg.ServerAddress == "" { return errors.New("ollama: need ServerAddress") } + if cfg.Timeout == 0 { + cfg.Timeout = 30 + } + state.timeout = cfg.Timeout state.serverAddress = cfg.ServerAddress state.initted = true return nil @@ -201,7 +208,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb Stream: stream, } } - client := &http.Client{Timeout: 30 * time.Second} + client := &http.Client{Timeout: time.Duration(g.timeout) * time.Second} payloadBytes, err := json.Marshal(payload) if err != nil { return nil, err