From 365055b71970245a94f842cbbd49153ae0e614be Mon Sep 17 00:00:00 2001 From: John McBride Date: Mon, 2 Feb 2026 22:30:41 -0500 Subject: [PATCH 1/2] feat: GolangCI linting Signed-off-by: John McBride --- .dagger/build.go | 15 ++- .dagger/linting.go | 38 ++++++++ .dagger/main.go | 26 +++--- .github/workflows/ci.yaml | 17 ++++ .golangci.yml | 117 ++++++++++++++++++++++++ api/api_test.go | 2 +- api/mcp/mcp.go | 12 +-- api/mcp/mcp_test.go | 4 +- api/mcp/search.go | 10 +- api/search/search.go | 34 +++---- api/search/search_test.go | 12 +-- api/search_handler_test.go | 26 +++--- cmd/tapes/chat/chat.go | 7 +- cmd/tapes/checkout/checkout.go | 20 ++-- cmd/tapes/checkout/checkout_test.go | 12 +-- cmd/tapes/search/search.go | 15 ++- cmd/tapes/serve/api/api.go | 12 +-- cmd/tapes/serve/proxy/proxy.go | 8 +- cmd/tapes/serve/serve.go | 12 +-- cmd/version/version.go | 7 +- dagger.json | 7 ++ makefile | 10 ++ pkg/dotdir/checkout.go | 4 +- pkg/embeddings/ollama/ollama.go | 10 +- pkg/llm/message.go | 8 +- pkg/llm/provider/anthropic/anthropic.go | 16 ++-- pkg/llm/provider/ollama/ollama.go | 16 ++-- pkg/llm/provider/openai/openai.go | 46 +++++----- pkg/llm/provider/openai/types.go | 10 -- pkg/llm/response.go | 2 +- pkg/llm/stream.go | 2 +- pkg/merkle/dag.go | 14 +-- pkg/merkle/dag_test.go | 18 ++-- pkg/storage/ent/driver/driver.go | 8 +- pkg/storage/error.go | 6 +- pkg/storage/inmemory/inmemory.go | 37 ++++---- pkg/storage/sqlite/sqlite.go | 14 +-- pkg/storage/sqlite/sqlite_test.go | 14 +-- pkg/utils/test/vector.go | 4 +- pkg/vector/chroma/chroma.go | 39 ++++---- pkg/vector/chroma/chroma_test.go | 14 +-- pkg/vector/sqlitevec/sqlitevec.go | 93 +++++++------------ pkg/vector/sqlitevec/sqlitevec_test.go | 30 +++--- pkg/vector/utils/new.go | 7 +- proxy/header/header_test.go | 28 +++--- proxy/proxy.go | 7 +- proxy/proxy_test.go | 58 ++++++------ proxy/worker/pool.go | 9 +- proxy/worker/pool_test.go | 9 +- 49 files changed, 565 insertions(+), 381 deletions(-) create mode 100644 .dagger/linting.go create mode 100644 .golangci.yml diff --git a/.dagger/build.go b/.dagger/build.go index 01f6861..27def91 100644 --- a/.dagger/build.go +++ b/.dagger/build.go @@ -72,19 +72,13 @@ func (t *Tapes) buildLinux(outputs *dagger.Directory, ldflags string) *dagger.Di zigDownloadURL := fmt.Sprintf("https://ziglang.org/download/%s/zig-%s-linux-%s.tar.xz", zigVersion, zigArch, zigVersion) zigDir := fmt.Sprintf("zig-%s-linux-%s", zigArch, zigVersion) - golang := dag.Container(). - From("golang:1.25-bookworm"). - WithExec([]string{"apt-get", "update"}). - WithExec([]string{"apt-get", "install", "-y", "libsqlite3-dev", "xz-utils"}). + golang := t.goContainer(). + WithExec([]string{"apt-get", "install", "-y", "xz-utils"}). WithExec([]string{"mkdir", "-p", "/opt/sqlite"}). WithExec([]string{"cp", "/usr/include/sqlite3.h", "/opt/sqlite/"}). WithExec([]string{"cp", "/usr/include/sqlite3ext.h", "/opt/sqlite/"}). WithExec([]string{"sh", "-c", fmt.Sprintf("curl -L %s | tar -xJ -C /usr/local", zigDownloadURL)}). - WithEnvVariable("PATH", fmt.Sprintf("/usr/local/%s:$PATH", zigDir), dagger.ContainerWithEnvVariableOpts{Expand: true}). - WithMountedCache("/go/pkg/mod", dag.CacheVolume("go-mod")). - WithMountedCache("/root/.cache/go-build", dag.CacheVolume("go-build")). - WithDirectory("/src", t.Source). - WithWorkdir("/src") + WithEnvVariable("PATH", fmt.Sprintf("/usr/local/%s:$PATH", zigDir), dagger.ContainerWithEnvVariableOpts{Expand: true}) for _, target := range targets { path := fmt.Sprintf("%s/%s/", target.goos, target.goarch) @@ -129,6 +123,7 @@ func (t *Tapes) buildDarwin(outputs *dagger.Directory, ldflags string) *dagger.D // Use Debian Trixie as the base for darwin builds because the osxcross // toolchain binaries require GLIBC 2.38+ (Bookworm only has 2.36). + // NOTE: this cannot reuse goContainer() since it needs Trixie, not Bookworm. golang := dag.Container(). From("golang:1.25-trixie"). WithExec([]string{"apt-get", "update"}). @@ -139,6 +134,8 @@ func (t *Tapes) buildDarwin(outputs *dagger.Directory, ldflags string) *dagger.D WithDirectory("/osxcross", osxcross). WithEnvVariable("PATH", "/osxcross/bin:$PATH", dagger.ContainerWithEnvVariableOpts{Expand: true}). WithEnvVariable("LD_LIBRARY_PATH", "/osxcross/lib:$LD_LIBRARY_PATH", dagger.ContainerWithEnvVariableOpts{Expand: true}). + WithEnvVariable("CGO_ENABLED", "1"). + WithEnvVariable("GOEXPERIMENT", "jsonv2"). WithMountedCache("/go/pkg/mod", dag.CacheVolume("go-mod")). WithMountedCache("/root/.cache/go-build", dag.CacheVolume("go-build")). WithDirectory("/src", t.Source). diff --git a/.dagger/linting.go b/.dagger/linting.go new file mode 100644 index 0000000..84e324c --- /dev/null +++ b/.dagger/linting.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "fmt" + + "dagger/tapes/internal/dagger" +) + +const golangciLintVersion = "v2.8.0" + +// lintOpts returns the common GolangcilintOpts used by both CheckLint and FixLint. +// It layers golangci-lint on top of goContainer() so the sqlite dev headers, +// CGO, and Go caches are already in place. +func (t *Tapes) lintOpts() dagger.GolangcilintOpts { + base := t.goContainer(). + WithExec([]string{ + "go", + "install", + fmt.Sprintf("github.com/golangci/golangci-lint/v2/cmd/golangci-lint@%s", golangciLintVersion), + }) + + return dagger.GolangcilintOpts{ + BaseCtr: base, + Config: t.Source.File(".golangci.yml"), + } +} + +// CheckLint runs golangci-lint against the tapes source code without applying fixes. +func (t *Tapes) CheckLint(ctx context.Context) (string, error) { + return dag.Golangcilint(t.Source, t.lintOpts()).Check(ctx) +} + +// FixLint runs golangci-lint against the tapes source code with --fix, applying +// automatic fixes where possible, and returns the modified source directory. +func (t *Tapes) FixLint(ctx context.Context) *dagger.Directory { + return dag.Golangcilint(t.Source, t.lintOpts()).Lint() +} diff --git a/.dagger/main.go b/.dagger/main.go index ce28e1d..402209d 100644 --- a/.dagger/main.go +++ b/.dagger/main.go @@ -31,25 +31,27 @@ func New( } } -// Test runs the tapes unit tests via "go test" -func (t *Tapes) Test(ctx context.Context) (string, error) { - return t.testContainer(). - WithExec([]string{"go", "test", "-v", "./..."}). - Stdout(ctx) -} - -// testContainer returns a container configured for running tests -// with a local gcc toolchain for CGO and sqlite dependencies. -func (t *Tapes) testContainer() *dagger.Container { +// goContainer returns a Debian Bookworm-based Go container with gcc, +// libsqlite3-dev, CGO enabled, and the project source mounted. +// +// It is the shared foundation for tests, builds, and linting. +func (t *Tapes) goContainer() *dagger.Container { return dag.Container(). From("golang:1.25-bookworm"). WithExec([]string{"apt-get", "update"}). - WithExec([]string{"apt-get", "install", "-y", "gcc"}). - WithExec([]string{"apt-get", "install", "-y", "libsqlite3-dev"}). + WithExec([]string{"apt-get", "install", "-y", "gcc", "libsqlite3-dev"}). WithEnvVariable("CGO_ENABLED", "1"). WithEnvVariable("GOEXPERIMENT", "jsonv2"). + WithEnvVariable("PATH", "/go/bin:$PATH", dagger.ContainerWithEnvVariableOpts{Expand: true}). WithMountedCache("/go/pkg/mod", dag.CacheVolume("go-mod")). WithMountedCache("/root/.cache/go-build", dag.CacheVolume("go-build")). WithWorkdir("/src"). WithDirectory("/src", t.Source) } + +// Test runs the tapes unit tests via "go test" +func (t *Tapes) Test(ctx context.Context) (string, error) { + return t.goContainer(). + WithExec([]string{"go", "test", "-v", "./..."}). + Stdout(ctx) +} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c2a04d2..68fc34a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -33,6 +33,23 @@ jobs: args: test version: ${{ env.DAGGER_VERSION }} + lint-check: + name: GolangCI Lint Check + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Daggerverse golangcilint + uses: dagger/dagger-for-github@v8.2.0 + env: + DAGGER_CLOUD_TOKEN: ${{ secrets.DAGGER_CLOUD_TOKEN }} + with: + verb: call + args: "-m github.com/papercomputeco/daggerverse/golangcilint check" + version: ${{ env.DAGGER_VERSION }} + build: name: Build runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..3130ac2 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,117 @@ +# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +# +# Paper Compute Co. — opinionated golangci-lint defaults +# https://golangci-lint.run/docs/configuration/file/ + +version: "2" + +linters: + enable: + # --- Bugs --- + - bodyclose # Report unclosed HTTP response bodies. + - contextcheck # Non-inherited context usage. + - durationcheck # Accidental time.duration * time.duration. + - errcheck # Raises unchecked errors. + - errchkjson # Unchecked JSON marshal/unmarshal errors. + - errorlint # Go 1.13+ error wrapping best practices. + - exhaustive # Check exhaustiveness of enum switch statements. + - fatcontext # Nested contexts in loops / closures. + - govet # Examines and reports suspicious constructs. + - ineffassign # Reports when assignments to existing vars are not used. + - makezero # Slices with non-zero len that are later appended to. + - musttag # Missing struct tags for (un)marshaling. + - nilerr # Returning nil when err != nil. + - nilnesserr # Reports returning a different nil error after err check. + - noctx # Reports HTTP requests without context. + - unused # Checks Go code for unused consts, vars, funcs, types. + + # --- Code quality --- + - copyloopvar # loop variable copy issues + - dupword # duplicate words in comments / strings + - goconst # repeated strings that could be constants + - gocritic # broad set of style & performance diagnostics + - gosec # security-oriented checks + - misspell # common English typos + - nakedret # naked returns in long functions + - prealloc # slice pre-allocation hints + - predeclared # shadowing predeclared identifiers + - reassign # package-variable reassignment + - revive # golint successor — style & correctness + - unconvert # unnecessary type conversions + - unparam # unused function parameters + - wastedassign # wasted assignments + - whitespace # unnecessary blank lines + + # --- Modernization --- + - exptostd # x/exp -> stdlib replacements + - intrange # use integer ranges in for loops + - modernize # modern Go idioms + - usestdlibvars # use stdlib constants/variables + - usetesting # prefer testing package helpers + - perfsprint # faster fmt.Sprintf alternatives + + # --- Style --- + - errname # error/sentinel naming conventions + - nolintlint # well-formed nolint directives + - recvcheck # consistent receiver types + - nosprintfhostport # host:port via net.JoinHostPort + + disable: + - gochecknoglobals # forbid all globals + - gochecknoinits # forbid all init() + + exclusions: + presets: + - comments + - std-error-handling + - common-false-positives + - legacy + generated: lax + rules: + - path: _test\.go + linters: + - errcheck + - gosec + - dupl + - goconst + - gocritic + - revive + - unparam + # @jpmcb - TODO: rename these packages to avoid "meaningless" names + - path: api/ + text: "var-naming: avoid meaningless package names" + linters: + - revive + - path: pkg/utils/ + text: "var-naming: avoid meaningless package names" + linters: + - revive + +formatters: + enable: + - gci + - gofmt + - goimports + - gofumpt + settings: + gci: + sections: + - standard + - default + - prefix(github.com/papercomputeco/tapes) + goimports: + local-prefixes: + - github.com/papercomputeco/tapes + exclusions: + generated: lax + paths: + - .dagger/internal + - .dagger/dagger.gen.go + +issues: + # Show all issues + max-issues-per-linter: 0 + max-same-issues: 0 + +run: + timeout: 5m diff --git a/api/api_test.go b/api/api_test.go index 890b1f2..f0e52bc 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -35,7 +35,7 @@ var _ = Describe("buildHistory", func() { BeforeEach(func() { var err error logger, _ := zap.NewDevelopment() - inMem := inmemory.NewInMemoryDriver() + inMem := inmemory.NewDriver() driver = inMem dagLoader = inMem server, err = NewServer(Config{ListenAddr: ":0"}, driver, dagLoader, logger) diff --git a/api/mcp/mcp.go b/api/mcp/mcp.go index 64111f9..0affc7a 100644 --- a/api/mcp/mcp.go +++ b/api/mcp/mcp.go @@ -2,7 +2,7 @@ package mcp import ( - "fmt" + "errors" "net/http" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -61,16 +61,16 @@ func NewServer(c Config) (*Server, error) { } if c.DagLoader == nil { - return nil, fmt.Errorf("storage driver is required") + return nil, errors.New("storage driver is required") } if c.VectorDriver == nil { - return nil, fmt.Errorf("vector driver is required") + return nil, errors.New("vector driver is required") } if c.Embedder == nil { - return nil, fmt.Errorf("embedder is required") + return nil, errors.New("embedder is required") } if c.Logger == nil { - return nil, fmt.Errorf("logger is required") + return nil, errors.New("logger is required") } // Add tools @@ -83,7 +83,7 @@ func NewServer(c Config) (*Server, error) { // Create a streamable HTTP net/http handler for stateless operations s.handler = mcp.NewStreamableHTTPHandler( - func(r *http.Request) *mcp.Server { + func(_ *http.Request) *mcp.Server { return mcpServer }, &mcp.StreamableHTTPOptions{ diff --git a/api/mcp/mcp_test.go b/api/mcp/mcp_test.go index 6474784..4d1c8d5 100644 --- a/api/mcp/mcp_test.go +++ b/api/mcp/mcp_test.go @@ -13,14 +13,14 @@ import ( var _ = Describe("MCP Server", func() { var ( server *mcp.Server - driver *inmemory.InMemoryDriver + driver *inmemory.Driver vectorDriver *testutils.MockVectorDriver embedder *testutils.MockEmbedder ) BeforeEach(func() { logger, _ := zap.NewDevelopment() - driver = inmemory.NewInMemoryDriver() + driver = inmemory.NewDriver() vectorDriver = testutils.NewMockVectorDriver() embedder = testutils.NewMockEmbedder() diff --git a/api/mcp/search.go b/api/mcp/search.go index 18bdb85..0138eeb 100644 --- a/api/mcp/search.go +++ b/api/mcp/search.go @@ -15,16 +15,16 @@ var ( searchDescription = "Search over stored LLM sessions using semantic search. Returns the most relevant sessions based on the query text, including the full conversation branch (ancestors and descendants)." ) -// MCPSearchInput represents the input arguments for the MCP search tool. +// SearchInput represents the input arguments for the MCP search tool. // It uses jsonschema tags specific to the MCP protocol. -type MCPSearchInput struct { +type SearchInput struct { Query string `json:"query" jsonschema:"the search query text to find relevant sessions"` TopK int `json:"top_k,omitempty" jsonschema:"number of results to return (default: 5)"` } // handleSearch processes a search request via MCP. // It delegates to the shared search package for the core search logic. -func (s *Server) handleSearch(ctx context.Context, req *mcp.CallToolRequest, input MCPSearchInput) (*mcp.CallToolResult, apisearch.SearchOutput, error) { +func (s *Server) handleSearch(ctx context.Context, _ *mcp.CallToolRequest, input SearchInput) (*mcp.CallToolResult, apisearch.Output, error) { output, err := apisearch.Search( ctx, input.Query, @@ -40,7 +40,7 @@ func (s *Server) handleSearch(ctx context.Context, req *mcp.CallToolRequest, inp Content: []mcp.Content{ &mcp.TextContent{Text: fmt.Sprintf("Search failed: %v", err)}, }, - }, apisearch.SearchOutput{}, nil + }, apisearch.Output{}, nil } // Serialize the structured output as JSON for the text field @@ -53,7 +53,7 @@ func (s *Server) handleSearch(ctx context.Context, req *mcp.CallToolRequest, inp Content: []mcp.Content{ &mcp.TextContent{Text: fmt.Sprintf("Failed to serialize results: %v", err)}, }, - }, apisearch.SearchOutput{}, nil + }, apisearch.Output{}, nil } return &mcp.CallToolResult{ diff --git a/api/search/search.go b/api/search/search.go index 97d68ca..9d97be6 100644 --- a/api/search/search.go +++ b/api/search/search.go @@ -14,14 +14,14 @@ import ( "github.com/papercomputeco/tapes/pkg/vector" ) -// SearchInput represents the input arguments for a search request. -type SearchInput struct { +// Input represents the input arguments for a search request. +type Input struct { Query string `json:"query"` TopK int `json:"top_k,omitempty"` } -// SearchResult represents a single search result. -type SearchResult struct { +// Result represents a single search result. +type Result struct { Hash string `json:"hash"` Score float32 `json:"score"` Role string `json:"role"` @@ -38,11 +38,11 @@ type Turn struct { Matched bool `json:"matched,omitempty"` } -// SearchOutput represents the output of a search operation. -type SearchOutput struct { - Query string `json:"query"` - Results []SearchResult `json:"results"` - Count int `json:"count"` +// Output represents the output of a search operation. +type Output struct { + Query string `json:"query"` + Results []Result `json:"results"` + Count int `json:"count"` } // Search performs a semantic search over stored LLM sessions. @@ -56,7 +56,7 @@ func Search( vectorDriver vector.Driver, dagLoader merkle.DagLoader, logger *zap.Logger, -) (*SearchOutput, error) { +) (*Output, error) { if topK <= 0 { topK = 5 } @@ -79,7 +79,7 @@ func Search( } // Build search results with full branch using merkle.LoadDag - searchResults := make([]SearchResult, 0, len(results)) + searchResults := make([]Result, 0, len(results)) for _, result := range results { dag, err := merkle.LoadDag(ctx, dagLoader, result.Hash) if err != nil { @@ -90,25 +90,25 @@ func Search( continue } - searchResult := BuildSearchResult(result, dag) + searchResult := BuildResult(result, dag) searchResults = append(searchResults, searchResult) } - return &SearchOutput{ + return &Output{ Query: query, Results: searchResults, Count: len(searchResults), }, nil } -// BuildSearchResult converts a vector query result and DAG into a SearchResult. -func BuildSearchResult(result vector.QueryResult, dag *merkle.Dag) SearchResult { +// BuildResult converts a vector query result and DAG into a Result. +func BuildResult(result vector.QueryResult, dag *merkle.Dag) Result { turns := []Turn{} preview := "" role := "" // Build turns from the DAG using Walk (depth-first from root to leaves) - dag.Walk(func(node *merkle.DagNode) (bool, error) { + _ = dag.Walk(func(node *merkle.DagNode) (bool, error) { isMatched := node.Hash == result.Hash turns = append(turns, Turn{ Hash: node.Hash, @@ -125,7 +125,7 @@ func BuildSearchResult(result vector.QueryResult, dag *merkle.Dag) SearchResult return true, nil }) - return SearchResult{ + return Result{ Hash: result.Hash, Score: result.Score, Role: role, diff --git a/api/search/search_test.go b/api/search/search_test.go index 4244ec9..e1140f1 100644 --- a/api/search/search_test.go +++ b/api/search/search_test.go @@ -22,7 +22,7 @@ func TestSearch(t *testing.T) { var _ = Describe("Search", func() { var ( - driver *inmemory.InMemoryDriver + driver *inmemory.Driver vectorDriver *testutils.MockVectorDriver embedder *testutils.MockEmbedder logger *zap.Logger @@ -31,7 +31,7 @@ var _ = Describe("Search", func() { BeforeEach(func() { logger, _ = zap.NewDevelopment() - driver = inmemory.NewInMemoryDriver() + driver = inmemory.NewDriver() vectorDriver = testutils.NewMockVectorDriver() embedder = testutils.NewMockEmbedder() ctx = context.Background() @@ -113,7 +113,7 @@ var _ = Describe("Search", func() { }) }) - Describe("BuildSearchResult", func() { + Describe("BuildResult", func() { It("builds a result from a single node", func() { node := merkle.NewNode(testutils.NewTestBucket("user", "Hello world"), nil) _, err := driver.Put(ctx, node) @@ -130,7 +130,7 @@ var _ = Describe("Search", func() { dag, err := merkle.LoadDag(ctx, driver, node.Hash) Expect(err).NotTo(HaveOccurred()) - searchResult := search.BuildSearchResult(result, dag) + searchResult := search.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal(node.Hash)) Expect(searchResult.Score).To(Equal(float32(0.95))) @@ -163,7 +163,7 @@ var _ = Describe("Search", func() { dag, err := merkle.LoadDag(ctx, driver, node3.Hash) Expect(err).NotTo(HaveOccurred()) - searchResult := search.BuildSearchResult(result, dag) + searchResult := search.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal(node3.Hash)) Expect(searchResult.Turns).To(Equal(3)) @@ -190,7 +190,7 @@ var _ = Describe("Search", func() { } dag := merkle.NewDag() - searchResult := search.BuildSearchResult(result, dag) + searchResult := search.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal("empty-hash")) Expect(searchResult.Turns).To(Equal(0)) diff --git a/api/search_handler_test.go b/api/search_handler_test.go index b814839..0cd604a 100644 --- a/api/search_handler_test.go +++ b/api/search_handler_test.go @@ -21,7 +21,7 @@ import ( var _ = Describe("handleSearchEndpoint", func() { var ( server *Server - inMem *inmemory.InMemoryDriver + inMem *inmemory.Driver vectorDriver *testutils.MockVectorDriver embedder *testutils.MockEmbedder ctx context.Context @@ -29,7 +29,7 @@ var _ = Describe("handleSearchEndpoint", func() { BeforeEach(func() { logger, _ := zap.NewDevelopment() - inMem = inmemory.NewInMemoryDriver() + inMem = inmemory.NewDriver() vectorDriver = testutils.NewMockVectorDriver() embedder = testutils.NewMockEmbedder() ctx = context.Background() @@ -59,7 +59,7 @@ var _ = Describe("handleSearchEndpoint", func() { ) Expect(err).NotTo(HaveOccurred()) - req, err := http.NewRequest("GET", "/v1/search?query=test", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=test", nil) Expect(err).NotTo(HaveOccurred()) resp, err := noSearchServer.app.Test(req) @@ -70,7 +70,7 @@ var _ = Describe("handleSearchEndpoint", func() { Context("when query parameter is missing", func() { It("returns 400", func() { - req, err := http.NewRequest("GET", "/v1/search", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) @@ -85,7 +85,7 @@ var _ = Describe("handleSearchEndpoint", func() { Context("when query parameter is empty", func() { It("returns 400", func() { - req, err := http.NewRequest("GET", "/v1/search?query=", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) @@ -96,7 +96,7 @@ var _ = Describe("handleSearchEndpoint", func() { Context("when top_k is invalid", func() { It("returns 400 for non-integer top_k", func() { - req, err := http.NewRequest("GET", "/v1/search?query=test&top_k=abc", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=test&top_k=abc", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) @@ -109,7 +109,7 @@ var _ = Describe("handleSearchEndpoint", func() { }) It("returns 400 for zero top_k", func() { - req, err := http.NewRequest("GET", "/v1/search?query=test&top_k=0", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=test&top_k=0", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) @@ -118,7 +118,7 @@ var _ = Describe("handleSearchEndpoint", func() { }) It("returns 400 for negative top_k", func() { - req, err := http.NewRequest("GET", "/v1/search?query=test&top_k=-1", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=test&top_k=-1", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) @@ -129,14 +129,14 @@ var _ = Describe("handleSearchEndpoint", func() { Context("when search succeeds with no results", func() { It("returns 200 with empty results", func() { - req, err := http.NewRequest("GET", "/v1/search?query=hello", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=hello", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) Expect(err).NotTo(HaveOccurred()) Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) - var output apisearch.SearchOutput + var output apisearch.Output body, err := io.ReadAll(resp.Body) Expect(err).NotTo(HaveOccurred()) Expect(json.Unmarshal(body, &output)).To(Succeed()) @@ -167,14 +167,14 @@ var _ = Describe("handleSearchEndpoint", func() { }, } - req, err := http.NewRequest("GET", "/v1/search?query=greeting&top_k=3", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=greeting&top_k=3", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) Expect(err).NotTo(HaveOccurred()) Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) - var output apisearch.SearchOutput + var output apisearch.Output body, err := io.ReadAll(resp.Body) Expect(err).NotTo(HaveOccurred()) Expect(json.Unmarshal(body, &output)).To(Succeed()) @@ -197,7 +197,7 @@ var _ = Describe("handleSearchEndpoint", func() { It("returns 500", func() { vectorDriver.FailQuery = true - req, err := http.NewRequest("GET", "/v1/search?query=test", nil) + req, err := http.NewRequest(http.MethodGet, "/v1/search?query=test", nil) Expect(err).NotTo(HaveOccurred()) resp, err := server.app.Test(req) diff --git a/cmd/tapes/chat/chat.go b/cmd/tapes/chat/chat.go index 9779547..3100296 100644 --- a/cmd/tapes/chat/chat.go +++ b/cmd/tapes/chat/chat.go @@ -5,6 +5,7 @@ package chatcmder import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -87,7 +88,7 @@ func NewChatCmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } return cmder.run() @@ -104,7 +105,7 @@ func NewChatCmd() *cobra.Command { func (c *chatCommander) run() error { c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() // Load checkout state dotdirManager := dotdir.NewManager() @@ -205,7 +206,7 @@ func (c *chatCommander) sendAndStream(messages []ollamaMessage) (string, error) // POST to the proxy's Ollama-compatible chat endpoint url := c.proxy + "/api/chat" - httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("creating request: %w", err) } diff --git a/cmd/tapes/checkout/checkout.go b/cmd/tapes/checkout/checkout.go index 0bf0a91..0cad95d 100644 --- a/cmd/tapes/checkout/checkout.go +++ b/cmd/tapes/checkout/checkout.go @@ -3,10 +3,12 @@ package checkoutcmder import ( + "context" "encoding/json" "fmt" "io" "net/http" + "strings" "time" "github.com/spf13/cobra" @@ -75,12 +77,12 @@ func NewCheckoutCmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } cmder.api, err = cmd.Flags().GetString("api") if err != nil { - return fmt.Errorf("could not get api flag: %v", err) + return fmt.Errorf("could not get api flag: %w", err) } return cmder.run() @@ -95,7 +97,7 @@ func NewCheckoutCmd() *cobra.Command { func (c *checkoutCommander) run() error { dotdirManager := dotdir.NewManager() c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() // If no hash provided, clear checkout state if c.hash == "" { @@ -150,7 +152,11 @@ func (c *checkoutCommander) fetchHistory(hash string) (*historyResponse, error) url := fmt.Sprintf("%s/dag/history/%s", c.api, hash) client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Get(url) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("requesting history from API: %w", err) } @@ -176,11 +182,11 @@ func (c *checkoutCommander) fetchHistory(hash string) (*historyResponse, error) // extractText concatenates all text content blocks from a message. func extractText(content []llm.ContentBlock) string { - var text string + var b strings.Builder for _, block := range content { if block.Type == "text" { - text += block.Text + b.WriteString(block.Text) } } - return text + return b.String() } diff --git a/cmd/tapes/checkout/checkout_test.go b/cmd/tapes/checkout/checkout_test.go index 18ef607..bc9f016 100644 --- a/cmd/tapes/checkout/checkout_test.go +++ b/cmd/tapes/checkout/checkout_test.go @@ -2,9 +2,9 @@ package checkoutcmder_test import ( "encoding/json" - "fmt" "net/http" "net/http/httptest" + "strings" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -100,13 +100,13 @@ var _ = Describe("History API response parsing", func() { Expect(parsed.Messages[1].Role).To(Equal("assistant")) // Extract text from content blocks - var text string + var b strings.Builder for _, block := range parsed.Messages[1].Content { if block.Type == "text" { - text += block.Text + b.WriteString(block.Text) } } - Expect(text).To(Equal("Hi there!")) + Expect(b.String()).To(Equal("Hi there!")) }) It("correctly handles a mock API server returning history", func() { @@ -141,7 +141,7 @@ var _ = Describe("History API response parsing", func() { defer server.Close() // Fetch from mock server - url := fmt.Sprintf("%s/dag/history/abc123", server.URL) + url := server.URL + "/dag/history/abc123" resp, err := http.Get(url) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -164,7 +164,7 @@ var _ = Describe("History API response parsing", func() { })) defer server.Close() - resp, err := http.Get(fmt.Sprintf("%s/dag/history/unknown", server.URL)) + resp, err := http.Get(server.URL + "/dag/history/unknown") Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() diff --git a/cmd/tapes/search/search.go b/cmd/tapes/search/search.go index 6ea0a60..227cc6f 100644 --- a/cmd/tapes/search/search.go +++ b/cmd/tapes/search/search.go @@ -2,6 +2,7 @@ package searchcmder import ( + "context" "encoding/json" "fmt" "io" @@ -57,7 +58,7 @@ func NewSearchCmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } return cmder.run() @@ -72,7 +73,7 @@ func NewSearchCmd() *cobra.Command { func (c *searchCommander) run() error { c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() c.logger.Debug("searching via API", zap.String("api_target", c.apiTarget), @@ -93,7 +94,11 @@ func (c *searchCommander) run() error { c.logger.Debug("requesting search", zap.String("url", searchURL.String())) - resp, err := http.Get(searchURL.String()) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, searchURL.String(), nil) + if err != nil { + return fmt.Errorf("creating search request: %w", err) + } + resp, err := (&http.Client{}).Do(req) if err != nil { return fmt.Errorf("failed to connect to Tapes API at %s: %w", c.apiTarget, err) } @@ -108,7 +113,7 @@ func (c *searchCommander) run() error { return fmt.Errorf("search request failed (HTTP %d): %s", resp.StatusCode, string(body)) } - var output apisearch.SearchOutput + var output apisearch.Output if err := json.Unmarshal(body, &output); err != nil { return fmt.Errorf("failed to parse search response: %w", err) } @@ -129,7 +134,7 @@ func (c *searchCommander) run() error { return nil } -func (c *searchCommander) printResult(rank int, result apisearch.SearchResult) { +func (c *searchCommander) printResult(rank int, result apisearch.Result) { fmt.Printf("\n[%d] Score: %.4f\n", rank, result.Score) fmt.Printf(" Hash: %s\n", result.Hash) diff --git a/cmd/tapes/serve/api/api.go b/cmd/tapes/serve/api/api.go index ff5db0f..f699557 100644 --- a/cmd/tapes/serve/api/api.go +++ b/cmd/tapes/serve/api/api.go @@ -37,7 +37,7 @@ func NewAPICmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } return cmder.run() @@ -52,7 +52,7 @@ func NewAPICmd() *cobra.Command { func (c *apiCommander) run() error { c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() driver, err := c.newStorageDriver() if err != nil { @@ -84,7 +84,7 @@ func (c *apiCommander) run() error { func (c *apiCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewSQLiteDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -93,12 +93,12 @@ func (c *apiCommander) newStorageDriver() (storage.Driver, error) { } c.logger.Info("using in-memory storage") - return inmemory.NewInMemoryDriver(), nil + return inmemory.NewDriver(), nil } func (c *apiCommander) newDagLoader() (merkle.DagLoader, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewSQLiteDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -107,5 +107,5 @@ func (c *apiCommander) newDagLoader() (merkle.DagLoader, error) { } c.logger.Info("using in-memory storage") - return inmemory.NewInMemoryDriver(), nil + return inmemory.NewDriver(), nil } diff --git a/cmd/tapes/serve/proxy/proxy.go b/cmd/tapes/serve/proxy/proxy.go index d72ebb0..e327961 100644 --- a/cmd/tapes/serve/proxy/proxy.go +++ b/cmd/tapes/serve/proxy/proxy.go @@ -56,7 +56,7 @@ func NewProxyCmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } return cmder.run() @@ -78,7 +78,7 @@ func NewProxyCmd() *cobra.Command { func (c *proxyCommander) run() error { c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() driver, err := c.newStorageDriver() if err != nil { @@ -142,7 +142,7 @@ func (c *proxyCommander) run() error { func (c *proxyCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewSQLiteDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -151,5 +151,5 @@ func (c *proxyCommander) newStorageDriver() (storage.Driver, error) { } c.logger.Info("using in-memory storage") - return inmemory.NewInMemoryDriver(), nil + return inmemory.NewDriver(), nil } diff --git a/cmd/tapes/serve/serve.go b/cmd/tapes/serve/serve.go index 88b88f0..3c540b5 100644 --- a/cmd/tapes/serve/serve.go +++ b/cmd/tapes/serve/serve.go @@ -68,7 +68,7 @@ func NewServeCmd() *cobra.Command { var err error cmder.debug, err = cmd.Flags().GetBool("debug") if err != nil { - return fmt.Errorf("could not get debug flag: %v", err) + return fmt.Errorf("could not get debug flag: %w", err) } return cmder.run() }, @@ -101,7 +101,7 @@ func NewServeCmd() *cobra.Command { func (c *ServeCommander) run() error { c.logger = logger.NewLogger(c.debug) - defer c.logger.Sync() + defer func() { _ = c.logger.Sync() }() // Create shared driver driver, err := c.newStorageDriver() @@ -211,7 +211,7 @@ func (c *ServeCommander) run() error { func (c *ServeCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewSQLiteDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -220,12 +220,12 @@ func (c *ServeCommander) newStorageDriver() (storage.Driver, error) { } c.logger.Info("using in-memory storage") - return inmemory.NewInMemoryDriver(), nil + return inmemory.NewDriver(), nil } func (c *ServeCommander) newDagLoader() (merkle.DagLoader, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewSQLiteDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -234,5 +234,5 @@ func (c *ServeCommander) newDagLoader() (merkle.DagLoader, error) { } c.logger.Info("using in-memory storage") - return inmemory.NewInMemoryDriver(), nil + return inmemory.NewDriver(), nil } diff --git a/cmd/version/version.go b/cmd/version/version.go index 832dbfc..031fcdd 100644 --- a/cmd/version/version.go +++ b/cmd/version/version.go @@ -9,10 +9,7 @@ import ( "github.com/papercomputeco/tapes/pkg/utils" ) -type VersionCommander struct { - semVer string - commit string -} +type VersionCommander struct{} func NewVersionCmd() *cobra.Command { cmder := &VersionCommander{} @@ -21,7 +18,7 @@ func NewVersionCmd() *cobra.Command { Use: "version", Short: "displays version", Long: "displays the version of this CLI", - RunE: func(cmd *cobra.Command, _ []string) error { + RunE: func(_ *cobra.Command, _ []string) error { return cmder.run() }, } diff --git a/dagger.json b/dagger.json index a704236..09ee97e 100644 --- a/dagger.json +++ b/dagger.json @@ -4,5 +4,12 @@ "sdk": { "source": "go" }, + "dependencies": [ + { + "name": "golangcilint", + "source": "github.com/papercomputeco/daggerverse/golangcilint@main", + "pin": "8646f8127d0bb464437e12f0a5d83d72edc211d8" + } + ], "source": ".dagger" } diff --git a/makefile b/makefile index 7a865a4..b715537 100644 --- a/makefile +++ b/makefile @@ -14,6 +14,16 @@ LDFLAGS := -s -w \ format: find . -type f -name "*.go" -exec goimports -local github.com/papercomputeco/tapes -w {} \; +.PHONY: check-lint +check-lint: ## Runs golangci-lint check. Auto-fixes are not automatically applied. + $(call print-target) + dagger call check-lint + +.PHONY: fix-lint +fix-lint: ## Runs golangci-lint lint with auto-fixes applied. + $(call print-target) + dagger call fix-lint export --path . + .PHONY: generate generate: ## Regenerates ent code from schema go generate ./pkg/storage/ent/... diff --git a/pkg/dotdir/checkout.go b/pkg/dotdir/checkout.go index 9d1d251..eb475ed 100644 --- a/pkg/dotdir/checkout.go +++ b/pkg/dotdir/checkout.go @@ -59,7 +59,7 @@ func (m *Manager) LoadCheckoutState(overrideDir string) (*CheckoutState, error) // SaveCheckout persists the checkout state to a target .tapes/checkout.json. func (m *Manager) SaveCheckout(state *CheckoutState, overrideDir string) error { if state == nil { - return fmt.Errorf("cannot save nil checkout state") + return errors.New("cannot save nil checkout state") } dir, err := m.Target(overrideDir) @@ -73,7 +73,7 @@ func (m *Manager) SaveCheckout(state *CheckoutState, overrideDir string) error { } path := filepath.Join(dir, checkoutFile) - if err := os.WriteFile(path, data, 0o644); err != nil { + if err := os.WriteFile(path, data, 0o644); err != nil { //nolint:gosec // @jpmcb - TODO: refactor file permissions return fmt.Errorf("writing checkout state: %w", err) } diff --git a/pkg/embeddings/ollama/ollama.go b/pkg/embeddings/ollama/ollama.go index 9c4ad05..2321c17 100644 --- a/pkg/embeddings/ollama/ollama.go +++ b/pkg/embeddings/ollama/ollama.go @@ -81,18 +81,18 @@ func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) { jsonBody, err := json.Marshal(reqBody) if err != nil { - return nil, fmt.Errorf("%w: marshaling request: %v", vector.ErrEmbedding, err) + return nil, fmt.Errorf("%w: marshaling request: %w", vector.ErrEmbedding, err) } - req, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/api/embed", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/api/embed", bytes.NewReader(jsonBody)) if err != nil { - return nil, fmt.Errorf("%w: creating request: %v", vector.ErrEmbedding, err) + return nil, fmt.Errorf("%w: creating request: %w", vector.ErrEmbedding, err) } req.Header.Set("Content-Type", "application/json") resp, err := e.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("%w: sending request: %v", vector.ErrEmbedding, err) + return nil, fmt.Errorf("%w: sending request: %w", vector.ErrEmbedding, err) } defer resp.Body.Close() @@ -103,7 +103,7 @@ func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) { var embedResp embedResponse if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { - return nil, fmt.Errorf("%w: decoding response: %v", vector.ErrEmbedding, err) + return nil, fmt.Errorf("%w: decoding response: %w", vector.ErrEmbedding, err) } if len(embedResp.Embeddings) == 0 { diff --git a/pkg/llm/message.go b/pkg/llm/message.go index 4422dc2..2f5d078 100644 --- a/pkg/llm/message.go +++ b/pkg/llm/message.go @@ -1,5 +1,7 @@ package llm +import "strings" + // Message represents a single message in a conversation. // Content is stored as an array of ContentBlocks to support multimodal content // (text, images, tool use, etc.) in a provider-agnostic way. @@ -45,11 +47,11 @@ func NewTextMessage(role, text string) Message { // GetText returns the concatenated text content from all text blocks in the message. // This is a convenience method for simple text-only messages. func (m *Message) GetText() string { - var result string + var b strings.Builder for _, block := range m.Content { if block.Type == "text" { - result += block.Text + b.WriteString(block.Text) } } - return result + return b.String() } diff --git a/pkg/llm/provider/anthropic/anthropic.go b/pkg/llm/provider/anthropic/anthropic.go index a721544..874a84d 100644 --- a/pkg/llm/provider/anthropic/anthropic.go +++ b/pkg/llm/provider/anthropic/anthropic.go @@ -9,23 +9,23 @@ import ( "github.com/papercomputeco/tapes/pkg/llm" ) -// provider implements the Provider interface for Anthropic's Claude API. -type provider struct{} +// Provider implements the Provider interface for Anthropic's Claude API. +type Provider struct{} // New -func New() *provider { return &provider{} } +func New() *Provider { return &Provider{} } // Name -func (p *provider) Name() string { +func (p *Provider) Name() string { return "anthropic" } // DefaultStreaming is false - Anthropic requires explicit "stream": true. -func (p *provider) DefaultStreaming() bool { +func (p *Provider) DefaultStreaming() bool { return false } -func (p *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { +func (p *Provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { var req anthropicRequest if err := json.Unmarshal(payload, &req); err != nil { return nil, err @@ -123,7 +123,7 @@ func parseAnthropicSystem(system any) string { } } -func (p *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { +func (p *Provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { var resp anthropicResponse if err := json.Unmarshal(payload, &resp); err != nil { return nil, err @@ -173,6 +173,6 @@ func (p *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { return result, nil } -func (p *provider) ParseStreamChunk(payload []byte) (*llm.StreamChunk, error) { +func (p *Provider) ParseStreamChunk(_ []byte) (*llm.StreamChunk, error) { panic("not implemented") } diff --git a/pkg/llm/provider/ollama/ollama.go b/pkg/llm/provider/ollama/ollama.go index c26596d..14448d2 100644 --- a/pkg/llm/provider/ollama/ollama.go +++ b/pkg/llm/provider/ollama/ollama.go @@ -6,21 +6,21 @@ import ( "github.com/papercomputeco/tapes/pkg/llm" ) -// provider implements the Provider interface for Ollama's API. -type provider struct{} +// Provider implements the Provider interface for Ollama's API. +type Provider struct{} -func New() *provider { return &provider{} } +func New() *Provider { return &Provider{} } -func (o *provider) Name() string { +func (o *Provider) Name() string { return "ollama" } // DefaultStreaming is true - Ollama streams by default -func (o *provider) DefaultStreaming() bool { +func (o *Provider) DefaultStreaming() bool { return true } -func (o *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { +func (o *Provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { var req ollamaRequest if err := json.Unmarshal(payload, &req); err != nil { return nil, err @@ -105,7 +105,7 @@ func (o *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { return result, nil } -func (o *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { +func (o *Provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { var resp ollamaResponse if err := json.Unmarshal(payload, &resp); err != nil { return nil, err @@ -180,6 +180,6 @@ func (o *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { return result, nil } -func (o *provider) ParseStreamChunk(payload []byte) (*llm.StreamChunk, error) { +func (o *Provider) ParseStreamChunk(_ []byte) (*llm.StreamChunk, error) { panic("Not yet implemented") } diff --git a/pkg/llm/provider/openai/openai.go b/pkg/llm/provider/openai/openai.go index 2d4678c..bcc60b6 100644 --- a/pkg/llm/provider/openai/openai.go +++ b/pkg/llm/provider/openai/openai.go @@ -8,21 +8,21 @@ import ( "github.com/papercomputeco/tapes/pkg/llm" ) -// provider implements the Provider interface for OpenAI's Chat Completions API. -type provider struct{} +// Provider implements the Provider interface for OpenAI's Chat Completions API. +type Provider struct{} -func New() *provider { return &provider{} } +func New() *Provider { return &Provider{} } -func (o *provider) Name() string { +func (o *Provider) Name() string { return "openai" } // DefaultStreaming is false - OpenAI requires explicit "stream": true. -func (o *provider) DefaultStreaming() bool { +func (o *Provider) DefaultStreaming() bool { return false } -func (o *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { +func (o *Provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { var req openaiRequest if err := json.Unmarshal(payload, &req); err != nil { return nil, err @@ -63,13 +63,14 @@ func (o *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { // Handle tool calls in assistant messages for _, tc := range msg.ToolCalls { var input map[string]any - json.Unmarshal([]byte(tc.Function.Arguments), &input) - converted.Content = append(converted.Content, llm.ContentBlock{ - Type: "tool_use", - ToolUseID: tc.ID, - ToolName: tc.Function.Name, - ToolInput: input, - }) + if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err == nil { + converted.Content = append(converted.Content, llm.ContentBlock{ + Type: "tool_use", + ToolUseID: tc.ID, + ToolName: tc.Function.Name, + ToolInput: input, + }) + } } // Handle tool results @@ -130,7 +131,7 @@ func (o *provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) { return result, nil } -func (o *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { +func (o *Provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { var resp openaiResponse if err := json.Unmarshal(payload, &resp); err != nil { return nil, err @@ -173,13 +174,14 @@ func (o *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { // Handle tool calls for _, tc := range msg.ToolCalls { var input map[string]any - json.Unmarshal([]byte(tc.Function.Arguments), &input) - content = append(content, llm.ContentBlock{ - Type: "tool_use", - ToolUseID: tc.ID, - ToolName: tc.Function.Name, - ToolInput: input, - }) + if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err == nil { + content = append(content, llm.ContentBlock{ + Type: "tool_use", + ToolUseID: tc.ID, + ToolName: tc.Function.Name, + ToolInput: input, + }) + } } var usage *llm.Usage @@ -211,6 +213,6 @@ func (o *provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) { return result, nil } -func (o *provider) ParseStreamChunk(payload []byte) (*llm.StreamChunk, error) { +func (o *Provider) ParseStreamChunk(_ []byte) (*llm.StreamChunk, error) { panic("Not yet implemented") } diff --git a/pkg/llm/provider/openai/types.go b/pkg/llm/provider/openai/types.go index 485b062..d8d909a 100644 --- a/pkg/llm/provider/openai/types.go +++ b/pkg/llm/provider/openai/types.go @@ -32,16 +32,6 @@ type openaiMessage struct { } `json:"tool_calls,omitempty"` } -// openaiContentPart represents a content part for multimodal messages. -type openaiContentPart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ImageURL *struct { - URL string `json:"url"` - Detail string `json:"detail,omitempty"` - } `json:"image_url,omitempty"` -} - // openaiResponse represents OpenAI's response format. type openaiResponse struct { ID string `json:"id"` diff --git a/pkg/llm/response.go b/pkg/llm/response.go index a99ea9b..d2b6c99 100644 --- a/pkg/llm/response.go +++ b/pkg/llm/response.go @@ -13,7 +13,7 @@ type ChatResponse struct { Model string `json:"model"` // Response timestamp - CreatedAt time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitzero"` // The assistant's response message Message Message `json:"message"` diff --git a/pkg/llm/stream.go b/pkg/llm/stream.go index e7b9608..ffd6e83 100644 --- a/pkg/llm/stream.go +++ b/pkg/llm/stream.go @@ -10,7 +10,7 @@ type StreamChunk struct { Model string `json:"model"` // Chunk timestamp - CreatedAt time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitzero"` // The content of this chunk (typically a partial message) Message Message `json:"message"` diff --git a/pkg/merkle/dag.go b/pkg/merkle/dag.go index 6ba0358..483d48c 100644 --- a/pkg/merkle/dag.go +++ b/pkg/merkle/dag.go @@ -2,6 +2,7 @@ package merkle import ( "context" + "errors" "fmt" ) @@ -113,12 +114,13 @@ func (d *Dag) Leaves() []*DagNode { // Walk traverses the DAG depth-first from root, calling fn for each node. // If the provided function returns false, traversal stops. // If the provided function errors, traversal stops and the error is propagated. -func (d *Dag) Walk(f func(*DagNode) (bool, error)) { +func (d *Dag) Walk(f func(*DagNode) (bool, error)) error { if d.Root == nil { - return + return nil } - d.walkNode(d.Root, f) + _, err := d.walkNode(d.Root, f) + return err } // walkNode recursively, depth first, traverses the given node with the provided @@ -168,7 +170,7 @@ func (d *Dag) Descendants(hash string) []*DagNode { } descendants := []*DagNode{} - d.Walk(func(n *DagNode) (bool, error) { + _ = d.Walk(func(n *DagNode) (bool, error) { descendants = append(descendants, n.Children...) return true, nil }) @@ -211,7 +213,7 @@ func (d *Dag) BranchPoints() []*DagNode { // - The node already exists in the DAG func (d *Dag) addNode(node *Node) (*DagNode, error) { if node == nil { - return nil, fmt.Errorf("cannot add nil node to dag") + return nil, errors.New("cannot add nil node to dag") } dagNode, ok := d.index[node.Hash] @@ -227,7 +229,7 @@ func (d *Dag) addNode(node *Node) (*DagNode, error) { if node.ParentHash == nil { // This is a root node if d.Root != nil { - return nil, fmt.Errorf("DAG already has a root node") + return nil, errors.New("DAG already has a root node") } d.Root = dagNode diff --git a/pkg/merkle/dag_test.go b/pkg/merkle/dag_test.go index 510ac05..47ca442 100644 --- a/pkg/merkle/dag_test.go +++ b/pkg/merkle/dag_test.go @@ -27,7 +27,7 @@ func dagTestBucket(role, text string) merkle.Bucket { // the DAG from the specified node hash. If loadFromHash is empty, it loads from // the last node in the slice. func buildTestDag(ctx context.Context, nodes []*merkle.Node, loadFromHash string) (*merkle.Dag, error) { - driver := inmemory.NewInMemoryDriver() + driver := inmemory.NewDriver() for _, node := range nodes { if _, err := driver.Put(ctx, node); err != nil { return nil, err @@ -156,10 +156,11 @@ var _ = Describe("Dag", func() { Expect(err).NotTo(HaveOccurred()) var visited []string - dag.Walk(func(node *merkle.DagNode) (bool, error) { + err = dag.Walk(func(node *merkle.DagNode) (bool, error) { visited = append(visited, node.Bucket.ExtractText()) return true, nil }) + Expect(err).NotTo(HaveOccurred()) Expect(visited).To(Equal([]string{"1", "2", "3"})) }) @@ -173,12 +174,13 @@ var _ = Describe("Dag", func() { Expect(err).NotTo(HaveOccurred()) var visited []string - dag.Walk(func(node *merkle.DagNode) (bool, error) { + err = dag.Walk(func(node *merkle.DagNode) (bool, error) { visited = append(visited, node.Bucket.ExtractText()) // Stop after we hit "2" return node.Bucket.ExtractText() != "2", nil }) + Expect(err).NotTo(HaveOccurred()) Expect(visited).To(Equal([]string{"1", "2"})) }) @@ -193,13 +195,14 @@ var _ = Describe("Dag", func() { testErr := errors.New("test error") var visited []string - dag.Walk(func(node *merkle.DagNode) (bool, error) { + walkErr := dag.Walk(func(node *merkle.DagNode) (bool, error) { visited = append(visited, node.Bucket.ExtractText()) if node.Bucket.ExtractText() == "2" { return false, testErr } return true, nil }) + Expect(walkErr).To(MatchError(testErr)) // Should have stopped at node "2" Expect(visited).To(Equal([]string{"1", "2"})) @@ -209,10 +212,11 @@ var _ = Describe("Dag", func() { dag := merkle.NewDag() var visited []string - dag.Walk(func(node *merkle.DagNode) (bool, error) { + err := dag.Walk(func(node *merkle.DagNode) (bool, error) { visited = append(visited, node.Bucket.ExtractText()) return true, nil }) + Expect(err).NotTo(HaveOccurred()) Expect(visited).To(BeEmpty()) }) @@ -365,10 +369,10 @@ var _ = Describe("Dag", func() { }) Describe("LoadDag", func() { - var driver *inmemory.InMemoryDriver + var driver *inmemory.Driver BeforeEach(func() { - driver = inmemory.NewInMemoryDriver() + driver = inmemory.NewDriver() }) It("loads a single node", func() { diff --git a/pkg/storage/ent/driver/driver.go b/pkg/storage/ent/driver/driver.go index f8db1e8..3ba31a0 100644 --- a/pkg/storage/ent/driver/driver.go +++ b/pkg/storage/ent/driver/driver.go @@ -4,6 +4,7 @@ package entdriver import ( "context" "encoding/json" + "errors" "fmt" "github.com/papercomputeco/tapes/pkg/llm" @@ -23,7 +24,7 @@ type EntDriver struct { // false if it already existed. This is a no-op due to content-addressing. func (ed *EntDriver) Put(ctx context.Context, n *merkle.Node) (bool, error) { if n == nil { - return false, fmt.Errorf("cannot store nil node") + return false, errors.New("cannot store nil node") } // Check if node already exists (idempotent insert) @@ -100,7 +101,7 @@ func (ed *EntDriver) Get(ctx context.Context, hash string) (*merkle.Node, error) entNode, err := ed.Client.Node.Get(ctx, hash) if err != nil { if ent.IsNotFound(err) { - return nil, storage.ErrNotFound{Hash: hash} + return nil, storage.NotFoundError{Hash: hash} } return nil, fmt.Errorf("failed to get node: %w", err) } @@ -175,7 +176,7 @@ func (ed *EntDriver) Ancestry(ctx context.Context, hash string) ([]*merkle.Node, current, err := ed.Client.Node.Get(ctx, hash) if err != nil { if ent.IsNotFound(err) { - return nil, storage.ErrNotFound{Hash: hash} + return nil, storage.NotFoundError{Hash: hash} } return nil, fmt.Errorf("failed to get node: %w", err) } @@ -261,7 +262,6 @@ func (ed *EntDriver) entNodeToMerkleNode(entNode *ent.Node) (*merkle.Node, error if entNode.PromptDurationNs != nil { node.Usage.PromptDurationNs = *entNode.PromptDurationNs - } } diff --git a/pkg/storage/error.go b/pkg/storage/error.go index a14c19e..f4bc868 100644 --- a/pkg/storage/error.go +++ b/pkg/storage/error.go @@ -1,11 +1,11 @@ package storage -// ErrNotFound is returned when a node doesn't exist in the store. -type ErrNotFound struct { +// NotFoundError is returned when a node doesn't exist in the store. +type NotFoundError struct { Hash string } -func (e ErrNotFound) Error() string { +func (e NotFoundError) Error() string { if e.Hash == "" { return "node not found" } diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index ed606db..359129f 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -2,6 +2,7 @@ package inmemory import ( "context" + "errors" "fmt" "sync" @@ -9,8 +10,8 @@ import ( "github.com/papercomputeco/tapes/pkg/storage" ) -// InMemoryDriver implements Storer using an in-memory map. -type InMemoryDriver struct { +// Driver implements Storer using an in-memory map. +type Driver struct { // mu is a read write sync mutex for locking the mapping of nodes mu sync.RWMutex @@ -19,18 +20,18 @@ type InMemoryDriver struct { nodes map[string]*merkle.Node } -// NewInMemoryDriver creates a new in-memory storer. -func NewInMemoryDriver() *InMemoryDriver { - return &InMemoryDriver{ +// NewDriver creates a new in-memory storer. +func NewDriver() *Driver { + return &Driver{ nodes: make(map[string]*merkle.Node), } } // Put stores a node. Returns true if the node was newly inserted, // false if it already existed (no-op due to content-addressing). -func (s *InMemoryDriver) Put(ctx context.Context, node *merkle.Node) (bool, error) { +func (s *Driver) Put(_ context.Context, node *merkle.Node) (bool, error) { if node == nil { - return false, fmt.Errorf("cannot store nil node") + return false, errors.New("cannot store nil node") } s.mu.Lock() @@ -47,20 +48,20 @@ func (s *InMemoryDriver) Put(ctx context.Context, node *merkle.Node) (bool, erro } // Get retrieves a node by its hash. -func (s *InMemoryDriver) Get(ctx context.Context, hash string) (*merkle.Node, error) { +func (s *Driver) Get(_ context.Context, hash string) (*merkle.Node, error) { s.mu.RLock() defer s.mu.RUnlock() node, ok := s.nodes[hash] if !ok { - return nil, storage.ErrNotFound{Hash: hash} + return nil, storage.NotFoundError{Hash: hash} } return node, nil } // Has checks if a node exists by its hash. -func (s *InMemoryDriver) Has(ctx context.Context, hash string) (bool, error) { +func (s *Driver) Has(_ context.Context, hash string) (bool, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -70,7 +71,7 @@ func (s *InMemoryDriver) Has(ctx context.Context, hash string) (bool, error) { // GetByParent retrieves all nodes that have the provided parent. // This is useful for determining where branching occurs. -func (s *InMemoryDriver) GetByParent(ctx context.Context, parentHash *string) ([]*merkle.Node, error) { +func (s *Driver) GetByParent(_ context.Context, parentHash *string) ([]*merkle.Node, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -90,7 +91,7 @@ func (s *InMemoryDriver) GetByParent(ctx context.Context, parentHash *string) ([ } // List returns all nodes in the store. -func (s *InMemoryDriver) List(ctx context.Context) ([]*merkle.Node, error) { +func (s *Driver) List(_ context.Context) ([]*merkle.Node, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -103,12 +104,12 @@ func (s *InMemoryDriver) List(ctx context.Context) ([]*merkle.Node, error) { } // Roots returns all root nodes -func (s *InMemoryDriver) Roots(ctx context.Context) ([]*merkle.Node, error) { +func (s *Driver) Roots(ctx context.Context) ([]*merkle.Node, error) { return s.GetByParent(ctx, nil) } // Leaves returns all leaf nodes -func (s *InMemoryDriver) Leaves(ctx context.Context) ([]*merkle.Node, error) { +func (s *Driver) Leaves(_ context.Context) ([]*merkle.Node, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -132,7 +133,7 @@ func (s *InMemoryDriver) Leaves(ctx context.Context) ([]*merkle.Node, error) { } // Ancestry returns the path from a node back to its root (node first, root last). -func (s *InMemoryDriver) Ancestry(ctx context.Context, hash string) ([]*merkle.Node, error) { +func (s *Driver) Ancestry(ctx context.Context, hash string) ([]*merkle.Node, error) { var path []*merkle.Node current := hash @@ -153,7 +154,7 @@ func (s *InMemoryDriver) Ancestry(ctx context.Context, hash string) ([]*merkle.N } // Depth returns the depth of a node (0 for roots). -func (s *InMemoryDriver) Depth(ctx context.Context, hash string) (int, error) { +func (s *Driver) Depth(ctx context.Context, hash string) (int, error) { depth := 0 current := hash @@ -173,13 +174,13 @@ func (s *InMemoryDriver) Depth(ctx context.Context, hash string) (int, error) { } // Count returns the number of nodes in the in-memory store. -func (s *InMemoryDriver) Count() int { +func (s *Driver) Count() int { s.mu.RLock() defer s.mu.RUnlock() return len(s.nodes) } // Close is a no-op for the in-memory storer. -func (s *InMemoryDriver) Close() error { +func (s *Driver) Close() error { return nil } diff --git a/pkg/storage/sqlite/sqlite.go b/pkg/storage/sqlite/sqlite.go index bb8aae5..504e16c 100644 --- a/pkg/storage/sqlite/sqlite.go +++ b/pkg/storage/sqlite/sqlite.go @@ -8,20 +8,20 @@ import ( "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" // Register sqlite3 driver "github.com/papercomputeco/tapes/pkg/storage/ent" entdriver "github.com/papercomputeco/tapes/pkg/storage/ent/driver" ) -// SQLiteDriver implements storage.Driver using SQLite via the ent driver -type SQLiteDriver struct { +// Driver implements storage.Driver using SQLite via the ent driver +type Driver struct { *entdriver.EntDriver } -// NewSQLiteDriver creates a new SQLite-backed storer. +// NewDriver creates a new SQLite-backed storer. // The dbPath can be a file path or ":memory:" for an in-memory database. -func NewSQLiteDriver(dbPath string) (*SQLiteDriver, error) { +func NewDriver(dbPath string) (*Driver, error) { // Open the database using the github.com/mattn/go-sqlite3 driver (registered as "sqlite3") db, err := sql.Open("sqlite3", dbPath) if err != nil { @@ -29,7 +29,7 @@ func NewSQLiteDriver(dbPath string) (*SQLiteDriver, error) { } // SQLite-specific pragmas - if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + if _, err := db.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil { db.Close() return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } @@ -45,7 +45,7 @@ func NewSQLiteDriver(dbPath string) (*SQLiteDriver, error) { return nil, fmt.Errorf("failed to create schema: %w", err) } - return &SQLiteDriver{ + return &Driver{ EntDriver: &entdriver.EntDriver{ Client: client, }, diff --git a/pkg/storage/sqlite/sqlite_test.go b/pkg/storage/sqlite/sqlite_test.go index cdc12f4..05cbe81 100644 --- a/pkg/storage/sqlite/sqlite_test.go +++ b/pkg/storage/sqlite/sqlite_test.go @@ -25,16 +25,16 @@ func sqliteTestBucket(text string) merkle.Bucket { } } -var _ = Describe("SQLiteDriver", func() { +var _ = Describe("Driver", func() { var ( - driver *sqlite.SQLiteDriver + driver *sqlite.Driver ctx context.Context ) BeforeEach(func() { ctx = context.Background() var err error - driver, err = sqlite.NewSQLiteDriver(":memory:") + driver, err = sqlite.NewDriver(":memory:") Expect(err).NotTo(HaveOccurred()) }) @@ -44,12 +44,12 @@ var _ = Describe("SQLiteDriver", func() { } }) - Describe("NewSQLiteDriver", func() { + Describe("NewDriver", func() { It("creates a driver with file database", func() { tmpDir := GinkgoT().TempDir() dbPath := filepath.Join(tmpDir, "test.db") - s, err := sqlite.NewSQLiteDriver(dbPath) + s, err := sqlite.NewDriver(dbPath) Expect(err).NotTo(HaveOccurred()) defer s.Close() @@ -89,11 +89,11 @@ var _ = Describe("SQLiteDriver", func() { Expect(*retrieved.ParentHash).To(Equal(parent.Hash)) }) - It("returns ErrNotFound for non-existent hash", func() { + It("returns NotFoundError for non-existent hash", func() { _, err := driver.Get(ctx, "nonexistent") Expect(err).To(HaveOccurred()) - var notFoundErr storage.ErrNotFound + var notFoundErr storage.NotFoundError Expect(err).To(BeAssignableToTypeOf(notFoundErr)) }) diff --git a/pkg/utils/test/vector.go b/pkg/utils/test/vector.go index c434bc4..994524b 100644 --- a/pkg/utils/test/vector.go +++ b/pkg/utils/test/vector.go @@ -2,7 +2,7 @@ package testutils import ( "context" - "fmt" + "errors" "github.com/papercomputeco/tapes/pkg/vector" ) @@ -30,7 +30,7 @@ func (m *MockVectorDriver) Add(_ context.Context, docs []vector.Document) error func (m *MockVectorDriver) Query(_ context.Context, _ []float32, topK int) ([]vector.QueryResult, error) { if m.FailQuery { - return nil, fmt.Errorf("mock vector query failure") + return nil, errors.New("mock vector query failure") } if len(m.Results) < topK { diff --git a/pkg/vector/chroma/chroma.go b/pkg/vector/chroma/chroma.go index e8abe3d..dc215e6 100644 --- a/pkg/vector/chroma/chroma.go +++ b/pkg/vector/chroma/chroma.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -30,8 +31,8 @@ const ( defaultMaxRetryDelay = 15 * time.Second ) -// ChromaDriver implements vector.Driver using Chroma's REST API. -type ChromaDriver struct { +// Driver implements vector.Driver using Chroma's REST API. +type Driver struct { baseURL string collectionName string collectionID string @@ -61,11 +62,11 @@ type Config struct { MaxRetryDelay time.Duration } -// NewChromaDriver creates a new Chroma vector driver. +// NewDriver creates a new Chroma vector driver. // It uses exponential delay retries if it cannot connect to Chroma. -func NewChromaDriver(c Config, logger *zap.Logger) (*ChromaDriver, error) { +func NewDriver(c Config, logger *zap.Logger) (*Driver, error) { if c.URL == "" { - return nil, fmt.Errorf("chroma URL is required") + return nil, errors.New("chroma URL is required") } collectionName := c.CollectionName @@ -73,7 +74,7 @@ func NewChromaDriver(c Config, logger *zap.Logger) (*ChromaDriver, error) { collectionName = DefaultCollectionName } - d := &ChromaDriver{ + d := &Driver{ baseURL: c.URL, collectionName: collectionName, httpClient: &http.Client{ @@ -130,11 +131,11 @@ func NewChromaDriver(c Config, logger *zap.Logger) (*ChromaDriver, error) { } // getOrCreateCollection gets an existing collection or creates a new one. -func (d *ChromaDriver) getOrCreateCollection(ctx context.Context) (string, error) { +func (d *Driver) getOrCreateCollection(ctx context.Context) (string, error) { // Try to get existing collection first url := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections/%s", d.baseURL, d.collectionName) - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", fmt.Errorf("creating get request: %w", err) } @@ -154,14 +155,14 @@ func (d *ChromaDriver) getOrCreateCollection(ctx context.Context) (string, error } // Collection doesn't exist, create it - createURL := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections", d.baseURL) + createURL := d.baseURL + "/api/v2/tenants/default_tenant/databases/default_database/collections" createBody := map[string]string{"name": d.collectionName} jsonBody, err := json.Marshal(createBody) if err != nil { return "", fmt.Errorf("marshaling create request: %w", err) } - req, err = http.NewRequestWithContext(ctx, "POST", createURL, bytes.NewReader(jsonBody)) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, createURL, bytes.NewReader(jsonBody)) if err != nil { return "", fmt.Errorf("creating create request: %w", err) } @@ -187,7 +188,7 @@ func (d *ChromaDriver) getOrCreateCollection(ctx context.Context) (string, error } // Add stores documents with their embeddings. -func (d *ChromaDriver) Add(ctx context.Context, docs []vector.Document) error { +func (d *Driver) Add(ctx context.Context, docs []vector.Document) error { if len(docs) == 0 { return nil } @@ -214,7 +215,7 @@ func (d *ChromaDriver) Add(ctx context.Context, docs []vector.Document) error { } url := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections/%s/add", d.baseURL, d.collectionID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { return fmt.Errorf("creating add request: %w", err) } @@ -239,7 +240,7 @@ func (d *ChromaDriver) Add(ctx context.Context, docs []vector.Document) error { } // Query finds the topK most similar documents to the given embedding. -func (d *ChromaDriver) Query(ctx context.Context, embedding []float32, topK int) ([]vector.QueryResult, error) { +func (d *Driver) Query(ctx context.Context, embedding []float32, topK int) ([]vector.QueryResult, error) { if topK <= 0 { topK = 10 } @@ -256,7 +257,7 @@ func (d *ChromaDriver) Query(ctx context.Context, embedding []float32, topK int) } url := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections/%s/query", d.baseURL, d.collectionID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("creating query request: %w", err) } @@ -335,7 +336,7 @@ func (d *ChromaDriver) Query(ctx context.Context, embedding []float32, topK int) } // Get retrieves documents by their IDs. -func (d *ChromaDriver) Get(ctx context.Context, ids []string) ([]vector.Document, error) { +func (d *Driver) Get(ctx context.Context, ids []string) ([]vector.Document, error) { if len(ids) == 0 { return nil, nil } @@ -351,7 +352,7 @@ func (d *ChromaDriver) Get(ctx context.Context, ids []string) ([]vector.Document } url := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections/%s/get", d.baseURL, d.collectionID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("creating get request: %w", err) } @@ -397,7 +398,7 @@ func (d *ChromaDriver) Get(ctx context.Context, ids []string) ([]vector.Document } // Delete removes documents by their IDs. -func (d *ChromaDriver) Delete(ctx context.Context, ids []string) error { +func (d *Driver) Delete(ctx context.Context, ids []string) error { if len(ids) == 0 { return nil } @@ -412,7 +413,7 @@ func (d *ChromaDriver) Delete(ctx context.Context, ids []string) error { } url := fmt.Sprintf("%s/api/v2/tenants/default_tenant/databases/default_database/collections/%s/delete", d.baseURL, d.collectionID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { return fmt.Errorf("creating delete request: %w", err) } @@ -437,7 +438,7 @@ func (d *ChromaDriver) Delete(ctx context.Context, ids []string) error { } // Close releases resources held by the driver. -func (d *ChromaDriver) Close() error { +func (d *Driver) Close() error { // HTTP client doesn't require explicit cleanup return nil } diff --git a/pkg/vector/chroma/chroma_test.go b/pkg/vector/chroma/chroma_test.go index 4998770..bcd7827 100644 --- a/pkg/vector/chroma/chroma_test.go +++ b/pkg/vector/chroma/chroma_test.go @@ -15,16 +15,16 @@ import ( "github.com/papercomputeco/tapes/pkg/vector/chroma" ) -var _ = Describe("ChromaDriver", func() { +var _ = Describe("Driver", func() { var logger *zap.Logger BeforeEach(func() { logger = zap.NewNop() }) - Describe("NewChromaDriver", func() { + Describe("NewDriver", func() { It("should return an error when URL is empty", func() { - _, err := chroma.NewChromaDriver(chroma.Config{URL: ""}, logger) + _, err := chroma.NewDriver(chroma.Config{URL: ""}, logger) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("chroma URL is required")) }) @@ -61,7 +61,7 @@ var _ = Describe("ChromaDriver", func() { })) defer server.Close() - driver, err := chroma.NewChromaDriver(chroma.Config{ + driver, err := chroma.NewDriver(chroma.Config{ URL: server.URL, MaxRetries: 5, RetryDelay: 10 * time.Millisecond, @@ -78,7 +78,7 @@ var _ = Describe("ChromaDriver", func() { })) defer server.Close() - _, err := chroma.NewChromaDriver(chroma.Config{ + _, err := chroma.NewDriver(chroma.Config{ URL: server.URL, MaxRetries: 3, RetryDelay: 10 * time.Millisecond, @@ -91,8 +91,8 @@ var _ = Describe("ChromaDriver", func() { Describe("Interface compliance", func() { It("should implement vector.Driver interface", func() { - // Compile-time check that ChromaDriver implements vector.Driver - var _ vector.Driver = (*chroma.ChromaDriver)(nil) + // Compile-time check that Driver implements vector.Driver + var _ vector.Driver = (*chroma.Driver)(nil) }) }) }) diff --git a/pkg/vector/sqlitevec/sqlitevec.go b/pkg/vector/sqlitevec/sqlitevec.go index 5f78f75..d1d42d6 100644 --- a/pkg/vector/sqlitevec/sqlitevec.go +++ b/pkg/vector/sqlitevec/sqlitevec.go @@ -5,19 +5,20 @@ import ( "context" "database/sql" "encoding/binary" + "errors" "fmt" "math" "strings" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" // Register sqlite3 driver "go.uber.org/zap" "github.com/papercomputeco/tapes/pkg/vector" ) -// SQLiteVecDriver implements vector.Driver using SQLite with sqlite-vec. -type SQLiteVecDriver struct { +// Driver implements vector.Driver using SQLite with sqlite-vec. +type Driver struct { db *sql.DB logger *zap.Logger } @@ -33,18 +34,18 @@ type Config struct { Dimensions uint } -// NewSQLiteVecDriver creates a new SQLite vector driver backed by sqlite-vec. -func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) { +// NewDriver creates a new SQLite vector driver backed by sqlite-vec. +func NewDriver(c Config, logger *zap.Logger) (*Driver, error) { // enable connection to have sqlite-vec extension sqlite_vec.Auto() if c.DBPath == "" { - return nil, fmt.Errorf("database path is required") + return nil, errors.New("database path is required") } dimensions := c.Dimensions if dimensions == 0 { - return nil, fmt.Errorf("sqlite-vec embedding dimensions cannot be 0, must be configured") + return nil, errors.New("sqlite-vec embedding dimensions cannot be 0, must be configured") } db, err := sql.Open("sqlite3", c.DBPath) @@ -54,7 +55,7 @@ func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) // Verify sqlite-vec is loaded var vecVersion string - if err := db.QueryRow("SELECT vec_version()").Scan(&vecVersion); err != nil { + if err := db.QueryRowContext(context.Background(), "SELECT vec_version()").Scan(&vecVersion); err != nil { db.Close() return nil, fmt.Errorf("sqlite-vec not available: %w", err) } @@ -62,7 +63,7 @@ func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) // Create the document ID mapping table. // vec0 virtual tables use integer rowids, so we need a mapping from // string document IDs to integer rowids. - _, err = db.Exec(` + _, err = db.ExecContext(context.Background(), ` CREATE TABLE IF NOT EXISTS vec_documents ( rowid INTEGER PRIMARY KEY AUTOINCREMENT, doc_id TEXT NOT NULL UNIQUE, @@ -79,7 +80,7 @@ func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) `CREATE VIRTUAL TABLE IF NOT EXISTS vec_embeddings USING vec0(embedding float[%d])`, dimensions, ) - if _, err := db.Exec(createVec); err != nil { + if _, err := db.ExecContext(context.Background(), createVec); err != nil { db.Close() return nil, fmt.Errorf("creating vec0 table: %w", err) } @@ -90,7 +91,7 @@ func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) zap.String("vec_version", vecVersion), ) - return &SQLiteVecDriver{ + return &Driver{ db: db, logger: logger, }, nil @@ -98,12 +99,12 @@ func NewSQLiteVecDriver(c Config, logger *zap.Logger) (*SQLiteVecDriver, error) // serializeFloat32 converts a float32 slice to a little-endian byte slice // suitable for sqlite-vec BLOB format. -func serializeFloat32(v []float32) ([]byte, error) { +func serializeFloat32(v []float32) []byte { buf := make([]byte, len(v)*4) for i, f := range v { binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(f)) } - return buf, nil + return buf } // deserializeFloat32 converts a little-endian byte slice back to a float32 slice. @@ -120,7 +121,7 @@ func deserializeFloat32(b []byte) ([]float32, error) { // Add stores documents with their embeddings. // If a document with the same ID already exists, it is updated. -func (d *SQLiteVecDriver) Add(ctx context.Context, docs []vector.Document) error { +func (d *Driver) Add(ctx context.Context, docs []vector.Document) error { if len(docs) == 0 { return nil } @@ -129,13 +130,10 @@ func (d *SQLiteVecDriver) Add(ctx context.Context, docs []vector.Document) error if err != nil { return fmt.Errorf("beginning transaction: %w", err) } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() for _, doc := range docs { - embBlob, err := serializeFloat32(doc.Embedding) - if err != nil { - return fmt.Errorf("serializing embedding for doc %s: %w", doc.ID, err) - } + embBlob := serializeFloat32(doc.Embedding) // Check if document already exists var existingRowID int64 @@ -143,50 +141,27 @@ func (d *SQLiteVecDriver) Add(ctx context.Context, docs []vector.Document) error `SELECT rowid FROM vec_documents WHERE doc_id = ?`, doc.ID, ).Scan(&existingRowID) - switch err { - case nil: - // Document exists — update hash and embedding - if _, err := tx.ExecContext(ctx, - `UPDATE vec_documents SET hash = ? WHERE rowid = ?`, - doc.Hash, existingRowID, - ); err != nil { + switch { + case err == nil: + if _, err := tx.ExecContext(ctx, `UPDATE vec_documents SET hash = ? WHERE rowid = ?`, doc.Hash, existingRowID); err != nil { return fmt.Errorf("updating document %s: %w", doc.ID, err) } - - // Update embedding in vec0 table via DELETE + INSERT - // (vec0 does not support UPDATE) - if _, err := tx.ExecContext(ctx, - `DELETE FROM vec_embeddings WHERE rowid = ?`, existingRowID, - ); err != nil { + if _, err := tx.ExecContext(ctx, `DELETE FROM vec_embeddings WHERE rowid = ?`, existingRowID); err != nil { return fmt.Errorf("deleting old embedding for doc %s: %w", doc.ID, err) } - - if _, err := tx.ExecContext(ctx, - `INSERT INTO vec_embeddings(rowid, embedding) VALUES (?, ?)`, - existingRowID, embBlob, - ); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO vec_embeddings(rowid, embedding) VALUES (?, ?)`, existingRowID, embBlob); err != nil { return fmt.Errorf("re-inserting embedding for doc %s: %w", doc.ID, err) } - case sql.ErrNoRows: - // New document — insert into mapping table first to get the rowid - result, err := tx.ExecContext(ctx, - `INSERT INTO vec_documents(doc_id, hash) VALUES (?, ?)`, - doc.ID, doc.Hash, - ) + case errors.Is(err, sql.ErrNoRows): + result, err := tx.ExecContext(ctx, `INSERT INTO vec_documents(doc_id, hash) VALUES (?, ?)`, doc.ID, doc.Hash) if err != nil { return fmt.Errorf("inserting document %s: %w", doc.ID, err) } - rowID, err := result.LastInsertId() if err != nil { return fmt.Errorf("getting rowid for doc %s: %w", doc.ID, err) } - - // Insert embedding into vec0 table with matching rowid - if _, err := tx.ExecContext(ctx, - `INSERT INTO vec_embeddings(rowid, embedding) VALUES (?, ?)`, - rowID, embBlob, - ); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO vec_embeddings(rowid, embedding) VALUES (?, ?)`, rowID, embBlob); err != nil { return fmt.Errorf("inserting embedding for doc %s: %w", doc.ID, err) } default: @@ -206,15 +181,12 @@ func (d *SQLiteVecDriver) Add(ctx context.Context, docs []vector.Document) error } // Query finds the topK most similar documents to the given embedding. -func (d *SQLiteVecDriver) Query(ctx context.Context, embedding []float32, topK int) ([]vector.QueryResult, error) { +func (d *Driver) Query(ctx context.Context, embedding []float32, topK int) ([]vector.QueryResult, error) { if topK <= 0 { topK = 10 } - queryBlob, err := serializeFloat32(embedding) - if err != nil { - return nil, fmt.Errorf("serializing query embedding: %w", err) - } + queryBlob := serializeFloat32(embedding) // Use KNN query via vec0 MATCH, then JOIN back to get doc_id and hash. rows, err := d.db.QueryContext(ctx, ` @@ -263,7 +235,7 @@ func (d *SQLiteVecDriver) Query(ctx context.Context, embedding []float32, topK i } // Get retrieves documents by their IDs. -func (d *SQLiteVecDriver) Get(ctx context.Context, ids []string) ([]vector.Document, error) { +func (d *Driver) Get(ctx context.Context, ids []string) ([]vector.Document, error) { if len(ids) == 0 { return nil, nil } @@ -276,6 +248,7 @@ func (d *SQLiteVecDriver) Get(ctx context.Context, ids []string) ([]vector.Docum args[i] = id } + //nolint:gosec // @jpmcb - TODO: refactor to avoid SQL string formatting; placeholders are safe ?-marks query := fmt.Sprintf(` SELECT d.doc_id, d.hash, d.rowid FROM vec_documents d @@ -333,7 +306,7 @@ func (d *SQLiteVecDriver) Get(ctx context.Context, ids []string) ([]vector.Docum } // Delete removes documents by their IDs. -func (d *SQLiteVecDriver) Delete(ctx context.Context, ids []string) error { +func (d *Driver) Delete(ctx context.Context, ids []string) error { if len(ids) == 0 { return nil } @@ -342,7 +315,7 @@ func (d *SQLiteVecDriver) Delete(ctx context.Context, ids []string) error { if err != nil { return fmt.Errorf("beginning transaction: %w", err) } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() // Build placeholders for IN clause placeholders := make([]string, len(ids)) @@ -354,6 +327,7 @@ func (d *SQLiteVecDriver) Delete(ctx context.Context, ids []string) error { inClause := strings.Join(placeholders, ",") // First, get the rowids for the documents to delete from vec0 + //nolint:gosec // @jpmcb - TODO: refactor to avoid SQL string formatting; placeholders are safe ?-marks query := fmt.Sprintf( `SELECT rowid FROM vec_documents WHERE doc_id IN (%s)`, inClause, ) @@ -386,6 +360,7 @@ func (d *SQLiteVecDriver) Delete(ctx context.Context, ids []string) error { } // Delete from mapping table + //nolint:gosec // @jpmcb - TODO: refactor to avoid SQL string formatting; placeholders are safe ?-marks deleteQuery := fmt.Sprintf( `DELETE FROM vec_documents WHERE doc_id IN (%s)`, inClause, ) @@ -405,6 +380,6 @@ func (d *SQLiteVecDriver) Delete(ctx context.Context, ids []string) error { } // Close releases resources held by the driver. -func (d *SQLiteVecDriver) Close() error { +func (d *Driver) Close() error { return d.db.Close() } diff --git a/pkg/vector/sqlitevec/sqlitevec_test.go b/pkg/vector/sqlitevec/sqlitevec_test.go index d04f233..31905d2 100644 --- a/pkg/vector/sqlitevec/sqlitevec_test.go +++ b/pkg/vector/sqlitevec/sqlitevec_test.go @@ -11,22 +11,22 @@ import ( "github.com/papercomputeco/tapes/pkg/vector/sqlitevec" ) -var _ = Describe("SQLiteVecDriver", func() { +var _ = Describe("Driver", func() { var logger *zap.Logger BeforeEach(func() { logger = zap.NewNop() }) - Describe("NewSQLiteVecDriver", func() { + Describe("NewDriver", func() { It("should return an error when DBPath is empty", func() { - _, err := sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{DBPath: ""}, logger) + _, err := sqlitevec.NewDriver(sqlitevec.Config{DBPath: ""}, logger) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("database path is required")) }) It("should create a driver with an in-memory database", func() { - driver, err := sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err := sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) @@ -36,7 +36,7 @@ var _ = Describe("SQLiteVecDriver", func() { }) It("should error when dimension not specified", func() { - _, err := sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + _, err := sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", }, logger) Expect(err).To(HaveOccurred()) @@ -45,16 +45,16 @@ var _ = Describe("SQLiteVecDriver", func() { Describe("Interface compliance", func() { It("should implement vector.Driver interface", func() { - var _ vector.Driver = (*sqlitevec.SQLiteVecDriver)(nil) + var _ vector.Driver = (*sqlitevec.Driver)(nil) }) }) Describe("Add", func() { - var driver *sqlitevec.SQLiteVecDriver + var driver *sqlitevec.Driver BeforeEach(func() { var err error - driver, err = sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err = sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) @@ -127,11 +127,11 @@ var _ = Describe("SQLiteVecDriver", func() { }) Describe("Query", func() { - var driver *sqlitevec.SQLiteVecDriver + var driver *sqlitevec.Driver BeforeEach(func() { var err error - driver, err = sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err = sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) @@ -193,11 +193,11 @@ var _ = Describe("SQLiteVecDriver", func() { }) Describe("Get", func() { - var driver *sqlitevec.SQLiteVecDriver + var driver *sqlitevec.Driver BeforeEach(func() { var err error - driver, err = sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err = sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) @@ -247,11 +247,11 @@ var _ = Describe("SQLiteVecDriver", func() { }) Describe("Delete", func() { - var driver *sqlitevec.SQLiteVecDriver + var driver *sqlitevec.Driver BeforeEach(func() { var err error - driver, err = sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err = sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) @@ -321,7 +321,7 @@ var _ = Describe("SQLiteVecDriver", func() { Describe("Close", func() { It("should close the database connection", func() { - driver, err := sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + driver, err := sqlitevec.NewDriver(sqlitevec.Config{ DBPath: ":memory:", Dimensions: 4, }, logger) diff --git a/pkg/vector/utils/new.go b/pkg/vector/utils/new.go index 1b85361..e97f745 100644 --- a/pkg/vector/utils/new.go +++ b/pkg/vector/utils/new.go @@ -1,6 +1,7 @@ package vectorutils import ( + "errors" "fmt" "go.uber.org/zap" @@ -30,16 +31,16 @@ func NewVectorDriver(o *NewVectorDriverOpts) (vector.Driver, error) { func newChromaDriver(o *NewVectorDriverOpts) (vector.Driver, error) { if o.Target == "" { - return nil, fmt.Errorf("chroma target URL must be provided") + return nil, errors.New("chroma target URL must be provided") } - return chroma.NewChromaDriver(chroma.Config{ + return chroma.NewDriver(chroma.Config{ URL: o.Target, }, o.Logger) } func newSqliteVecDriver(o *NewVectorDriverOpts) (vector.Driver, error) { - return sqlitevec.NewSQLiteVecDriver(sqlitevec.Config{ + return sqlitevec.NewDriver(sqlitevec.Config{ DBPath: o.Target, Dimensions: o.Dimensions, }, o.Logger) diff --git a/proxy/header/header_test.go b/proxy/header/header_test.go index 709590f..26aef35 100644 --- a/proxy/header/header_test.go +++ b/proxy/header/header_test.go @@ -28,13 +28,13 @@ var _ = Describe("SetUpstreamRequestHeaders", func() { var got http.Header app.Post("/test", func(c *fiber.Ctx) error { - req, _ := http.NewRequest("POST", "http://upstream/test", nil) + req, _ := http.NewRequest(http.MethodPost, "http://upstream/test", nil) hh.SetUpstreamRequestHeaders(c, req) got = req.Header return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("POST", "/test", nil) + req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer token123") req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Api-Key", "secret") @@ -52,13 +52,13 @@ var _ = Describe("SetUpstreamRequestHeaders", func() { var got http.Header app.Post("/test", func(c *fiber.Ctx) error { - req, _ := http.NewRequest("POST", "http://upstream/test", nil) + req, _ := http.NewRequest(http.MethodPost, "http://upstream/test", nil) hh.SetUpstreamRequestHeaders(c, req) got = req.Header return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("POST", "/test", nil) + req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Connection", "keep-alive") resp, err := app.Test(req) @@ -72,13 +72,13 @@ var _ = Describe("SetUpstreamRequestHeaders", func() { var got http.Header app.Post("/test", func(c *fiber.Ctx) error { - req, _ := http.NewRequest("POST", "http://upstream/test", nil) + req, _ := http.NewRequest(http.MethodPost, "http://upstream/test", nil) hh.SetUpstreamRequestHeaders(c, req) got = req.Header return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("POST", "/test", nil) + req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Host", "client.example.com") resp, err := app.Test(req) @@ -92,13 +92,13 @@ var _ = Describe("SetUpstreamRequestHeaders", func() { var got http.Header app.Post("/test", func(c *fiber.Ctx) error { - req, _ := http.NewRequest("POST", "http://upstream/test", nil) + req, _ := http.NewRequest(http.MethodPost, "http://upstream/test", nil) hh.SetUpstreamRequestHeaders(c, req) got = req.Header return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("POST", "/test", nil) + req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Accept-Encoding", "gzip, deflate, br") req.Header.Set("Authorization", "Bearer token123") @@ -140,7 +140,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -161,7 +161,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -180,7 +180,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -200,7 +200,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -222,7 +222,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -245,7 +245,7 @@ var _ = Describe("SetClientResponseHeaders", func() { return c.SendStatus(fiber.StatusOK) }) - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() diff --git a/proxy/proxy.go b/proxy/proxy.go index e4ac273..52fa2cd 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -56,12 +56,15 @@ func New(config Config, driver storage.Driver, logger *zap.Logger) (*Proxy, erro // Add compression middleware to handle responses app.Use(compress.New()) - wp := worker.NewPool(&worker.Config{ + wp, err := worker.NewPool(&worker.Config{ Driver: driver, VectorDriver: config.VectorDriver, Embedder: config.Embedder, Logger: logger, }) + if err != nil { + return nil, fmt.Errorf("could not create worker pool: %w", err) + } p := &Proxy{ config: config, @@ -234,7 +237,7 @@ func (p *Proxy) handleStreamingProxy(c *fiber.Ctx, path string, body []byte, par // its RequestCtx after the handler returns, but the streaming callback runs // asynchronously in a separate goroutine and needs the upstream connection // to remain open. - httpReq, err := http.NewRequestWithContext(context.Background(), "POST", upstreamURL, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodPost, upstreamURL, bytes.NewReader(body)) if err != nil { p.logger.Error("failed to create upstream request", zap.Error(err)) return c.Status(fiber.StatusInternalServerError).JSON(llm.ErrorResponse{Error: "internal error"}) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 6e3fffc..4e11891 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -45,9 +45,9 @@ func boolPtr(b bool) *bool { return &b } // newTestProxy creates a Proxy pointed at the given upstream URL, // using an in-memory storage driver and the ollama provider. -func newTestProxy(upstreamURL string) (*Proxy, *inmemory.InMemoryDriver) { +func newTestProxy(upstreamURL string) (*Proxy, *inmemory.Driver) { logger, _ := zap.NewDevelopment() - driver := inmemory.NewInMemoryDriver() + driver := inmemory.NewDriver() p, err := New( Config{ @@ -92,7 +92,7 @@ func makeOllamaResponseBody(model, role, content string) []byte { var _ = Describe("Non-Streaming Proxy", func() { var ( p *Proxy - driver *inmemory.InMemoryDriver + driver *inmemory.Driver upstream *httptest.Server ) @@ -121,7 +121,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -137,7 +137,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -150,7 +150,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -178,7 +178,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -214,7 +214,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "hello"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -230,7 +230,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "hello"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -255,7 +255,7 @@ var _ = Describe("Non-Streaming Proxy", func() { }) It("forwards GET requests transparently without storing", func() { - resp, err := p.server.Test(httptest.NewRequest("GET", "/api/tags", nil)) + resp, err := p.server.Test(httptest.NewRequest(http.MethodGet, "/api/tags", nil)) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -293,7 +293,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "hello"}, }, boolPtr(false)) - req := httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))) + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))) req.Header.Set("X-Api-Key", "secret-token") req.Header.Set("Content-Type", "application/json") @@ -312,7 +312,7 @@ var _ = Describe("Non-Streaming Proxy", func() { {Role: "user", Content: "hello"}, }, boolPtr(false)) - req := httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))) + req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))) req.Header.Set("Accept-Encoding", "gzip, deflate, br") req.Header.Set("Authorization", "Bearer token123") @@ -335,7 +335,7 @@ var _ = Describe("Non-Streaming Proxy", func() { var _ = Describe("Streaming Proxy", func() { var ( p *Proxy - driver *inmemory.InMemoryDriver + driver *inmemory.Driver upstream *httptest.Server ) @@ -375,7 +375,7 @@ var _ = Describe("Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(true)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -397,7 +397,7 @@ var _ = Describe("Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(true)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -434,7 +434,7 @@ var _ = Describe("Streaming Proxy", func() { {Role: "user", Content: "hello"}, }, boolPtr(true)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() @@ -480,7 +480,7 @@ var _ = Describe("Streaming Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(true)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -553,7 +553,7 @@ var _ = Describe("Streaming Detection", func() { {Role: "user", Content: "hello"}, }, boolPtr(true)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -574,7 +574,7 @@ var _ = Describe("Streaming Detection", func() { {Role: "user", Content: "hello"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -595,7 +595,7 @@ var _ = Describe("Streaming Detection", func() { {Role: "user", Content: "hello"}, }, nil) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody))), -1) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody))), -1) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -683,7 +683,7 @@ var _ = Describe("reconstructStreamedResponse", func() { var _ = Describe("New", func() { It("returns an error for unrecognized provider type", func() { logger, _ := zap.NewDevelopment() - driver := inmemory.NewInMemoryDriver() + driver := inmemory.NewDriver() _, err := New(Config{ ListenAddr: ":0", @@ -696,7 +696,7 @@ var _ = Describe("New", func() { It("creates a proxy with a valid provider type", func() { logger, _ := zap.NewDevelopment() - driver := inmemory.NewInMemoryDriver() + driver := inmemory.NewDriver() p, err := New(Config{ ListenAddr: ":0", @@ -713,7 +713,7 @@ var _ = Describe("New", func() { var _ = Describe("End-to-End Multi-Turn Proxy", func() { var ( p *Proxy - driver *inmemory.InMemoryDriver + driver *inmemory.Driver upstream *httptest.Server turnNum int ) @@ -749,7 +749,7 @@ var _ = Describe("End-to-End Multi-Turn Proxy", func() { {Role: "user", Content: "What is 2+2?"}, }, boolPtr(false)) - resp1, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody1)))) + resp1, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody1)))) Expect(err).NotTo(HaveOccurred()) resp1.Body.Close() @@ -761,7 +761,7 @@ var _ = Describe("End-to-End Multi-Turn Proxy", func() { {Role: "user", Content: "And what is 3+3?"}, }, boolPtr(false)) - resp2, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody2)))) + resp2, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody2)))) Expect(err).NotTo(HaveOccurred()) resp2.Body.Close() @@ -801,7 +801,7 @@ var _ = Describe("End-to-End Multi-Turn Proxy", func() { var _ = Describe("Storage Provider Metadata", func() { var ( p *Proxy - driver *inmemory.InMemoryDriver + driver *inmemory.Driver upstream *httptest.Server ) @@ -826,7 +826,7 @@ var _ = Describe("Storage Provider Metadata", func() { {Role: "user", Content: "hi"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -848,7 +848,7 @@ var _ = Describe("Storage Provider Metadata", func() { {Role: "user", Content: "hi"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() @@ -870,7 +870,7 @@ var _ = Describe("Storage Provider Metadata", func() { {Role: "user", Content: "hi"}, }, boolPtr(false)) - resp, err := p.server.Test(httptest.NewRequest("POST", "/api/chat", strings.NewReader(string(reqBody)))) + resp, err := p.server.Test(httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(string(reqBody)))) Expect(err).NotTo(HaveOccurred()) resp.Body.Close() diff --git a/proxy/worker/pool.go b/proxy/worker/pool.go index 6f42c21..9545bd8 100644 --- a/proxy/worker/pool.go +++ b/proxy/worker/pool.go @@ -9,6 +9,7 @@ package worker import ( "context" "fmt" + "math" "sync" "go.uber.org/zap" @@ -63,7 +64,7 @@ type Pool struct { } // NewPool creates a new Storer and starts its worker goroutines. -func NewPool(c *Config) *Pool { +func NewPool(c *Config) (*Pool, error) { if c.NumWorkers == 0 { c.NumWorkers = defaultNumWorkers } @@ -72,6 +73,10 @@ func NewPool(c *Config) *Pool { c.QueueSize = defaultJobQueueSize } + if c.NumWorkers > uint(math.MaxInt) { + return nil, fmt.Errorf("NumWorkers %d exceeds max int", c.NumWorkers) + } + wp := &Pool{ config: c, queue: make(chan Job, c.QueueSize), @@ -83,7 +88,7 @@ func NewPool(c *Config) *Pool { go wp.worker(i) } - return wp + return wp, nil } // Enqueue submits a job for processing by the worker pool. diff --git a/proxy/worker/pool_test.go b/proxy/worker/pool_test.go index 5a0cdcd..cd493fa 100644 --- a/proxy/worker/pool_test.go +++ b/proxy/worker/pool_test.go @@ -13,14 +13,15 @@ import ( // newTestPool creates a worker pool backed by an in-memory driver. // Callers should "wp.Close()" to drain enqueued jobs before asserting storage state. -func newTestPool() (*Pool, *inmemory.InMemoryDriver) { +func newTestPool() (*Pool, *inmemory.Driver) { logger, _ := zap.NewDevelopment() - driver := inmemory.NewInMemoryDriver() + driver := inmemory.NewDriver() - wp := NewPool(&Config{ + wp, err := NewPool(&Config{ Driver: driver, Logger: logger, }) + Expect(err).NotTo(HaveOccurred()) return wp, driver } @@ -28,7 +29,7 @@ func newTestPool() (*Pool, *inmemory.InMemoryDriver) { var _ = Describe("Worker Pool", func() { var ( wp *Pool - driver *inmemory.InMemoryDriver + driver *inmemory.Driver ctx context.Context ) From f9c9c4ebd3bced85939d191526e2a7fb835418e7 Mon Sep 17 00:00:00 2001 From: John McBride Date: Mon, 2 Feb 2026 22:39:55 -0500 Subject: [PATCH 2/2] fix: Lint checks for new "deck" TUI command Signed-off-by: John McBride --- .github/workflows/ci.yaml | 12 +- .golangci.yml | 5 +- api/mcp/search.go | 5 +- api/search/search.go | 54 +++++--- api/search/search_test.go | 20 +-- api/search_handler.go | 5 +- cmd/tapes/deck/deck.go | 11 +- cmd/tapes/deck/tui.go | 198 ++++++++---------------------- cmd/tapes/deck/tui_test.go | 4 +- cmd/tapes/deck/web.go | 16 ++- cmd/tapes/search/search.go | 3 +- cmd/tapes/serve/api/api.go | 5 +- cmd/tapes/serve/proxy/proxy.go | 3 +- cmd/tapes/serve/serve.go | 5 +- makefile | 12 +- pkg/deck/pricing.go | 5 +- pkg/deck/query.go | 32 +++-- pkg/deck/types.go | 2 +- pkg/dotdir/checkout.go | 2 +- pkg/storage/sqlite/sqlite.go | 8 +- pkg/storage/sqlite/sqlite_test.go | 4 +- pkg/vector/sqlitevec/sqlitevec.go | 2 +- 22 files changed, 179 insertions(+), 234 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 68fc34a..549d52b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -41,15 +41,17 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Daggerverse golangcilint + - name: Install Dagger uses: dagger/dagger-for-github@v8.2.0 - env: - DAGGER_CLOUD_TOKEN: ${{ secrets.DAGGER_CLOUD_TOKEN }} with: - verb: call - args: "-m github.com/papercomputeco/daggerverse/golangcilint check" + verb: version version: ${{ env.DAGGER_VERSION }} + - name: Run lint check + env: + DAGGER_CLOUD_TOKEN: ${{ secrets.DAGGER_CLOUD_TOKEN }} + run: make check + build: name: Build runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index 3130ac2..ea7b44c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,8 +57,9 @@ linters: - nosprintfhostport # host:port via net.JoinHostPort disable: - - gochecknoglobals # forbid all globals - - gochecknoinits # forbid all init() + # forbid all init() - we do this when the tuis come up to force color alignments. + # TODO @jpmcb - we should refactor this to avoid using an "init()" + - gochecknoinits exclusions: presets: diff --git a/api/mcp/search.go b/api/mcp/search.go index 0138eeb..28f4460 100644 --- a/api/mcp/search.go +++ b/api/mcp/search.go @@ -25,15 +25,14 @@ type SearchInput struct { // handleSearch processes a search request via MCP. // It delegates to the shared search package for the core search logic. func (s *Server) handleSearch(ctx context.Context, _ *mcp.CallToolRequest, input SearchInput) (*mcp.CallToolResult, apisearch.Output, error) { - output, err := apisearch.Search( + searcher := apisearch.NewSearcher( ctx, - input.Query, - input.TopK, s.config.Embedder, s.config.VectorDriver, s.config.DagLoader, s.config.Logger, ) + output, err := searcher.Search(input.Query, input.TopK) if err != nil { return &mcp.CallToolResult{ IsError: true, diff --git a/api/search/search.go b/api/search/search.go index 9d97be6..9f2035c 100644 --- a/api/search/search.go +++ b/api/search/search.go @@ -45,35 +45,55 @@ type Output struct { Count int `json:"count"` } -// Search performs a semantic search over stored LLM sessions. -// It embeds the query text, queries the vector store for similar documents, -// then loads the full conversation branch from the Merkle DAG for each result. -func Search( +type Searcher struct { + ctx context.Context + + embedder embeddings.Embedder + vectorDriver vector.Driver + dagLoader merkle.DagLoader + logger *zap.Logger +} + +func NewSearcher( ctx context.Context, - query string, - topK int, embedder embeddings.Embedder, vectorDriver vector.Driver, dagLoader merkle.DagLoader, logger *zap.Logger, +) *Searcher { + return &Searcher{ + ctx, + embedder, + vectorDriver, + dagLoader, + logger, + } +} + +// Search performs a semantic search over stored LLM sessions. +// It embeds the query text, queries the vector store for similar documents, +// then loads the full conversation branch from the Merkle DAG for each result. +func (s *Searcher) Search( + query string, + topK int, ) (*Output, error) { if topK <= 0 { topK = 5 } - logger.Debug("search request", + s.logger.Debug("search request", zap.String("query", query), zap.Int("topK", topK), ) // Embed the query - queryEmbedding, err := embedder.Embed(ctx, query) + queryEmbedding, err := s.embedder.Embed(s.ctx, query) if err != nil { return nil, fmt.Errorf("failed to embed query: %w", err) } // Query the vector store - results, err := vectorDriver.Query(ctx, queryEmbedding, topK) + results, err := s.vectorDriver.Query(s.ctx, queryEmbedding, topK) if err != nil { return nil, fmt.Errorf("failed to query vector store: %w", err) } @@ -81,16 +101,16 @@ func Search( // Build search results with full branch using merkle.LoadDag searchResults := make([]Result, 0, len(results)) for _, result := range results { - dag, err := merkle.LoadDag(ctx, dagLoader, result.Hash) + dag, err := merkle.LoadDag(s.ctx, s.dagLoader, result.Hash) if err != nil { - logger.Warn("failed to load branch for result", + s.logger.Warn("failed to load branch for result", zap.String("hash", result.Hash), zap.Error(err), ) continue } - searchResult := BuildResult(result, dag) + searchResult := s.BuildResult(result, dag) searchResults = append(searchResults, searchResult) } @@ -102,13 +122,13 @@ func Search( } // BuildResult converts a vector query result and DAG into a Result. -func BuildResult(result vector.QueryResult, dag *merkle.Dag) Result { +func (s *Searcher) BuildResult(result vector.QueryResult, dag *merkle.Dag) Result { turns := []Turn{} preview := "" role := "" // Build turns from the DAG using Walk (depth-first from root to leaves) - _ = dag.Walk(func(node *merkle.DagNode) (bool, error) { + err := dag.Walk(func(node *merkle.DagNode) (bool, error) { isMatched := node.Hash == result.Hash turns = append(turns, Turn{ Hash: node.Hash, @@ -124,6 +144,12 @@ func BuildResult(result vector.QueryResult, dag *merkle.Dag) Result { } return true, nil }) + if err != nil { + s.logger.Error( + "could not walk graph during search", + zap.Error(err), + ) + } return Result{ Hash: result.Hash, diff --git a/api/search/search_test.go b/api/search/search_test.go index e1140f1..013d564 100644 --- a/api/search/search_test.go +++ b/api/search/search_test.go @@ -27,6 +27,7 @@ var _ = Describe("Search", func() { embedder *testutils.MockEmbedder logger *zap.Logger ctx context.Context + searcher *search.Searcher ) BeforeEach(func() { @@ -35,11 +36,12 @@ var _ = Describe("Search", func() { vectorDriver = testutils.NewMockVectorDriver() embedder = testutils.NewMockEmbedder() ctx = context.Background() + searcher = search.NewSearcher(ctx, embedder, vectorDriver, driver, logger) }) Describe("Search function", func() { It("returns empty results when vector store has no matches", func() { - output, err := search.Search(ctx, "hello", 5, embedder, vectorDriver, driver, logger) + output, err := searcher.Search("hello", 5) Expect(err).NotTo(HaveOccurred()) Expect(output.Query).To(Equal("hello")) Expect(output.Count).To(Equal(0)) @@ -65,7 +67,7 @@ var _ = Describe("Search", func() { }, } - output, err := search.Search(ctx, "greeting", 5, embedder, vectorDriver, driver, logger) + output, err := searcher.Search("greeting", 5) Expect(err).NotTo(HaveOccurred()) Expect(output.Query).To(Equal("greeting")) Expect(output.Count).To(Equal(1)) @@ -76,21 +78,21 @@ var _ = Describe("Search", func() { }) It("defaults topK to 5 when zero", func() { - output, err := search.Search(ctx, "test", 0, embedder, vectorDriver, driver, logger) + output, err := searcher.Search("test", 0) Expect(err).NotTo(HaveOccurred()) Expect(output).NotTo(BeNil()) }) It("returns an error when embedding fails", func() { embedder.FailOn = "fail-query" - _, err := search.Search(ctx, "fail-query", 5, embedder, vectorDriver, driver, logger) + _, err := searcher.Search("fail-query", 5) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("failed to embed query")) }) It("returns an error when vector query fails", func() { vectorDriver.FailQuery = true - _, err := search.Search(ctx, "test", 5, embedder, vectorDriver, driver, logger) + _, err := searcher.Search("test", 5) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("failed to query vector store")) }) @@ -107,7 +109,7 @@ var _ = Describe("Search", func() { }, } - output, err := search.Search(ctx, "test", 5, embedder, vectorDriver, driver, logger) + output, err := searcher.Search("test", 5) Expect(err).NotTo(HaveOccurred()) Expect(output.Count).To(Equal(0)) }) @@ -130,7 +132,7 @@ var _ = Describe("Search", func() { dag, err := merkle.LoadDag(ctx, driver, node.Hash) Expect(err).NotTo(HaveOccurred()) - searchResult := search.BuildResult(result, dag) + searchResult := searcher.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal(node.Hash)) Expect(searchResult.Score).To(Equal(float32(0.95))) @@ -163,7 +165,7 @@ var _ = Describe("Search", func() { dag, err := merkle.LoadDag(ctx, driver, node3.Hash) Expect(err).NotTo(HaveOccurred()) - searchResult := search.BuildResult(result, dag) + searchResult := searcher.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal(node3.Hash)) Expect(searchResult.Turns).To(Equal(3)) @@ -190,7 +192,7 @@ var _ = Describe("Search", func() { } dag := merkle.NewDag() - searchResult := search.BuildResult(result, dag) + searchResult := searcher.BuildResult(result, dag) Expect(searchResult.Hash).To(Equal("empty-hash")) Expect(searchResult.Turns).To(Equal(0)) diff --git a/api/search_handler.go b/api/search_handler.go index 5f886f5..268494e 100644 --- a/api/search_handler.go +++ b/api/search_handler.go @@ -39,15 +39,14 @@ func (s *Server) handleSearchEndpoint(c *fiber.Ctx) error { topK = parsed } - output, err := apisearch.Search( + searcher := apisearch.NewSearcher( c.Context(), - query, - topK, s.config.Embedder, s.config.VectorDriver, s.dagLoader, s.logger, ) + output, err := searcher.Search(query, topK) if err != nil { return c.Status(fiber.StatusInternalServerError).JSON(llm.ErrorResponse{ Error: err.Error(), diff --git a/cmd/tapes/deck/deck.go b/cmd/tapes/deck/deck.go index 955b5b0..b261efb 100644 --- a/cmd/tapes/deck/deck.go +++ b/cmd/tapes/deck/deck.go @@ -3,6 +3,7 @@ package deckcmder import ( "context" + "errors" "fmt" "os" "path/filepath" @@ -84,11 +85,11 @@ func (c *deckCommander) run(ctx context.Context) error { return err } - query, closeFn, err := deck.NewQuery(sqlitePath, pricing) + query, closeFn, err := deck.NewQuery(ctx, sqlitePath, pricing) if err != nil { return err } - defer closeFn() + defer func() { _ = closeFn() }() filters, err := c.parseFilters() if err != nil { @@ -140,7 +141,7 @@ func (c *deckCommander) parseFilters() (deck.Filters, error) { func parseTime(value string) (time.Time, error) { value = strings.TrimSpace(value) if value == "" { - return time.Time{}, fmt.Errorf("empty time") + return time.Time{}, errors.New("empty time") } if parsed, err := time.Parse(time.RFC3339, value); err == nil { @@ -151,7 +152,7 @@ func parseTime(value string) (time.Time, error) { return parsed, nil } - return time.Time{}, fmt.Errorf("expected RFC3339 or YYYY-MM-DD") + return time.Time{}, errors.New("expected RFC3339 or YYYY-MM-DD") } func resolveSQLitePath(override string) (string, error) { @@ -194,5 +195,5 @@ func resolveSQLitePath(override string) (string, error) { } } - return "", fmt.Errorf("could not find tapes SQLite database; pass --sqlite") + return "", errors.New("could not find tapes SQLite database; pass --sqlite") } diff --git a/cmd/tapes/deck/tui.go b/cmd/tapes/deck/tui.go index d411653..5fb3275 100644 --- a/cmd/tapes/deck/tui.go +++ b/cmd/tapes/deck/tui.go @@ -36,7 +36,7 @@ const ( type deckModel struct { query *deck.Query filters deck.Filters - overview *deck.DeckOverview + overview *deck.Overview detail *deck.SessionDetail view deckView cursor int @@ -70,9 +70,11 @@ var ( deckRoleAsstStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("220")) ) -var sortOrder = []string{"cost", "time", "tokens", "duration"} -var messageSortOrder = []string{"time", "tokens", "cost", "delta"} -var statusFilters = []string{"", deck.StatusCompleted, deck.StatusFailed, deck.StatusAbandoned} +var ( + sortOrder = []string{"cost", "time", "tokens", "duration"} + messageSortOrder = []string{"time", "tokens", "cost", "delta"} + statusFilters = []string{"", deck.StatusCompleted, deck.StatusFailed, deck.StatusAbandoned} +) type deckKeyMap struct { Up key.Binding @@ -114,7 +116,7 @@ type sessionLoadedMsg struct { } type overviewLoadedMsg struct { - overview *deck.DeckOverview + overview *deck.Overview err error } @@ -145,9 +147,9 @@ func runDeckTUI(ctx context.Context, query *deck.Query, filters deck.Filters) er return err } -func newDeckModel(query *deck.Query, filters deck.Filters, overview *deck.DeckOverview) deckModel { +func newDeckModel(query *deck.Query, filters deck.Filters, overview *deck.Overview) deckModel { toggles := map[int]bool{} - for i := 0; i < 8; i++ { + for i := range 8 { toggles[i] = true } @@ -195,7 +197,7 @@ func (m deckModel) Update(msg bubbletea.Msg) (bubbletea.Model, bubbletea.Cmd) { } m.overview = msg.overview if m.cursor >= len(m.overview.Sessions) { - m.cursor = clamp(m.cursor, 0, len(m.overview.Sessions)-1) + m.cursor = clamp(m.cursor, len(m.overview.Sessions)-1) } return m, nil case sessionLoadedMsg: @@ -231,11 +233,12 @@ func (m deckModel) Update(msg bubbletea.Msg) (bubbletea.Model, bubbletea.Cmd) { func (m deckModel) View() string { switch m.view { + case viewOverview: + return m.viewOverview() case viewSession: return m.viewSession() - default: - return m.viewOverview() } + return m.viewOverview() } func (m deckModel) handleKey(msg bubbletea.KeyMsg) (bubbletea.Model, bubbletea.Cmd) { @@ -301,14 +304,14 @@ func (m deckModel) moveCursor(delta int) (bubbletea.Model, bubbletea.Cmd) { } // Limit cursor to first 8 sessions maxIdx := min(7, len(m.overview.Sessions)-1) - m.cursor = clamp(m.cursor+delta, 0, maxIdx) + m.cursor = clamp(m.cursor+delta, maxIdx) return m, nil } if m.detail == nil || len(m.detail.Messages) == 0 { return m, nil } - m.messageCursor = clamp(m.messageCursor+delta, 0, len(m.detail.Messages)-1) + m.messageCursor = clamp(m.messageCursor+delta, len(m.detail.Messages)-1) return m, nil } @@ -339,7 +342,7 @@ func (m deckModel) cycleMessageSort() (bubbletea.Model, bubbletea.Cmd) { m.messageCursor = 0 return m, nil } - m.messageCursor = clamp(m.messageCursor, 0, len(m.sortedMessages())-1) + m.messageCursor = clamp(m.messageCursor, len(m.sortedMessages())-1) return m, nil } @@ -362,7 +365,8 @@ func (m deckModel) viewOverview() string { headerLeft := deckTitleStyle.Render("tapes deck") headerRight := deckMutedStyle.Render(m.headerSessionCount(lastWindow, len(selected), len(m.overview.Sessions), filtered)) header := renderHeaderLine(m.width, headerLeft, headerRight) - lines := []string{header, renderRule(m.width), ""} + lines := make([]string, 0, 10) + lines = append(lines, header, renderRule(m.width), "") lines = append(lines, m.viewMetrics(stats)) lines = append(lines, "", m.viewCostByModel(stats), "", m.viewSessionList(), "", m.viewFooter()) @@ -380,13 +384,13 @@ func (m deckModel) viewMetrics(stats deckOverviewStats) string { formatCost(stats.TotalCost), fmt.Sprintf("%s in %s out", formatTokens(stats.InputTokens), formatTokens(stats.OutputTokens)), formatDuration(stats.TotalDuration), - fmt.Sprintf("%d", stats.TotalToolCalls), + strconv.Itoa(stats.TotalToolCalls), formatPercent(stats.SuccessRate), } avgValues := []string{ - fmt.Sprintf("%s avg", formatCost(avgCost)), + formatCost(avgCost) + " avg", fmt.Sprintf("%s in %s out", formatTokens(avgTokenCount(stats.InputTokens, stats.TotalSessions)), formatTokens(avgTokenCount(stats.OutputTokens, stats.TotalSessions))), - fmt.Sprintf("%s avg", formatDuration(avgTime)), + formatDuration(avgTime) + " avg", fmt.Sprintf("%d avg", avgTools), fmt.Sprintf("%d/%d", stats.Completed, stats.TotalSessions), } @@ -428,10 +432,7 @@ func (m deckModel) viewSessionList() string { } // Limit to 8 visible sessions - maxVisible := 8 - if len(m.overview.Sessions) < maxVisible { - maxVisible = len(m.overview.Sessions) - } + maxVisible := min(len(m.overview.Sessions), 8) status := m.filters.Status if status == "" { @@ -439,16 +440,16 @@ func (m deckModel) viewSessionList() string { } lines := []string{deckSectionStyle.Render(fmt.Sprintf("sessions (sort: %s, status: %s)", m.filters.Sort, status)), renderRule(m.width)} lines = append(lines, deckMutedStyle.Render(" label model dur tokens cost tools msgs status")) - for i := 0; i < maxVisible; i++ { + for i := range maxVisible { session := m.overview.Sessions[i] cursor := " " if i == m.cursor { cursor = ">" } - toggle := " " + var toggle string if m.trackToggles[i] { - toggle = fmt.Sprintf("%d", i+1) + toggle = strconv.Itoa(i + 1) } else { toggle = "-" } @@ -488,10 +489,11 @@ func (m deckModel) viewSession() string { statusStyle := statusStyleFor(m.detail.Summary.Status) statusDot := statusStyle.Render("●") - headerLeft := deckTitleStyle.Render(fmt.Sprintf("⏏ tapes deck › %s", m.detail.Summary.Label)) + headerLeft := deckTitleStyle.Render("⏏ tapes deck › " + m.detail.Summary.Label) headerRight := deckMutedStyle.Render(fmt.Sprintf("%s · %s %s", m.detail.Summary.ID, statusDot, m.detail.Summary.Status)) header := renderHeaderLine(m.width, headerLeft, headerRight) - lines := []string{header, renderRule(m.width), ""} + lines := make([]string, 0, 20) + lines = append(lines, header, renderRule(m.width), "") lines = append(lines, deckSectionStyle.Render("session"), renderRule(m.width)) lines = append(lines, deckMutedStyle.Render("MODEL DURATION INPUT COST OUTPUT COST TOTAL")) @@ -506,8 +508,8 @@ func (m deckModel) viewSession() string { lines = append(lines, deckMutedStyle.Render(fmt.Sprintf("%-15s %-15s %-14s %-14s", "", "", - fmt.Sprintf("%s tokens", formatTokens(m.detail.Summary.InputTokens)), - fmt.Sprintf("%s tokens", formatTokens(m.detail.Summary.OutputTokens)), + formatTokens(m.detail.Summary.InputTokens)+" tokens", + formatTokens(m.detail.Summary.OutputTokens)+" tokens", ))) inputRate, outputRate := costPerMTok(m.detail.Summary.InputCost, m.detail.Summary.InputTokens), costPerMTok(m.detail.Summary.OutputCost, m.detail.Summary.OutputTokens) @@ -538,15 +540,9 @@ func (m deckModel) viewSession() string { screenHeight = 40 } footerHeight := 2 - remaining := screenHeight - len(lines) - footerHeight - if remaining < 8 { - remaining = 8 - } + remaining := max(screenHeight-len(lines)-footerHeight, 8) gap := 3 - leftWidth := (m.width - gap) * 2 / 3 - if leftWidth < 30 { - leftWidth = 30 - } + leftWidth := max((m.width-gap)*2/3, 30) rightWidth := m.width - gap - leftWidth if rightWidth < 24 { rightWidth = 24 @@ -612,12 +608,12 @@ func numberKey(key string) (int, bool) { } } -func clamp(value, min, max int) int { - if value < min { - return min +func clamp(value, upper int) int { + if value < 0 { + return 0 } - if value > max { - return max + if value > upper { + return upper } return value } @@ -633,7 +629,7 @@ func formatTokens(value int64) string { if value >= 1_000 { return fmt.Sprintf("%.1fK", float64(value)/1_000.0) } - return fmt.Sprintf("%d", value) + return strconv.FormatInt(value, 10) } func formatDuration(value time.Duration) string { @@ -644,7 +640,7 @@ func formatDuration(value time.Duration) string { minutes := int(value.Minutes()) seconds := int(value.Seconds()) % 60 hours := minutes / 60 - minutes = minutes % 60 + minutes %= 60 if hours > 0 { return fmt.Sprintf("%dh%dm", hours, minutes) } @@ -668,18 +664,12 @@ func truncateText(value string, limit int) string { return value[:limit-3] + "..." } -func renderBar(value, max float64, width int) string { - if max <= 0 { +func renderBar(value, ceiling float64, width int) string { + if ceiling <= 0 { return strings.Repeat("░", width) } - ratio := value / max - filled := int(ratio * float64(width)) - if filled < 0 { - filled = 0 - } - if filled > width { - filled = width - } + ratio := value / ceiling + filled := min(max(int(ratio*float64(width)), 0), width) return strings.Repeat("█", filled) + strings.Repeat("░", width-filled) } @@ -715,10 +705,7 @@ func renderMetricRow(width int, items []string, style lipgloss.Style) string { } cols := len(items) spaceWidth := (cols - 1) * 2 - colWidth := (lineWidth - spaceWidth) / cols - if colWidth < 12 { - colWidth = 12 - } + colWidth := max((lineWidth-spaceWidth)/cols, 12) parts := make([]string, 0, len(items)) for _, item := range items { parts = append(parts, style.Render(fitCell(item, colWidth))) @@ -768,13 +755,7 @@ func renderSplitBar(label string, percent float64, width int) string { if width <= 0 { width = 24 } - filled := int((percent / 100) * float64(width)) - if filled < 0 { - filled = 0 - } - if filled > width { - filled = width - } + filled := min(max(int((percent/100)*float64(width)), 0), width) bar := strings.Repeat("█", filled) + strings.Repeat("░", width-filled) return fmt.Sprintf("%s %s %2.0f%%", label, bar, percent) } @@ -801,10 +782,7 @@ func (m deckModel) renderTimelineBlock(width, height int) []string { if height < 3 { height = 3 } - maxVisible := height - 1 - if maxVisible < 1 { - maxVisible = 1 - } + maxVisible := max(height-1, 1) messages := m.sortedMessages() start, end := visibleRange(len(messages), m.messageCursor, maxVisible) @@ -856,7 +834,7 @@ func (m deckModel) renderDetailBlock(width, height int) []string { role := roleLabel(msg.Role) lines = append(lines, - fmt.Sprintf("role: %s", role), + "role: "+role, fmt.Sprintf("time: %s delta: %s", msg.Timestamp.Format("15:04:05"), formatDelta(msg.Delta)), fmt.Sprintf("tokens: in %s out %s total %s", formatTokensDetail(msg.InputTokens), formatTokensDetail(msg.OutputTokens), formatTokensDetail(msg.TotalTokens)), fmt.Sprintf("cost: in %s out %s total %s", formatCost(msg.InputCost), formatCost(msg.OutputCost), formatCost(msg.TotalCost)), @@ -892,10 +870,7 @@ func renderTimelineLine(width int, cursor, timestamp, role, tokens, cost, tools, costWidth := 7 deltaWidth := 6 baseWidth := cursorWidth + timeWidth + roleWidth + tokensWidth + costWidth + deltaWidth + gap*6 - toolWidth := lineWidth - baseWidth - if toolWidth < 0 { - toolWidth = 0 - } + toolWidth := max(lineWidth-baseWidth, 0) columns := []string{ fitCell(cursor, cursorWidth), @@ -1060,13 +1035,10 @@ func padRight(value string, width int) string { } func joinColumns(left, right []string, gap int) []string { - maxLines := len(left) - if len(right) > maxLines { - maxLines = len(right) - } + maxLines := max(len(right), len(left)) lines := make([]string, 0, maxLines) gapSpace := strings.Repeat(" ", gap) - for i := 0; i < maxLines; i++ { + for i := range maxLines { leftLine := "" if i < len(left) { leftLine = left[i] @@ -1093,17 +1065,11 @@ func visibleRange(total, cursor, size int) (int, int) { if cursor >= total { cursor = total - 1 } - start := cursor - (size / 2) - if start < 0 { - start = 0 - } + start := max(cursor-(size/2), 0) end := start + size if end > total { end = total - start = end - size - if start < 0 { - start = 0 - } + start = max(end-size, 0) } return start, end } @@ -1139,9 +1105,9 @@ func formatDelta(value time.Duration) string { func formatTokensDetail(value int64) string { if value < 10_000 { - return fmt.Sprintf("%s tok", formatInt(value)) + return formatInt(value) + " tok" } - return fmt.Sprintf("%s tok", formatTokens(value)) + return formatTokens(value) + " tok" } func formatInt(value int64) string { @@ -1216,59 +1182,3 @@ func wrapText(text string, width int) []string { } return lines } - -func (m deckModel) viewToolFrequency() string { - if m.detail == nil || len(m.detail.ToolFrequency) == 0 { - return deckMutedStyle.Render("no tools recorded") - } - - items := make([]toolCount, 0, len(m.detail.ToolFrequency)) - maxCount := 0 - for tool, count := range m.detail.ToolFrequency { - items = append(items, toolCount{name: tool, count: count}) - if count > maxCount { - maxCount = count - } - } - - sort.Slice(items, func(i, j int) bool { - if items[i].count == items[j].count { - return items[i].name < items[j].name - } - return items[i].count > items[j].count - }) - - maxVisible := 6 - if len(items) < maxVisible { - maxVisible = len(items) - } - - lines := make([]string, 0, maxVisible) - for i := 0; i < maxVisible; i++ { - item := items[i] - bar := renderBar(float64(item.count), float64(maxCount), 24) - line := fmt.Sprintf("* %-16s %s %d", item.name, deckAccentStyle.Render(bar), item.count) - lines = append(lines, line) - } - - return strings.Join(lines, "\n") -} - -type toolCount struct { - name string - count int -} - -func max(a, b int) int { - if a > b { - return a - } - return b -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/cmd/tapes/deck/tui_test.go b/cmd/tapes/deck/tui_test.go index 8222b1d..8019eb4 100644 --- a/cmd/tapes/deck/tui_test.go +++ b/cmd/tapes/deck/tui_test.go @@ -75,7 +75,7 @@ var _ = Describe("Deck TUI helpers", func() { It("returns all sessions when nothing is toggled off", func() { sessions := []deck.SessionSummary{{ID: "s1"}, {ID: "s2"}, {ID: "s3"}} model := deckModel{ - overview: &deck.DeckOverview{Sessions: sessions}, + overview: &deck.Overview{Sessions: sessions}, trackToggles: map[int]bool{ 0: true, 1: true, @@ -96,7 +96,7 @@ var _ = Describe("Deck TUI helpers", func() { It("excludes deselected sessions", func() { sessions := []deck.SessionSummary{{ID: "s1"}, {ID: "s2"}, {ID: "s3"}} model := deckModel{ - overview: &deck.DeckOverview{Sessions: sessions}, + overview: &deck.Overview{Sessions: sessions}, trackToggles: map[int]bool{ 0: true, 1: false, diff --git a/cmd/tapes/deck/web.go b/cmd/tapes/deck/web.go index 971e774..a86ca30 100644 --- a/cmd/tapes/deck/web.go +++ b/cmd/tapes/deck/web.go @@ -50,7 +50,8 @@ func runDeckWeb(ctx context.Context, query *deck.Query, filters deck.Filters, po ReadHeaderTimeout: 5 * time.Second, } - listener, err := net.Listen("tcp", address) + lc := net.ListenConfig{} + listener, err := lc.Listen(ctx, "tcp", address) if err != nil { return err } @@ -59,7 +60,9 @@ func runDeckWeb(ctx context.Context, query *deck.Query, filters deck.Filters, po go func() { <-ctx.Done() - _ = server.Shutdown(context.Background()) + shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) }() return server.Serve(listener) @@ -67,10 +70,15 @@ func runDeckWeb(ctx context.Context, query *deck.Query, filters deck.Filters, po func writeJSON(w http.ResponseWriter, payload any) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(payload) + if err := json.NewEncoder(w).Encode(payload); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } func writeJSONError(w http.ResponseWriter, err error) { w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + resp := map[string]string{"error": err.Error()} + if encErr := json.NewEncoder(w).Encode(resp); encErr != nil { + http.Error(w, encErr.Error(), http.StatusInternalServerError) + } } diff --git a/cmd/tapes/search/search.go b/cmd/tapes/search/search.go index 227cc6f..475e4a9 100644 --- a/cmd/tapes/search/search.go +++ b/cmd/tapes/search/search.go @@ -74,6 +74,7 @@ func NewSearchCmd() *cobra.Command { func (c *searchCommander) run() error { c.logger = logger.NewLogger(c.debug) defer func() { _ = c.logger.Sync() }() + client := &http.Client{} c.logger.Debug("searching via API", zap.String("api_target", c.apiTarget), @@ -98,7 +99,7 @@ func (c *searchCommander) run() error { if err != nil { return fmt.Errorf("creating search request: %w", err) } - resp, err := (&http.Client{}).Do(req) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to connect to Tapes API at %s: %w", c.apiTarget, err) } diff --git a/cmd/tapes/serve/api/api.go b/cmd/tapes/serve/api/api.go index f699557..a10f6eb 100644 --- a/cmd/tapes/serve/api/api.go +++ b/cmd/tapes/serve/api/api.go @@ -2,6 +2,7 @@ package apicmder import ( + "context" "fmt" "github.com/spf13/cobra" @@ -84,7 +85,7 @@ func (c *apiCommander) run() error { func (c *apiCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(context.Background(), c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -98,7 +99,7 @@ func (c *apiCommander) newStorageDriver() (storage.Driver, error) { func (c *apiCommander) newDagLoader() (merkle.DagLoader, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(context.Background(), c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } diff --git a/cmd/tapes/serve/proxy/proxy.go b/cmd/tapes/serve/proxy/proxy.go index e327961..ff74ece 100644 --- a/cmd/tapes/serve/proxy/proxy.go +++ b/cmd/tapes/serve/proxy/proxy.go @@ -2,6 +2,7 @@ package proxycmder import ( + "context" "fmt" "github.com/spf13/cobra" @@ -142,7 +143,7 @@ func (c *proxyCommander) run() error { func (c *proxyCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(context.Background(), c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } diff --git a/cmd/tapes/serve/serve.go b/cmd/tapes/serve/serve.go index 3c540b5..3738e8c 100644 --- a/cmd/tapes/serve/serve.go +++ b/cmd/tapes/serve/serve.go @@ -2,6 +2,7 @@ package servecmder import ( + "context" "fmt" "os" "os/signal" @@ -211,7 +212,7 @@ func (c *ServeCommander) run() error { func (c *ServeCommander) newStorageDriver() (storage.Driver, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(context.Background(), c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } @@ -225,7 +226,7 @@ func (c *ServeCommander) newStorageDriver() (storage.Driver, error) { func (c *ServeCommander) newDagLoader() (merkle.DagLoader, error) { if c.sqlitePath != "" { - driver, err := sqlite.NewDriver(c.sqlitePath) + driver, err := sqlite.NewDriver(context.Background(), c.sqlitePath) if err != nil { return nil, fmt.Errorf("failed to create SQLite storer: %w", err) } diff --git a/makefile b/makefile index b715537..884ce6f 100644 --- a/makefile +++ b/makefile @@ -10,17 +10,13 @@ LDFLAGS := -s -w \ -X 'github.com/papercomputeco/tapes/pkg/utils.Sha=$(COMMIT)' \ -X 'github.com/papercomputeco/tapes/pkg/utils.Buildtime=$(BUILDTIME)' -.PHONY: format -format: - find . -type f -name "*.go" -exec goimports -local github.com/papercomputeco/tapes -w {} \; - -.PHONY: check-lint -check-lint: ## Runs golangci-lint check. Auto-fixes are not automatically applied. +.PHONY: check +check: ## Runs golangci-lint check. Auto-fixes are not automatically applied. $(call print-target) dagger call check-lint -.PHONY: fix-lint -fix-lint: ## Runs golangci-lint lint with auto-fixes applied. +.PHONY: format +format: ## Runs golangci-lint linters and formatters with auto-fixes applied. $(call print-target) dagger call fix-lint export --path . diff --git a/pkg/deck/pricing.go b/pkg/deck/pricing.go index 172e70d..8b23033 100644 --- a/pkg/deck/pricing.go +++ b/pkg/deck/pricing.go @@ -3,6 +3,7 @@ package deck import ( "encoding/json" "fmt" + "maps" "os" "strings" ) @@ -45,9 +46,7 @@ func LoadPricing(path string) (PricingTable, error) { return nil, fmt.Errorf("parse pricing file: %w", err) } - for model, price := range overrides { - pricing[model] = price - } + maps.Copy(pricing, overrides) return pricing, nil } diff --git a/pkg/deck/query.go b/pkg/deck/query.go index 9abe1a5..ef1536f 100644 --- a/pkg/deck/query.go +++ b/pkg/deck/query.go @@ -3,6 +3,7 @@ package deck import ( "context" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -14,13 +15,15 @@ import ( "github.com/papercomputeco/tapes/pkg/storage/sqlite" ) +const blockTypeToolUse = "tool_use" + type Query struct { client *ent.Client pricing PricingTable } -func NewQuery(dbPath string, pricing PricingTable) (*Query, func() error, error) { - driver, err := sqlite.NewSQLiteDriver(dbPath) +func NewQuery(ctx context.Context, dbPath string, pricing PricingTable) (*Query, func() error, error) { + driver, err := sqlite.NewDriver(ctx, dbPath) if err != nil { return nil, nil, err } @@ -32,13 +35,13 @@ func NewQuery(dbPath string, pricing PricingTable) (*Query, func() error, error) return &Query{client: driver.Client, pricing: pricing}, closeFn, nil } -func (q *Query) Overview(ctx context.Context, filters Filters) (*DeckOverview, error) { +func (q *Query) Overview(ctx context.Context, filters Filters) (*Overview, error) { leaves, err := q.client.Node.Query().Where(node.Not(node.HasChildren())).All(ctx) if err != nil { return nil, fmt.Errorf("list leaves: %w", err) } - overview := &DeckOverview{ + overview := &Overview{ Sessions: make([]SessionSummary, 0, len(leaves)), CostByModel: map[string]ModelCost{}, } @@ -169,15 +172,12 @@ func (q *Query) buildSessionSummary(ctx context.Context, leaf *ent.Node) (Sessio func (q *Query) buildSessionSummaryFromNodes(nodes []*ent.Node) (SessionSummary, map[string]ModelCost, string, error) { if len(nodes) == 0 { - return SessionSummary{}, nil, "", fmt.Errorf("empty session nodes") + return SessionSummary{}, nil, "", errors.New("empty session nodes") } start := nodes[0].CreatedAt end := nodes[len(nodes)-1].CreatedAt - duration := end.Sub(start) - if duration < 0 { - duration = 0 - } + duration := max(end.Sub(start), 0) label := buildLabel(nodes) toolCalls := 0 @@ -284,19 +284,17 @@ func (q *Query) costForNode(node *ent.Node, inputTokens, outputTokens int64) (fl } func tokenCounts(node *ent.Node) (int64, int64, int64) { - inputTokens := int64(0) - outputTokens := int64(0) - totalTokens := int64(0) + var inputTokens, outputTokens int64 if node.PromptTokens != nil { inputTokens = int64(*node.PromptTokens) } if node.CompletionTokens != nil { outputTokens = int64(*node.CompletionTokens) } + + totalTokens := inputTokens + outputTokens if node.TotalTokens != nil { totalTokens = int64(*node.TotalTokens) - } else { - totalTokens = inputTokens + outputTokens } return inputTokens, outputTokens, totalTokens @@ -323,7 +321,7 @@ func parseContentBlocks(raw []map[string]any) ([]llm.ContentBlock, error) { func extractToolCalls(blocks []llm.ContentBlock) []string { tools := []string{} for _, block := range blocks { - if block.Type == "tool_use" && block.ToolName != "" { + if block.Type == blockTypeToolUse && block.ToolName != "" { tools = append(tools, block.ToolName) } } @@ -333,7 +331,7 @@ func extractToolCalls(blocks []llm.ContentBlock) []string { func countToolCalls(blocks []llm.ContentBlock) int { count := 0 for _, block := range blocks { - if block.Type == "tool_use" { + if block.Type == blockTypeToolUse { count++ } } @@ -358,7 +356,7 @@ func extractText(blocks []llm.ContentBlock) string { case block.ToolOutput != "": texts = append(texts, block.ToolOutput) case block.ToolName != "": - texts = append(texts, fmt.Sprintf("tool call: %s", block.ToolName)) + texts = append(texts, "tool call: "+block.ToolName) } } return strings.Join(texts, "\n") diff --git a/pkg/deck/types.go b/pkg/deck/types.go index 592cd6a..29ed6d8 100644 --- a/pkg/deck/types.go +++ b/pkg/deck/types.go @@ -56,7 +56,7 @@ type ModelCost struct { SessionCount int `json:"session_count"` } -type DeckOverview struct { +type Overview struct { Sessions []SessionSummary `json:"sessions"` TotalCost float64 `json:"total_cost"` TotalTokens int64 `json:"total_tokens"` diff --git a/pkg/dotdir/checkout.go b/pkg/dotdir/checkout.go index eb475ed..9aca861 100644 --- a/pkg/dotdir/checkout.go +++ b/pkg/dotdir/checkout.go @@ -73,7 +73,7 @@ func (m *Manager) SaveCheckout(state *CheckoutState, overrideDir string) error { } path := filepath.Join(dir, checkoutFile) - if err := os.WriteFile(path, data, 0o644); err != nil { //nolint:gosec // @jpmcb - TODO: refactor file permissions + if err := os.WriteFile(path, data, 0o600); err != nil { return fmt.Errorf("writing checkout state: %w", err) } diff --git a/pkg/storage/sqlite/sqlite.go b/pkg/storage/sqlite/sqlite.go index 504e16c..16a528f 100644 --- a/pkg/storage/sqlite/sqlite.go +++ b/pkg/storage/sqlite/sqlite.go @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" - _ "github.com/mattn/go-sqlite3" // Register sqlite3 driver + _ "github.com/mattn/go-sqlite3" // load up the sqlite3 CGO libs "github.com/papercomputeco/tapes/pkg/storage/ent" entdriver "github.com/papercomputeco/tapes/pkg/storage/ent/driver" @@ -21,7 +21,7 @@ type Driver struct { // NewDriver creates a new SQLite-backed storer. // The dbPath can be a file path or ":memory:" for an in-memory database. -func NewDriver(dbPath string) (*Driver, error) { +func NewDriver(ctx context.Context, dbPath string) (*Driver, error) { // Open the database using the github.com/mattn/go-sqlite3 driver (registered as "sqlite3") db, err := sql.Open("sqlite3", dbPath) if err != nil { @@ -29,7 +29,7 @@ func NewDriver(dbPath string) (*Driver, error) { } // SQLite-specific pragmas - if _, err := db.ExecContext(context.Background(), "PRAGMA foreign_keys = ON"); err != nil { + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { db.Close() return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } @@ -40,7 +40,7 @@ func NewDriver(dbPath string) (*Driver, error) { // Run ent's auto-migration to create/update the schema // This handles append-only schema changes (new tables, columns, indexes) - if err := client.Schema.Create(context.Background()); err != nil { + if err := client.Schema.Create(ctx); err != nil { client.Close() return nil, fmt.Errorf("failed to create schema: %w", err) } diff --git a/pkg/storage/sqlite/sqlite_test.go b/pkg/storage/sqlite/sqlite_test.go index 05cbe81..d4262da 100644 --- a/pkg/storage/sqlite/sqlite_test.go +++ b/pkg/storage/sqlite/sqlite_test.go @@ -34,7 +34,7 @@ var _ = Describe("Driver", func() { BeforeEach(func() { ctx = context.Background() var err error - driver, err = sqlite.NewDriver(":memory:") + driver, err = sqlite.NewDriver(ctx, ":memory:") Expect(err).NotTo(HaveOccurred()) }) @@ -49,7 +49,7 @@ var _ = Describe("Driver", func() { tmpDir := GinkgoT().TempDir() dbPath := filepath.Join(tmpDir, "test.db") - s, err := sqlite.NewDriver(dbPath) + s, err := sqlite.NewDriver(context.Background(), dbPath) Expect(err).NotTo(HaveOccurred()) defer s.Close() diff --git a/pkg/vector/sqlitevec/sqlitevec.go b/pkg/vector/sqlitevec/sqlitevec.go index d1d42d6..13af6f9 100644 --- a/pkg/vector/sqlitevec/sqlitevec.go +++ b/pkg/vector/sqlitevec/sqlitevec.go @@ -11,7 +11,7 @@ import ( "strings" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" - _ "github.com/mattn/go-sqlite3" // Register sqlite3 driver + _ "github.com/mattn/go-sqlite3" // load up the sqlite3 CGO libs "go.uber.org/zap" "github.com/papercomputeco/tapes/pkg/vector"