From 157f280c967f6ad12dd4ca4312225e180ec68131 Mon Sep 17 00:00:00 2001 From: Colton Date: Fri, 9 Jan 2026 00:24:40 -0700 Subject: [PATCH] improve repl to handle history and expected arrow key backspace etc behavior --- go.mod | 9 ++- go.sum | 4 ++ repl/repl.go | 164 +++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 152 insertions(+), 25 deletions(-) create mode 100644 go.sum diff --git a/go.mod b/go.mod index e768802..b9a3456 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,10 @@ module github.com/chirst/cdb -go 1.24 +go 1.24.0 + +toolchain go1.24.3 + +require ( + golang.org/x/sys v0.39.0 // indirect + golang.org/x/term v0.38.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..08d079a --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= diff --git a/repl/repl.go b/repl/repl.go index b8aa634..3b67b8d 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -2,12 +2,17 @@ package repl import ( - "bufio" + "errors" "fmt" + "io" "os" + "os/signal" + "slices" "strings" + "syscall" "github.com/chirst/cdb/db" + "golang.org/x/term" ) const ( @@ -15,39 +20,60 @@ const ( emptyRowValue = "NULL" // emptyHeaderValue is printed when the cell in a header is the empty string emptyHeaderValue = "" -) - -// ansi are codes that will color terminal output -const ( - ansiWarn = "\033[33m" - ansiReset = "\033[0m" + // prompt is the prompt. + prompt = "cdb> " + // promptContinued is the prompt when it is pending termination for example + // by a semi colon. + promptContinued = "...> " ) type repl struct { - db *db.DB + db *db.DB + terminal *term.Terminal } func New(db *db.DB) *repl { - return &repl{db: db} + r := &repl{ + db: db, + terminal: term.NewTerminal(os.Stdin, prompt), + } + r.loadHistory() + return r } func (r *repl) Run() { - fmt.Println("Welcome to cdb. Type .exit to exit") + r.writeLn("Welcome to cdb. Type .exit to exit") if r.db.UseMemory { - fmt.Println(ansiWarn + "WARN database is running in memory and will not persist changes" + ansiReset) + r.writeWarning("WARN database is running in memory and will not persist changes") } - reader := bufio.NewScanner(os.Stdin) + + // Handling kill signals works under two methods for the REPL. When the + // terminal is in raw mode the signals are caught by readline as bytes. When + // the terminal is not in raw mode the signals are caught by the following + // channel. + // + // The handling keeping in mind two major considerations in that the + // terminal history is written to and the database always allows a long + // running query to be shut down. + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + r.exitGracefully() + }() + previousInput := "" - for r.getInput(reader, previousInput) { - input := previousInput + reader.Text() + for { + line := r.readLine(previousInput) + input := previousInput + line if len(input) == 0 { continue } if input[0] == '.' { if input == ".exit" { - os.Exit(0) + r.exitGracefully() } - fmt.Println("Command not supported") + r.writeLn("Command not supported") continue } @@ -61,27 +87,50 @@ func (r *repl) Run() { for _, statement := range statements { result := r.db.Execute(statement, []any{}) if result.Err != nil { - fmt.Printf("Err: %s\n", result.Err) + r.writeLn("Err: " + result.Err.Error()) continue } if result.Text != "" { - fmt.Println(result.Text) + r.writeLn(result.Text) } if len(result.ResultRows) != 0 { - fmt.Println(r.printRows(result.ResultHeader, result.ResultRows)) + r.writeLn(r.printRows(result.ResultHeader, result.ResultRows)) } - fmt.Printf("Time: %s\n", result.Duration) + r.writeLn("Time: " + result.Duration.String()) } } } -func (*repl) getInput(reader *bufio.Scanner, previousInput string) bool { +func (r *repl) readLine(previousInput string) string { + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + panic(err) + } + defer term.Restore(int(os.Stdin.Fd()), oldState) if previousInput == "" { - fmt.Printf("cdb> ") + r.terminal.SetPrompt(prompt) } else { - fmt.Printf("...> ") + r.terminal.SetPrompt(promptContinued) + } + line, err := r.terminal.ReadLine() + if err != nil { + if err == io.EOF { + term.Restore(int(os.Stdin.Fd()), oldState) + r.exitGracefully() + } + panic("err reading line: " + err.Error()) } - return reader.Scan() + return line +} + +func (r *repl) writeLn(text string) { + r.terminal.Write(([]byte)(text + "\n")) +} + +func (r *repl) writeWarning(text string) { + r.terminal.Write(r.terminal.Escape.Yellow) + r.writeLn(text) + r.terminal.Write(r.terminal.Escape.Reset) } func (r *repl) printRows(resultHeader []string, resultRows [][]*string) string { @@ -163,3 +212,70 @@ func (*repl) printRow(row []*string, widths []int) string { } return ret } + +func (r *repl) exitGracefully() { + r.saveHistory() + os.Exit(0) +} + +func (r *repl) loadHistory() { + p, err := r.getHistoryPath() + if err != nil { + r.writeWarning("failed to get history path " + err.Error()) + return + } + contents, err := os.ReadFile(p) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return + } + r.writeWarning("failed to load history " + err.Error()) + return + } + lines := strings.Split((string)(contents), "\n") + slices.Reverse(lines) + for _, line := range lines { + if line == "" { + continue + } + r.terminal.History.Add(line) + } +} + +func (r *repl) saveHistory() { + history := []byte{} + for i := range r.terminal.History.Len() { + str_entry := r.terminal.History.At(i) + byte_entry := ([]byte)(str_entry + "\n") + history = append(history, byte_entry...) + } + p, err := r.getHistoryPath() + if err != nil { + r.writeWarning("failed to get history path for saving " + err.Error()) + return + } + f, err := os.OpenFile(p, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + r.writeWarning("failed to open history file for saving " + err.Error()) + return + } + defer f.Close() + err = f.Truncate(0) + if err != nil { + r.writeWarning("failed to overwrite history " + err.Error()) + return + } + _, err = f.Write(history) + if err != nil { + r.writeWarning("failed to write history " + err.Error()) + return + } +} + +func (r *repl) getHistoryPath() (string, error) { + dir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return dir + "/.cdb_history", nil +}