diff --git a/cagent-schema.json b/cagent-schema.json index 677be8384..0322603c6 100644 --- a/cagent-schema.json +++ b/cagent-schema.json @@ -51,6 +51,13 @@ "$ref": "#/definitions/RAGConfig" } }, + "memory": { + "type": "object", + "description": "Map of memory scope configurations for pluggable memory backends", + "additionalProperties": { + "$ref": "#/definitions/MemoryConfig" + } + }, "metadata": { "$ref": "#/definitions/Metadata", "description": "Configuration metadata" @@ -277,6 +284,13 @@ "type": "string" } }, + "memory": { + "type": "array", + "description": "List of memory scopes to use for this agent", + "items": { + "type": "string" + } + }, "hooks": { "$ref": "#/definitions/HooksConfig", "description": "Lifecycle hooks for executing shell commands at various points in the agent's execution" @@ -1297,6 +1311,80 @@ "strategies" ], "additionalProperties": false + }, + "MemoryConfig": { + "type": "object", + "description": "Memory scope configuration for pluggable memory backends supporting long-term (RAG-style) and short-term (whiteboard) strategies", + "required": ["kind"], + "properties": { + "kind": { + "type": "string", + "description": "Memory backend type", + "enum": ["sqlite", "neo4j", "qdrant", "redis", "whiteboard"] + }, + "strategy": { + "type": "string", + "description": "Memory strategy: long_term (persistent RAG-style) or short_term (ephemeral whiteboard)", + "enum": ["long_term", "short_term"] + }, + "description": { + "type": "string", + "description": "Human-readable description of this memory scope" + }, + "path": { + "type": "string", + "description": "File path for file-based backends (sqlite)" + }, + "ttl": { + "type": "integer", + "description": "Time-to-live in seconds for ephemeral memory (whiteboard, redis)", + "minimum": 0 + }, + "mode": { + "type": "string", + "description": "Access mode for the memory", + "enum": ["read_write", "read_only", "append_only"] + }, + "connection": { + "type": "object", + "description": "Connection details for remote backends", + "properties": { + "url": { + "type": "string", + "description": "Connection URL for the memory backend" + }, + "database": { + "type": "string", + "description": "Database name (for backends supporting multiple databases)" + }, + "collection": { + "type": "string", + "description": "Collection/table name (for vector stores)" + }, + "auth": { + "type": "object", + "description": "Authentication credentials", + "properties": { + "username": { + "type": "string", + "description": "Username for basic authentication" + }, + "password": { + "type": "string", + "description": "Password for basic authentication" + }, + "token": { + "type": "string", + "description": "Token for token-based authentication" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false } } } diff --git a/examples/memory_demo.yaml b/examples/memory_demo.yaml new file mode 100644 index 000000000..6c1f54785 --- /dev/null +++ b/examples/memory_demo.yaml @@ -0,0 +1,79 @@ +#!/usr/bin/env cagent run + +metadata: + author: Memory System Demo + readme: | + Demonstrates the pluggable memory system with both long-term and short-term memory. + + Memory strategies: + - **Long-term (SQLite)**: Persistent memory for user facts + - **Short-term (Whiteboard)**: Ephemeral shared context for multi-agent collaboration + + Future backends (config-ready, implementation pending): + - Neo4j GraphRAG for knowledge graphs + - Qdrant for vector-based semantic search + - Redis for distributed whiteboard + +memory: + # Long-term persistent memory (current implementation) + user_facts: + kind: sqlite + strategy: long_term + path: ./memory/user_facts.db + description: "Persistent memory about the user" + + # Short-term shared whiteboard (implementation pending) + # team_whiteboard: + # kind: whiteboard + # strategy: short_term + # ttl: 3600 # 1 hour expiry + # description: "Shared context for agent collaboration" + + # GraphRAG with Neo4j (implementation pending) + # knowledge_graph: + # kind: neo4j + # strategy: long_term + # connection: + # url: bolt://localhost:7687 + # database: cagent + # auth: + # username: neo4j + # password: ${NEO4J_PASSWORD} + # description: "Knowledge graph for semantic relationships" + +agents: + root: + model: anthropic/claude-sonnet-4-5 + description: "Assistant with long-term memory" + instruction: | + You are a helpful assistant with memory capabilities. + + Use the memory tool to remember things about the user. + Before responding, always check memories to personalize your responses. + memory: + - user_facts + toolsets: + - type: think + + # Multi-agent example with shared whiteboard (pending whiteboard implementation) + # coordinator: + # model: anthropic/claude-sonnet-4-5 + # description: "Coordinates team using shared whiteboard" + # memory: + # - team_whiteboard # Shared with sub-agents + # - user_facts # Personal long-term memory + # sub_agents: + # - researcher + # - writer + # + # researcher: + # model: openai/gpt-4o + # description: "Research specialist" + # memory: + # - team_whiteboard # Shared whiteboard + # + # writer: + # model: anthropic/claude-sonnet-4-5 + # description: "Writing specialist" + # memory: + # - team_whiteboard # Shared whiteboard diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index aab024847..e0921ec57 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -18,6 +18,7 @@ type Config struct { Providers map[string]ProviderConfig `json:"providers,omitempty"` Models map[string]ModelConfig `json:"models,omitempty"` RAG map[string]RAGConfig `json:"rag,omitempty"` + Memory map[string]MemoryConfig `json:"memory,omitempty"` Metadata Metadata `json:"metadata,omitempty"` Permissions *PermissionsConfig `json:"permissions,omitempty"` } @@ -119,6 +120,7 @@ type AgentConfig struct { SubAgents []string `json:"sub_agents,omitempty"` Handoffs []string `json:"handoffs,omitempty"` RAG []string `json:"rag,omitempty"` + Memory []string `json:"memory,omitempty"` AddDate bool `json:"add_date,omitempty"` AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` CodeModeTools bool `json:"code_mode_tools,omitempty"` @@ -1003,3 +1005,45 @@ func (h *HookDefinition) validate(prefix string, index int) error { return nil } + +// MemoryConfig represents a named memory scope configuration. +// Memory scopes define how agents store and retrieve information. +type MemoryConfig struct { + // Kind specifies the memory backend type: sqlite, whiteboard, neo4j, qdrant, redis. + Kind string `json:"kind"` + // Strategy specifies the memory strategy: "long_term" (persistent RAG-style) or "short_term" (ephemeral whiteboard). + // Default is "long_term" for sqlite/neo4j/qdrant, "short_term" for whiteboard/redis. + Strategy string `json:"strategy,omitempty"` + // Description is an optional human-readable description of this memory scope. + Description string `json:"description,omitempty"` + // Connection holds connection details for remote backends. + Connection *MemoryConnectionConfig `json:"connection,omitempty"` + // Path is the file path for file-based backends like sqlite. + Path string `json:"path,omitempty"` + // TTL is the time-to-live in seconds for ephemeral memory (e.g., whiteboard). 0 means no expiry. + TTL int `json:"ttl,omitempty"` + // Mode specifies the access mode: "read_write" (default), "read_only", or "append_only" (event-log style). + Mode string `json:"mode,omitempty"` +} + +// MemoryConnectionConfig holds connection details for remote memory backends. +type MemoryConnectionConfig struct { + // URL is the connection URL for the memory backend. + URL string `json:"url"` + // Database is the database name (for backends that support multiple databases). + Database string `json:"database,omitempty"` + // Collection is the collection/table name (for vector stores). + Collection string `json:"collection,omitempty"` + // Auth holds authentication credentials. + Auth *MemoryAuthConfig `json:"auth,omitempty"` +} + +// MemoryAuthConfig holds authentication credentials for memory backends. +type MemoryAuthConfig struct { + // Username for basic authentication. + Username string `json:"username,omitempty"` + // Password for basic authentication. + Password string `json:"password,omitempty"` + // Token for token-based authentication. + Token string `json:"token,omitempty"` +} diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 752c29987..372420650 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -2,6 +2,7 @@ package latest import ( "errors" + "fmt" "strings" ) @@ -16,8 +17,7 @@ func (t *Config) UnmarshalYAML(unmarshal func(any) error) error { } func (t *Config) validate() error { - for i := range t.Agents { - agent := t.Agents[i] + for agentName, agent := range t.Agents { for j := range agent.Toolsets { if err := agent.Toolsets[j].validate(); err != nil { return err @@ -28,6 +28,18 @@ func (t *Config) validate() error { return err } } + // Validate agent memory references exist in top-level memory map + for _, memRef := range agent.Memory { + if _, exists := t.Memory[memRef]; !exists { + return fmt.Errorf("agent %q: references undefined memory %q", agentName, memRef) + } + } + } + + for name, mem := range t.Memory { + if err := mem.validate(name); err != nil { + return err + } } return nil @@ -123,3 +135,89 @@ func (t *Toolset) validate() error { return nil } + +func (m *MemoryConfig) validate(name string) error { + if m.Kind == "" { + return fmt.Errorf("memory %q: kind is required", name) + } + + validKinds := map[string]bool{ + "sqlite": true, + "neo4j": true, + "qdrant": true, + "redis": true, + "whiteboard": true, + } + if !validKinds[m.Kind] { + return fmt.Errorf("memory %q: invalid kind %q, must be one of: sqlite, neo4j, qdrant, redis, whiteboard", name, m.Kind) + } + + // Validate strategy if provided + if m.Strategy != "" { + validStrategies := map[string]bool{ + "long_term": true, // Persistent RAG-style memory (sqlite, neo4j, qdrant) + "short_term": true, // Ephemeral whiteboard-style memory (whiteboard, redis) + } + if !validStrategies[m.Strategy] { + return fmt.Errorf("memory %q: invalid strategy %q, must be one of: long_term, short_term", name, m.Strategy) + } + + // Validate strategy matches kind semantics + longTermKinds := map[string]bool{"sqlite": true, "neo4j": true, "qdrant": true} + shortTermKinds := map[string]bool{"whiteboard": true, "redis": true} + + if m.Strategy == "long_term" && shortTermKinds[m.Kind] { + return fmt.Errorf("memory %q: kind %q is not suitable for long_term strategy (use sqlite, neo4j, or qdrant)", name, m.Kind) + } + if m.Strategy == "short_term" && longTermKinds[m.Kind] && m.Kind != "redis" { + // Note: redis can be used for both strategies; sqlite/neo4j/qdrant are long-term only + if m.Kind != "redis" { + return fmt.Errorf("memory %q: kind %q is not suitable for short_term strategy (use whiteboard or redis)", name, m.Kind) + } + } + } + + // Validate mode if provided + if m.Mode != "" { + validModes := map[string]bool{ + "read_write": true, // Default: full read/write access + "read_only": true, // Read-only access (useful for shared knowledge bases) + "append_only": true, // Append-only (event-log style, no updates/deletes) + } + if !validModes[m.Mode] { + return fmt.Errorf("memory %q: invalid mode %q, must be one of: read_write, read_only, append_only", name, m.Mode) + } + } + + // Validate auth completeness if present (check Connection is not nil first) + if m.Connection != nil && m.Connection.Auth != nil { + auth := m.Connection.Auth + hasUserPass := auth.Username != "" || auth.Password != "" + hasToken := auth.Token != "" + + if hasUserPass && hasToken { + return fmt.Errorf("memory %q: auth must use either username/password or token, not both", name) + } + if hasUserPass && (auth.Username == "" || auth.Password == "") { + return fmt.Errorf("memory %q: auth requires both username and password when using user/password auth", name) + } + } + + // For sqlite, path is required + if m.Kind == "sqlite" && m.Path == "" { + return fmt.Errorf("memory %q: sqlite requires a path", name) + } + + // For remote backends, connection URL is typically required + remoteKinds := map[string]bool{"neo4j": true, "qdrant": true, "redis": true} + if remoteKinds[m.Kind] && (m.Connection == nil || m.Connection.URL == "") { + return fmt.Errorf("memory %q: %s requires connection.url", name, m.Kind) + } + + // TTL validation: only meaningful for short-term/ephemeral memory + if m.TTL > 0 && m.Kind != "whiteboard" && m.Kind != "redis" { + return fmt.Errorf("memory %q: ttl is only supported for whiteboard and redis kinds", name) + } + + return nil +} diff --git a/pkg/memory/adapter.go b/pkg/memory/adapter.go new file mode 100644 index 000000000..c8140a34f --- /dev/null +++ b/pkg/memory/adapter.go @@ -0,0 +1,50 @@ +package memory + +import ( + "context" + + "github.com/google/uuid" + + "github.com/docker/cagent/pkg/memory/database" +) + +// DatabaseAdapter adapts the new Driver interface to the legacy database.Database interface +type DatabaseAdapter struct { + driver Driver +} + +var _ database.Database = (*DatabaseAdapter)(nil) + +// NewDatabaseAdapter creates an adapter that wraps a Driver +func NewDatabaseAdapter(driver Driver) *DatabaseAdapter { + return &DatabaseAdapter{driver: driver} +} + +func (a *DatabaseAdapter) AddMemory(ctx context.Context, memory database.UserMemory) error { + key := memory.ID + if key == "" { + key = uuid.New().String() + } + return a.driver.Store(ctx, key, memory.Memory) +} + +func (a *DatabaseAdapter) GetMemories(ctx context.Context) ([]database.UserMemory, error) { + entries, err := a.driver.Retrieve(ctx, Query{}) + if err != nil { + return nil, err + } + + memories := make([]database.UserMemory, len(entries)) + for i, e := range entries { + memories[i] = database.UserMemory{ + ID: e.ID, + CreatedAt: e.CreatedAt, + Memory: e.Content, + } + } + return memories, nil +} + +func (a *DatabaseAdapter) DeleteMemory(ctx context.Context, memory database.UserMemory) error { + return a.driver.Delete(ctx, memory.ID) +} diff --git a/pkg/memory/driver.go b/pkg/memory/driver.go new file mode 100644 index 000000000..8b012e03b --- /dev/null +++ b/pkg/memory/driver.go @@ -0,0 +1,88 @@ +package memory + +import ( + "context" + "io" + + "github.com/docker/cagent/pkg/config/latest" +) + +// Driver defines the interface for memory backends. +// Implementations support different strategies (long-term RAG, short-term whiteboard). +type Driver interface { + // Store saves a memory entry with the given key and value + Store(ctx context.Context, key, value string) error + + // Retrieve fetches memory entries matching the query + Retrieve(ctx context.Context, query Query) ([]Entry, error) + + // Delete removes a memory entry by key + Delete(ctx context.Context, key string) error + + // Close releases resources held by the driver + io.Closer +} + +// Query represents different types of memory queries +type Query struct { + // ID for exact match retrieval + ID string + + // Semantic for natural language queries (GraphRAG, vector search) + Semantic string + + // Limit on number of results + Limit int + + // Filters for metadata-based filtering + Filters map[string]any +} + +// Entry represents a memory item returned from a query +type Entry struct { + ID string + CreatedAt string + Content string + Metadata map[string]any + Score float64 // Relevance score for semantic queries +} + +// Factory creates memory drivers from configuration +type Factory interface { + CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) +} + +// Registry holds registered driver factories +type Registry struct { + factories map[string]Factory +} + +// NewRegistry creates a new driver registry +func NewRegistry() *Registry { + return &Registry{ + factories: make(map[string]Factory), + } +} + +// Register adds a factory for a specific backend kind +func (r *Registry) Register(kind string, factory Factory) { + r.factories[kind] = factory +} + +// CreateDriver instantiates a driver from config +func (r *Registry) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) { + factory, ok := r.factories[cfg.Kind] + if !ok { + return nil, &UnsupportedKindError{Kind: cfg.Kind} + } + return factory.CreateDriver(ctx, cfg) +} + +// UnsupportedKindError indicates an unknown backend kind +type UnsupportedKindError struct { + Kind string +} + +func (e *UnsupportedKindError) Error() string { + return "unsupported memory kind: " + e.Kind +} diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go new file mode 100644 index 000000000..a9eed971c --- /dev/null +++ b/pkg/memory/memory_test.go @@ -0,0 +1,203 @@ +package memory_test + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/memory/database" + "github.com/docker/cagent/pkg/memory/sqlite" +) + +// MockDriver implements memory.Driver for testing +type MockDriver struct { + stored map[string]string + entries []memory.Entry + closeErr error +} + +func NewMockDriver() *MockDriver { + return &MockDriver{ + stored: make(map[string]string), + entries: []memory.Entry{}, + } +} + +func (m *MockDriver) Store(ctx context.Context, key, value string) error { + m.stored[key] = value + m.entries = append(m.entries, memory.Entry{ + ID: key, + Content: value, + CreatedAt: "2026-01-17T00:00:00Z", + }) + return nil +} + +func (m *MockDriver) Retrieve(ctx context.Context, query memory.Query) ([]memory.Entry, error) { + if query.ID != "" { + for _, e := range m.entries { + if e.ID == query.ID { + return []memory.Entry{e}, nil + } + } + return []memory.Entry{}, nil + } + return m.entries, nil +} + +func (m *MockDriver) Delete(ctx context.Context, key string) error { + delete(m.stored, key) + var newEntries []memory.Entry + for _, e := range m.entries { + if e.ID != key { + newEntries = append(newEntries, e) + } + } + m.entries = newEntries + return nil +} + +func (m *MockDriver) Close() error { + return m.closeErr +} + +func TestRegistry(t *testing.T) { + t.Parallel() + + t.Run("register and create driver", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + registry := memory.NewRegistry() + + // Register a mock factory + mockFactory := &mockFactory{} + registry.Register("mock", mockFactory) + + // Create driver + cfg := latest.MemoryConfig{Kind: "mock"} + driver, err := registry.CreateDriver(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, driver) + }) + + t.Run("error on unknown kind", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + registry := memory.NewRegistry() + + cfg := latest.MemoryConfig{Kind: "unknown"} + driver, err := registry.CreateDriver(ctx, cfg) + require.Error(t, err) + assert.Nil(t, driver) + + var unsupportedErr *memory.UnsupportedKindError + require.ErrorAs(t, err, &unsupportedErr) + assert.Equal(t, "unknown", unsupportedErr.Kind) + }) +} + +type mockFactory struct{} + +func (f *mockFactory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { + return NewMockDriver(), nil +} + +func TestDatabaseAdapter(t *testing.T) { + t.Parallel() + ctx := t.Context() + + mockDriver := NewMockDriver() + adapter := memory.NewDatabaseAdapter(mockDriver) + + t.Run("add memory", func(t *testing.T) { + mem := database.UserMemory{ + ID: "test-1", + Memory: "test content", + } + err := adapter.AddMemory(ctx, mem) + require.NoError(t, err) + assert.Equal(t, "test content", mockDriver.stored["test-1"]) + }) + + t.Run("add memory with auto ID", func(t *testing.T) { + mem := database.UserMemory{ + Memory: "auto id content", + } + err := adapter.AddMemory(ctx, mem) + require.NoError(t, err) + // Should have stored with a UUID + assert.Len(t, mockDriver.stored, 2) + }) + + t.Run("get memories", func(t *testing.T) { + memories, err := adapter.GetMemories(ctx) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(memories), 2) + }) + + t.Run("delete memory", func(t *testing.T) { + mem := database.UserMemory{ID: "test-1"} + err := adapter.DeleteMemory(ctx, mem) + require.NoError(t, err) + _, exists := mockDriver.stored["test-1"] + assert.False(t, exists) + }) +} + +func TestDefaultRegistry(t *testing.T) { + t.Run("default registry is singleton", func(t *testing.T) { + reg1 := memory.DefaultRegistry() + reg2 := memory.DefaultRegistry() + assert.Same(t, reg1, reg2) + }) +} + +// Integration test with real SQLite driver +func TestSQLiteDriverIntegration(t *testing.T) { + t.Parallel() + ctx := t.Context() + + dbPath := filepath.Join(t.TempDir(), "integration_test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + + // Use the actual sqlite factory + sqliteFactory := &sqlite.Factory{} + registry := memory.NewRegistry() + registry.Register("sqlite", sqliteFactory) + + driver, err := registry.CreateDriver(ctx, cfg) + require.NoError(t, err) + defer driver.Close() + + // Test full workflow through adapter + adapter := memory.NewDatabaseAdapter(driver) + + // Add + err = adapter.AddMemory(ctx, database.UserMemory{ + ID: "integration-1", + Memory: "Integration test memory", + }) + require.NoError(t, err) + + // Get + memories, err := adapter.GetMemories(ctx) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "Integration test memory", memories[0].Memory) + + // Delete + err = adapter.DeleteMemory(ctx, database.UserMemory{ID: "integration-1"}) + require.NoError(t, err) + + memories, err = adapter.GetMemories(ctx) + require.NoError(t, err) + assert.Empty(t, memories) +} diff --git a/pkg/memory/registry.go b/pkg/memory/registry.go new file mode 100644 index 000000000..85134942f --- /dev/null +++ b/pkg/memory/registry.go @@ -0,0 +1,31 @@ +package memory + +import ( + "context" + "sync" + + "github.com/docker/cagent/pkg/config/latest" +) + +var ( + globalRegistry *Registry + globalRegistryOnce sync.Once +) + +// DefaultRegistry returns the global driver registry +func DefaultRegistry() *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistry() + }) + return globalRegistry +} + +// RegisterFactory registers a driver factory for a backend kind +func RegisterFactory(kind string, factory Factory) { + DefaultRegistry().Register(kind, factory) +} + +// CreateDriver creates a driver from config using the default registry +func CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) { + return DefaultRegistry().CreateDriver(ctx, cfg) +} diff --git a/pkg/memory/sqlite/driver.go b/pkg/memory/sqlite/driver.go new file mode 100644 index 000000000..ee680a956 --- /dev/null +++ b/pkg/memory/sqlite/driver.go @@ -0,0 +1,121 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/sqliteutil" +) + +// Driver implements the memory.Driver interface using SQLite +type Driver struct { + db *sql.DB +} + +// Factory creates SQLite drivers +type Factory struct{} + +var _ memory.Factory = (*Factory)(nil) + +func (f *Factory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { + if cfg.Path == "" { + return nil, fmt.Errorf("sqlite driver requires a path") + } + + db, err := sqliteutil.OpenDB(cfg.Path) + if err != nil { + return nil, fmt.Errorf("failed to open sqlite database: %w", err) + } + + _, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + created_at TEXT, + content TEXT, + metadata TEXT + )`) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create memories table: %w", err) + } + + return &Driver{db: db}, nil +} + +func (d *Driver) Store(ctx context.Context, key, value string) error { + if key == "" { + key = uuid.New().String() + } + + createdAt := time.Now().UTC().Format(time.RFC3339) + _, err := d.db.ExecContext(ctx, + "INSERT OR REPLACE INTO memories (id, created_at, content, metadata) VALUES (?, ?, ?, ?)", + key, createdAt, value, "{}") + if err != nil { + return fmt.Errorf("failed to store memory: %w", err) + } + return nil +} + +func (d *Driver) Retrieve(ctx context.Context, query memory.Query) ([]memory.Entry, error) { + var ( + rows *sql.Rows + err error + ) + + switch { + case query.ID != "": + rows, err = d.db.QueryContext( + ctx, + "SELECT id, created_at, content FROM memories WHERE id = ?", + query.ID, + ) + default: + // Semantic search is not yet implemented for SQLite. + sqlQuery := "SELECT id, created_at, content FROM memories ORDER BY created_at DESC" + if query.Limit > 0 { + sqlQuery = fmt.Sprintf("%s LIMIT %d", sqlQuery, query.Limit) + } + rows, err = d.db.QueryContext(ctx, sqlQuery) + } + + if err != nil { + return nil, fmt.Errorf("failed to retrieve memories: %w", err) + } + defer rows.Close() + + var entries []memory.Entry + for rows.Next() { + var e memory.Entry + if err := rows.Scan(&e.ID, &e.CreatedAt, &e.Content); err != nil { + return nil, fmt.Errorf("failed to scan memory row: %w", err) + } + entries = append(entries, e) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating memory rows: %w", err) + } + + return entries, nil +} + +func (d *Driver) Delete(ctx context.Context, key string) error { + _, err := d.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", key) + if err != nil { + return fmt.Errorf("failed to delete memory: %w", err) + } + return nil +} + +func (d *Driver) Close() error { + if d.db != nil { + return d.db.Close() + } + return nil +} diff --git a/pkg/memory/sqlite/driver_test.go b/pkg/memory/sqlite/driver_test.go new file mode 100644 index 000000000..63f15835c --- /dev/null +++ b/pkg/memory/sqlite/driver_test.go @@ -0,0 +1,157 @@ +package sqlite + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" +) + +func TestFactory_CreateDriver(t *testing.T) { + t.Parallel() + + ctx := t.Context() + factory := &Factory{} + + t.Run("creates driver with valid path", func(t *testing.T) { + t.Parallel() + dbPath := filepath.Join(t.TempDir(), "test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + + driver, err := factory.CreateDriver(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, driver) + defer driver.Close() + }) + + t.Run("fails without path", func(t *testing.T) { + t.Parallel() + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: "", + } + + driver, err := factory.CreateDriver(ctx, cfg) + require.Error(t, err) + assert.Nil(t, driver) + assert.Contains(t, err.Error(), "requires a path") + }) +} + +func TestDriver_StoreAndRetrieve(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + t.Run("store with explicit key", func(t *testing.T) { + err := driver.Store(ctx, "test-key-1", "test value 1") + require.NoError(t, err) + + entries, err := driver.Retrieve(ctx, memory.Query{ID: "test-key-1"}) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "test-key-1", entries[0].ID) + assert.Equal(t, "test value 1", entries[0].Content) + }) + + t.Run("store with auto-generated key", func(t *testing.T) { + err := driver.Store(ctx, "", "auto key value") + require.NoError(t, err) + + entries, err := driver.Retrieve(ctx, memory.Query{}) + require.NoError(t, err) + require.GreaterOrEqual(t, len(entries), 1) + }) + + t.Run("retrieve all with limit", func(t *testing.T) { + // Add more entries + for range 5 { + err := driver.Store(ctx, "", "bulk value") + require.NoError(t, err) + } + + entries, err := driver.Retrieve(ctx, memory.Query{Limit: 3}) + require.NoError(t, err) + assert.Len(t, entries, 3) + }) + + t.Run("retrieve with semantic query falls back to all", func(t *testing.T) { + entries, err := driver.Retrieve(ctx, memory.Query{ + Semantic: "some semantic query", + Limit: 2, + }) + require.NoError(t, err) + assert.LessOrEqual(t, len(entries), 2) + }) +} + +func TestDriver_Delete(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + // Store a memory + err := driver.Store(ctx, "delete-test", "to be deleted") + require.NoError(t, err) + + // Verify it exists + entries, err := driver.Retrieve(ctx, memory.Query{ID: "delete-test"}) + require.NoError(t, err) + require.Len(t, entries, 1) + + // Delete it + err = driver.Delete(ctx, "delete-test") + require.NoError(t, err) + + // Verify it's gone + entries, err = driver.Retrieve(ctx, memory.Query{ID: "delete-test"}) + require.NoError(t, err) + assert.Empty(t, entries) +} + +func TestDriver_UpdateExisting(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + // Store initial value + err := driver.Store(ctx, "update-key", "initial value") + require.NoError(t, err) + + // Update with same key + err = driver.Store(ctx, "update-key", "updated value") + require.NoError(t, err) + + // Retrieve and verify updated + entries, err := driver.Retrieve(ctx, memory.Query{ID: "update-key"}) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "updated value", entries[0].Content) +} + +func createTestDriver(t *testing.T) *Driver { + t.Helper() + ctx := t.Context() + factory := &Factory{} + dbPath := filepath.Join(t.TempDir(), "test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + driver, err := factory.CreateDriver(ctx, cfg) + require.NoError(t, err) + return driver.(*Driver) +} diff --git a/pkg/memory/sqlite/init.go b/pkg/memory/sqlite/init.go new file mode 100644 index 000000000..14adbf4e0 --- /dev/null +++ b/pkg/memory/sqlite/init.go @@ -0,0 +1,13 @@ +package sqlite + +import "github.com/docker/cagent/pkg/memory" + +// registerSQLite registers the sqlite driver factory via package side-effects. +// +//nolint:unparam // Return value exists only to allow calling from a var initializer. +func registerSQLite() struct{} { + memory.RegisterFactory("sqlite", &Factory{}) + return struct{}{} +} + +var _ = registerSQLite() diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 7374605cc..dd59e9ecf 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -12,7 +12,8 @@ import ( "github.com/docker/cagent/pkg/environment" "github.com/docker/cagent/pkg/gateway" "github.com/docker/cagent/pkg/js" - "github.com/docker/cagent/pkg/memory/database/sqlite" + "github.com/docker/cagent/pkg/memory" + _ "github.com/docker/cagent/pkg/memory/sqlite" // Register sqlite driver "github.com/docker/cagent/pkg/path" "github.com/docker/cagent/pkg/tools" "github.com/docker/cagent/pkg/tools/a2a" @@ -80,7 +81,7 @@ func createTodoTool(_ context.Context, toolset latest.Toolset, _ string, _ *conf return builtin.NewTodoTool(), nil } -func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { var memoryPath string if filepath.IsAbs(toolset.Path) { memoryPath = "" @@ -98,11 +99,19 @@ func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir strin return nil, fmt.Errorf("failed to create memory database directory: %w", err) } - db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) + // Use new driver-based approach + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: validatedMemoryPath, + } + + driver, err := memory.CreateDriver(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to create memory database: %w", err) + return nil, fmt.Errorf("failed to create memory driver: %w", err) } + // Adapt new driver to legacy database interface + db := memory.NewDatabaseAdapter(driver) return builtin.NewMemoryTool(db), nil } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 527d17378..45143d770 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -12,6 +12,8 @@ import ( "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/js" + "github.com/docker/cagent/pkg/memory" + _ "github.com/docker/cagent/pkg/memory/sqlite" // Register sqlite driver "github.com/docker/cagent/pkg/model/provider" "github.com/docker/cagent/pkg/model/provider/options" "github.com/docker/cagent/pkg/modelsdev" @@ -113,6 +115,16 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c return nil, err } + // Create memory drivers from top-level memory configs + memoryDrivers := make(map[string]memory.Driver) + for name, memCfg := range cfg.Memory { + driver, err := memory.CreateDriver(ctx, memCfg) + if err != nil { + return nil, fmt.Errorf("failed to create memory driver %q: %w", name, err) + } + memoryDrivers[name] = driver + } + // Create RAG managers parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir) ragManagers, err := rag.NewManagers(ctx, cfg, rag.ManagersBuildConfig{ @@ -175,6 +187,12 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agentTools = append(agentTools, ragTools...) } + // Add memory tools if agent has memory scopes + if len(agentConfig.Memory) > 0 { + memoryTools := createMemoryToolsForAgent(&agentConfig, memoryDrivers) + agentTools = append(agentTools, memoryTools...) + } + opts = append(opts, agent.WithToolSets(agentTools...)) ag := agent.New(agentConfig.Name, agentConfig.Instruction, opts...) @@ -392,3 +410,31 @@ func createRAGToolsForAgent(agentConfig *latest.AgentConfig, allManagers map[str return ragTools } + +// createMemoryToolsForAgent creates memory tools for an agent, one for each referenced memory scope +func createMemoryToolsForAgent(agentConfig *latest.AgentConfig, allDrivers map[string]memory.Driver) []tools.ToolSet { + if len(agentConfig.Memory) == 0 { + return nil + } + + var memoryTools []tools.ToolSet + + for _, memName := range agentConfig.Memory { + driver, exists := allDrivers[memName] + if !exists { + slog.Error("Memory scope not found", "memory_scope", memName) + continue + } + + // Adapt driver to legacy database interface + db := memory.NewDatabaseAdapter(driver) + memTool := builtin.NewMemoryTool(db) + memoryTools = append(memoryTools, memTool) + + slog.Debug("Created memory tool for agent", + "memory_scope", memName, + ) + } + + return memoryTools +}