From a0746e9b0e8ba1f25c716200e50c1c16a5928ff5 Mon Sep 17 00:00:00 2001 From: Juliano Viana Date: Thu, 9 Dec 2021 09:52:30 -0500 Subject: [PATCH] added support for generation of multiple completion options --- pkg/witty/suggest.go | 12 +++--- pkg/witty/suggestions_ui.go | 80 ++++++++++++++++++++++++++++++------- pkg/witty/witty.go | 33 +++++++-------- 3 files changed, 88 insertions(+), 37 deletions(-) diff --git a/pkg/witty/suggest.go b/pkg/witty/suggest.go index 3b3790d..32f0f58 100644 --- a/pkg/witty/suggest.go +++ b/pkg/witty/suggest.go @@ -10,17 +10,17 @@ const ( davinci = "davinci-codex" ) -func (w *Witty) suggest(prompt string) (string, error) { +func (w *Witty) suggest(prompt string) (*codex.Choice, error) { // Try cushman first, as it is faster and cheaper suggestion, err := w.suggestWithEngine(cushman, prompt) - if err != nil || len(suggestion) == 0 { + if err != nil || suggestion == nil || len(suggestion.Text) == 0 { // Try davinci as a fallback suggestion, err = w.suggestWithEngine(davinci, prompt) } return suggestion, err } -func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) { +func (w *Witty) suggestWithEngine(engine, prompt string) (*codex.Choice, error) { log.Debug().Msgf("requesting suggestion to %s with prompt: %s", engine, prompt) request := w.completionParameters @@ -29,7 +29,7 @@ func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) { completion, err := codex.GenerateCompletions(request) if err != nil { - return "", err + return nil, err } log.Debug().Msgf("Got completions: %+v", completion) @@ -39,7 +39,7 @@ func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) { for i, p := range probabilities { log.Debug().Msgf("Token %s probability %.3f", choice.Logprobs.Tokens[i], p) } - return completion.Choices[0].Text, nil + return &completion.Choices[0], nil } - return "", nil + return nil, nil } diff --git a/pkg/witty/suggestions_ui.go b/pkg/witty/suggestions_ui.go index aefbfce..009c9e5 100644 --- a/pkg/witty/suggestions_ui.go +++ b/pkg/witty/suggestions_ui.go @@ -1,29 +1,79 @@ package witty import ( - "time" + "fmt" + "math" + "sort" + "strings" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" + "github.com/rs/zerolog/log" ) +type topChoice struct { + text string + probability float64 +} + func (w *Witty) showCompletionsUI() { + if w.currentSuggestion == nil || w.currentSuggestion.Text == "" { + return + } app := tview.NewApplication() - list := tview.NewList(). - AddItem("List item 1", "Some explanatory text", 'a', nil). - AddItem("List item 2", "Some explanatory text", 'b', nil). - AddItem("List item 3", "Some explanatory text", 'c', nil). - AddItem("List item 4", "Some explanatory text", 'd', nil). - AddItem("Quit", "Press to exit", 'q', func() { + list := tview.NewList() + topChoices := topChoices(w.currentSuggestion.Logprobs.TopLogProbs[0]) + shortcut := 'a' + for _, choice := range topChoices { + list.AddItem(choice.text, fmt.Sprintf("%.0f%%", 100*choice.probability), shortcut, nil) + shortcut++ + } + list.SetBorder(true).SetTitle("Suggestions") + list.SetBorderPadding(10, 10, 10, 10) + // escape key closes the list + list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyEscape { app.Stop() - }) - go func() { - // sleep for 5 seconds - time.Sleep(5 * time.Second) - app.QueueUpdateDraw(func() { - list.AddItem("List item 5", "Some explanatory text", 'e', nil) - }) - }() + } + return event + }) + + prompt := w.getPrompt() + for i, choice := range topChoices { + go func(itemIndex int, choice topChoice) { + s, err := w.suggest(prompt + choice.text) + if err != nil { + log.Debug().Msgf("error suggesting %s: %v", choice.text, err) + } + app.QueueUpdateDraw(func() { + _, secondary := list.GetItemText(itemIndex) + list.SetItemText(itemIndex, choice.text+s.Text, secondary) + }) + }(i, choice) + } + + list.SetSelectedFunc(func(index int, mainText, secondaryText string, shortcut rune) { + app.Stop() + w.currentSuggestion.Text = strings.TrimRight(mainText, " ") + }) + if err := app.SetRoot(list, true).SetFocus(list).Run(); err != nil { panic(err) } } + +func topChoices(m map[string]float64) []topChoice { + var choices []topChoice + for k, v := range m { + choices = append(choices, topChoice{k, v}) + } + // sort by probability + sort.Slice(choices, func(i, j int) bool { + return choices[i].probability > choices[j].probability + }) + // transform logprob to probability + for i := range choices { + choices[i].probability = math.Exp(choices[i].probability) + } + return choices +} diff --git a/pkg/witty/witty.go b/pkg/witty/witty.go index d914c1c..5857407 100644 --- a/pkg/witty/witty.go +++ b/pkg/witty/witty.go @@ -24,7 +24,7 @@ type Witty struct { shellCommand string shellArgs []string wittyState int - currentSuggestion string + currentSuggestion *codex.Choice completionParameters codex.CompletionParameters terminalState vt10x.State vterm *vt10x.VT @@ -102,7 +102,7 @@ func (w *Witty) Run() error { if w.wittyState == StateSuggesting { // Reset the state as output has change w.wittyState = StateNormal - w.currentSuggestion = "" + w.currentSuggestion = nil } w.triggerScreenUpdate() } @@ -153,29 +153,30 @@ func (w *Witty) triggerScreenUpdate() { } func (w *Witty) fetchSuggestions() { - prompt := getPrompt(w.terminalState) + prompt := w.getPrompt() if len(prompt) > 0 { log.Debug().Msgf("prompt: %s", prompt) suggestion, err := w.suggest(prompt) if err != nil { log.Error().Err(err).Msg("error fetching suggestion") w.wittyState = StateNormal - w.currentSuggestion = "" + w.currentSuggestion = nil return } if w.wittyState == StateFetchingSuggestions { // someone else might have already changed the state w.wittyState = StateSuggesting - w.currentSuggestion = strings.TrimRight(suggestion, " ") + w.currentSuggestion = suggestion + w.currentSuggestion.Text = strings.TrimRight(suggestion.Text, " ") w.triggerScreenUpdate() } } else { w.wittyState = StateNormal - w.currentSuggestion = "" + w.currentSuggestion = nil } } -func getPrompt(state vt10x.State) string { - prompt := state.StringBeforeCursor() +func (w *Witty) getPrompt() string { + prompt := w.terminalState.StringBeforeCursor() if len(prompt) > 0 { prompt = prompt[:len(prompt)-1] // remove the trailing newline inserted wrongly by the vt10x parser } @@ -205,16 +206,16 @@ func (w *Witty) updateScreen(s tcell.Screen, state *vt10x.State, width, height i if state.CursorVisible() { curx, cury := state.Cursor() s.ShowCursor(curx, cury) - if w.currentSuggestion != "" { + if w.currentSuggestion != nil && w.currentSuggestion.Text != "" { style := tcell.StyleDefault.Foreground(w.suggestionColor) x := curx y := cury - for i := 0; i < len(w.currentSuggestion); i++ { - if w.currentSuggestion[i] == '\n' { + for i := 0; i < len(w.currentSuggestion.Text); i++ { + if w.currentSuggestion.Text[i] == '\n' { y++ x = 0 } - s.SetContent(x, y, rune(w.currentSuggestion[i]), nil, style) + s.SetContent(x, y, rune(w.currentSuggestion.Text[i]), nil, style) x++ } } @@ -229,8 +230,8 @@ func (w *Witty) stdinToShellLoop(stdin chan []byte) { log.Debug().Msgf("stdin: %+v", data) switch w.wittyState { case StateSuggesting: - if data[0] == '\t' && len(w.currentSuggestion) > 0 { - _, err := w.shellPty.Write([]byte(w.currentSuggestion)) + if data[0] == '\t' && w.currentSuggestion != nil && len(w.currentSuggestion.Text) > 0 { + _, err := w.shellPty.Write([]byte(w.currentSuggestion.Text)) if err != nil { log.Error().Err(err).Msg("failed to write to shell") os.Exit(1) @@ -247,11 +248,11 @@ func (w *Witty) stdinToShellLoop(stdin chan []byte) { } w.wittyState = StateNormal - w.currentSuggestion = "" + w.currentSuggestion = nil case StateFetchingSuggestions: // invalidate the suggestion fetch request as it is based on a stale prompt at this point w.wittyState = StateNormal - w.currentSuggestion = "" + w.currentSuggestion = nil } _, err := w.shellPty.Write(data) if err != nil {