diff --git a/database/query/client.go b/database/query/client.go index f5f51c12..2b34dc4f 100644 --- a/database/query/client.go +++ b/database/query/client.go @@ -18,8 +18,10 @@ import ( ) type ( - dbClient struct { + dbOperationReporter func(context.Context, operationType, error) + dbClient struct { db *connector.DB + operationReporter dbOperationReporter rollbackableEvents *xsync.Map[eventHash, *databaseRollbackRequest] relayPrivateKey string relayURL string @@ -53,6 +55,8 @@ func openDatabase(ctx context.Context, writeURLs []string, readURLs []string, ru client := &dbClient{ rollbackableEvents: xsync.NewMap[eventHash, *databaseRollbackRequest](), hasReadURLs: len(readURLs) > 0, + operationReporter: func(context.Context, operationType, error) { + }, } options := []connector.Option{ connector.WithFieldNameMapper(func(in string) string { @@ -83,6 +87,11 @@ func openDatabase(ctx context.Context, writeURLs []string, readURLs []string, ru return client } +func (client *dbClient) WithOperationReporter(reporter dbOperationReporter) *dbClient { + client.operationReporter = reporter + return client +} + func (client *dbClient) Close() (err error) { if client.db != nil { err = errors.Join(err, client.db.Close()) diff --git a/database/query/global.go b/database/query/global.go index 51a8df77..35737f84 100644 --- a/database/query/global.go +++ b/database/query/global.go @@ -20,8 +20,9 @@ import ( var ( globalDB struct { - Client *dbClient - Once sync.Once + Client *dbClient + Tracker *statusTracker + Once sync.Once } UsedDatabaseStorage atomic.Uint64 ) @@ -134,8 +135,10 @@ func MustInit(ctx context.Context, opts ...Option) { log.Warn().Msg("database DDL execution is disabled") } + globalDB.Tracker = newStatusTracker() globalDB.Client = openDatabase(ctx, conf.WriteURLs, conf.ReadURLs, conf.RunDDL). WithPrivateKey(conf.PrivateKey). + WithOperationReporter(globalDB.Tracker.Submit). 
WithRelayURL(conf.RelayURL) if !conf.DisableSelfTest { @@ -147,6 +150,7 @@ func MustInit(ctx context.Context, opts ...Option) { } } + globalDB.Tracker.Start(ctx) if !globalDB.Client.hasReadURLs { go globalDB.Client.StartExpiredEventsCleanup(ctx) } else { @@ -275,3 +279,10 @@ func (db *dbClient) StartCollectingUsedDatabaseStorage(ctx context.Context) { func CollectDeviceRegistrationEvents(ctx context.Context) EventIterator { return globalDB.Client.collectDeviceRegistrationEvents(ctx) } + +func GetStatusReport(ctx context.Context) (*Status, error) { + if globalDB.Tracker == nil { + return nil, errors.New("database status tracker is not initialized") + } + return globalDB.Tracker.Get(ctx) +} diff --git a/database/query/query.go b/database/query/query.go index 2e832f55..423083dc 100644 --- a/database/query/query.go +++ b/database/query/query.go @@ -1013,6 +1013,7 @@ func (db *dbClient) executeBatch(ctx context.Context, req *databaseBatchRequest) if err == nil && (!eventsToRollback.Empty() || len(eventsToRollback.ReplaceableEvents) > 0) { db.rollbackableEvents.Store(*req.EventsHash, &eventsToRollback) } + db.trackEventWriteError(ctx, err) return err } @@ -1238,6 +1239,7 @@ func (db *dbClient) SelectEvents(ctx context.Context, filters ...model.Filter) E if errors.Is(err, connector.ErrNotFound) { err = nil } + db.trackEventReadError(ctx, err) yield(nil, errors.Wrap(err, "failed to select events")) return } @@ -1697,3 +1699,14 @@ func startPeriodicSelfTest(ctx context.Context, writeURLs []string, readURLs []s } }() } + +func (db *dbClient) trackEventReadError(ctx context.Context, err error) { + if errors.Is(err, ErrRaceCondition) { + err = nil + } + db.operationReporter(ctx, operationTypeRead, err) +} + +func (db *dbClient) trackEventWriteError(ctx context.Context, err error) { + db.operationReporter(ctx, operationTypeWrite, err) +} diff --git a/database/query/status.go b/database/query/status.go new file mode 100644 index 00000000..9eb5d8d4 --- /dev/null +++ 
b/database/query/status.go @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: ice License 1.0 + +package query + +import ( + "context" + "time" + + "github.com/rs/zerolog/log" + + "github.com/ice-blockchain/subzero/tracing/statefsm" +) + +const ( + // Threshold for consecutive failed operations before marking the database as unhealthy. + consecutiveOperationThreshold = 3 + // Time window to consider operations as consecutive. + consecutiveWindow = time.Minute +) + +type ( + Status struct { + LastWrite time.Time // Timestamp of the last event write operation. + LastRead time.Time // Timestamp of the last event read operation. + InReadErrorState bool // Indicates if the last N read operations failed. + InWriteErrorState bool // Indicates if the last N write operations failed. + } + + operationType int + + operationStatus struct { + Timestamp time.Time + Err error + Type operationType + Ack chan struct{} + } + statusTracker struct { + In chan operationStatus + Req chan chan Status + } +) + +const ( + operationTypeRead operationType = iota + 1 + operationTypeWrite +) + +func newStatusTracker() *statusTracker { + return &statusTracker{ + In: make(chan operationStatus, 100), + Req: make(chan chan Status, 10), + } +} + +func (st *statusTracker) Start(ctx context.Context) { + go st.worker(ctx) +} + +func (st *statusTracker) worker(ctx context.Context) { + var lastReadTime, lastWriteTime time.Time + + readFSM := statefsm.New(consecutiveWindow, consecutiveOperationThreshold) + writeFSM := statefsm.New(consecutiveWindow, consecutiveOperationThreshold) + + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + + case status := <-st.In: + switch status.Type { + case operationTypeRead: + lastReadTime = status.Timestamp + readFSM.Push(status.Err != nil, status.Timestamp) + + case operationTypeWrite: + lastWriteTime = status.Timestamp + writeFSM.Push(status.Err != nil, status.Timestamp) + + default: + log.Warn().Str("context", "db-tracker").Int("operation_type", 
int(status.Type)).Msg("unknown operation type") + } + if status.Ack != nil { + close(status.Ack) + } + + case respChan := <-st.Req: + select { + case respChan <- Status{ + LastRead: lastReadTime, + LastWrite: lastWriteTime, + InReadErrorState: readFSM.InError(), + InWriteErrorState: writeFSM.InError(), + }: + case <-ctx.Done(): + return + } + } + } +} + +func (st *statusTracker) Submit(ctx context.Context, opType operationType, err error) { + st.submitOp(ctx, opType, true, err) +} + +func (st *statusTracker) SubmitSync(ctx context.Context, opType operationType, err error) bool { + return st.submitOp(ctx, opType, false, err) +} + +func (st *statusTracker) submitOp(ctx context.Context, opType operationType, async bool, err error) bool { + data := operationStatus{ + Timestamp: time.Now().UTC(), + Err: err, + Type: opType, + } + + if async { + select { + case st.In <- data: + return true + case <-ctx.Done(): + default: + // Drop the status update if the channel is full, meaning we already have plenty of data to process. + } + return false + } + + // For synchronous submission, use an acknowledgment channel. + data.Ack = make(chan struct{}) + + select { + case st.In <- data: + select { + case <-data.Ack: + return true + case <-ctx.Done(): + } + case <-ctx.Done(): + } + + return false +} + +func (st *statusTracker) Get(ctx context.Context) (*Status, error) { + respChan := make(chan Status, 1) // Allow buffer to avoid workers blocking. 
+ + select { + case st.Req <- respChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + + select { + case status := <-respChan: + return &status, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/database/query/status_test.go b/database/query/status_test.go new file mode 100644 index 00000000..258a4408 --- /dev/null +++ b/database/query/status_test.go @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: ice License 1.0 + +package query + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/nbd-wtf/go-nostr" + "github.com/stretchr/testify/require" + + "github.com/ice-blockchain/subzero/model" +) + +func TestStatusTracker(t *testing.T) { + t.Parallel() + + tracker := newStatusTracker() + db := helperNewDatabase(t). + WithOperationReporter(tracker.Submit) + defer db.Close() + + workerCtx, cancel := context.WithCancel(t.Context()) + tracker.Start(workerCtx) + + isHealthy := func(t *testing.T) (readOK, writeOK bool) { + status, err := tracker.Get(t.Context()) + require.NoError(t, err) + + return !status.InReadErrorState, !status.InWriteErrorState + } + + t.Run("Initial status is healthy", func(t *testing.T) { + readOK, writeOK := isHealthy(t) + require.True(t, readOK) + require.True(t, writeOK) + }) + t.Run("Insert events successfully", func(t *testing.T) { + for range consecutiveOperationThreshold * 10 { + var ev model.Event + + ev.Kind = nostr.KindTextNote + ev.CreatedAt = nostr.Now() + + require.NoError(t, ev.SignWithAlg(model.GeneratePrivateKey(), model.SignAlgEDDSA, model.KeyAlgCurve25519)) + require.NoError(t, db.AcceptEvents(t.Context(), &ev)) + } + }) + t.Run("Single error does not trip FSM", func(t *testing.T) { + require.True(t, tracker.SubmitSync(t.Context(), operationTypeRead, errors.New("simulated read error"))) + require.True(t, tracker.SubmitSync(t.Context(), operationTypeWrite, errors.New("simulated write error"))) + + readOK, writeOK := isHealthy(t) + require.True(t, readOK) + require.True(t, writeOK) + }) + 
t.Run("Simulate read errors to trip FSM", func(t *testing.T) { + for range consecutiveOperationThreshold { + require.True(t, tracker.SubmitSync(t.Context(), operationTypeRead, errors.New("simulated read error"))) + } + readOK, writeOK := isHealthy(t) + require.False(t, readOK) + require.True(t, writeOK) + + t.Run("Recover from read errors", func(t *testing.T) { + for range consecutiveOperationThreshold { + require.True(t, tracker.SubmitSync(t.Context(), operationTypeRead, nil)) + } + + readOK, writeOK := isHealthy(t) + require.True(t, readOK) + require.True(t, writeOK) + }) + }) + t.Run("Simulate write errors to trip FSM", func(t *testing.T) { + for range consecutiveOperationThreshold { + require.True(t, tracker.SubmitSync(t.Context(), operationTypeWrite, errors.New("simulated write error"))) + } + readOK, writeOK := isHealthy(t) + require.True(t, readOK) + require.False(t, writeOK) + + t.Run("Recover from write errors", func(t *testing.T) { + for range consecutiveOperationThreshold { + require.True(t, tracker.SubmitSync(t.Context(), operationTypeWrite, nil)) + } + readOK, writeOK := isHealthy(t) + require.True(t, readOK) + require.True(t, writeOK) + }) + }) + + cancel() +} diff --git a/server/http/nip11/nip11.go b/server/http/nip11/nip11.go index 93bc0671..6b82993f 100644 --- a/server/http/nip11/nip11.go +++ b/server/http/nip11/nip11.go @@ -22,6 +22,7 @@ import ( "github.com/ice-blockchain/subzero/appcontext" "github.com/ice-blockchain/subzero/database/query" "github.com/ice-blockchain/subzero/model" + "github.com/ice-blockchain/subzero/storage" ) type ( @@ -39,22 +40,13 @@ type ( UsedCPU uint16 `json:"used_cpu"` UsedBandwidth uint64 `json:"used_bandwidth"` } - SystemStatusState string - SystemStatus struct { - EventsWrite SystemStatusState `json:"publishing_events"` - EventsRead SystemStatusState `json:"subscribing_for_events"` - DVM SystemStatusState `json:"dvm"` - FilesWrite SystemStatusState `json:"uploading_files"` - FilesRead SystemStatusState 
`json:"reading_files"` - PushesSend SystemStatusState `json:"sending_push_notifications"` - } RelayInformationDocument struct { + SystemStatus *SystemStatus `json:"system_status,omitzero"` SystemMetrics *SystemMetrics `json:"system_metrics,omitempty"` nip11.RelayInformationDocument `json:",inline"` - FCMAndroidConfigs []FCMConfig `json:"fcm_android_configs"` - FCMIOSConfigs []FCMConfig `json:"fcm_ios_configs"` - FCMWebConfigs []FCMConfig `json:"fcm_web_configs"` - SystemStatus SystemStatus `json:"system_status,omitzero"` + FCMAndroidConfigs []FCMConfig `json:"fcm_android_configs"` + FCMIOSConfigs []FCMConfig `json:"fcm_ios_configs"` + FCMWebConfigs []FCMConfig `json:"fcm_web_configs"` } Config struct { PrivateKey string @@ -64,8 +56,11 @@ type ( MinLeadingZeroBits int } nip11handler struct { + storageClient storage.StorageClient cfg *Config systemMetrics *atomic.Pointer[SystemMetrics] + systemStatus *atomic.Pointer[SystemStatus] + databaseReportGetter func(context.Context) (*query.Status, error) storagePath string commandPath string lastBandwidthBytes uint64 @@ -73,22 +68,31 @@ type ( } ) -const ( - SystemStatusStateOK SystemStatusState = "UP" - SystemStatusStateError SystemStatusState = "DOWN" - SystemStatusStateMaintenance SystemStatusState = "MAINTENANCE" -) - const systemMetricsCollectionTime = 30 * time.Second func NewNIP11Handler(ctx context.Context, cfg *Config, storagePath, commandPath string) http.Handler { h := &nip11handler{ - cfg: cfg, - storagePath: storagePath, - commandPath: commandPath, - systemMetrics: new(atomic.Pointer[SystemMetrics]), + cfg: cfg, + storagePath: storagePath, + commandPath: commandPath, + systemMetrics: new(atomic.Pointer[SystemMetrics]), + systemStatus: new(atomic.Pointer[SystemStatus]), + databaseReportGetter: query.GetStatusReport, + storageClient: storage.Client(), } + + // Assume healthy at start. 
+ h.systemStatus.Store(&SystemStatus{ + EventsWrite: SystemStatusStateOK, + EventsRead: SystemStatusStateOK, + DVM: SystemStatusStateOK, + FilesWrite: SystemStatusStateOK, + FilesRead: SystemStatusStateOK, + PushesSend: SystemStatusStateOK, + }) + go h.startSystemMetricsCollector(ctx) + go h.startSystemStatusCollector(ctx, nil) return h } @@ -98,15 +102,18 @@ func (n *nip11handler) ServeHTTP(writer http.ResponseWriter, req *http.Request) return } writer.Header().Add("Content-Type", "application/json") - info := n.info() - bytes, err := json.Marshal(info) + info := n.info(req.Context()) + + encoder := json.NewEncoder(writer) + encoder.SetEscapeHTML(true) + + err := encoder.Encode(info) if err != nil { log.Error().Err(err).Interface("info", info).Msg("failed to serialize NIP11 json") } - writer.Write(bytes) } -func (n *nip11handler) info() RelayInformationDocument { +func (n *nip11handler) info(context.Context) RelayInformationDocument { var androidConfigs []FCMConfig var iosConfigs []FCMConfig var webConfigs []FCMConfig @@ -173,18 +180,8 @@ func (n *nip11handler) info() RelayInformationDocument { FCMAndroidConfigs: androidConfigs, FCMIOSConfigs: iosConfigs, FCMWebConfigs: webConfigs, - SystemStatus: SystemStatus{ - EventsRead: SystemStatusStateOK, - EventsWrite: SystemStatusStateOK, - - DVM: SystemStatusStateMaintenance, - - FilesRead: SystemStatusStateOK, - FilesWrite: SystemStatusStateMaintenance, - - PushesSend: SystemStatusStateError, - }, - SystemMetrics: n.systemMetrics.Load(), + SystemStatus: n.systemStatus.Load(), + SystemMetrics: n.systemMetrics.Load(), } } diff --git a/server/http/nip11/nip11_test.go b/server/http/nip11/nip11_test.go index f8916ba5..7f6cf8c2 100644 --- a/server/http/nip11/nip11_test.go +++ b/server/http/nip11/nip11_test.go @@ -10,6 +10,7 @@ import ( "slices" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/gin-gonic/gin" @@ -22,6 +23,7 @@ import ( "github.com/ice-blockchain/subzero/server/cert" wsserver 
"github.com/ice-blockchain/subzero/server/ws" "github.com/ice-blockchain/subzero/server/ws/fixture" + "github.com/ice-blockchain/subzero/storage" ) const ( @@ -62,11 +64,35 @@ func initServer(serverCtx context.Context, port uint16) { time.Sleep(100 * time.Millisecond) } +func helperNewHandler(t testing.TB) *nip11handler { + t.Helper() + + handler := &nip11handler{ + cfg: &Config{MinLeadingZeroBits: minLeadingZeroBits}, + systemMetrics: new(atomic.Pointer[SystemMetrics]), + systemStatus: new(atomic.Pointer[SystemStatus]), + databaseReportGetter: query.GetStatusReport, + storageClient: storage.Client(), + } + + handler.systemMetrics.Store(&SystemMetrics{}) + handler.systemStatus.Store(&SystemStatus{ + EventsRead: SystemStatusStateOK, + EventsWrite: SystemStatusStateOK, + DVM: SystemStatusStateOK, + FilesRead: SystemStatusStateOK, + FilesWrite: SystemStatusStateOK, + PushesSend: SystemStatusStateOK, + }) + + return handler +} + func TestNIP11(t *testing.T) { t.Parallel() - handler := nip11handler{cfg: &Config{MinLeadingZeroBits: minLeadingZeroBits}, systemMetrics: new(atomic.Pointer[SystemMetrics])} - expected := handler.info() + handler := helperNewHandler(t) + expected := handler.info(t.Context()) require.NotNil(t, expected) t.Run("Fetch via standard nip11 fetcher", func(t *testing.T) { @@ -118,17 +144,14 @@ func TestFCMConfigParsing(t *testing.T) { iosConfig := `{"apiKey":"ios-key","appId":"ios-app-id","messagingSenderId":"ios-messaging-sender","projectId":"ios-project"}` webConfig := `{"apiKey":"web-key","appId":"web-app-id","messagingSenderId":"web-messaging-sender","projectId":"web-project"}` - handler := nip11handler{ - cfg: &Config{ - MinLeadingZeroBits: minLeadingZeroBits, - FCMAndroidConfigs: []string{androidConfig}, - FCMIOSConfigs: []string{iosConfig}, - FCMWebConfigs: []string{webConfig}, - }, - systemMetrics: new(atomic.Pointer[SystemMetrics]), + handler := helperNewHandler(t) + handler.cfg = &Config{ + MinLeadingZeroBits: minLeadingZeroBits, + 
FCMAndroidConfigs: []string{androidConfig}, + FCMIOSConfigs: []string{iosConfig}, + FCMWebConfigs: []string{webConfig}, } - handler.systemMetrics.Store(&SystemMetrics{}) - info := handler.info() + info := handler.info(t.Context()) require.Len(t, info.FCMAndroidConfigs, 1) androidCfg := info.FCMAndroidConfigs[0] @@ -151,14 +174,168 @@ func TestFCMConfigParsing(t *testing.T) { require.Equal(t, "web-messaging-sender", webCfg.MessagingSenderID) require.Equal(t, "web-project", webCfg.ProjectID) - handlerWithInvalidJSON := nip11handler{ - cfg: &Config{ - MinLeadingZeroBits: minLeadingZeroBits, - FCMAndroidConfigs: []string{`invalid json`}, - }, - systemMetrics: new(atomic.Pointer[SystemMetrics]), + handlerWithInvalidJSON := helperNewHandler(t) + handlerWithInvalidJSON.cfg = &Config{ + MinLeadingZeroBits: minLeadingZeroBits, + FCMAndroidConfigs: []string{`invalid json`}, } - handlerWithInvalidJSON.systemMetrics.Store(&SystemMetrics{}) - infoWithInvalidJSON := handlerWithInvalidJSON.info() + infoWithInvalidJSON := handlerWithInvalidJSON.info(t.Context()) require.Empty(t, infoWithInvalidJSON.FCMAndroidConfigs) } + +func TestStorageStatusCheck(t *testing.T) { + t.Parallel() + + ctx := appcontext.TestContext(t) + storageClient := storage.NewClient(ctx, nil, storage.WithConfig(&storage.Config{ + PrivateKey: model.GeneratePrivateKey(), + AbsoluteRootStoragePath: t.TempDir(), + RelayURL: testRelayURL, + IONLibertyDisabled: true, + ExternalADNLPort: 12345, + IONStorageConfigURL: "https://ton.org/testnet-global.config.json", + })) + require.NotNil(t, storageClient) + + h := helperNewHandler(t) + h.storageClient = storageClient + + var report SystemStatus + h.RunStorageStatusCheck(ctx, &report) + require.Equal(t, SystemStatusStateOK, report.FilesRead) + require.Equal(t, SystemStatusStateOK, report.FilesWrite) + + internalReport := storageClient.Health() + require.False(t, internalReport.InReadErrorState) + require.False(t, internalReport.InWriteErrorState) + + storageClient.Close() 
+} + +func TestSystemStatusCollectorDatabase(t *testing.T) { + t.Parallel() + + t.Run("System status collector sets correct statuses", func(t *testing.T) { + var cases = []struct { + Name string + query.Status + }{ + { + Name: "Healthy database", + Status: query.Status{ + LastWrite: time.Now(), + LastRead: time.Now(), + }, + }, + { + Name: "Broken reads", + Status: query.Status{ + LastWrite: time.Now(), + LastRead: time.Now(), + InReadErrorState: true, + }, + }, + { + Name: "Broken writes", + Status: query.Status{ + LastWrite: time.Now(), + LastRead: time.Now(), + InWriteErrorState: true, + }, + }, + } + synctest.Test(t, func(t *testing.T) { + for _, tc := range cases { + t.Logf("Running case: %s", tc.Name) // t.Run() is not available inside synctest.Test. + handler := helperNewHandler(t) + require.NotNil(t, handler) + + handler.storageClient = nil // Disable storage checks. + handler.databaseReportGetter = func(context.Context) (*query.Status, error) { + return &tc.Status, nil + } + + ticker := make(chan struct{}, 1) + + workerCtx, workerCancel := context.WithCancel(appcontext.TestContext(t)) + go handler.startSystemStatusCollector(workerCtx, ticker) + + ticker <- struct{}{} + synctest.Wait() + + data := handler.systemStatus.Load() + require.NotNil(t, data) + + if tc.InReadErrorState { + require.Equal(t, SystemStatusStateError, data.EventsRead) + } else { + require.Equal(t, SystemStatusStateOK, data.EventsRead) + } + + if tc.InWriteErrorState { + require.Equal(t, SystemStatusStateError, data.EventsWrite) + } else { + require.Equal(t, SystemStatusStateOK, data.EventsWrite) + } + + if tc.InReadErrorState || tc.InWriteErrorState { + require.Equal(t, SystemStatusStateError, data.DVM) + } else { + require.Equal(t, SystemStatusStateOK, data.DVM) + } + workerCancel() + } + }) + }) + t.Run("Forced check scheduling", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + handler := helperNewHandler(t) + handler.storageClient = nil // Disable storage checks. 
+ require.NotNil(t, handler) + + handler.databaseReportGetter = func(context.Context) (*query.Status, error) { + return &query.Status{ + LastWrite: time.Now().Add(-2 * forceDatabaseCheckInterval), + LastRead: time.Now().Add(-2 * forceDatabaseCheckInterval), + InReadErrorState: true, + InWriteErrorState: true, + }, nil + } + + ticker := make(chan struct{}, 1) + + workerCtx, workerCancel := context.WithCancel(appcontext.TestContext(t)) + go handler.startSystemStatusCollector(workerCtx, ticker) + + ticker <- struct{}{} + time.Sleep(forceDatabaseCheckInterval + time.Second) // Ensure that next forced check would be due. + synctest.Wait() + + events := helperSelectEvents(t, model.Filter{Kinds: []int{9998}}) + require.Len(t, events, 3) // Manual check should have created 3 test events. + + // Force check should fix the status. + data := handler.systemStatus.Load() + require.NotNil(t, data) + require.Equal(t, SystemStatusStateOK, data.EventsRead) + require.Equal(t, SystemStatusStateOK, data.EventsWrite) + require.Equal(t, SystemStatusStateOK, data.DVM) + + workerCancel() + }) + }) +} + +func helperSelectEvents(t *testing.T, filters ...model.Filter) (events []*model.Event) { + t.Helper() + + t.Logf("selecting events: %s", model.Filters(filters).String()) + + for ev, err := range query.GetStoredEvents(t.Context(), filters...) 
{ + require.NoError(t, err) + require.NotNil(t, ev) + events = append(events, ev) + } + + return events +} diff --git a/server/http/nip11/system_status.go b/server/http/nip11/system_status.go new file mode 100644 index 00000000..93a5758f --- /dev/null +++ b/server/http/nip11/system_status.go @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: ice License 1.0 + +package nip11 + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/cockroachdb/errors" + "github.com/nbd-wtf/go-nostr" + "github.com/rs/zerolog/log" + + "github.com/ice-blockchain/subzero/appcontext" + "github.com/ice-blockchain/subzero/database/query" + "github.com/ice-blockchain/subzero/model" + "github.com/ice-blockchain/subzero/storage" +) + +type ( + SystemStatusState string + + SystemStatus struct { + EventsWrite SystemStatusState `json:"publishing_events"` + EventsRead SystemStatusState `json:"subscribing_for_events"` + DVM SystemStatusState `json:"dvm"` + FilesWrite SystemStatusState `json:"uploading_files"` + FilesRead SystemStatusState `json:"reading_files"` + PushesSend SystemStatusState `json:"sending_push_notifications"` + } +) + +const ( + SystemStatusStateOK SystemStatusState = "UP" + SystemStatusStateError SystemStatusState = "DOWN" + SystemStatusStateMaintenance SystemStatusState = "MAINTENANCE" + + forceDatabaseCheckInterval = 5 * time.Minute + forceStorageCheckInterval = 10 * time.Minute +) + +func (n *nip11handler) collectStorageStatus(_ context.Context, systemReport *SystemStatus) (scheduleCheck bool) { + if n.storageClient == nil { + log.Error().Msg("storage client is nil, cannot collect storage status") + return false + } + + report := n.storageClient.Health() + + systemReport.FilesRead = SystemStatusStateOK + systemReport.FilesWrite = SystemStatusStateOK + + if report.InReadErrorState { + systemReport.FilesRead = SystemStatusStateError + } + + if 
report.InWriteErrorState { + systemReport.FilesWrite = SystemStatusStateError + } + + scheduleCheck = time.Since(report.LastWrite) > forceStorageCheckInterval || + time.Since(report.LastRead) > forceStorageCheckInterval + + return scheduleCheck +} + +func (n *nip11handler) collectDatabaseStatus(ctx context.Context, systemReport *SystemStatus) (scheduleCheck bool) { + report, err := n.databaseReportGetter(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to collect database status") + return false + } + + systemReport.EventsRead = SystemStatusStateOK + systemReport.EventsWrite = SystemStatusStateOK + systemReport.DVM = SystemStatusStateOK + + if report.InReadErrorState { + systemReport.EventsRead = SystemStatusStateError + } + + if report.InWriteErrorState { + systemReport.EventsWrite = SystemStatusStateError + } + + if report.InReadErrorState || report.InWriteErrorState { + systemReport.DVM = SystemStatusStateError + } + + // Schedule a manual check if the last read or write was long ago. 
+ scheduleCheck = time.Since(report.LastWrite) > forceDatabaseCheckInterval || + time.Since(report.LastRead) > forceDatabaseCheckInterval + + return scheduleCheck +} + +func createStorageUploadRequest() (req *http.Request, filename string, hash []byte, err error) { + var body bytes.Buffer + + testFileName := "subzero_storage_healthcheck_" + strconv.FormatInt(time.Now().UnixNano(), 16) + ".txt" + testFileContent := rand.Text() + testFileHash := sha256.Sum256([]byte(testFileContent)) + + writer := multipart.NewWriter(&body) + + part, err := writer.CreateFormFile("file", testFileName) + if err != nil { + return nil, "", nil, errors.Wrap(err, "failed to create form file") + } + + _, wErr := io.WriteString(part, testFileContent) + cErr := writer.Close() + if wErr != nil || cErr != nil { + return nil, "", nil, errors.Wrap(errors.Join(wErr, cErr), "failed to write to form file or close writer") + } + + req = &http.Request{ + Header: make(http.Header), + Body: io.NopCloser(&body), + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + return req, testFileName, testFileHash[:], nil +} + +func (n *nip11handler) RunStorageStatusCheck(ctx context.Context, systemReport *SystemStatus) { + systemReport.FilesWrite = SystemStatusStateError + systemReport.FilesRead = SystemStatusStateError + + if n.storageClient == nil { + log.Error().Msg("storage client is nil, cannot run storage status check") + return + } + + req, filename, expectedHash, err := createStorageUploadRequest() + if err != nil { + log.Error().Err(err).Str("context", "storage health check").Msg("failed to create upload request") + return + } + + now := time.Now() + _, masterPubKey := model.GenerateKeyPair() + + uploadPath, metaInput, hash, err := n.storageClient.SaveFile(ctx, now, masterPubKey, req, 1<<20) + if err != nil { + log.Error().Err(err).Str("context", "storage health check").Msg("failed to save file to storage") + return + } else if !bytes.Equal(hash, expectedHash) { + log.Error(). 
+ Hex("expected_hash", expectedHash). + Hex("actual_hash", hash). + Str("context", "storage health check"). + Msg("hash mismatch after saving file to storage") + return + } + hashHex := hex.EncodeToString(hash) + + defer func() { + deleteErr := n.storageClient.Delete(ctx, masterPubKey, masterPubKey, hashHex) + if deleteErr != nil && !errors.Is(deleteErr, storage.ErrNotFound) { + log.Error().Err(deleteErr).Str("context", "storage health check").Msg("failed to delete file from storage") + } + }() + + log.Trace(). + Str("file_path", uploadPath). + Str("context", "storage health check"). + Hex("hash", hash). + Msg("file saved to storage successfully") + + bagID, targetURL, _, err := n.storageClient.StartUpload(ctx, now, masterPubKey, masterPubKey, uploadPath, hashHex, metaInput) + if err != nil { + log.Error().Err(err).Str("context", "storage health check").Msg("failed to start upload to storage") + return + } + + log.Trace(). + Str("bag_id", bagID). + Str("target_url", targetURL). + Str("context", "storage health check"). + Msg("upload started successfully") + + systemReport.FilesWrite = SystemStatusStateOK + + fullPath, err := n.storageClient.FilePath(masterPubKey, hashHex, filepath.Ext(filename)) + if err != nil { + log.Error().Err(err).Str("context", "storage health check").Msg("failed to get file path from storage client") + return + } + + data, err := os.ReadFile(fullPath) + if err != nil { + log.Error().Err(err).Str("context", "storage health check").Msg("failed to read file from storage") + return + } + + actualHash := sha256.Sum256(data) + if !bytes.Equal(actualHash[:], expectedHash) { + log.Error(). + Hex("expected_hash", expectedHash). + Hex("actual_hash", actualHash[:]). + Str("context", "storage health check"). + Msg("hash mismatch after reading file from storage") + return + } else { + log.Trace(). + Str("file_path", fullPath). + Str("context", "storage health check"). + Hex("hash", actualHash[:]). 
+ Msg("file read from storage successfully with matching hash") + } + + systemReport.FilesRead = SystemStatusStateOK +} + +func (*nip11handler) RunDatabaseStatusCheck(ctx context.Context, systemReport *SystemStatus) { + const numEvents = 3 + const testEventKind = 9998 + + var events model.Events + for i := range numEvents { + var ev model.Event + + ev.Kind = testEventKind + ev.CreatedAt = nostr.Now() + ev.Content = "status check event " + strconv.Itoa(i) + ev.Tags = model.Tags{ + {"expiration", ev.CreatedAt.Add(time.Minute * 2).String()}, + } + ev.SignWithAlg(model.GeneratePrivateKey(), model.SignAlgEDDSA, model.KeyAlgCurve25519) + + events = append(events, &ev) + } + + systemReport.EventsWrite = SystemStatusStateError + systemReport.DVM = SystemStatusStateError + systemReport.EventsRead = SystemStatusStateError + + err := query.AcceptEvents(ctx, events...) + if err != nil { + log.Error().Err(err).Msg("failed to write test events for database status check") + return + } + + systemReport.EventsWrite = SystemStatusStateOK + + var received model.Events + for attempt := range numEvents { + received = nil + for ev, err := range query.GetStoredEvents(ctx, model.Filter{IDs: events.IDs()}) { + if err != nil { + log.Error().Err(err).Msg("failed to read test events for database status check") + break + } + received = append(received, ev) + } + if len(received) == len(events) { + break + } + + // Trying again if we didn't get all events. + log.Info().Msgf("database status check: attempt %d: expected %d events, got %d", attempt+1, len(events), len(received)) + if attempt < numEvents-1 { // No wait on last attempt. 
+ select { + case <-ctx.Done(): + return + case <-time.After(time.Second * 10): + } + } + } + + if len(received) == len(events) { + systemReport.EventsRead = SystemStatusStateOK + systemReport.DVM = SystemStatusStateOK + } +} + +func (n *nip11handler) startSystemStatusCollector(ctx context.Context, regularTicks chan struct{}) { + var status = SystemStatus{ + EventsWrite: SystemStatusStateOK, + EventsRead: SystemStatusStateOK, + DVM: SystemStatusStateOK, + FilesWrite: SystemStatusStateOK, + FilesRead: SystemStatusStateOK, + PushesSend: SystemStatusStateOK, + } + + defer appcontext.GetAppContext(ctx).Recover() + + if regularTicks == nil { + // Use internal ticker if none provided. + regularTicks = make(chan struct{}, 1) + go func() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for ctx.Err() == nil { + select { + case <-ticker.C: + select { + case <-ctx.Done(): + return + case regularTicks <- struct{}{}: + default: + log.Debug().Msg("skipping system status collection tick, previous tick still being processed") + } + case <-ctx.Done(): + return + } + } + }() + } + + databaseCheck := make(chan struct{}, 1) + storageCheck := make(chan struct{}, 1) + + for ctx.Err() == nil { + workingStatus := status // Working copy. 
+ + select { + case <-ctx.Done(): + return + + case <-regularTicks: + scheduleDatabaseCheck := n.collectDatabaseStatus(ctx, &workingStatus) + if scheduleDatabaseCheck { + select { + case databaseCheck <- struct{}{}: + default: + log.Debug().Msg("database status check already scheduled, skipping") + } + } + + scheduleStorageCheck := n.collectStorageStatus(ctx, &workingStatus) + if scheduleStorageCheck { + select { + case storageCheck <- struct{}{}: + default: + log.Debug().Msg("storage status check already scheduled, skipping") + } + } + + case <-databaseCheck: + log.Info().Msg("starting scheduled database status check") + testCtx, testCancel := context.WithTimeout(ctx, time.Minute*3) + n.RunDatabaseStatusCheck(testCtx, &workingStatus) + testCancel() + + case <-storageCheck: + log.Info().Msg("starting scheduled storage status check") + testCtx, testCancel := context.WithTimeout(ctx, time.Minute*3) + n.RunStorageStatusCheck(testCtx, &workingStatus) + testCancel() + } + + status = workingStatus + n.systemStatus.Store(&workingStatus) + } +} diff --git a/server/http/nip96/storage_nip96.go b/server/http/nip96/storage_nip96.go index 7e7fb60d..8a247b3b 100644 --- a/server/http/nip96/storage_nip96.go +++ b/server/http/nip96/storage_nip96.go @@ -149,7 +149,7 @@ func (s *storageHandler) Upload() gin.HandlerFunc { return } ctx = storage.WithSyncCdnUpload(ctx) - bagID, url, existed, err := s.storageClient.StartUpload(ctx, now, token.PubKey(), token.MasterPubKey(), input.Filename, hex.EncodeToString(hash), input) + bagID, url, existed, err := s.storageClient.StartUpload(ctx, now, token.PubKey(), token.MasterPubKey(), input.Filename, hashHex, input) if err != nil { log.Error().Err(err).Msg("failed to upload file to ion storage") diff --git a/storage/download.go b/storage/download.go index c700b259..f1578bfe 100644 --- a/storage/download.go +++ b/storage/download.go @@ -29,27 +29,27 @@ import ( func (c *client) DownloadUrl(masterPubkey string, fileHash string) (string, error) { bag, _, 
err := c.bagByUser(masterPubkey) - if err != nil { + if c.RecordReadOperation(err) != nil { return "", errors.Wrapf(err, "failed to get bagID for the user %v", masterPubkey) } if bag == nil { return "", ErrNotFound } bs, err := c.buildBootstrapNodeInfo(bag) - if err != nil { + if c.RecordReadOperation(err) != nil { return "", errors.Wrapf(err, "failed to build bootstap for bag %v", hex.EncodeToString(bag.BagID)) } file, err := c.detectFile(bag, fileHash) - if err != nil { + if c.RecordReadOperation(err) != nil { return "", errors.Wrapf(err, "failed to detect file %v in bag %v", fileHash, hex.EncodeToString(bag.BagID)) } b, err := json.Marshal([]*Bootstrap{bs}) - if err != nil { + if c.RecordReadOperation(err) != nil { return "", errors.Wrapf(err, "failed to marshal %#v", bs) } bootstrap := base64.StdEncoding.EncodeToString(b) u, _, err := c.buildUrl(hex.EncodeToString(bag.BagID), file, masterPubkey, fileHash, bootstrap) - return u, err + return u, c.RecordReadOperation(err) } func acceptNewBag(ctx context.Context, event *model.Event, acceptor func(ctx context.Context, fh, master, infohash string) error) error { @@ -70,20 +70,25 @@ func acceptNewBag(ctx context.Context, event *model.Event, acceptor func(ctx con return acceptor(ctx, fileHash, event.GetMasterPublicKey(), infohash) } -func (c *client) StartDownloadNewBag(ctx context.Context, fileHash, userMasterKey, infohash string) error { +func (c *client) StartDownloadNewBag(ctx context.Context, fileHash, userMasterKey, infohash string) (err error) { log.Info().Str("context", "STORAGE"). Str("user", userMasterKey). Str("infohash", infohash). 
Msg("accepting NIP-94 infohash with new files") + + defer func() { + err = c.RecordReadOperation(err) + }() + spl := strings.Split(infohash, ":") if len(spl) != 3 { - return errors.Newf("malformed i tag %v, cannot detect bootstrap and version", infohash) + return errors.Errorf("malformed i tag %v, cannot detect bootstrap and version", infohash) } infohash = spl[0] bootstrap := spl[1] version, cErr := strconv.ParseInt(spl[2], 10, 64) if cErr != nil { - return errors.Wrapf(cErr, "malformed i tag %v, cannot version", infohash) + return errors.Wrapf(cErr, "i tag parse error for tag %v", infohash) } if err := c.newBagIDPromoted(ctx, userMasterKey, infohash, &bootstrap, version); err != nil { @@ -132,6 +137,9 @@ func (c *client) newBagIDPromoted(ctx context.Context, user, bagID string, boots } func (c *client) download(ctx context.Context, bagID, user string, bootstrap *string, newVersion int64) (err error) { + defer func() { + err = c.RecordReadOperation(err) + }() bag, err := hex.DecodeString(bagID) if err != nil { return errors.Wrapf(err, "invalid bagID %v", bagID) @@ -198,7 +206,7 @@ func (c *client) torrentStateCallback(tor *storage.Torrent, user *string) func(e Uint64("file_size", tor.Info.FileSize). Strs("files", files). Msg("bag downloaded, disabling download") - if pErr := tor.Start(true, false, false); pErr != nil { + if pErr := tor.Start(true, false, false); c.RecordReadOperation(pErr) != nil { log.Error().Err(pErr).Hex("bag_id", tor.BagID).Str("user", usr).Msg("failed to stop torrent download after downloading data") } c.activeDownloadsMx.Lock() @@ -218,7 +226,7 @@ func (c *client) torrentStateCallback(tor *storage.Torrent, user *string) func(e Uint32("files_count", tor.Header.FilesCount). Uint64("file_size", uint64(tor.Info.FileSize)). 
Msg("bag header resolved, enabling upload to serve clients with chunks we own") - if pErr := tor.StartWithCallback(true, true, false, c.torrentStateCallback(tor, user)); pErr != nil { + if pErr := tor.StartWithCallback(true, true, false, c.torrentStateCallback(tor, user)); c.RecordReadOperation(pErr) != nil { log.Error().Err(pErr).Hex("bag_id", tor.BagID).Msg("failed to start torrent upload after downloading header") } if user != nil { @@ -231,7 +239,7 @@ func (c *client) torrentStateCallback(tor *storage.Torrent, user *string) func(e } } ver := int64(tor.Header.FilesCount) - if pErr := c.saveTorrent(tor, user, nil, false, &ver); pErr != nil { + if pErr := c.saveTorrent(tor, user, nil, false, &ver); c.RecordReadOperation(pErr) != nil { log.Error().Err(pErr).Hex("bag_id", tor.BagID).Msg("failed save torrent with stopped download after downloading") } } @@ -267,7 +275,10 @@ func (c *client) connectToBootstrap(ctx context.Context, torrent *storage.Torren return nil } -func (c *client) saveTorrent(tr *storage.Torrent, userPubKey *string, bs *string, deletion bool, newVersion *int64) error { +func (c *client) saveTorrent(tr *storage.Torrent, userPubKey *string, bs *string, deletion bool, newVersion *int64) (err error) { + defer func() { + err = c.RecordWriteOperation(err) + }() if err := c.progressStorage.SetTorrent(tr); err != nil { return errors.Wrap(err, "failed to save torrent into storage") } diff --git a/storage/global.go b/storage/global.go index 83c87bee..6279a5ed 100644 --- a/storage/global.go +++ b/storage/global.go @@ -37,6 +37,7 @@ import ( "github.com/ice-blockchain/subzero/rq" "github.com/ice-blockchain/subzero/storage/internal" "github.com/ice-blockchain/subzero/storage/statistics" + "github.com/ice-blockchain/subzero/tracing/statefsm" ) var ( @@ -260,7 +261,7 @@ func WithConfig(cfg *Config) Option { func MustInit(ctx context.Context, rqClient rq.Client, opts ...Option) { globalClient.Once.Do(func() { - globalClient.Client = mustInit(ctx, rqClient, 
opts...) + globalClient.Client = mustCreateClient(ctx, rqClient, opts...) }) appcontext.GetAppContext(ctx).OnShutdown(func() error { if globalClient.Client == nil { @@ -273,13 +274,15 @@ func MustInit(ctx context.Context, rqClient rq.Client, opts ...Option) { }) } -func mustInit(ctx context.Context, rqClient rq.Client, opts ...Option) *client { +func mustCreateClient(ctx context.Context, rqClient rq.Client, opts ...Option) *client { var cl = &client{ newFiles: make(map[string]map[string]*FileMetaInput), newFilesMx: &sync.RWMutex{}, downloadQueue: make(chan queueItem, 1000000), activeDownloads: make(map[string]bool), activeDownloadsMx: &sync.RWMutex{}, + healthReadFSM: statefsm.New(time.Minute*3, 10, statefsm.WithThreadSafety()), + healthWriteFSM: statefsm.New(time.Minute*3, 10, statefsm.WithThreadSafety()), } for _, opt := range opts { diff --git a/storage/storage.go b/storage/storage.go index 1b8744fe..6f6d5d7c 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -31,6 +31,7 @@ import ( "github.com/ice-blockchain/subzero/appcontext" "github.com/ice-blockchain/subzero/storage/internal" "github.com/ice-blockchain/subzero/storage/statistics" + "github.com/ice-blockchain/subzero/tracing/statefsm" ) type ( @@ -46,6 +47,13 @@ type ( Delete(ctx context.Context, userPubkey, masterKey string, fileSha256 string) error DeleteUser(masterKey string) error StartDownloadNewBag(ctx context.Context, fileHash, masterKey, infohash string) error + Health() HealthStatus + } + HealthStatus struct { + LastRead time.Time + LastWrite time.Time + InReadErrorState bool + InWriteErrorState bool } Bootstrap struct { Overlay *overlay.Node @@ -86,6 +94,10 @@ type ( config *Config rootStoragePath string closed atomic.Bool + healthReadFSM statefsm.FSM + healthWriteFSM statefsm.FSM + lastRead atomic.Pointer[time.Time] + lastWrite atomic.Pointer[time.Time] cdn internal.CDNClient } queueItem struct { @@ -373,3 +385,31 @@ func (c *client) report(ctx context.Context) { } } } + +func (c *client) 
Health() (status HealthStatus) { + if lastRead := c.lastRead.Load(); lastRead != nil { + status.LastRead = *lastRead + } + if lastWrite := c.lastWrite.Load(); lastWrite != nil { + status.LastWrite = *lastWrite + } + status.InReadErrorState = c.healthReadFSM.InError() + status.InWriteErrorState = c.healthWriteFSM.InError() + return status +} + +// RecordReadOperation records the result of a read operation to update health status and returns the given error as is. +func (c *client) RecordReadOperation(err error) error { + now := time.Now() + c.lastRead.Store(&now) + c.healthReadFSM.Push(err != nil, now) + return err +} + +// RecordWriteOperation records the result of a write operation to update health status and returns the given error as is. +func (c *client) RecordWriteOperation(err error) error { + now := time.Now() + c.lastWrite.Store(&now) + c.healthWriteFSM.Push(err != nil, now) + return err +} diff --git a/storage/storage_fixture.go b/storage/storage_fixture.go index 75fb56db..ece078d6 100644 --- a/storage/storage_fixture.go +++ b/storage/storage_fixture.go @@ -17,10 +17,11 @@ import ( "github.com/fsnotify/fsnotify" "github.com/stretchr/testify/require" + "github.com/ice-blockchain/subzero/rq" "github.com/ice-blockchain/subzero/storage/internal" ) -func calcFileHash(t *testing.T, path string) (string, error) { +func calcFileHash(t testing.TB, path string) (string, error) { t.Helper() f, err := os.Open(path) @@ -147,9 +148,14 @@ func Reset() { globalClient.Once = sync.Once{} } -func VerifyFileOnCdn(tb *testing.T, ctx context.Context, fileName string) { +func VerifyFileOnCdn(tb testing.TB, ctx context.Context, fileName string) { internal.VerifyFileOnCdn(tb, ctx, globalClient.Client.cdn, fileName) } -func VerifyFileDeletedOnCdn(tb *testing.T, ctx context.Context, fileName string) { + +func VerifyFileDeletedOnCdn(tb testing.TB, ctx context.Context, fileName string) { internal.VerifyFileDeletedOnCdn(tb, ctx, globalClient.Client.cdn, fileName) } + +func 
NewClient(ctx context.Context, rqClient rq.Client, options ...Option) StorageClient { + return mustCreateClient(ctx, rqClient, options...) +} diff --git a/storage/upload.go b/storage/upload.go index 2c95f65c..72bd344e 100644 --- a/storage/upload.go +++ b/storage/upload.go @@ -33,7 +33,7 @@ import ( func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, masterPubKey, relativePathToFileForUrl, hash string, newFile *FileMetaInput) (bagID, url string, existed bool, err error) { existingBagForUser, _, err := c.bagByUser(masterPubKey) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "failed to find existing bag for user %s", masterPubKey) } var existingHDData []byte @@ -42,12 +42,12 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas if existingBagForUser.Header != nil && len(existingBagForUser.Header.Data) > 0 { existingHDData = existingBagForUser.Header.Data } else { - if existingHDData, err = c.latestHeaderForBag(existingBagForUser.BagID); err != nil { + if existingHDData, err = c.latestHeaderForBag(existingBagForUser.BagID); c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "failed to get header for bag %v", hex.EncodeToString(existingBagForUser.BagID)) } } if len(existingHDData) > 0 { - if err = json.Unmarshal(existingHDData, &existingHD); err != nil { + if err = json.Unmarshal(existingHDData, &existingHD); c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "corrupted header metadata for bag %v", hex.EncodeToString(existingBagForUser.BagID)) } } @@ -61,7 +61,7 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas existed = false } else { return "", "", false, - errors.Wrapf(err, "failed to build download url for already existing file %v/%v(%v)", masterPubKey, relativePathToFileForUrl, hash) + errors.Wrapf(c.RecordWriteOperation(err), "failed to build download url for already 
existing file %v/%v(%v)", masterPubKey, relativePathToFileForUrl, hash) } } if existed { @@ -72,7 +72,7 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas } bs := []*Bootstrap{bootstrapNode} b, err := json.Marshal(bs) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "failed to marshal %#v", bs) } bootstrap := base64.StdEncoding.EncodeToString(b) @@ -89,7 +89,7 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas var bs []*Bootstrap var bag *storage.Torrent bag, bs, err = c.upload(ctx, now, userPubKey, masterPubKey, relativePathToFileForUrl, hash, newFile, &existingHD) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "failed to start upload of %v", relativePathToFileForUrl) } bagID = hex.EncodeToString(bag.BagID) @@ -110,7 +110,7 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas go c.stats.ProcessFile(ctx, fullFilePath, gomime.TypeByExtension(filepath.Ext(fullFilePath)), uplFile.Size) } b, err := json.Marshal(bs) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", "", false, errors.Wrapf(err, "failed to marshal %#v", bs) } bootstrap := base64.StdEncoding.EncodeToString(b) @@ -122,8 +122,8 @@ func (c *client) StartUpload(ctx context.Context, now time.Time, userPubKey, mas if !c.cdnEnabled() || newFile == nil { return bagID + ":" + bootstrap + ":" + strconv.FormatInt(int64(bag.Header.FilesCount), 10), url, existed, nil } - if err = c.cdnUpload(ctx, masterPubKey, relativePathToFileForUrl, fileNameForCdn, newFile); err != nil { - return "", "", false, errors.Wrapf(c.cdnUpload(ctx, masterPubKey, relativePathToFileForUrl, fileNameForCdn, newFile), "failed to upload file to cdn") + if err = c.cdnUpload(ctx, masterPubKey, relativePathToFileForUrl, fileNameForCdn, newFile); c.RecordWriteOperation(err) != nil { + return "", "", false, errors.Wrapf(err, "failed 
to upload file to cdn") } return bagID + ":" + bootstrap + ":" + strconv.FormatInt(int64(bag.Header.FilesCount), 10), url, existed, nil } @@ -140,11 +140,11 @@ func (c *client) cdnUpload(ctx context.Context, masterPubKey, relativePathToFile return nil } f, ferr := os.Open(fullFilePath) - if ferr != nil { + if c.RecordWriteOperation(ferr) != nil { return errors.Wrapf(ferr, "failed to open %v", fullFilePath) } defer f.Close() - if err := c.cdn.FileUpload(ctx, f, newFile.ContentType, fileNameForCdn); err != nil { + if err := c.cdn.FileUpload(ctx, f, newFile.ContentType, fileNameForCdn); c.RecordWriteOperation(err) != nil { if err = c.cdn.FileUploadAsync(ctx, strings.TrimPrefix(fullFilePath, c.rootStoragePath), newFile.ContentType, fileNameForCdn); err != nil { return errors.Wrapf(err, "failed to enqueue file upload %v to cdn", fileNameForCdn) } @@ -351,16 +351,17 @@ func (c *client) SaveFile(ctx context.Context, now time.Time, masterPubKey strin contentType = gomime.TypeByExtension(filepath.Ext(fileName)) } uploadingFilePath := filepath.Join(storagePath, fileName) - if err = os.MkdirAll(filepath.Dir(uploadingFilePath), 0o744); err != nil { + err = os.MkdirAll(filepath.Dir(uploadingFilePath), 0o744) + if c.RecordWriteOperation(err) != nil { log.Error().Str("context", "STORAGE").Err(err).Msg("failed to open temp file while processing upload") return "", nil, nil, errors.Wrapf(err, "failed to create tmp dir") } userDir, err := os.OpenRoot(storagePath) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", nil, nil, errors.Wrap(err, "failed to open user folder while processing upload") } fileUploadTo, err := userDir.OpenFile(fileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) - if err != nil { + if c.RecordWriteOperation(err) != nil { return "", nil, nil, errors.Wrap(err, "failed to open temp file while processing upload") } defer func() { @@ -368,7 +369,7 @@ func (c *client) SaveFile(ctx context.Context, now time.Time, masterPubKey strin 
fileUploadTo.Close() }() written, err := io.Copy(io.MultiWriter(fileUploadTo, hashCalc), part) - if err != nil { + if c.RecordWriteOperation(err) != nil { log.Error(). Str("context", "STORAGE"). Err(err). @@ -380,6 +381,7 @@ func (c *client) SaveFile(ctx context.Context, now time.Time, masterPubKey strin if fileSize > maxSize { part.Close() defer os.Remove(uploadingFilePath) + c.RecordWriteOperation(ErrFileTooBig) return "", &FileMetaInput{FileSize: fileSize}, nil, ErrFileTooBig } log.Trace().Str("context", "STORAGE"). @@ -426,7 +428,7 @@ func (c *client) SaveFile(ctx context.Context, now time.Time, masterPubKey strin Msg("hash") hexHash := hex.EncodeToString(hash) newName = hexHash + filepath.Ext(fileName) - if err = os.Rename(filepath.Join(storagePath, fileName), filepath.Join(storagePath, newName)); err != nil { + if err = os.Rename(filepath.Join(storagePath, fileName), filepath.Join(storagePath, newName)); c.RecordWriteOperation(err) != nil { log.Error(). Str("context", "STORAGE"). Err(err). 
@@ -450,16 +452,12 @@ func (c *client) SaveFile(ctx context.Context, now time.Time, masterPubKey strin } func readString(part *multipart.Part, name string) (string, error) { - bufSize := 1024 - b := make([]byte, bufSize) - read, err := part.Read(b) + const sizeCap = 1024 + data, err := io.ReadAll(io.LimitReader(part, sizeCap)) if err != nil { - if err == io.EOF { - return string(b[:read]), nil - } return "", errors.Wrapf(err, "failed to read %v", name) } - return string(b[:read]), nil + return string(data), nil } func (c *client) forceUploadExistingFiles(ctx context.Context) error { @@ -477,7 +475,7 @@ func (c *client) forceUploadExistingFiles(ctx context.Context) error { masterKey := userDir.Name() userPath, _ := c.BuildUserPath(masterKey, "") userFiles, uploadErr := os.ReadDir(userPath) - if err != nil { + if uploadErr != nil { log.Error().Err(uploadErr).Str("user", masterKey).Msg("failed to list files in user folder") return } @@ -486,8 +484,8 @@ func (c *client) forceUploadExistingFiles(ctx context.Context) error { contentType := c.detectContentType(masterKey, uf.Name()) uploadErr = errors.Join(uploadErr, errors.Wrapf(c.cdn.FileUploadAsync(ctx, strings.TrimPrefix(filepath.Join(userPath, uf.Name()), c.rootStoragePath), contentType, fName), "failed to upload file %v for usr %v", uf.Name(), masterKey)) } - if err != nil { - log.Error().Err(err).Str("user", masterKey).Msg("failed to upload files for user") + if uploadErr != nil { + log.Error().Err(uploadErr).Str("user", masterKey).Msg("failed to upload files for user") } }() } diff --git a/tracing/statefsm/fsm.go b/tracing/statefsm/fsm.go new file mode 100644 index 00000000..fa28913e --- /dev/null +++ b/tracing/statefsm/fsm.go @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: ice License 1.0 + +package statefsm + +import ( + "fmt" + "math" + "sync" + "time" +) + +type ( + // FSM is a finite state machine to track consecutive successes and failures. 
+ FSM struct { + startTime time.Time // Timestamp of the first error/success in the current sequence. + mu *sync.RWMutex // Optional mutex for thread safety. + windowSize time.Duration // Consecutive time window to consider operations. + threshold uint32 // Number of consecutive operations to trigger state change. + okCounter uint32 // Count of consecutive successful operations. + errCounter uint32 // Count of consecutive failed operations. + currentState state // Current state of the FSM. + } + Option func(*FSM) + + state uint8 +) + +const ( + stateHealthy state = iota + stateUnhealthy +) + +func WithThreadSafety() Option { + return func(fsm *FSM) { + fsm.mu = new(sync.RWMutex) + } +} + +// New creates a new FSM with the specified window size and operation threshold. +func New(consecutiveWindow time.Duration, consecutiveOperationThreshold uint32, opts ...Option) (fsm FSM) { + if consecutiveWindow <= 0 { + panic("consecutiveWindow must be greater than zero") + } + if consecutiveOperationThreshold == 0 { + panic("consecutiveOperationThreshold must be greater than zero") + } + + fsm = FSM{ + windowSize: consecutiveWindow, + threshold: consecutiveOperationThreshold, + currentState: stateHealthy, + } + for _, opt := range opts { + opt(&fsm) + } + + return fsm +} + +// Push updates the FSM state based on the result of an operation and returns new state. +func (s *FSM) Push(hasError bool, currentTime time.Time) bool { + if s.mu != nil { + s.mu.Lock() + defer s.mu.Unlock() + } + + if hasError { + s.okCounter = 0 + startTime := s.startTime + if s.errCounter == 0 { + s.errCounter++ + startTime = currentTime + } else { + if currentTime.Sub(startTime) <= s.windowSize { + if s.errCounter < math.MaxUint32 { + s.errCounter++ + } + } else { + // Reset error count if the time window has expired. + if s.inError() { + if s.errCounter < math.MaxUint32 { + s.errCounter++ + } + } else { + // Reset to 1 as this is a new error after the window. 
+ s.errCounter = 1 + } + startTime = currentTime + } + } + + s.startTime = startTime + if s.errCounter >= s.OperationThreshold() { + s.currentState = stateUnhealthy + } + } else { + if s.okCounter < math.MaxUint32 { + s.okCounter++ + } + if s.inError() && s.okCounter >= s.OperationThreshold() { + s.currentState = stateHealthy + s.errCounter = 0 + s.startTime = time.Time{} + } + } + return s.inError() +} + +func (s *FSM) inError() bool { + return s.currentState == stateUnhealthy +} + +// InError returns true if the FSM is in an unhealthy state. +func (s *FSM) InError() bool { + if s.mu != nil { + s.mu.RLock() + defer s.mu.RUnlock() + } + return s.inError() +} + +// String returns a string representation of the FSM state. +func (s *FSM) String() string { + if s.mu != nil { + s.mu.RLock() + defer s.mu.RUnlock() + } + return fmt.Sprintf("StateFSM{window=%v, threshold=%d, state=%v, errs=%d, oks=%d, ts=%s}", + s.windowSize, + s.threshold, + s.currentState, + s.errCounter, + s.okCounter, + s.startTime.Format(time.RFC3339Nano), + ) +} + +// CurrentErrors returns the current count of consecutive errors. +func (s *FSM) CurrentErrors() uint32 { + if s.mu != nil { + s.mu.RLock() + defer s.mu.RUnlock() + } + return s.errCounter +} + +// CurrentSuccesses returns the current count of consecutive successes. +func (s *FSM) CurrentSuccesses() uint32 { + if s.mu != nil { + s.mu.RLock() + defer s.mu.RUnlock() + } + return s.okCounter +} + +// OperationThreshold returns the configured threshold for consecutive operations to trigger state change. 
+func (s *FSM) OperationThreshold() uint32 { + return s.threshold +} diff --git a/tracing/statefsm/fsm_test.go b/tracing/statefsm/fsm_test.go new file mode 100644 index 00000000..c72592a7 --- /dev/null +++ b/tracing/statefsm/fsm_test.go @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: ice License 1.0 + +package statefsm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const ( + testWindowSize = time.Duration(30 * time.Second) + testOperationThreshold = 3 +) + +func helperNewFSM(tb testing.TB, opts ...Option) FSM { + tb.Helper() + return New(testWindowSize, testOperationThreshold, opts...) +} + +func TestStateFSM(t *testing.T) { + t.Parallel() + + t.Run("Enters error state after threshold errors", func(t *testing.T) { + fsm := helperNewFSM(t) + + now := time.Now() + fsm.Push(true, now) + fsm.Push(true, now) + + t.Logf("FSM state: %s", fsm.String()) + require.False(t, fsm.InError()) + + fsm.Push(true, now) + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + require.EqualValues(t, testOperationThreshold, fsm.CurrentErrors()) + }) + t.Run("Recovers from error state after threshold successes", func(t *testing.T) { + fsm := helperNewFSM(t) + + now := time.Now() + for range 2 * fsm.OperationThreshold() { + fsm.Push(true, now) + } + + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + + for range fsm.OperationThreshold() { + fsm.Push(false, now) + } + + t.Logf("FSM state: %s", fsm.String()) + require.False(t, fsm.InError()) + }) + t.Run("Does not enter error state if errors are below threshold", func(t *testing.T) { + fsm := helperNewFSM(t) + + now := time.Now() + for range fsm.OperationThreshold() - 1 { + fsm.Push(true, now) + } + + t.Logf("FSM state: %s", fsm.String()) + require.False(t, fsm.InError()) + }) + t.Run("Does not recover from error state if successes are below threshold", func(t *testing.T) { + fsm := helperNewFSM(t) + + now := time.Now() + for range fsm.OperationThreshold() { + 
fsm.Push(true, now) + } + + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + + for range fsm.OperationThreshold() - 1 { + fsm.Push(false, now) + } + + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + }) + t.Run("Trips correctly within fast burst", func(t *testing.T) { + fsm := helperNewFSM(t) + now := time.Now() + + fsm.Push(true, now) + fsm.Push(true, now.Add(10*time.Second)) + fsm.Push(true, now.Add(20*time.Second)) // 3rd error within 20s of first + + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + require.Zero(t, fsm.CurrentSuccesses()) + }) + t.Run("Does not trip if errors are spaced out", func(t *testing.T) { + fsm := helperNewFSM(t) + now := time.Now() + + fsm.Push(true, now) + fsm.Push(true, now.Add(10*time.Second)) + fsm.Push(true, now.Add(90*time.Second)) + + t.Logf("FSM state: %s", fsm.String()) + require.False(t, fsm.InError()) + }) + t.Run("Sustainability: error while down + window expired = maintain threshold", func(t *testing.T) { + fsm := helperNewFSM(t) + start := time.Now() + + fsm.Push(true, start) + fsm.Push(true, start.Add(10*time.Second)) + fsm.Push(true, start.Add(20*time.Second)) + + t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + + // The original 'startTime' was T+0. T+90 is clearly expired. + future := start.Add(90 * time.Second) + fsm.Push(true, future) + + // State should still be DOWN. 
+ t.Logf("FSM state: %s", fsm.String()) + require.True(t, fsm.InError()) + }) +} + +func TestStateFSMWithThreads(t *testing.T) { + t.Parallel() + + t.Run("Concurrent Push calls are safe", func(t *testing.T) { + const n = 10 + fsm := helperNewFSM(t, WithThreadSafety()) + now := time.Now() + + done := make(chan bool, n) + for i := range n { + go func(idx int) { + defer func() { done <- true }() + + for j := range n * 10 { + hasError := idx%2 == 0 + fsm.Push(hasError, now.Add(time.Duration(j)*time.Millisecond)) + } + }(i) + } + + for range n { + <-done + } + + require.NotNil(t, fsm) + t.Logf("FSM state: %s", fsm.String()) + }) + + t.Run("Concurrent reads are safe", func(t *testing.T) { + fsm := helperNewFSM(t, WithThreadSafety()) + now := time.Now() + + for i := range testOperationThreshold { + fsm.Push(true, now.Add(time.Duration(i)*time.Second)) + } + require.True(t, fsm.InError()) + + const n = 20 + done := make(chan bool, n) + for range n { + go func() { + defer func() { done <- true }() + + for range n * 5 { + _ = fsm.InError() + _ = fsm.CurrentErrors() + _ = fsm.CurrentSuccesses() + _ = fsm.String() + } + }() + } + + for range n { + <-done + } + + require.True(t, fsm.InError()) + }) + + t.Run("Concurrent reads and writes are safe", func(t *testing.T) { + fsm := helperNewFSM(t, WithThreadSafety()) + now := time.Now() + + const n = 22 + const r = n / 3 + const w = n - r + + done := make(chan bool, n) + + for i := range w { + go func(idx int) { + defer func() { done <- true }() + + for j := range 50 { + hasError := idx%2 == 0 + fsm.Push(hasError, now.Add(time.Duration(j)*time.Millisecond)) + } + }(i) + } + + for range r { + go func() { + defer func() { done <- true }() + + for range 50 { + _ = fsm.InError() + _ = fsm.CurrentErrors() + _ = fsm.CurrentSuccesses() + _ = fsm.String() + } + }() + } + + for range n { + <-done + } + + require.NotNil(t, fsm) + t.Logf("FSM state: %s", fsm.String()) + }) +} + +func TestFSMPanicWithInvalidValues(t *testing.T) { + t.Parallel() 
+ + require.Panics(t, func() { + _ = New(0, 1) + }) + require.Panics(t, func() { + _ = New(1, 0) + }) +}