diff --git a/.gitignore b/.gitignore index f679f4a..376045d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .vscode **/.venv output.json -log.txt \ No newline at end of file +log.txt +.idea diff --git a/README.md b/README.md index 8ab0deb..397da10 100644 --- a/README.md +++ b/README.md @@ -171,3 +171,60 @@ Agents are the main component of the library. Agents can perform complex tasks t ### Prebuilt (WIP) A collection of ready-made agents that can be easily integrated with your application. + +### Evaluation (WIP) +A collection of evaluation tools for agents and engines. +## Example +```go +package main + +import ( + "fmt" + "os" + + "github.com/natexcvi/go-llm/engines" + "github.com/natexcvi/go-llm/evaluation" +) + +func goodness(_ *engines.ChatPrompt, _ *engines.ChatMessage, err error) float64 { + if err != nil { + return 0 + } + + return 1 +} + +func main() { + engine := engines.NewGPTEngine(os.Getenv("OPENAI_TOKEN"), "gpt-3.5-turbo-0613") + engineRunner := evaluation.NewLLMRunner(engine) + + evaluator := evaluation.NewEvaluator(engineRunner, &evaluation.Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: goodness, + Repetitions: 5, + }) + + testPack := []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "Hello, how are you?", + }, + { + Text: "I'm trying to understand how this works.", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Could you please explain it to me?", + }, + }, + }, + } + + results := evaluator.Evaluate(testPack) + fmt.Println("Goodness level of the first prompt:", results[0]) + fmt.Println("Goodness level of the second prompt:", results[1]) +} +``` \ No newline at end of file diff --git a/evaluation/agent_evaluator.go b/evaluation/agent_evaluator.go new file mode 100644 index 0000000..0a78610 --- /dev/null +++ b/evaluation/agent_evaluator.go @@ -0,0 +1,20 @@ +package evaluation + +import ( + "github.com/natexcvi/go-llm/agents" +) + +type agentRunner[Input, Output any] struct { + agent agents.Agent[Input, Output] +} + +// Returns a new agent runner that can be used to evaluate the output. +func NewAgentRunner[Input, Output any](agent agents.Agent[Input, Output]) Runner[Input, Output] { + return &agentRunner[Input, Output]{ + agent: agent, + } +} + +func (t *agentRunner[Input, Output]) Run(input Input) (Output, error) { + return t.agent.Run(input) +} diff --git a/evaluation/evaluator.go b/evaluation/evaluator.go new file mode 100644 index 0000000..b16e28c --- /dev/null +++ b/evaluation/evaluator.go @@ -0,0 +1,104 @@ +package evaluation + +import ( + "fmt" + "github.com/samber/mo" +) + +// GoodnessFunction is a function that takes an input, an output and an error (if one occurred) and returns a float64 +// which represents the goodness score of the output. +type GoodnessFunction[Input, Output any] func(input Input, output Output, err error) float64 + +// Options is a struct that contains the options for the evaluator. +type Options[Input, Output any] struct { + // The goodness function that will be used to evaluate the output. + GoodnessFunction GoodnessFunction[Input, Output] + // The number of times the test will be repeated. The goodness level of each output will be + // averaged. + Repetitions int +} + +// Runner is an interface that represents a test runner that will be used to evaluate the output. +// It takes an input and returns an output and an error. +type Runner[Input, Output any] interface { + Run(input Input) (Output, error) +} + +// Evaluator is a struct that runs the tests and evaluates the outputs. +type Evaluator[Input, Output any] struct { + options *Options[Input, Output] + runner Runner[Input, Output] +} + +// Creates a new `Evaluator` with the provided configuration. +func NewEvaluator[Input, Output any](runner Runner[Input, Output], options *Options[Input, Output]) *Evaluator[Input, Output] { + return &Evaluator[Input, Output]{ + options: options, + runner: runner, + } +} + +// Runs the tests and evaluates the outputs. The function receives a test pack +// which is a slice of inputs and returns a slice of float64 which represents the goodness level +// of each respective output. +func (e *Evaluator[Input, Output]) Evaluate(testPack []Input) []float64 { + repetitionChannels := make([]chan []float64, e.options.Repetitions) + + for i := 0; i < e.options.Repetitions; i++ { + repetitionChannels[i] = make(chan []float64) + go func(i int) { + report, err := e.evaluate(testPack) + if err != nil { + repetitionChannels[i] <- nil + return + } + repetitionChannels[i] <- report + }(i) + } + + responses := make([][]float64, e.options.Repetitions) + for i := 0; i < e.options.Repetitions; i++ { + responses[i] = <-repetitionChannels[i] + } + + report := make([]float64, len(testPack)) + for i := 0; i < len(testPack); i++ { + sum := 0.0 + for j := 0; j < e.options.Repetitions; j++ { + sum += responses[j][i] + } + report[i] = sum / float64(e.options.Repetitions) + } + + return report +} + +func (e *Evaluator[Input, Output]) evaluate(testPack []Input) ([]float64, error) { + responses, err := e.test(testPack) + if err != nil { + return nil, fmt.Errorf("failed to test: %w", err) + } + + report := make([]float64, len(testPack)) + for i, response := range responses { + res, resErr := response.Get() + report[i] = e.options.GoodnessFunction(testPack[i], res, resErr) + } + + return report, nil +} + +func (e *Evaluator[Input, Output]) test(testPack []Input) ([]mo.Result[Output], error) { + responses := make([]mo.Result[Output], len(testPack)) + + for i, test := range testPack { + response, err := e.runner.Run(test) + if err != nil { + responses[i] = mo.Err[Output](err) + } else { + responses[i] = mo.Ok(response) + } + } + + return responses, nil +} diff --git a/evaluation/evaluator_test.go b/evaluation/evaluator_test.go new file mode 100644 index 0000000..8695495 --- /dev/null +++ b/evaluation/evaluator_test.go @@ -0,0 +1,283 @@ +package evaluation + +import ( + "errors" + "github.com/golang/mock/gomock" + "github.com/natexcvi/go-llm/engines" + "github.com/natexcvi/go-llm/engines/mocks" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "math" + "strings" + "testing" +) + +func createMockEchoLLM(t *testing.T) engines.LLM { + t.Helper() + ctrl := gomock.NewController(t) + mock := mocks.NewMockLLM(ctrl) + mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { + return &engines.ChatMessage{ + Text: prompt.History[0].Text, + }, nil + }).AnyTimes() + return mock +} + +func createMockIncrementalLLM(t *testing.T) engines.LLM { + t.Helper() + ctrl := gomock.NewController(t) + mock := mocks.NewMockLLM(ctrl) + counters := make(map[string]int) + mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { + counters[prompt.History[0].Text]++ + return &engines.ChatMessage{ + Text: strings.Repeat(prompt.History[0].Text, counters[prompt.History[0].Text]), + }, nil + }).AnyTimes() + return mock +} + +func createMockExponentialLLM(t *testing.T) engines.LLM { + t.Helper() + ctrl := gomock.NewController(t) + mock := mocks.NewMockLLM(ctrl) + counters := make(map[string]int) + mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { + counters[prompt.History[0].Text]++ + return &engines.ChatMessage{ + Text: strings.Repeat(prompt.History[0].Text, int(math.Pow(float64(len(prompt.History[0].Text)), float64(counters[prompt.History[0].Text]+1)))), + }, nil + }).AnyTimes() + return mock +} + +func createMockOddErrorLLM(t *testing.T) engines.LLM { + t.Helper() + ctrl := gomock.NewController(t) + mock := mocks.NewMockLLM(ctrl) + counters := make(map[string]int) + mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { + counters[prompt.History[0].Text]++ + if counters[prompt.History[0].Text]%2 == 1 { + return nil, errors.New("error") + } + return &engines.ChatMessage{ + Text: "OK!", + }, nil + }).AnyTimes() + return mock +} + +func TestLLMEvaluator(t *testing.T) { + tests := []struct { + name string + options *Options[*engines.ChatPrompt, *engines.ChatMessage] + engine engines.LLM + testPack []*engines.ChatPrompt + want []float64 + }{ + { + name: "Test echo engine with response length goodness and 1 repetition", + options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { + return float64(len(response.Text)) + }, + Repetitions: 1, + }, + engine: createMockEchoLLM(t), + testPack: []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello Hello Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello Hello Hello Hello Hello", + }, + }, + }, + }, + want: []float64{5, 11, 23, 35}, + }, + { + name: "Test echo engine with response length goodness and 5 repetitions", + options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { + return float64(len(response.Text)) + }, + Repetitions: 5, + }, + engine: createMockEchoLLM(t), + testPack: []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello Hello Hello", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "Hello Hello Hello Hello Hello Hello", + }, + }, + }, + }, + want: []float64{5, 11, 23, 35}, + }, + { + name: "Test incremental engine with response length goodness and 5 repetitions", + options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { + return float64(len(response.Text)) + }, + Repetitions: 5, + }, + engine: createMockIncrementalLLM(t), + testPack: []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "a", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aa", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aaa", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aaaa", + }, + }, + }, + }, + want: []float64{3, 6, 9, 12}, + }, + { + name: "Test exponential engine with response length goodness and 4 repetitions", + options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { + return float64(len(response.Text)) + }, + Repetitions: 4, + }, + engine: createMockExponentialLLM(t), + testPack: []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "a", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aa", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aaa", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aaaa", + }, + }, + }, + }, + want: []float64{1, 30, 270, 1360}, + }, + { + name: "Test error engine with dummy error goodness and 4 repetitions", + options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ + GoodnessFunction: func(_ *engines.ChatPrompt, _ *engines.ChatMessage, err error) float64 { + return lo.If(err == nil, 100.0).Else(0.0) + }, + Repetitions: 4, + }, + engine: createMockOddErrorLLM(t), + testPack: []*engines.ChatPrompt{ + { + History: []*engines.ChatMessage{ + { + Text: "a", + }, + }, + }, + { + History: []*engines.ChatMessage{ + { + Text: "aa", + }, + }, + }, + }, + want: []float64{50, 50}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + runner := NewLLMRunner(tt.engine) + evaluator := NewEvaluator(runner, tt.options) + + got := evaluator.Evaluate(tt.testPack) + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/evaluation/llm_evaluator.go b/evaluation/llm_evaluator.go new file mode 100644 index 0000000..c12e77e --- /dev/null +++ b/evaluation/llm_evaluator.go @@ -0,0 +1,18 @@ +package evaluation + +import "github.com/natexcvi/go-llm/engines" + +type llmRunner struct { + llm engines.LLM +} + +// Returns a new llm runner that can be used to evaluate the output. +func NewLLMRunner(llm engines.LLM) Runner[*engines.ChatPrompt, *engines.ChatMessage] { + return &llmRunner{ + llm: llm, + } +} + +func (t *llmRunner) Run(input *engines.ChatPrompt) (*engines.ChatMessage, error) { + return t.llm.Chat(input) +} diff --git a/go.mod b/go.mod index 5c6ed8b..befe2fe 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/mattn/go-isatty v0.0.8 // indirect github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/samber/mo v1.8.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index 64481de..bf165e5 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= +github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/PuerkitoBio/goquery v1.8.1 h1:uQxhNlArOIdbrH1tr0UXwdVFgDcZDrZVdcpygAcwmWM= github.com/PuerkitoBio/goquery v1.8.1/go.mod h1:Q8ICL1kNUJ2sXGoAhPGUdYDJvgQgHzJsnnd3H7Ho5jQ= @@ -8,6 +9,7 @@ github.com/andybalholm/cascadia v1.3.1/go.mod h1:R4bJ1UQfqADjvDa4P6HZHLh/3OxWWEq github.com/briandowns/spinner v1.23.0 h1:alDF2guRWqa/FOZZYWjlMIx2L6H0wyewPxo/CH4Pt2A= github.com/briandowns/spinner v1.23.0/go.mod h1:rPG4gmXeN3wQV/TsAY4w8lPdIM6RX3yqeBQJSrbXjuE= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= @@ -26,6 +28,7 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -44,6 +47,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/samber/mo v1.8.0 h1:vYjHTfg14JF9tD2NLhpoUsRi9bjyRoYwa4+do0nvbVw= +github.com/samber/mo v1.8.0/go.mod h1:BfkrCPuYzVG3ZljnZB783WIJIGk1mcZr9c9CPf8tAxs= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=