diff --git a/.env.example b/.env.example index cd79533..c3edf26 100644 --- a/.env.example +++ b/.env.example @@ -8,6 +8,11 @@ SMTP_HEALTH_PORT= SMTP_HEALTH_DISABLE=false SMTP_QUEUE_PATH=./data/spool SMTP_QUEUE_WORKERS= +SMTP_RETENTION_DAYS=7 + +# Authentication +SMTP_AUTH_USERS= +SMTP_AUTH_INSECURE=false # Access control SMTP_ALLOW_NETWORKS=127.0.0.1/32 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2dd44fc..b492a06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ - Config: Document and surface the new `SMTP_QUEUE_WORKERS` environment variable across README, `.env.example`, install tooling, and the marketing site. - Brand: Refresh GopherPost logo and favicon with updated gopher-and-envelope concept; refine site header hover styling. - Brand: Refresh GopherPost logo and favicon with updated gopher-and-envelope concept; add site styling for ringed logo hover state. +- Queue: Persist delivery queue state to disk so pending deliveries survive restarts; restored entries are logged on startup. +- Storage: Add configurable retention policy via `SMTP_RETENTION_DAYS` (default 7 days) with automatic cleanup of expired spool directories. +- Auth: Implement SMTP AUTH (PLAIN and LOGIN mechanisms) via `SMTP_AUTH_USERS` configuration; AUTH is only advertised over TLS unless `SMTP_AUTH_INSECURE` is set. Authenticated users bypass domain restrictions. ## v0.4.0 - Added subscription-based audit fan-out so `/healthz` can stream live debug logs when `SMTP_DEBUG=true`. diff --git a/README.md b/README.md index d8e5b8f..63056a3 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,16 @@ SMTP_HEALTH_PORT # Override only the port component of the health address (e.g. SMTP_HEALTH_DISABLE # Disable the health endpoint when `true` (default `false`). SMTP_QUEUE_PATH # Directory used to persist inbound messages (default ./data/spool). SMTP_QUEUE_WORKERS # Number of concurrent delivery workers processing the outbound queue (default logical CPU count). +SMTP_RETENTION_DAYS # Number of days to retain stored messages before automatic cleanup (default 7). ``` +#### Authentication + +```yml +SMTP_AUTH_USERS # Comma-separated list of user:password pairs for SMTP AUTH (e.g. alice:secret,bob:pass123). +SMTP_AUTH_INSECURE # Allow AUTH without TLS when `true` (default `false`, strongly discouraged). +``` +When `SMTP_AUTH_USERS` is configured, the server advertises AUTH PLAIN and AUTH LOGIN in EHLO responses (only over TLS unless `SMTP_AUTH_INSECURE` is set). Authenticated users bypass the `SMTP_REQUIRE_LOCAL_DOMAIN` restriction, allowing them to send from any address. + #### Access control ```yml @@ -73,7 +82,7 @@ SMTP_DKIM_KEY_PATH # Filesystem path to the DKIM private key (e.g. /etc/dkim/mai SMTP_DKIM_PRIVATE_KEY # Inline PEM-formatted DKIM private key (e.g. -----BEGIN RSA PRIVATE KEY-----). SMTP_DKIM_DOMAIN # Domain to sign messages as when overriding the sender domain (e.g. example.com). ``` -**Security note:** configure `SMTP_ALLOW_NETWORKS`, `SMTP_ALLOW_HOSTS`, and `SMTP_REQUIRE_LOCAL_DOMAIN` to enforce ingress and sender restrictions. The server lacks authentication, so deploy behind firewalls or proxies and run as a non-root service account. +**Security note:** configure `SMTP_ALLOW_NETWORKS`, `SMTP_ALLOW_HOSTS`, and `SMTP_REQUIRE_LOCAL_DOMAIN` to enforce ingress and sender restrictions. For authenticated access, configure `SMTP_AUTH_USERS` and ensure TLS is enabled. Deploy behind firewalls or proxies and run as a non-root service account. Use an absolute path for `SMTP_QUEUE_PATH` when running the daemon under systemd so that the service `ReadWritePaths` setting can be aligned. diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..1d511c3 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,141 @@ +// Package auth provides SMTP authentication support. +package auth + +import ( + "crypto/subtle" + "encoding/base64" + "errors" + "os" + "strings" + "sync" +) + +var ( + ErrInvalidCredentials = errors.New("invalid credentials") + ErrAuthNotConfigured = errors.New("authentication not configured") + ErrMalformedAuth = errors.New("malformed authentication data") +) + +// Credentials holds configured authentication credentials. +type Credentials struct { + mu sync.RWMutex + users map[string]string // username -> password + enabled bool + insecure bool // allow AUTH without TLS +} + +var creds = &Credentials{ + users: make(map[string]string), +} + +// LoadFromEnv loads authentication configuration from environment variables. +// SMTP_AUTH_USERS: comma-separated list of user:password pairs +// SMTP_AUTH_INSECURE: allow AUTH without TLS (default false) +func LoadFromEnv() { + creds.mu.Lock() + defer creds.mu.Unlock() + + creds.users = make(map[string]string) + creds.enabled = false + + usersEnv := strings.TrimSpace(os.Getenv("SMTP_AUTH_USERS")) + if usersEnv == "" { + return + } + + pairs := strings.Split(usersEnv, ",") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + parts := strings.SplitN(pair, ":", 2) + if len(parts) != 2 { + continue + } + username := strings.TrimSpace(parts[0]) + password := parts[1] // Don't trim password - spaces may be intentional + if username != "" && password != "" { + creds.users[username] = password + } + } + + creds.enabled = len(creds.users) > 0 + creds.insecure = strings.EqualFold(os.Getenv("SMTP_AUTH_INSECURE"), "true") +} + +// Enabled returns true if authentication is configured. +func Enabled() bool { + creds.mu.RLock() + defer creds.mu.RUnlock() + return creds.enabled +} + +// AllowInsecure returns true if AUTH is permitted without TLS. +func AllowInsecure() bool { + creds.mu.RLock() + defer creds.mu.RUnlock() + return creds.insecure +} + +// Validate checks if the given credentials are valid. +func Validate(username, password string) error { + creds.mu.RLock() + defer creds.mu.RUnlock() + + if !creds.enabled { + return ErrAuthNotConfigured + } + + expected, ok := creds.users[username] + if !ok { + return ErrInvalidCredentials + } + + if subtle.ConstantTimeCompare([]byte(password), []byte(expected)) != 1 { + return ErrInvalidCredentials + } + + return nil +} + +// DecodePlain decodes SASL PLAIN authentication data. +// Format: base64(authzid\0authcid\0password) or base64(\0username\0password) +func DecodePlain(encoded string) (username, password string, err error) { + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", "", ErrMalformedAuth + } + + // PLAIN format: authzid\0authcid\0password + // authzid is optional, authcid is the username + parts := strings.Split(string(decoded), "\x00") + if len(parts) != 3 { + return "", "", ErrMalformedAuth + } + + // Use authcid (second part) as username + username = parts[1] + password = parts[2] + + if username == "" { + return "", "", ErrMalformedAuth + } + + return username, password, nil +} + +// DecodeLogin decodes SASL LOGIN username or password. +// Each is simply base64 encoded. +func DecodeLogin(encoded string) (string, error) { + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", ErrMalformedAuth + } + return string(decoded), nil +} + +// EncodeChallenge encodes a LOGIN challenge (Username: or Password:). +func EncodeChallenge(text string) string { + return base64.StdEncoding.EncodeToString([]byte(text)) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..036378d --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,143 @@ +package auth + +import ( + "encoding/base64" + "os" + "testing" +) + +func TestLoadFromEnv(t *testing.T) { + // Save and restore env + origUsers := os.Getenv("SMTP_AUTH_USERS") + origInsecure := os.Getenv("SMTP_AUTH_INSECURE") + defer func() { + os.Setenv("SMTP_AUTH_USERS", origUsers) + os.Setenv("SMTP_AUTH_INSECURE", origInsecure) + }() + + // Test with users configured + os.Setenv("SMTP_AUTH_USERS", "alice:secret123,bob:password456") + os.Setenv("SMTP_AUTH_INSECURE", "false") + LoadFromEnv() + + if !Enabled() { + t.Error("expected auth to be enabled") + } + if AllowInsecure() { + t.Error("expected insecure to be false") + } + + // Validate correct credentials + if err := Validate("alice", "secret123"); err != nil { + t.Errorf("valid credentials rejected: %v", err) + } + if err := Validate("bob", "password456"); err != nil { + t.Errorf("valid credentials rejected: %v", err) + } + + // Validate incorrect credentials + if err := Validate("alice", "wrong"); err != ErrInvalidCredentials { + t.Errorf("expected ErrInvalidCredentials, got %v", err) + } + if err := Validate("unknown", "secret123"); err != ErrInvalidCredentials { + t.Errorf("expected ErrInvalidCredentials, got %v", err) + } + + // Test insecure mode + os.Setenv("SMTP_AUTH_INSECURE", "true") + LoadFromEnv() + if !AllowInsecure() { + t.Error("expected insecure to be true") + } + + // Test empty config + os.Setenv("SMTP_AUTH_USERS", "") + LoadFromEnv() + if Enabled() { + t.Error("expected auth to be disabled with empty users") + } +} + +func TestDecodePlain(t *testing.T) { + tests := []struct { + name string + input string + wantUser string + wantPass string + wantErr bool + }{ + { + name: "valid with authzid", + input: base64.StdEncoding.EncodeToString([]byte("authzid\x00username\x00password")), + wantUser: "username", + wantPass: "password", + }, + { + name: "valid without authzid", + input: base64.StdEncoding.EncodeToString([]byte("\x00username\x00password")), + wantUser: "username", + wantPass: "password", + }, + { + name: "invalid base64", + input: "not-valid-base64!!!", + wantErr: true, + }, + { + name: "wrong format", + input: base64.StdEncoding.EncodeToString([]byte("just-one-part")), + wantErr: true, + }, + { + name: "empty username", + input: base64.StdEncoding.EncodeToString([]byte("authzid\x00\x00password")), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, pass, err := DecodePlain(tt.input) + if tt.wantErr { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if user != tt.wantUser { + t.Errorf("user = %q, want %q", user, tt.wantUser) + } + if pass != tt.wantPass { + t.Errorf("pass = %q, want %q", pass, tt.wantPass) + } + }) + } +} + +func TestDecodeLogin(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("testuser")) + decoded, err := DecodeLogin(encoded) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if decoded != "testuser" { + t.Errorf("got %q, want %q", decoded, "testuser") + } + + _, err = DecodeLogin("invalid!!!") + if err == nil { + t.Error("expected error for invalid base64") + } +} + +func TestEncodeChallenge(t *testing.T) { + result := EncodeChallenge("Username:") + expected := base64.StdEncoding.EncodeToString([]byte("Username:")) + if result != expected { + t.Errorf("got %q, want %q", result, expected) + } +} diff --git a/main.go b/main.go index 7198cac..03030cb 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ import ( health "gopherpost/health" audit "gopherpost/internal/audit" + "gopherpost/internal/auth" "gopherpost/internal/config" "gopherpost/internal/dkim" "gopherpost/internal/email" @@ -42,7 +43,16 @@ const ( func main() { _ = godotenv.Load() audit.RefreshFromEnv() + auth.LoadFromEnv() log.Printf("GopherPost version %s starting", version.Number) + if auth.Enabled() { + if auth.AllowInsecure() { + log.Printf("SMTP AUTH enabled (warning: insecure mode allows AUTH without TLS)") + } else { + log.Printf("SMTP AUTH enabled (requires TLS)") + } + audit.Log("auth enabled insecure=%v", auth.AllowInsecure()) + } audit.Log("version %s boot", version.Number) port := defaultSMTPPort @@ -93,8 +103,19 @@ func main() { q := queue.NewManager(queue.WithWorkers(workerCount)) if dir := strings.TrimSpace(os.Getenv("SMTP_QUEUE_PATH")); dir != "" { storage.SetBaseDir(dir) + queue.SetPersistDir(dir) log.Printf("Queue storage path set to %s", dir) } + storage.LoadRetentionFromEnv() + retentionStop := storage.StartRetentionCleanup() + defer close(retentionStop) + log.Printf("Message retention set to %d days", storage.RetentionDays()) + if err := q.LoadQueue(); err != nil { + log.Printf("Warning: failed to load persisted queue: %v", err) + } else if depth := q.Depth(); depth > 0 { + log.Printf("Restored %d pending deliveries from disk", depth) + audit.Log("queue restored %d entries", depth) + } log.Printf("Queue workers configured: %d", workerCount) audit.Log("queue workers %d", workerCount) q.Start() @@ -127,11 +148,12 @@ func main() { log.Printf("Accept error: %v", err) continue } - go handleSession(conn, q, greeting, hostname, dkimSigner) + tlsActive := tlsConf != nil + go handleSession(conn, q, greeting, hostname, dkimSigner, tlsActive) } } -func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname string, signer *dkim.Signer) { +func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname string, signer *dkim.Signer, tlsActive bool) { defer conn.Close() tp := textproto.NewConn(conn) defer tp.Close() @@ -179,6 +201,11 @@ func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname st var from string var to []string var data bytes.Buffer + var authenticated bool + var authUser string + + // Determine if AUTH should be offered + authAvailable := auth.Enabled() && (tlsActive || auth.AllowInsecure()) reset := func() { from = "" @@ -202,11 +229,145 @@ func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname st alog("recv %s", summarizeCommand(line)) cmd := strings.ToUpper(line) switch { - case strings.HasPrefix(cmd, "HELO") || strings.HasPrefix(cmd, "EHLO"): + case strings.HasPrefix(cmd, "HELO"): if !send(250, hostname) { return } - alog("handshake %s", cmd[:4]) + alog("handshake HELO") + case strings.HasPrefix(cmd, "EHLO"): + // Send multi-line EHLO response + lines := []string{hostname} + if authAvailable { + lines = append(lines, "AUTH PLAIN LOGIN") + } + for i, l := range lines { + var prefix string + if i == len(lines)-1 { + prefix = "250 " + } else { + prefix = "250-" + } + if err := tp.PrintfLine("%s%s", prefix, l); err != nil { + log.Printf("send error to %s: %v", remote, err) + alog("send error: %v", err) + return + } + } + alog("handshake EHLO (auth_available=%v)", authAvailable) + case strings.HasPrefix(cmd, "AUTH "): + if !authAvailable { + if !send(503, "5.5.1 AUTH not available") { + return + } + alog("AUTH rejected: not available") + continue + } + if authenticated { + if !send(503, "5.5.1 Already authenticated") { + return + } + alog("AUTH rejected: already authenticated") + continue + } + authMethod := strings.TrimPrefix(cmd, "AUTH ") + switch { + case strings.HasPrefix(authMethod, "PLAIN"): + // AUTH PLAIN may have credentials inline or require continuation + parts := strings.SplitN(authMethod, " ", 2) + var encodedCreds string + if len(parts) == 2 && parts[1] != "" { + encodedCreds = parts[1] + } else { + // Request credentials + if err := tp.PrintfLine("334 "); err != nil { + alog("send error: %v", err) + return + } + creds, err := tp.ReadLine() + if err != nil { + alog("read error during AUTH PLAIN: %v", err) + return + } + encodedCreds = creds + } + user, pass, err := auth.DecodePlain(encodedCreds) + if err != nil { + if !send(501, "5.5.4 Malformed authentication data") { + return + } + alog("AUTH PLAIN decode error: %v", err) + continue + } + if err := auth.Validate(user, pass); err != nil { + if !send(535, "5.7.8 Authentication failed") { + return + } + alog("AUTH PLAIN failed for user %s", user) + continue + } + authenticated = true + authUser = user + if !send(235, "2.7.0 Authentication successful") { + return + } + alog("AUTH PLAIN success user=%s", user) + case strings.HasPrefix(authMethod, "LOGIN"): + // AUTH LOGIN uses challenge-response + // Send Username: challenge + if err := tp.PrintfLine("334 %s", auth.EncodeChallenge("Username:")); err != nil { + alog("send error: %v", err) + return + } + encodedUser, err := tp.ReadLine() + if err != nil { + alog("read error during AUTH LOGIN: %v", err) + return + } + user, err := auth.DecodeLogin(encodedUser) + if err != nil { + if !send(501, "5.5.4 Malformed authentication data") { + return + } + alog("AUTH LOGIN decode error: %v", err) + continue + } + // Send Password: challenge + if err := tp.PrintfLine("334 %s", auth.EncodeChallenge("Password:")); err != nil { + alog("send error: %v", err) + return + } + encodedPass, err := tp.ReadLine() + if err != nil { + alog("read error during AUTH LOGIN: %v", err) + return + } + pass, err := auth.DecodeLogin(encodedPass) + if err != nil { + if !send(501, "5.5.4 Malformed authentication data") { + return + } + alog("AUTH LOGIN decode error: %v", err) + continue + } + if err := auth.Validate(user, pass); err != nil { + if !send(535, "5.7.8 Authentication failed") { + return + } + alog("AUTH LOGIN failed for user %s", user) + continue + } + authenticated = true + authUser = user + if !send(235, "2.7.0 Authentication successful") { + return + } + alog("AUTH LOGIN success user=%s", user) + default: + if !send(504, "5.5.4 Unrecognized authentication mechanism") { + return + } + alog("AUTH rejected: unknown mechanism") + } case strings.HasPrefix(cmd, "MAIL FROM:"): addr, err := email.ParseCommandAddress(line) if err != nil { @@ -216,7 +377,8 @@ func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname st alog("invalid MAIL FROM: %v", err) continue } - if requireLocalDomain { + // Skip domain restriction for authenticated users + if requireLocalDomain && !authenticated { domain, derr := email.Domain(addr) if derr != nil { if !send(501, "Invalid sender domain") { @@ -335,10 +497,11 @@ func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname st } persistedPaths = append(persistedPaths, path) queued = append(queued, queue.QueuedMessage{ - ID: messageID, - From: from, - To: rcpt, - Payload: payload, + ID: messageID, + From: from, + To: rcpt, + FilePath: path, + Payload: payload, }) } if persistErr != nil { @@ -361,7 +524,11 @@ func handleSession(conn net.Conn, q *queue.Manager, greeting string, hostname st if !send(250, fmt.Sprintf("Message queued as %s", messageID)) { return } - alog("message %s queued (size=%d bytes, recipients=%d)", messageID, len(messageBytes), len(to)) + if authenticated { + alog("message %s queued (size=%d bytes, recipients=%d, auth_user=%s)", messageID, len(messageBytes), len(to), authUser) + } else { + alog("message %s queued (size=%d bytes, recipients=%d)", messageID, len(messageBytes), len(to)) + } reset() case strings.HasPrefix(cmd, "QUIT"): if !send(221, "Bye") { diff --git a/queue/manager.go b/queue/manager.go index b08866a..2b912c4 100644 --- a/queue/manager.go +++ b/queue/manager.go @@ -63,7 +63,6 @@ func (m *Manager) Enqueue(msg QueuedMessage) { return } m.mu.Lock() - defer m.mu.Unlock() if msg.Attempts == 0 && msg.NextRetry.IsZero() { msg.NextRetry = time.Now() } @@ -72,6 +71,11 @@ func (m *Manager) Enqueue(msg QueuedMessage) { audit.Log("queue enqueue %s -> %s attempt %d next %s", msg.ID, msg.To, msg.Attempts, msg.NextRetry.Format(time.RFC3339)) metrics.MessagesQueued.Add(1) metrics.SetQueueDepth(len(m.queue)) + m.mu.Unlock() + + if err := m.SaveQueue(); err != nil { + log.Printf("Failed to persist queue: %v", err) + } } // Start starts the queue processor in a background goroutine. @@ -168,6 +172,10 @@ func (m *Manager) processQueue() { } wg.Wait() + + if err := m.SaveQueue(); err != nil { + log.Printf("Failed to persist queue after processing: %v", err) + } } // Depth returns the current queue length. diff --git a/queue/persist.go b/queue/persist.go new file mode 100644 index 0000000..a9879b3 --- /dev/null +++ b/queue/persist.go @@ -0,0 +1,136 @@ +package queue + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" +) + +// persistedEntry represents a queue entry stored on disk. +type persistedEntry struct { + ID string `json:"id"` + From string `json:"from"` + To string `json:"to"` + FilePath string `json:"file_path"` + Attempts int `json:"attempts"` + NextRetry time.Time `json:"next_retry"` + LastError string `json:"last_error,omitempty"` +} + +// persistedQueue represents the entire queue state on disk. +type persistedQueue struct { + Entries []persistedEntry `json:"entries"` +} + +var ( + persistDir = "./data/spool" + persistFile = "queue.json" + persistMu sync.Mutex +) + +// SetPersistDir overrides the directory used for queue persistence. +func SetPersistDir(dir string) { + persistMu.Lock() + defer persistMu.Unlock() + persistDir = dir +} + +// persistPath returns the full path to the queue state file. +func persistPath() string { + persistMu.Lock() + defer persistMu.Unlock() + return filepath.Join(persistDir, persistFile) +} + +// SaveQueue persists the current queue state to disk. +func (m *Manager) SaveQueue() error { + m.mu.Lock() + entries := make([]persistedEntry, 0, len(m.queue)) + for _, msg := range m.queue { + entries = append(entries, persistedEntry{ + ID: msg.ID, + From: msg.From, + To: msg.To, + FilePath: msg.FilePath, + Attempts: msg.Attempts, + NextRetry: msg.NextRetry, + LastError: msg.LastError, + }) + } + m.mu.Unlock() + + pq := persistedQueue{Entries: entries} + data, err := json.MarshalIndent(pq, "", " ") + if err != nil { + return err + } + + path := persistPath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + // Write atomically via temp file + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o600); err != nil { + return err + } + return os.Rename(tmp, path) +} + +// LoadQueue restores the queue state from disk. +// Messages are rehydrated from their stored .eml files. +func (m *Manager) LoadQueue() error { + path := persistPath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil // No persisted state + } + return err + } + + var pq persistedQueue + if err := json.Unmarshal(data, &pq); err != nil { + return err + } + + m.mu.Lock() + defer m.mu.Unlock() + + for _, entry := range pq.Entries { + // Read message content from stored file + content, err := os.ReadFile(entry.FilePath) + if err != nil { + // File missing - skip this entry (already delivered or cleaned up) + continue + } + + msg := QueuedMessage{ + ID: entry.ID, + From: entry.From, + To: entry.To, + FilePath: entry.FilePath, + Payload: NewPayload(content), + Attempts: entry.Attempts, + NextRetry: entry.NextRetry, + LastError: entry.LastError, + } + m.queue = append(m.queue, msg) + } + + return nil +} + +// ClearPersistedQueue removes the persisted queue file. +func ClearPersistedQueue() error { + path := persistPath() + err := os.Remove(path) + if os.IsNotExist(err) { + return nil + } + return err +} diff --git a/queue/persist_test.go b/queue/persist_test.go new file mode 100644 index 0000000..973310d --- /dev/null +++ b/queue/persist_test.go @@ -0,0 +1,124 @@ +package queue + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestQueuePersistence(t *testing.T) { + // Create temp directory for test + tmpDir := t.TempDir() + SetPersistDir(tmpDir) + + // Create a test message file + msgPath := filepath.Join(tmpDir, "test_msg.eml") + if err := os.WriteFile(msgPath, []byte("test message body"), 0o600); err != nil { + t.Fatalf("failed to create test message: %v", err) + } + + // Create manager and enqueue a message + m1 := NewManager() + msg := QueuedMessage{ + ID: "test-123", + From: "sender@example.com", + To: "rcpt@example.net", + FilePath: msgPath, + Payload: NewPayload([]byte("test message body")), + Attempts: 2, + NextRetry: time.Now().Add(time.Hour), + LastError: "temporary failure", + } + m1.Enqueue(msg) + + if m1.Depth() != 1 { + t.Fatalf("expected depth 1, got %d", m1.Depth()) + } + + // Verify queue file exists + queueFile := filepath.Join(tmpDir, "queue.json") + if _, err := os.Stat(queueFile); os.IsNotExist(err) { + t.Fatalf("queue.json was not created") + } + + // Create new manager and load queue + m2 := NewManager() + if err := m2.LoadQueue(); err != nil { + t.Fatalf("LoadQueue failed: %v", err) + } + + if m2.Depth() != 1 { + t.Fatalf("expected restored depth 1, got %d", m2.Depth()) + } + + // Verify message fields were restored + m2.mu.Lock() + restored := m2.queue[0] + m2.mu.Unlock() + + if restored.ID != "test-123" { + t.Errorf("ID mismatch: got %s", restored.ID) + } + if restored.From != "sender@example.com" { + t.Errorf("From mismatch: got %s", restored.From) + } + if restored.To != "rcpt@example.net" { + t.Errorf("To mismatch: got %s", restored.To) + } + if restored.Attempts != 2 { + t.Errorf("Attempts mismatch: got %d", restored.Attempts) + } + if restored.LastError != "temporary failure" { + t.Errorf("LastError mismatch: got %s", restored.LastError) + } + if string(restored.Payload.Bytes()) != "test message body" { + t.Errorf("Payload mismatch: got %s", string(restored.Payload.Bytes())) + } +} + +func TestQueuePersistenceSkipsMissingFiles(t *testing.T) { + tmpDir := t.TempDir() + SetPersistDir(tmpDir) + + // Create a manager and enqueue a message + m1 := NewManager() + msg := QueuedMessage{ + ID: "missing-file", + From: "sender@example.com", + To: "rcpt@example.net", + FilePath: filepath.Join(tmpDir, "nonexistent.eml"), + Payload: NewPayload([]byte("body")), + } + + // Manually save queue state without the actual file + m1.mu.Lock() + m1.queue = append(m1.queue, msg) + m1.mu.Unlock() + if err := m1.SaveQueue(); err != nil { + t.Fatalf("SaveQueue failed: %v", err) + } + + // Load into new manager - should skip the missing file entry + m2 := NewManager() + if err := m2.LoadQueue(); err != nil { + t.Fatalf("LoadQueue failed: %v", err) + } + + if m2.Depth() != 0 { + t.Errorf("expected depth 0 (missing file skipped), got %d", m2.Depth()) + } +} + +func TestLoadQueueNoFile(t *testing.T) { + tmpDir := t.TempDir() + SetPersistDir(tmpDir) + + m := NewManager() + if err := m.LoadQueue(); err != nil { + t.Errorf("LoadQueue should not error on missing file: %v", err) + } + if m.Depth() != 0 { + t.Errorf("expected empty queue, got depth %d", m.Depth()) + } +} diff --git a/queue/types.go b/queue/types.go index ba273ab..5268769 100644 --- a/queue/types.go +++ b/queue/types.go @@ -26,6 +26,7 @@ type QueuedMessage struct { ID string From string To string + FilePath string Payload *Payload Attempts int NextRetry time.Time diff --git a/storage/retention.go b/storage/retention.go new file mode 100644 index 0000000..46bf859 --- /dev/null +++ b/storage/retention.go @@ -0,0 +1,250 @@ +package storage + +import ( + "log" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +var ( + retentionDays = 7 // default 7 days + retentionMu sync.RWMutex + cleanupInterval = 1 * time.Hour +) + +// SetRetentionDays configures how many days to retain stored messages. +// Messages older than this are eligible for cleanup. +func SetRetentionDays(days int) { + retentionMu.Lock() + defer retentionMu.Unlock() + if days > 0 { + retentionDays = days + } +} + +// RetentionDays returns the current retention period in days. +func RetentionDays() int { + retentionMu.RLock() + defer retentionMu.RUnlock() + return retentionDays +} + +// LoadRetentionFromEnv loads retention settings from environment variables. +func LoadRetentionFromEnv() { + if val := strings.TrimSpace(os.Getenv("SMTP_RETENTION_DAYS")); val != "" { + if days, err := strconv.Atoi(val); err == nil && days > 0 { + SetRetentionDays(days) + } + } +} + +// CleanupOldMessages removes message files older than the retention period. +// Returns the number of files removed and any error encountered. +func CleanupOldMessages() (int, error) { + retentionMu.RLock() + days := retentionDays + retentionMu.RUnlock() + + // Use start of today to ensure consistent date comparison + now := time.Now().UTC() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + cutoff := today.AddDate(0, 0, -days) + removed := 0 + + entries, err := os.ReadDir(baseDir) + if err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, err + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + // Directory names are YYYY-MM-DD + dirDate, err := time.Parse("2006-01-02", entry.Name()) + if err != nil { + continue // Skip non-date directories + } + + if dirDate.Before(cutoff) { + dirPath := filepath.Join(baseDir, entry.Name()) + count, err := removeDir(dirPath) + if err != nil { + log.Printf("Failed to remove old spool directory %s: %v", dirPath, err) + continue + } + removed += count + } + } + + return removed, nil +} + +// removeDir removes a directory and returns the count of files removed. +func removeDir(path string) (int, error) { + entries, err := os.ReadDir(path) + if err != nil { + return 0, err + } + + count := len(entries) + if err := os.RemoveAll(path); err != nil { + return 0, err + } + + return count, nil +} + +// SpoolStats returns statistics about the current spool. +type SpoolStats struct { + TotalFiles int + TotalBytes int64 + OldestDate string + NewestDate string + DirCount int + RetentionDay int +} + +// GetSpoolStats returns statistics about the message spool. +func GetSpoolStats() (SpoolStats, error) { + stats := SpoolStats{ + RetentionDay: RetentionDays(), + } + + entries, err := os.ReadDir(baseDir) + if err != nil { + if os.IsNotExist(err) { + return stats, nil + } + return stats, err + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + // Validate it's a date directory + if _, err := time.Parse("2006-01-02", entry.Name()); err != nil { + continue + } + + stats.DirCount++ + if stats.OldestDate == "" || entry.Name() < stats.OldestDate { + stats.OldestDate = entry.Name() + } + if entry.Name() > stats.NewestDate { + stats.NewestDate = entry.Name() + } + + dirPath := filepath.Join(baseDir, entry.Name()) + files, err := os.ReadDir(dirPath) + if err != nil { + continue + } + + for _, f := range files { + if f.IsDir() { + continue + } + stats.TotalFiles++ + if info, err := f.Info(); err == nil { + stats.TotalBytes += info.Size() + } + } + } + + return stats, nil +} + +// StartRetentionCleanup starts a background goroutine that periodically +// cleans up old messages. Returns a channel that can be closed to stop cleanup. +func StartRetentionCleanup() chan struct{} { + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + // Run initial cleanup + if count, err := CleanupOldMessages(); err != nil { + log.Printf("Retention cleanup error: %v", err) + } else if count > 0 { + log.Printf("Retention cleanup removed %d old messages", count) + } + + for { + select { + case <-stop: + return + case <-ticker.C: + if count, err := CleanupOldMessages(); err != nil { + log.Printf("Retention cleanup error: %v", err) + } else if count > 0 { + log.Printf("Retention cleanup removed %d old messages", count) + } + } + } + }() + return stop +} + +// ListMessages returns message files in the spool directory. +// If date is empty, lists all dates. Otherwise filters to specific date. +func ListMessages(date string) ([]string, error) { + var messages []string + + if date != "" { + // List specific date + dirPath := filepath.Join(baseDir, date) + entries, err := os.ReadDir(dirPath) + if err != nil { + if os.IsNotExist(err) { + return messages, nil + } + return nil, err + } + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".eml") { + messages = append(messages, filepath.Join(date, entry.Name())) + } + } + return messages, nil + } + + // List all dates + entries, err := os.ReadDir(baseDir) + if err != nil { + if os.IsNotExist(err) { + return messages, nil + } + return nil, err + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + if _, err := time.Parse("2006-01-02", entry.Name()); err != nil { + continue + } + subEntries, err := os.ReadDir(filepath.Join(baseDir, entry.Name())) + if err != nil { + continue + } + for _, f := range subEntries { + if !f.IsDir() && strings.HasSuffix(f.Name(), ".eml") { + messages = append(messages, filepath.Join(entry.Name(), f.Name())) + } + } + } + + return messages, nil +} diff --git a/storage/retention_test.go b/storage/retention_test.go new file mode 100644 index 0000000..b26ec3b --- /dev/null +++ b/storage/retention_test.go @@ -0,0 +1,170 @@ +package storage + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestCleanupOldMessages(t *testing.T) { + tmpDir := t.TempDir() + originalBaseDir := baseDir + baseDir = tmpDir + defer func() { baseDir = originalBaseDir }() + + SetRetentionDays(3) + + // Create directories for different dates + today := time.Now().UTC() + dirs := []struct { + date time.Time + files int + deleted bool + }{ + {today, 2, false}, // today - keep + {today.AddDate(0, 0, -1), 1, false}, // yesterday - keep + {today.AddDate(0, 0, -2), 1, false}, // 2 days ago - keep + {today.AddDate(0, 0, -3), 3, false}, // 3 days ago - keep (boundary) + {today.AddDate(0, 0, -4), 2, true}, // 4 days ago - delete + {today.AddDate(0, 0, -10), 5, true}, // 10 days ago - delete + } + + for _, d := range dirs { + dirPath := filepath.Join(tmpDir, d.date.Format("2006-01-02")) + if err := os.MkdirAll(dirPath, 0o755); err != nil { + t.Fatalf("failed to create dir: %v", err) + } + for i := 0; i < d.files; i++ { + fpath := filepath.Join(dirPath, "msg"+string(rune('a'+i))+".eml") + if err := os.WriteFile(fpath, []byte("test"), 0o600); err != nil { + t.Fatalf("failed to create file: %v", err) + } + } + } + + removed, err := CleanupOldMessages() + if err != nil { + t.Fatalf("CleanupOldMessages failed: %v", err) + } + + expectedRemoved := 2 + 5 // 4 days ago + 10 days ago files + if removed != expectedRemoved { + t.Errorf("expected %d removed, got %d", expectedRemoved, removed) + } + + // Verify directories + for _, d := range dirs { + dirPath := filepath.Join(tmpDir, d.date.Format("2006-01-02")) + _, err := os.Stat(dirPath) + if d.deleted { + if !os.IsNotExist(err) { + t.Errorf("directory %s should have been deleted", d.date.Format("2006-01-02")) + } + } else { + if err != nil { + t.Errorf("directory %s should exist: %v", d.date.Format("2006-01-02"), err) + } + } + } +} + +func TestGetSpoolStats(t *testing.T) { + tmpDir := t.TempDir() + originalBaseDir := baseDir + baseDir = tmpDir + defer func() { baseDir = originalBaseDir }() + + // Create some test data + dates := []string{"2024-01-01", "2024-01-02", "2024-01-03"} + for _, date := range dates { + dirPath := filepath.Join(tmpDir, date) + if err := os.MkdirAll(dirPath, 0o755); err != nil { + t.Fatalf("failed to create dir: %v", err) + } + for i := 0; i < 3; i++ { + fpath := filepath.Join(dirPath, "msg"+string(rune('a'+i))+".eml") + if err := os.WriteFile(fpath, []byte("test content"), 0o600); err != nil { + t.Fatalf("failed to create file: %v", err) + } + } + } + + stats, err := GetSpoolStats() + if err != nil { + t.Fatalf("GetSpoolStats failed: %v", err) + } + + if stats.DirCount != 3 { + t.Errorf("expected 3 dirs, got %d", stats.DirCount) + } + if stats.TotalFiles != 9 { + t.Errorf("expected 9 files, got %d", stats.TotalFiles) + } + if stats.OldestDate != "2024-01-01" { + t.Errorf("expected oldest 2024-01-01, got %s", stats.OldestDate) + } + if stats.NewestDate != "2024-01-03" { + t.Errorf("expected newest 2024-01-03, got %s", stats.NewestDate) + } +} + +func TestListMessages(t *testing.T) { + tmpDir := t.TempDir() + originalBaseDir := baseDir + baseDir = tmpDir + defer func() { baseDir = originalBaseDir }() + + // Create test messages + dirPath := filepath.Join(tmpDir, "2024-01-15") + if err := os.MkdirAll(dirPath, 0o755); err != nil { + t.Fatalf("failed to create dir: %v", err) + } + for _, name := range []string{"msg1.eml", "msg2.eml"} { + if err := os.WriteFile(filepath.Join(dirPath, name), []byte("test"), 0o600); err != nil { + t.Fatalf("failed to create file: %v", err) + } + } + + // List specific date + msgs, err := ListMessages("2024-01-15") + if err != nil { + t.Fatalf("ListMessages failed: %v", err) + } + if len(msgs) != 2 { + t.Errorf("expected 2 messages, got %d", len(msgs)) + } + + // List all + msgs, err = ListMessages("") + if err != nil { + t.Fatalf("ListMessages failed: %v", err) + } + if len(msgs) != 2 { + t.Errorf("expected 2 messages, got %d", len(msgs)) + } + + // List nonexistent date + msgs, err = ListMessages("2024-12-31") + if err != nil { + t.Fatalf("ListMessages failed: %v", err) + } + if len(msgs) != 0 { + t.Errorf("expected 0 messages, got %d", len(msgs)) + } +} + +func TestRetentionDays(t *testing.T) { + original := RetentionDays() + defer SetRetentionDays(original) + + SetRetentionDays(14) + if got := RetentionDays(); got != 14 { + t.Errorf("expected 14, got %d", got) + } + + SetRetentionDays(0) // Should not change + if got := RetentionDays(); got != 14 { + t.Errorf("expected 14 (unchanged), got %d", got) + } +}