Skip to content

Commit

Permalink
added support for generation of multiple completion options
Browse files Browse the repository at this point in the history
  • Loading branch information
jjviana committed Dec 13, 2021
1 parent 3cbe5c6 commit a0746e9
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 37 deletions.
12 changes: 6 additions & 6 deletions pkg/witty/suggest.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ const (
davinci = "davinci-codex"
)

func (w *Witty) suggest(prompt string) (string, error) {
func (w *Witty) suggest(prompt string) (*codex.Choice, error) {
// Try cushman first, as it is faster and cheaper
suggestion, err := w.suggestWithEngine(cushman, prompt)
if err != nil || len(suggestion) == 0 {
if err != nil || suggestion == nil || len(suggestion.Text) == 0 {
// Try davinci as a fallback
suggestion, err = w.suggestWithEngine(davinci, prompt)
}
return suggestion, err
}

func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) {
func (w *Witty) suggestWithEngine(engine, prompt string) (*codex.Choice, error) {
log.Debug().Msgf("requesting suggestion to %s with prompt: %s", engine, prompt)

request := w.completionParameters
Expand All @@ -29,7 +29,7 @@ func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) {

completion, err := codex.GenerateCompletions(request)
if err != nil {
return "", err
return nil, err
}
log.Debug().Msgf("Got completions: %+v", completion)

Expand All @@ -39,7 +39,7 @@ func (w *Witty) suggestWithEngine(engine, prompt string) (string, error) {
for i, p := range probabilities {
log.Debug().Msgf("Token %s probability %.3f", choice.Logprobs.Tokens[i], p)
}
return completion.Choices[0].Text, nil
return &completion.Choices[0], nil
}
return "", nil
return nil, nil
}
80 changes: 65 additions & 15 deletions pkg/witty/suggestions_ui.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,79 @@
package witty

import (
"time"
"fmt"
"math"
"sort"
"strings"

"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
"github.com/rs/zerolog/log"
)

type topChoice struct {
text string
probability float64
}

func (w *Witty) showCompletionsUI() {
if w.currentSuggestion == nil || w.currentSuggestion.Text == "" {
return
}
app := tview.NewApplication()
list := tview.NewList().
AddItem("List item 1", "Some explanatory text", 'a', nil).
AddItem("List item 2", "Some explanatory text", 'b', nil).
AddItem("List item 3", "Some explanatory text", 'c', nil).
AddItem("List item 4", "Some explanatory text", 'd', nil).
AddItem("Quit", "Press to exit", 'q', func() {
list := tview.NewList()
topChoices := topChoices(w.currentSuggestion.Logprobs.TopLogProbs[0])
shortcut := 'a'
for _, choice := range topChoices {
list.AddItem(choice.text, fmt.Sprintf("%.0f%%", 100*choice.probability), shortcut, nil)
shortcut++
}
list.SetBorder(true).SetTitle("Suggestions")
list.SetBorderPadding(10, 10, 10, 10)
// escape key closes the list
list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEscape {
app.Stop()
})
go func() {
// sleep for 5 seconds
time.Sleep(5 * time.Second)
app.QueueUpdateDraw(func() {
list.AddItem("List item 5", "Some explanatory text", 'e', nil)
})
}()
}
return event
})

prompt := w.getPrompt()
for i, choice := range topChoices {
go func(itemIndex int, choice topChoice) {
s, err := w.suggest(prompt + choice.text)
if err != nil {
log.Debug().Msgf("error suggesting %s: %v", choice.text, err)
}
app.QueueUpdateDraw(func() {
_, secondary := list.GetItemText(itemIndex)
list.SetItemText(itemIndex, choice.text+s.Text, secondary)
})
}(i, choice)
}

list.SetSelectedFunc(func(index int, mainText, secondaryText string, shortcut rune) {
app.Stop()
w.currentSuggestion.Text = strings.TrimRight(mainText, " ")
})

if err := app.SetRoot(list, true).SetFocus(list).Run(); err != nil {
panic(err)
}
}

func topChoices(m map[string]float64) []topChoice {
var choices []topChoice
for k, v := range m {
choices = append(choices, topChoice{k, v})
}
// sort by probability
sort.Slice(choices, func(i, j int) bool {
return choices[i].probability > choices[j].probability
})
// transform logprob to probability
for i := range choices {
choices[i].probability = math.Exp(choices[i].probability)
}
return choices
}
33 changes: 17 additions & 16 deletions pkg/witty/witty.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Witty struct {
shellCommand string
shellArgs []string
wittyState int
currentSuggestion string
currentSuggestion *codex.Choice
completionParameters codex.CompletionParameters
terminalState vt10x.State
vterm *vt10x.VT
Expand Down Expand Up @@ -102,7 +102,7 @@ func (w *Witty) Run() error {
if w.wittyState == StateSuggesting {
// Reset the state as output has change
w.wittyState = StateNormal
w.currentSuggestion = ""
w.currentSuggestion = nil
}
w.triggerScreenUpdate()
}
Expand Down Expand Up @@ -153,29 +153,30 @@ func (w *Witty) triggerScreenUpdate() {
}

func (w *Witty) fetchSuggestions() {
prompt := getPrompt(w.terminalState)
prompt := w.getPrompt()
if len(prompt) > 0 {
log.Debug().Msgf("prompt: %s", prompt)
suggestion, err := w.suggest(prompt)
if err != nil {
log.Error().Err(err).Msg("error fetching suggestion")
w.wittyState = StateNormal
w.currentSuggestion = ""
w.currentSuggestion = nil
return
}
if w.wittyState == StateFetchingSuggestions { // someone else might have already changed the state
w.wittyState = StateSuggesting
w.currentSuggestion = strings.TrimRight(suggestion, " ")
w.currentSuggestion = suggestion
w.currentSuggestion.Text = strings.TrimRight(suggestion.Text, " ")
w.triggerScreenUpdate()
}
} else {
w.wittyState = StateNormal
w.currentSuggestion = ""
w.currentSuggestion = nil
}
}

func getPrompt(state vt10x.State) string {
prompt := state.StringBeforeCursor()
func (w *Witty) getPrompt() string {
prompt := w.terminalState.StringBeforeCursor()
if len(prompt) > 0 {
prompt = prompt[:len(prompt)-1] // remove the trailing newline inserted wrongly by the vt10x parser
}
Expand Down Expand Up @@ -205,16 +206,16 @@ func (w *Witty) updateScreen(s tcell.Screen, state *vt10x.State, width, height i
if state.CursorVisible() {
curx, cury := state.Cursor()
s.ShowCursor(curx, cury)
if w.currentSuggestion != "" {
if w.currentSuggestion != nil && w.currentSuggestion.Text != "" {
style := tcell.StyleDefault.Foreground(w.suggestionColor)
x := curx
y := cury
for i := 0; i < len(w.currentSuggestion); i++ {
if w.currentSuggestion[i] == '\n' {
for i := 0; i < len(w.currentSuggestion.Text); i++ {
if w.currentSuggestion.Text[i] == '\n' {
y++
x = 0
}
s.SetContent(x, y, rune(w.currentSuggestion[i]), nil, style)
s.SetContent(x, y, rune(w.currentSuggestion.Text[i]), nil, style)
x++
}
}
Expand All @@ -229,8 +230,8 @@ func (w *Witty) stdinToShellLoop(stdin chan []byte) {
log.Debug().Msgf("stdin: %+v", data)
switch w.wittyState {
case StateSuggesting:
if data[0] == '\t' && len(w.currentSuggestion) > 0 {
_, err := w.shellPty.Write([]byte(w.currentSuggestion))
if data[0] == '\t' && w.currentSuggestion != nil && len(w.currentSuggestion.Text) > 0 {
_, err := w.shellPty.Write([]byte(w.currentSuggestion.Text))
if err != nil {
log.Error().Err(err).Msg("failed to write to shell")
os.Exit(1)
Expand All @@ -247,11 +248,11 @@ func (w *Witty) stdinToShellLoop(stdin chan []byte) {
}

w.wittyState = StateNormal
w.currentSuggestion = ""
w.currentSuggestion = nil
case StateFetchingSuggestions:
// invalidate the suggestion fetch request as it is based on a stale prompt at this point
w.wittyState = StateNormal
w.currentSuggestion = ""
w.currentSuggestion = nil
}
_, err := w.shellPty.Write(data)
if err != nil {
Expand Down

0 comments on commit a0746e9

Please sign in to comment.