Skip to content

Commit

Permalink
small refactorings and adding support for requesting top log probabil…
Browse files Browse the repository at this point in the history
…ities from Codex
  • Loading branch information
jjviana committed Dec 7, 2021
1 parent 9b26fbb commit 3cbe5c6
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 32 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module github.com/jjviana/codex

go 1.16


require (
github.com/ActiveState/vt10x v1.3.1
github.com/creack/pty v1.1.17
Expand Down
13 changes: 7 additions & 6 deletions pkg/codex/codex.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type CompletionParameters struct {
FrequencyPenalty float64
PresencePenalty float64
Stop []string
LogProbs int
}

// GenerateCompletions generates a list of possible completions for the given prompt.
Expand Down Expand Up @@ -64,9 +65,9 @@ func GenerateCompletions(params CompletionParameters) (Completion, error) {
"top_p": %f,
"frequency_penalty": %f,
"presence_penalty": %f,
"logprobs": 0,
"logprobs": %d,
"stop": %s
}`, promptJSON, params.Temperature, params.MaxTokens, params.TopP, params.FrequencyPenalty, params.PresencePenalty, string(stopJSON))
}`, promptJSON, params.Temperature, params.MaxTokens, params.TopP, params.FrequencyPenalty, params.PresencePenalty, params.LogProbs, string(stopJSON))

resp, err := httpPost(url, params.APIKey, body)
if err != nil {
Expand Down Expand Up @@ -105,10 +106,10 @@ type Choice struct {
}

type Logprobs struct {
TextOffset []float64 `json:"text_offset"`
TokenLogProbs []float64 `json:"token_logprobs"`
Tokens []string `json:"tokens"`
TopLogProbs []float64 `json:"top_logprobs"`
TextOffset []float64 `json:"text_offset"`
TokenLogProbs []float64 `json:"token_logprobs"`
Tokens []string `json:"tokens"`
TopLogProbs []map[string]float64 `json:"top_logprobs"`
}

func (l Logprobs) TokenProbabilities() []float64 {
Expand Down
28 changes: 28 additions & 0 deletions pkg/witty/suggestions_ui.go
Original file line number Diff line number Diff line change
@@ -1 +1,29 @@
package witty

import (
"time"

"github.com/rivo/tview"
)

func (w *Witty) showCompletionsUI() {
app := tview.NewApplication()
list := tview.NewList().
AddItem("List item 1", "Some explanatory text", 'a', nil).
AddItem("List item 2", "Some explanatory text", 'b', nil).
AddItem("List item 3", "Some explanatory text", 'c', nil).
AddItem("List item 4", "Some explanatory text", 'd', nil).
AddItem("Quit", "Press to exit", 'q', func() {
app.Stop()
})
go func() {
// sleep for 5 seconds
time.Sleep(5 * time.Second)
app.QueueUpdateDraw(func() {
list.AddItem("List item 5", "Some explanatory text", 'e', nil)
})
}()
if err := app.SetRoot(list, true).SetFocus(list).Run(); err != nil {
panic(err)
}
}
21 changes: 19 additions & 2 deletions pkg/witty/tty.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package witty
import (
"errors"
"fmt"
"io"
"os"
"os/signal"
"strconv"
Expand All @@ -25,6 +26,7 @@ import (
"time"

"github.com/gdamore/tcell/v2"
"github.com/rs/zerolog/log"
"golang.org/x/term"
)

Expand All @@ -45,13 +47,18 @@ type stdIoTty struct {
}

func (tty *stdIoTty) Read(b []byte) (int, error) {
// log.Debug().Msgf("tty.Read()")
n, err := tty.in.Read(b)
if err != nil {
return n, err
log.Debug().Msgf("tty.Read() error: %v", err)
return n, io.EOF
}
log.Debug().Msgf("tty.Read() - %d bytes", n)
if tty.mirror != nil {
log.Debug().Msgf("tty.Read() - mirroring %d bytes", n)
tty.mirror <- b[:n]
}
log.Debug().Msgf("tty.Read() - done")
return n, nil
}

Expand All @@ -64,6 +71,7 @@ func (tty *stdIoTty) Close() error {
}

func (tty *stdIoTty) Start() error {
log.Debug().Msgf("tty.Start()")
tty.l.Lock()
defer tty.l.Unlock()

Expand All @@ -85,7 +93,15 @@ func (tty *stdIoTty) Start() error {
return errors.New("device is not a terminal")
}

_ = tty.in.SetReadDeadline(time.Time{})
err = tty.in.SetReadDeadline(time.Time{})
if err != nil {
log.Debug().Msgf("tty.Start() - SetReadDeadline() failed: %v", err)
}
err = syscall.SetNonblock(tty.fd, false)
if err != nil {
log.Debug().Msgf("tty.Start() - SetNonblock() failed: %v", err)
}

saved, err := term.MakeRaw(tty.fd) // also sets vMin and vTime
if err != nil {
return err
Expand Down Expand Up @@ -124,6 +140,7 @@ func (tty *stdIoTty) Drain() error {
}

func (tty *stdIoTty) Stop() error {
log.Debug().Msgf("tty.Stop()")
tty.l.Lock()
if err := term.Restore(tty.fd, tty.saved); err != nil {
tty.l.Unlock()
Expand Down
63 changes: 39 additions & 24 deletions pkg/witty/witty.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,27 @@ const (
type Witty struct {
shellCommand string
shellArgs []string
shellState int
wittyState int
currentSuggestion string
completionParameters codex.CompletionParameters
terminalState vt10x.State
vterm *vt10x.VT
screen tcell.Screen
shellPty *os.File
suggestionColor tcell.Color
updateTrigger chan struct{}
}

func New(completionParameters codex.CompletionParameters, color tcell.Color, shell string, args []string) *Witty {
return &Witty{
shellState: StateNormal,
w := &Witty{
wittyState: StateNormal,
completionParameters: completionParameters,
suggestionColor: color,
shellCommand: shell,
shellArgs: args,
}
w.completionParameters.LogProbs = 10
return w
}

func (w *Witty) Run() error {
Expand Down Expand Up @@ -86,7 +89,7 @@ func (w *Witty) Run() error {
w.vterm.Resize(width, height)

endc := make(chan bool)
updatec := make(chan struct{}, 1)
w.updateTrigger = make(chan struct{}, 1)
go func() {
defer close(endc)
// Parses the shell output
Expand All @@ -96,15 +99,12 @@ func (w *Witty) Run() error {
fmt.Fprintln(os.Stderr, err)
break
}
if w.shellState == StateSuggesting {
if w.wittyState == StateSuggesting {
// Reset the state as output has change
w.shellState = StateNormal
w.wittyState = StateNormal
w.currentSuggestion = ""
}
select {
case updatec <- struct{}{}:
default:
}
w.triggerScreenUpdate()
}
}()

Expand All @@ -131,38 +131,45 @@ func (w *Witty) Run() error {
case <-endc:
return nil

case <-updatec:
case <-w.updateTrigger:
w.updateScreen(w.screen, &w.terminalState, width, height)

case <-time.After(1 * time.Second):
log.Debug().Msg("shell is idle, state is " + string(w.shellState))
if w.shellState == StateNormal {
w.shellState = StateFetchingSuggestions
go w.fetchSuggestions(updatec)
log.Debug().Msg("shell is idle, state is " + string(w.wittyState))
if w.wittyState == StateNormal {
w.wittyState = StateFetchingSuggestions
go w.fetchSuggestions()
}

}
}
}

func (w *Witty) fetchSuggestions(updatec chan struct{}) {
func (w *Witty) triggerScreenUpdate() {
select {
case w.updateTrigger <- struct{}{}:
default:
}
}

func (w *Witty) fetchSuggestions() {
prompt := getPrompt(w.terminalState)
if len(prompt) > 0 {
log.Debug().Msgf("prompt: %s", prompt)
suggestion, err := w.suggest(prompt)
if err != nil {
log.Error().Err(err).Msg("error fetching suggestion")
w.shellState = StateNormal
w.wittyState = StateNormal
w.currentSuggestion = ""
return
}
if w.shellState == StateFetchingSuggestions { // someone else might have already changed the state
w.shellState = StateSuggesting
if w.wittyState == StateFetchingSuggestions { // someone else might have already changed the state
w.wittyState = StateSuggesting
w.currentSuggestion = strings.TrimRight(suggestion, " ")
updatec <- struct{}{} // trigger a screen updateScreen
w.triggerScreenUpdate()
}
} else {
w.shellState = StateNormal
w.wittyState = StateNormal
w.currentSuggestion = ""
}
}
Expand Down Expand Up @@ -220,7 +227,7 @@ func (w *Witty) updateScreen(s tcell.Screen, state *vt10x.State, width, height i
func (w *Witty) stdinToShellLoop(stdin chan []byte) {
for data := range stdin {
log.Debug().Msgf("stdin: %+v", data)
switch w.shellState {
switch w.wittyState {
case StateSuggesting:
if data[0] == '\t' && len(w.currentSuggestion) > 0 {
_, err := w.shellPty.Write([]byte(w.currentSuggestion))
Expand All @@ -229,13 +236,21 @@ func (w *Witty) stdinToShellLoop(stdin chan []byte) {
os.Exit(1)
}
data = data[1:]
} else if data[0] == 15 { // ctrl-o
log.Debug().Msgf("Suspending normal UI...")
w.screen.Suspend()
w.showCompletionsUI()
w.screen.Resume()
log.Debug().Msgf("Resumed from suggestions UI")
w.triggerScreenUpdate()
continue
}

w.shellState = StateNormal
w.wittyState = StateNormal
w.currentSuggestion = ""
case StateFetchingSuggestions:
// invalidate the suggestion fetch request as it is based on a stale prompt at this point
w.shellState = StateNormal
w.wittyState = StateNormal
w.currentSuggestion = ""
}
_, err := w.shellPty.Write(data)
Expand Down

0 comments on commit 3cbe5c6

Please sign in to comment.