diff --git a/cl/README.md b/cl/README.md index cbf88ce14..106edba73 100644 --- a/cl/README.md +++ b/cl/README.md @@ -105,24 +105,6 @@ geth --verbosity 5 \ - `--syncmode full`: Synchronization mode. - `--miner.recommit`: Frequency of miner recommit. -### Obtaining the Genesis Block Hash - -You can obtain the genesis block hash by querying the latest block after initializing your node: - -```bash -cast block latest -r http://localhost:8545 -``` - -Look for the `hash` field in the output, which represents the latest block hash. Since the chain is just initialized, this will be the genesis block hash. - -Alternatively, you can use `curl` to get the genesis block hash: - -```bash -curl -X POST --data '{"jsonrpc":"2.0","method":"eth_getBlockByNumber","params":["0x0", false],"id":1}' -H "Content-Type: application/json" http://localhost:8545 -``` - -Extract the `hash` value from the response and use it without `0x`. - ## Running Redis We will use Docker Compose to run Redis. @@ -163,29 +145,27 @@ Ensure all dependencies are installed and build the application: ```bash go mod tidy -go build -o consensus-client main.go +go build -o consensus-client cmd/redisapp/main.go ``` -### Configuration +### Consensus Client Configuration The consensus client can be configured via command-line flags, environment variables, or a YAML configuration file. -#### Command-Line Flags +#### Command-Line Flags for Consensus Client - `--instance-id`: **(Required)** Unique instance ID for this node. - `--eth-client-url`: Ethereum client URL (default: `http://localhost:8551`). - `--jwt-secret`: JWT secret for Ethereum client. -- `--genesis-block-hash`: Genesis block hash. - `--redis-addr`: Redis address (default: `127.0.0.1:7001`). - `--evm-build-delay`: EVM build delay (default: `1s`). - `--config`: Path to a YAML configuration file. -#### Environment Variables +#### Environment Variables for Consensus Client - `RAPP_INSTANCE_ID` - `RAPP_ETH_CLIENT_URL` - `RAPP_JWT_SECRET` -- `RAPP_GENESIS_BLOCK_HASH` - `RAPP_REDIS_ADDR` - `RAPP_EVM_BUILD_DELAY` - `RAPP_CONFIG` @@ -199,7 +179,6 @@ Run the client using command-line flags: --instance-id "node1" \ --eth-client-url "http://localhost:8551" \ --jwt-secret "your_jwt_secret" \ - --genesis-block-hash "your_genesis_block_hash" \ --redis-addr "127.0.0.1:7001" \ --evm-build-delay "1s" ``` @@ -207,9 +186,8 @@ Run the client using command-line flags: **Note**: - Replace `"your_jwt_secret"` with the actual JWT secret you used earlier. -- Replace `"your_genesis_block_hash"` with the genesis block hash obtained earlier. -### Using a Configuration File +### Using a Configuration File for Consensus Client Create a `config.yaml` file: @@ -217,7 +195,6 @@ Create a `config.yaml` file: instance-id: "node1" eth-client-url: "http://localhost:8551" jwt-secret: "your_jwt_secret" -genesis-block-hash: "your_genesis_block_hash" redis-addr: "127.0.0.1:7001" evm-build-delay: "1s" ``` @@ -232,6 +209,143 @@ Run the client with the configuration file: - **Multiple Instances**: You can run multiple instances of the consensus client by changing the `--instance-id` and `--eth-client-url` parameters. +## Running the Streamer + +The Streamer is responsible for streaming payloads to member nodes, allowing them to apply these payloads to their respective Geth instances.
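+
+Under the hood, member clients and the Streamer talk over a single bidirectional gRPC stream: a member sends a subscribe request identifying itself, the Streamer pushes payload messages, and the member acknowledges each payload after applying it. The sketch below is reconstructed from the generated code in `cl/pb/pb/streamer.pb.go` (the `streamer.proto` source is not included in this diff), so treat it as a reference rather than the authoritative definition:
+
+```proto
+syntax = "proto3";
+
+package pb;
+
+service PayloadStreamer {
+  // Bidirectional stream: the client subscribes and sends acks,
+  // the Streamer pushes execution payloads.
+  rpc Subscribe(stream ClientMessage) returns (stream PayloadMessage);
+}
+
+message ClientMessage {
+  oneof message {
+    SubscribeRequest subscribe_request = 1;
+    AckPayloadRequest ack_payload = 2;
+  }
+}
+
+message SubscribeRequest {
+  string client_id = 1;
+}
+
+message PayloadMessage {
+  string payload_id = 1;
+  string execution_payload = 2;
+  string sender_instance_id = 3;
+  string message_id = 4;
+}
+
+message AckPayloadRequest {
+  string client_id = 1;
+  string payload_id = 2;
+  string message_id = 3;
+}
+```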
+ +### Build the Streamer + +Ensure all dependencies are installed and build the Streamer application: + +```bash +go mod tidy +go build -o streamer cmd/streamer/main.go +``` + +### Streamer Configuration + +The Streamer can be configured via command-line flags, environment variables, or a YAML configuration file. + +#### Command-Line Flags for Streamer + +- `--config`: Path to config file. +- `--redis-addr`: Redis address (default: `127.0.0.1:7001`). +- `--listen-addr`: Streamer listen address (default: `:50051`). +- `--log-fmt`: Log format to use, options are `text` or `json` (default: `text`). +- `--log-level`: Log level to use, options are `debug`, `info`, `warn`, `error` (default: `info`). + +#### Environment Variables for Streamer + +- `STREAMER_CONFIG` +- `STREAMER_REDIS_ADDR` +- `STREAMER_LISTEN_ADDR` +- `STREAMER_LOG_FMT` +- `STREAMER_LOG_LEVEL` + +#### Run the Streamer + +Run the Streamer using command-line flags: + +```bash +./streamer start \ + --redis-addr "127.0.0.1:7001" \ + --listen-addr ":50051" \ + --log-fmt "json" \ + --log-level "info" +``` + +#### Using a Configuration File for Streamer + +Create a `streamer_config.yaml` file: + +```yaml +redis-addr: "127.0.0.1:7001" +listen-addr: ":50051" +log-fmt: "json" +log-level: "info" +``` + +Run the Streamer with the configuration file: + +```bash +./streamer start --config streamer_config.yaml +``` + +## Running Member Nodes + +Member nodes connect to the Streamer to receive payloads from the stream and apply them to their Geth instances. + +### Build the Member Client + +Ensure all dependencies are installed and build the Member Client application: + +```bash +go mod tidy +go build -o memberclient cmd/member/main.go +``` + +### Member Client Configuration + +The Member Client can be configured via command-line flags, environment variables, or a YAML configuration file. + +#### Command-Line Flags for Member Client + +- `--config`: Path to config file. +- `--client-id`: **(Required)** Unique client ID for this member. +- `--streamer-addr`: **(Required)** Streamer address. +- `--eth-client-url`: Ethereum client URL (default: `http://localhost:8551`). +- `--jwt-secret`: JWT secret for Ethereum client. +- `--log-fmt`: Log format to use, options are `text` or `json` (default: `text`). +- `--log-level`: Log level to use, options are `debug`, `info`, `warn`, `error` (default: `info`). + +#### Environment Variables for Member Client + +- `MEMBER_CONFIG` +- `MEMBER_CLIENT_ID` +- `MEMBER_STREAMER_ADDR` +- `MEMBER_ETH_CLIENT_URL` +- `MEMBER_JWT_SECRET` +- `MEMBER_LOG_FMT` +- `MEMBER_LOG_LEVEL` + +#### Run the Member Client + +Run the Member Client using command-line flags: + +```bash +./memberclient start \ + --client-id "member1" \ + --streamer-addr "http://localhost:50051" \ + --eth-client-url "http://localhost:8551" \ + --jwt-secret "your_jwt_secret" \ + --log-fmt "json" \ + --log-level "info" +``` + +**Note**: + +- Replace `"your_jwt_secret"` with the actual JWT secret you used earlier. + +#### Using a Configuration File for Member Client + +Create a `member_config.yaml` file: + +```yaml +client-id: "member1" +streamer-addr: "http://localhost:50051" +eth-client-url: "http://localhost:8551" +jwt-secret: "your_jwt_secret" +log-fmt: "json" +log-level: "info" +``` + +Run the Member Client with the configuration file: + +```bash +./memberclient start --config member_config.yaml +``` + ## Conclusion You now have a local Ethereum environment with Geth nodes, Redis, and a consensus client.
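+
+To sanity-check that payloads are flowing end to end, query the latest block on a member node's Geth once everything is running. This is a minimal sketch, assuming the Geth HTTP RPC from the setup above is reachable on `http://localhost:8545`; the first command also assumes Foundry's `cast` is installed, while the `curl` variant needs no extra tooling:
+
+```bash
+# Latest block as seen by this node (requires Foundry's cast)
+cast block latest -r http://localhost:8545
+
+# Equivalent JSON-RPC call with curl
+curl -X POST --data '{"jsonrpc":"2.0","method":"eth_blockNumber","params":[],"id":1}' \
+  -H "Content-Type: application/json" http://localhost:8545
+```
+
+If the block number keeps advancing on the member nodes as new payloads are produced, the Streamer and member clients are applying payloads correctly.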
diff --git a/cl/cmd/member/main.go b/cl/cmd/member/main.go new file mode 100644 index 000000000..19daaef57 --- /dev/null +++ b/cl/cmd/member/main.go @@ -0,0 +1,170 @@ +package main + +import ( + "encoding/hex" + "fmt" + "net/url" + "os" + "os/signal" + "syscall" + + "github.com/primev/mev-commit/cl/member" + "github.com/primev/mev-commit/x/util" + "github.com/urfave/cli/v2" + "github.com/urfave/cli/v2/altsrc" +) + +var ( + configFlag = &cli.StringFlag{ + Name: "config", + Usage: "Path to config file", + EnvVars: []string{"MEMBER_CONFIG"}, + } + + clientIDFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "client-id", + Usage: "Unique client ID for this member", + EnvVars: []string{"MEMBER_CLIENT_ID"}, + Required: true, + }) + + streamerAddrFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "streamer-addr", + Usage: "Streamer address", + EnvVars: []string{"MEMBER_STREAMER_ADDR"}, + Required: true, + Action: func(_ *cli.Context, s string) error { + if _, err := url.Parse(s); err != nil { + return fmt.Errorf("invalid streamer-addr: %v", err) + } + return nil + }, + }) + + ethClientURLFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "eth-client-url", + Usage: "Ethereum client URL", + EnvVars: []string{"MEMBER_ETH_CLIENT_URL"}, + Value: "http://localhost:8551", + Action: func(_ *cli.Context, s string) error { + if _, err := url.Parse(s); err != nil { + return fmt.Errorf("invalid eth-client-url: %v", err) + } + return nil + }, + }) + + jwtSecretFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "jwt-secret", + Usage: "JWT secret for Ethereum client", + EnvVars: []string{"MEMBER_JWT_SECRET"}, + Value: "13373d9a0257983ad150392d7ddb2f9172c9396b4c450e26af469d123c7aaa5c", + Action: func(_ *cli.Context, s string) error { + if len(s) != 64 { + return fmt.Errorf("invalid jwt-secret: must be 64 hex characters") + } + if _, err := hex.DecodeString(s); err != nil { + return fmt.Errorf("invalid jwt-secret: %v", err) + } + return nil + }, + }) + + logFmtFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "log-fmt", + Usage: "Log format to use, options are 'text' or 'json'", + EnvVars: []string{"MEMBER_LOG_FMT"}, + Value: "text", + }) + + logLevelFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "log-level", + Usage: "Log level to use, options are 'debug', 'info', 'warn', 'error'", + EnvVars: []string{"MEMBER_LOG_LEVEL"}, + Value: "info", + }) +) + +type Config struct { + ClientID string + StreamerAddr string + EthClientURL string + JWTSecret string +} + +func main() { + flags := []cli.Flag{ + configFlag, + clientIDFlag, + streamerAddrFlag, + ethClientURLFlag, + jwtSecretFlag, + logFmtFlag, + logLevelFlag, + } + + app := &cli.App{ + Name: "memberclient", + Usage: "Start the member client", + Flags: flags, + Before: altsrc.InitInputSourceWithContext(flags, + func(c *cli.Context) (altsrc.InputSourceContext, error) { + configFile := c.String("config") + if configFile != "" { + return altsrc.NewYamlSourceFromFile(configFile) + } + return &altsrc.MapInputSource{}, nil + }), + Action: func(c *cli.Context) error { + return startMemberClient(c) + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Println("Error running member client:", err) + } +} + +func startMemberClient(c *cli.Context) error { + log, err := util.NewLogger( + c.String(logLevelFlag.Name), + c.String(logFmtFlag.Name), + "", // No log tags + c.App.Writer, + ) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + cfg := Config{ + ClientID: c.String(clientIDFlag.Name), + StreamerAddr: 
c.String(streamerAddrFlag.Name), + EthClientURL: c.String(ethClientURLFlag.Name), + JWTSecret: c.String(jwtSecretFlag.Name), + } + + log.Info("Starting member client with configuration", "config", cfg) + + // Initialize the MemberClient + memberClient, err := member.NewMemberClient(cfg.ClientID, cfg.StreamerAddr, cfg.EthClientURL, cfg.JWTSecret, log) + if err != nil { + log.Error("Failed to initialize MemberClient", "error", err) + return err + } + + ctx, stop := signal.NotifyContext(c.Context, os.Interrupt, syscall.SIGTERM) + defer stop() + + // Start the member client + go func() { + if err := memberClient.Run(ctx); err != nil { + log.Error("Member client exited with error", "error", err) + stop() + } + }() + + <-ctx.Done() + + log.Info("Member client shutdown completed") + return nil +} diff --git a/cl/cmd/redisapp/main.go b/cl/cmd/redisapp/main.go index b70118d5b..d760b68a4 100644 --- a/cl/cmd/redisapp/main.go +++ b/cl/cmd/redisapp/main.go @@ -83,22 +83,6 @@ var ( }, }) - genesisBlockHashFlag = altsrc.NewStringFlag(&cli.StringFlag{ - Name: "genesis-block-hash", - Usage: "Genesis block hash", - EnvVars: []string{"RAPP_GENESIS_BLOCK_HASH"}, - Value: "dfc7fa546e1268f5bb65b9ec67759307d2435ad1bf609307c7c306e9bb0edcde", - Action: func(_ *cli.Context, s string) error { - if len(s) != 64 { - return fmt.Errorf("invalid genesis-block-hash: must be 64 hex characters") - } - if _, err := hex.DecodeString(s); err != nil { - return fmt.Errorf("invalid genesis-block-hash: %v", err) - } - return nil - }, - }) - redisAddrFlag = altsrc.NewStringFlag(&cli.StringFlag{ Name: "redis-addr", Usage: "Redis address", @@ -177,7 +161,6 @@ type Config struct { InstanceID string EthClientURL string JWTSecret string - GenesisBlockHash string RedisAddr string EVMBuildDelay time.Duration EVMBuildDelayEmptyBlocks time.Duration @@ -190,7 +173,6 @@ func main() { instanceIDFlag, ethClientURLFlag, jwtSecretFlag, - genesisBlockHashFlag, redisAddrFlag, logFmtFlag, logLevelFlag, @@ -245,7 +227,6 @@ func startApplication(c *cli.Context) error { InstanceID: c.String(instanceIDFlag.Name), EthClientURL: c.String(ethClientURLFlag.Name), JWTSecret: c.String(jwtSecretFlag.Name), - GenesisBlockHash: c.String(genesisBlockHashFlag.Name), RedisAddr: c.String(redisAddrFlag.Name), EVMBuildDelay: c.Duration(evmBuildDelayFlag.Name), EVMBuildDelayEmptyBlocks: c.Duration(evmBuildDelayEmptyBlockFlag.Name), @@ -259,7 +240,6 @@ func startApplication(c *cli.Context) error { cfg.InstanceID, cfg.EthClientURL, cfg.JWTSecret, - cfg.GenesisBlockHash, cfg.RedisAddr, cfg.PriorityFeeReceipt, log, diff --git a/cl/cmd/streamer/main.go b/cl/cmd/streamer/main.go new file mode 100644 index 000000000..a955796a0 --- /dev/null +++ b/cl/cmd/streamer/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "fmt" + "net" + "os" + "os/signal" + "strconv" + "syscall" + + "github.com/primev/mev-commit/cl/streamer" + "github.com/primev/mev-commit/x/util" + "github.com/urfave/cli/v2" + "github.com/urfave/cli/v2/altsrc" +) + +var ( + configFlag = &cli.StringFlag{ + Name: "config", + Usage: "Path to config file", + EnvVars: []string{"STREAMER_CONFIG"}, + } + + redisAddrFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "redis-addr", + Usage: "Redis address", + EnvVars: []string{"STREAMER_REDIS_ADDR"}, + Value: "127.0.0.1:7001", + Action: func(_ *cli.Context, s string) error { + host, port, err := net.SplitHostPort(s) + if err != nil { + return fmt.Errorf("invalid redis-addr: %v", err) + } + if host == "" { + return fmt.Errorf("invalid redis-addr: missing host") + } + if p, 
err := strconv.Atoi(port); err != nil || p <= 0 || p > 65535 { + return fmt.Errorf("invalid redis-addr: invalid port number") + } + return nil + }, + }) + + listenAddrFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "listen-addr", + Usage: "Streamer listen address", + EnvVars: []string{"STREAMER_LISTEN_ADDR"}, + Value: ":50051", + }) + + logFmtFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "log-fmt", + Usage: "Log format to use, options are 'text' or 'json'", + EnvVars: []string{"STREAMER_LOG_FMT"}, + Value: "text", + }) + + logLevelFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: "log-level", + Usage: "Log level to use, options are 'debug', 'info', 'warn', 'error'", + EnvVars: []string{"STREAMER_LOG_LEVEL"}, + Value: "info", + }) +) + +type Config struct { + RedisAddr string + ListenAddr string +} + +func main() { + flags := []cli.Flag{ + configFlag, + redisAddrFlag, + listenAddrFlag, + logFmtFlag, + logLevelFlag, + } + + app := &cli.App{ + Name: "streamer", + Usage: "Start the streamer", + Flags: flags, + Before: altsrc.InitInputSourceWithContext(flags, + func(c *cli.Context) (altsrc.InputSourceContext, error) { + configFile := c.String("config") + if configFile != "" { + return altsrc.NewYamlSourceFromFile(configFile) + } + return &altsrc.MapInputSource{}, nil + }), + Action: func(c *cli.Context) error { + return startStreamer(c) + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Println("Error running streamer:", err) + } +} + +func startStreamer(c *cli.Context) error { + log, err := util.NewLogger( + c.String(logLevelFlag.Name), + c.String(logFmtFlag.Name), + "", // No log tags + c.App.Writer, + ) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + // Load configuration + cfg := Config{ + RedisAddr: c.String(redisAddrFlag.Name), + ListenAddr: c.String(listenAddrFlag.Name), + } + + log.Info("Starting streamer with configuration", "config", cfg) + + // Initialize the Streamer + streamer, err := streamer.NewPayloadStreamer(cfg.RedisAddr, log) + if err != nil { + log.Error("Failed to initialize Streamer", "error", err) + return err + } + + ctx, stop := signal.NotifyContext(c.Context, os.Interrupt, syscall.SIGTERM) + defer stop() + + // Start the streamer + go func() { + if err := streamer.Start(cfg.ListenAddr); err != nil { + log.Error("Streamer exited with error", "error", err) + stop() + } + }() + + <-ctx.Done() + + streamer.Stop() + + log.Info("Streamer shutdown completed") + return nil +} diff --git a/cl/go.mod b/cl/go.mod index bf74747d5..55c4ae25b 100644 --- a/cl/go.mod +++ b/cl/go.mod @@ -9,6 +9,7 @@ require ( github.com/redis/go-redis/v9 v9.6.1 github.com/urfave/cli/v2 v2.27.4 golang.org/x/tools v0.23.0 + google.golang.org/grpc v1.67.1 ) require ( @@ -54,7 +55,9 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect golang.org/x/mod v0.19.0 // indirect + golang.org/x/net v0.28.0 // indirect golang.org/x/sync v0.8.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) @@ -88,6 +91,6 @@ require ( golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.17.0 // indirect - google.golang.org/protobuf v1.34.2 // indirect + google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/cl/go.sum b/cl/go.sum index 881839bb6..63a090407 100644 --- a/cl/go.sum +++ b/cl/go.sum @@ 
-290,6 +290,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/cl/member/member.go b/cl/member/member.go new file mode 100644 index 000000000..5ce0abb33 --- /dev/null +++ b/cl/member/member.go @@ -0,0 +1,129 @@ +package member + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "math/big" + + "log/slog" + + "github.com/ethereum/go-ethereum/beacon/engine" + "github.com/ethereum/go-ethereum/common" + gtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/primev/mev-commit/cl/ethclient" + "github.com/primev/mev-commit/cl/pb/pb" + "github.com/primev/mev-commit/cl/redisapp/blockbuilder" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type MemberClient struct { + clientID string + streamerAddr string + conn *grpc.ClientConn + client pb.PayloadStreamerClient + logger *slog.Logger + engineCl EngineClient + bb BlockBuilder +} + +type EngineClient interface { + NewPayloadV3(ctx context.Context, params engine.ExecutableData, versionedHashes []common.Hash, beaconRoot *common.Hash) (engine.PayloadStatusV1, error) + ForkchoiceUpdatedV3(ctx context.Context, update engine.ForkchoiceStateV1, payloadAttributes *engine.PayloadAttributes) (engine.ForkChoiceResponse, error) + HeaderByNumber(ctx context.Context, number *big.Int) (*gtypes.Header, error) +} + +type BlockBuilder interface { + FinalizeBlock(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error +} + +func NewMemberClient(clientID, streamerAddr, ecURL, jwtSecret string, logger *slog.Logger) (*MemberClient, error) { + conn, err := grpc.NewClient(streamerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + client := pb.NewPayloadStreamerClient(conn) + + bytes, err := hex.DecodeString(jwtSecret) + if err != nil { + return nil, fmt.Errorf("error decoding JWT secret: %v", err) + } + + engineCL, err := ethclient.NewAuthClient(context.Background(), ecURL, bytes) + if err != nil { + return nil, fmt.Errorf("error creating engine client: %v", err) + } + + bb := blockbuilder.NewMemberBlockBuilder(engineCL, logger) + + return &MemberClient{ + clientID: clientID, + streamerAddr: streamerAddr, + conn: conn, + client: client, + engineCl: engineCL, + logger: logger, + bb: bb, + }, nil +} + +func (mc *MemberClient) Run(ctx context.Context) error { + stream, err := mc.client.Subscribe(ctx) + if err != nil 
{ + return err + } + + err = stream.Send(&pb.ClientMessage{ + Message: &pb.ClientMessage_SubscribeRequest{ + SubscribeRequest: &pb.SubscribeRequest{ + ClientId: mc.clientID, + }, + }, + }) + if err != nil { + mc.logger.Error("Failed to send SubscribeRequest", "error", err) + return err + } + + mc.logger.Info("Member client started", "clientID", mc.clientID) + + for { + select { + case <-ctx.Done(): + mc.logger.Info("Member client context done", "clientID", mc.clientID) + return nil + default: + msg, err := stream.Recv() + if err != nil { + if errors.Is(err, context.Canceled) { + mc.logger.Info("Member client context canceled", "clientID", mc.clientID) + return nil + } + mc.logger.Error("Error receiving message", "error", err) + continue + } + err = mc.bb.FinalizeBlock(ctx, msg.PayloadId, msg.ExecutionPayload, msg.MessageId) + if err != nil { + mc.logger.Error("Error processing payload", "error", err) + continue + } + + err = stream.Send(&pb.ClientMessage{ + Message: &pb.ClientMessage_AckPayload{ + AckPayload: &pb.AckPayloadRequest{ + ClientId: mc.clientID, + PayloadId: msg.PayloadId, + MessageId: msg.MessageId, + }, + }, + }) + if err != nil { + mc.logger.Error("Failed to send acknowledgment", "error", err) + continue + } + mc.logger.Info("Acknowledged message", "payloadID", msg.PayloadId) + } + } +} diff --git a/cl/member/member_test.go b/cl/member/member_test.go new file mode 100644 index 000000000..386ae80e8 --- /dev/null +++ b/cl/member/member_test.go @@ -0,0 +1,213 @@ +package member + +import ( + "context" + "encoding/hex" + "io" + "log/slog" + "math/big" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/beacon/engine" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/primev/mev-commit/cl/pb/pb" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type mockEngineClient struct{} + +func (m *mockEngineClient) NewPayloadV3(ctx context.Context, params engine.ExecutableData, versionedHashes []common.Hash, beaconRoot *common.Hash) (engine.PayloadStatusV1, error) { + return engine.PayloadStatusV1{}, nil +} + +func (m *mockEngineClient) ForkchoiceUpdatedV3(context.Context, engine.ForkchoiceStateV1, *engine.PayloadAttributes) (engine.ForkChoiceResponse, error) { + return engine.ForkChoiceResponse{}, nil +} + +func (m *mockEngineClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) { + return &types.Header{}, nil +} + +type mockBlockBuilder struct { + mu sync.Mutex + finalizeCalls []finalizeCall +} + +type finalizeCall struct { + payloadID string + executionPayload string + messageID string +} + +func (m *mockBlockBuilder) FinalizeBlock(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.finalizeCalls = append(m.finalizeCalls, finalizeCall{ + payloadID: payloadIDStr, + executionPayload: executionPayloadStr, + messageID: msgID, + }) + return nil +} + +func (f *mockBlockBuilder) Calls() []finalizeCall { + f.mu.Lock() + defer f.mu.Unlock() + return append([]finalizeCall(nil), f.finalizeCalls...) +} + +// fakePayloadStreamerServer simulates the PayloadStreamer gRPC service for testing. 
+type fakePayloadStreamerServer struct { + pb.UnimplementedPayloadStreamerServer + + mu sync.Mutex + subscribed bool + sentPayload bool + clientID string + serverStopped bool +} + +func (s *fakePayloadStreamerServer) Subscribe(stream pb.PayloadStreamer_SubscribeServer) error { + for { + msg, err := stream.Recv() + if err == io.EOF || s.serverStopped { + return nil + } + if err != nil { + return err + } + + if req := msg.GetSubscribeRequest(); req != nil { + // Acknowledge subscription + s.mu.Lock() + s.subscribed = true + s.clientID = req.GetClientId() + s.mu.Unlock() + + // After subscribing, send a single payload message, then close the stream. + resp := &pb.PayloadMessage{ + PayloadId: "test-payload-id", + ExecutionPayload: "test-exec-payload", + SenderInstanceId: "sender-123", + MessageId: "test-msg-id", + } + if err := stream.SendMsg(resp); err != nil { + return err + } + s.mu.Lock() + s.sentPayload = true + s.mu.Unlock() + + // Wait a moment and then return EOF to stop the stream + time.Sleep(200 * time.Millisecond) + return nil + } else if ack := msg.GetAckPayload(); ack != nil { + continue + } + } +} + +// TestMemberClientRun tests the Run method end-to-end with a fake server and fake dependencies. +func TestMemberClientRun(t *testing.T) { + lis, err := net.Listen("tcp", "127.0.0.1:0") // ephemeral port + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + s := grpc.NewServer() + defer s.Stop() + + streamerServer := &fakePayloadStreamerServer{} + pb.RegisterPayloadStreamerServer(s, streamerServer) + + errChan := make(chan error, 1) + go func() { + errChan <- s.Serve(lis) + }() + + select { + case err := <-errChan: + if err != nil { + t.Fatalf("failed to serve: %v", err) + } + case <-time.After(time.Millisecond * 100): + // Server started successfully + } + + clientID := "test-client-id" + streamerAddr := lis.Addr().String() + logger := slog.Default() + + engineClient := &mockEngineClient{} + blockBuilder := &mockBlockBuilder{} + + conn, err := grpc.NewClient(streamerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial test server: %v", err) + } + streamerClient := pb.NewPayloadStreamerClient(conn) + + mc := &MemberClient{ + clientID: clientID, + streamerAddr: streamerAddr, + conn: conn, + client: streamerClient, + logger: logger, + engineCl: engineClient, + bb: blockBuilder, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = mc.Run(ctx) + if err != nil { + t.Errorf("MemberClient.Run returned an error: %v", err) + } + + streamerServer.mu.Lock() + subscribed := streamerServer.subscribed + sentPayload := streamerServer.sentPayload + streamerServer.mu.Unlock() + + if !subscribed { + t.Errorf("Server did not receive subscription from client") + } + if !sentPayload { + t.Errorf("Server did not send a payload message") + } + + calls := blockBuilder.Calls() + if len(calls) != 1 { + t.Fatalf("Expected 1 FinalizeBlock call, got %d", len(calls)) + } + call := calls[0] + if call.payloadID != "test-payload-id" { + t.Errorf("Expected payloadID 'test-payload-id', got '%s'", call.payloadID) + } + if call.executionPayload != "test-exec-payload" { + t.Errorf("Expected executionPayload 'test-exec-payload', got '%s'", call.executionPayload) + } + if call.messageID != "test-msg-id" { + t.Errorf("Expected messageID 'test-msg-id', got '%s'", call.messageID) + } +} + +func TestJWTSecretDecodingNoMocks(t *testing.T) { + validSecret := "deadbeef" + invalidSecret := "zzzz" + + _, err := 
hex.DecodeString(validSecret) + if err != nil { + t.Errorf("Failed to decode valid secret: %v", err) + } + + _, err = hex.DecodeString(invalidSecret) + if err == nil { + t.Error("Expected error decoding invalid secret, got none") + } +} diff --git a/cl/mocks/mock_state.go b/cl/mocks/mock_state.go index f3803b80f..5eaaa36dd 100644 --- a/cl/mocks/mock_state.go +++ b/cl/mocks/mock_state.go @@ -10,6 +10,7 @@ import ( time "time" gomock "github.com/golang/mock/gomock" + state "github.com/primev/mev-commit/cl/redisapp/state" types "github.com/primev/mev-commit/cl/redisapp/types" redis "github.com/redis/go-redis/v9" ) @@ -7345,8 +7346,147 @@ func (m *MockStateManager) EXPECT() *MockStateManagerMockRecorder { return m.recorder } +// ExecuteTransaction mocks base method. +func (m *MockStateManager) ExecuteTransaction(ctx context.Context, ops ...state.PipelineOperation) error { + m.ctrl.T.Helper() + varargs := []interface{}{ctx} + for _, a := range ops { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecuteTransaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExecuteTransaction indicates an expected call of ExecuteTransaction. +func (mr *MockStateManagerMockRecorder) ExecuteTransaction(ctx interface{}, ops ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx}, ops...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTransaction", reflect.TypeOf((*MockStateManager)(nil).ExecuteTransaction), varargs...) +} + +// GetBlockBuildState mocks base method. +func (m *MockStateManager) GetBlockBuildState(ctx context.Context) types.BlockBuildState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBlockBuildState", ctx) + ret0, _ := ret[0].(types.BlockBuildState) + return ret0 +} + +// GetBlockBuildState indicates an expected call of GetBlockBuildState. +func (mr *MockStateManagerMockRecorder) GetBlockBuildState(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlockBuildState", reflect.TypeOf((*MockStateManager)(nil).GetBlockBuildState), ctx) +} + +// LoadExecutionHead mocks base method. +func (m *MockStateManager) LoadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadExecutionHead", ctx) + ret0, _ := ret[0].(*types.ExecutionHead) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadExecutionHead indicates an expected call of LoadExecutionHead. +func (mr *MockStateManagerMockRecorder) LoadExecutionHead(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadExecutionHead", reflect.TypeOf((*MockStateManager)(nil).LoadExecutionHead), ctx) +} + +// LoadOrInitializeBlockState mocks base method. +func (m *MockStateManager) LoadOrInitializeBlockState(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOrInitializeBlockState", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// LoadOrInitializeBlockState indicates an expected call of LoadOrInitializeBlockState. +func (mr *MockStateManagerMockRecorder) LoadOrInitializeBlockState(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOrInitializeBlockState", reflect.TypeOf((*MockStateManager)(nil).LoadOrInitializeBlockState), ctx) +} + +// ResetBlockState mocks base method. 
+func (m *MockStateManager) ResetBlockState(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResetBlockState", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetBlockState indicates an expected call of ResetBlockState. +func (mr *MockStateManagerMockRecorder) ResetBlockState(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetBlockState", reflect.TypeOf((*MockStateManager)(nil).ResetBlockState), ctx) +} + +// SaveBlockState mocks base method. +func (m *MockStateManager) SaveBlockState(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveBlockState", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveBlockState indicates an expected call of SaveBlockState. +func (mr *MockStateManagerMockRecorder) SaveBlockState(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveBlockState", reflect.TypeOf((*MockStateManager)(nil).SaveBlockState), ctx) +} + +// SaveExecutionHead mocks base method. +func (m *MockStateManager) SaveExecutionHead(ctx context.Context, head *types.ExecutionHead) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveExecutionHead", ctx, head) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveExecutionHead indicates an expected call of SaveExecutionHead. +func (mr *MockStateManagerMockRecorder) SaveExecutionHead(ctx, head interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveExecutionHead", reflect.TypeOf((*MockStateManager)(nil).SaveExecutionHead), ctx, head) +} + +// Stop mocks base method. +func (m *MockStateManager) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockStateManagerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStateManager)(nil).Stop)) +} + +// MockStreamManager is a mock of StreamManager interface. +type MockStreamManager struct { + ctrl *gomock.Controller + recorder *MockStreamManagerMockRecorder +} + +// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager. +type MockStreamManagerMockRecorder struct { + mock *MockStreamManager +} + +// NewMockStreamManager creates a new mock instance. +func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager { + mock := &MockStreamManager{ctrl: ctrl} + mock.recorder = &MockStreamManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { + return m.recorder +} + // AckMessage mocks base method. -func (m *MockStateManager) AckMessage(ctx context.Context, messageID string) error { +func (m *MockStreamManager) AckMessage(ctx context.Context, messageID string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AckMessage", ctx, messageID) ret0, _ := ret[0].(error) @@ -7354,13 +7494,13 @@ func (m *MockStateManager) AckMessage(ctx context.Context, messageID string) err } // AckMessage indicates an expected call of AckMessage. 
-func (mr *MockStateManagerMockRecorder) AckMessage(ctx, messageID interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AckMessage(ctx, messageID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AckMessage", reflect.TypeOf((*MockStateManager)(nil).AckMessage), ctx, messageID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AckMessage", reflect.TypeOf((*MockStreamManager)(nil).AckMessage), ctx, messageID) } // CreateConsumerGroup mocks base method. -func (m *MockStateManager) CreateConsumerGroup(ctx context.Context) error { +func (m *MockStreamManager) CreateConsumerGroup(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateConsumerGroup", ctx) ret0, _ := ret[0].(error) @@ -7368,13 +7508,143 @@ func (m *MockStateManager) CreateConsumerGroup(ctx context.Context) error { } // CreateConsumerGroup indicates an expected call of CreateConsumerGroup. -func (mr *MockStateManagerMockRecorder) CreateConsumerGroup(ctx interface{}) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) CreateConsumerGroup(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConsumerGroup", reflect.TypeOf((*MockStreamManager)(nil).CreateConsumerGroup), ctx) +} + +// ExecuteTransaction mocks base method. +func (m *MockStreamManager) ExecuteTransaction(ctx context.Context, ops ...state.PipelineOperation) error { + m.ctrl.T.Helper() + varargs := []interface{}{ctx} + for _, a := range ops { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecuteTransaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExecuteTransaction indicates an expected call of ExecuteTransaction. +func (mr *MockStreamManagerMockRecorder) ExecuteTransaction(ctx interface{}, ops ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx}, ops...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTransaction", reflect.TypeOf((*MockStreamManager)(nil).ExecuteTransaction), varargs...) +} + +// PublishToStream mocks base method. +func (m *MockStreamManager) PublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PublishToStream", ctx, bsState) + ret0, _ := ret[0].(error) + return ret0 +} + +// PublishToStream indicates an expected call of PublishToStream. +func (mr *MockStreamManagerMockRecorder) PublishToStream(ctx, bsState interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConsumerGroup", reflect.TypeOf((*MockStateManager)(nil).CreateConsumerGroup), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishToStream", reflect.TypeOf((*MockStreamManager)(nil).PublishToStream), ctx, bsState) +} + +// ReadMessagesFromStream mocks base method. +func (m *MockStreamManager) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessagesFromStream", ctx, msgType) + ret0, _ := ret[0].([]redis.XStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadMessagesFromStream indicates an expected call of ReadMessagesFromStream. 
+func (mr *MockStreamManagerMockRecorder) ReadMessagesFromStream(ctx, msgType interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessagesFromStream", reflect.TypeOf((*MockStreamManager)(nil).ReadMessagesFromStream), ctx, msgType) +} + +// Stop mocks base method. +func (m *MockStreamManager) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockStreamManagerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStreamManager)(nil).Stop)) +} + +// MockCoordinator is a mock of Coordinator interface. +type MockCoordinator struct { + ctrl *gomock.Controller + recorder *MockCoordinatorMockRecorder +} + +// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator. +type MockCoordinatorMockRecorder struct { + mock *MockCoordinator +} + +// NewMockCoordinator creates a new mock instance. +func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator { + mock := &MockCoordinator{ctrl: ctrl} + mock.recorder = &MockCoordinatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder { + return m.recorder +} + +// AckMessage mocks base method. +func (m *MockCoordinator) AckMessage(ctx context.Context, messageID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AckMessage", ctx, messageID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AckMessage indicates an expected call of AckMessage. +func (mr *MockCoordinatorMockRecorder) AckMessage(ctx, messageID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AckMessage", reflect.TypeOf((*MockCoordinator)(nil).AckMessage), ctx, messageID) +} + +// CreateConsumerGroup mocks base method. +func (m *MockCoordinator) CreateConsumerGroup(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateConsumerGroup", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateConsumerGroup indicates an expected call of CreateConsumerGroup. +func (mr *MockCoordinatorMockRecorder) CreateConsumerGroup(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConsumerGroup", reflect.TypeOf((*MockCoordinator)(nil).CreateConsumerGroup), ctx) +} + +// ExecuteTransaction mocks base method. +func (m *MockCoordinator) ExecuteTransaction(ctx context.Context, ops ...state.PipelineOperation) error { + m.ctrl.T.Helper() + varargs := []interface{}{ctx} + for _, a := range ops { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecuteTransaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExecuteTransaction indicates an expected call of ExecuteTransaction. +func (mr *MockCoordinatorMockRecorder) ExecuteTransaction(ctx interface{}, ops ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx}, ops...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTransaction", reflect.TypeOf((*MockCoordinator)(nil).ExecuteTransaction), varargs...) } // GetBlockBuildState mocks base method. 
-func (m *MockStateManager) GetBlockBuildState(ctx context.Context) types.BlockBuildState { +func (m *MockCoordinator) GetBlockBuildState(ctx context.Context) types.BlockBuildState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBlockBuildState", ctx) ret0, _ := ret[0].(types.BlockBuildState) @@ -7382,13 +7652,13 @@ func (m *MockStateManager) GetBlockBuildState(ctx context.Context) types.BlockBu } // GetBlockBuildState indicates an expected call of GetBlockBuildState. -func (mr *MockStateManagerMockRecorder) GetBlockBuildState(ctx interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) GetBlockBuildState(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlockBuildState", reflect.TypeOf((*MockStateManager)(nil).GetBlockBuildState), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlockBuildState", reflect.TypeOf((*MockCoordinator)(nil).GetBlockBuildState), ctx) } // LoadExecutionHead mocks base method. -func (m *MockStateManager) LoadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) { +func (m *MockCoordinator) LoadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadExecutionHead", ctx) ret0, _ := ret[0].(*types.ExecutionHead) @@ -7397,13 +7667,13 @@ func (m *MockStateManager) LoadExecutionHead(ctx context.Context) (*types.Execut } // LoadExecutionHead indicates an expected call of LoadExecutionHead. -func (mr *MockStateManagerMockRecorder) LoadExecutionHead(ctx interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) LoadExecutionHead(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadExecutionHead", reflect.TypeOf((*MockStateManager)(nil).LoadExecutionHead), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadExecutionHead", reflect.TypeOf((*MockCoordinator)(nil).LoadExecutionHead), ctx) } // LoadOrInitializeBlockState mocks base method. -func (m *MockStateManager) LoadOrInitializeBlockState(ctx context.Context) error { +func (m *MockCoordinator) LoadOrInitializeBlockState(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadOrInitializeBlockState", ctx) ret0, _ := ret[0].(error) @@ -7411,13 +7681,27 @@ func (m *MockStateManager) LoadOrInitializeBlockState(ctx context.Context) error } // LoadOrInitializeBlockState indicates an expected call of LoadOrInitializeBlockState. -func (mr *MockStateManagerMockRecorder) LoadOrInitializeBlockState(ctx interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) LoadOrInitializeBlockState(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOrInitializeBlockState", reflect.TypeOf((*MockStateManager)(nil).LoadOrInitializeBlockState), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOrInitializeBlockState", reflect.TypeOf((*MockCoordinator)(nil).LoadOrInitializeBlockState), ctx) +} + +// PublishToStream mocks base method. +func (m *MockCoordinator) PublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PublishToStream", ctx, bsState) + ret0, _ := ret[0].(error) + return ret0 +} + +// PublishToStream indicates an expected call of PublishToStream. 
+func (mr *MockCoordinatorMockRecorder) PublishToStream(ctx, bsState interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishToStream", reflect.TypeOf((*MockCoordinator)(nil).PublishToStream), ctx, bsState) } // ReadMessagesFromStream mocks base method. -func (m *MockStateManager) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { +func (m *MockCoordinator) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReadMessagesFromStream", ctx, msgType) ret0, _ := ret[0].([]redis.XStream) @@ -7426,13 +7710,13 @@ func (m *MockStateManager) ReadMessagesFromStream(ctx context.Context, msgType t } // ReadMessagesFromStream indicates an expected call of ReadMessagesFromStream. -func (mr *MockStateManagerMockRecorder) ReadMessagesFromStream(ctx, msgType interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) ReadMessagesFromStream(ctx, msgType interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessagesFromStream", reflect.TypeOf((*MockStateManager)(nil).ReadMessagesFromStream), ctx, msgType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessagesFromStream", reflect.TypeOf((*MockCoordinator)(nil).ReadMessagesFromStream), ctx, msgType) } // ResetBlockState mocks base method. -func (m *MockStateManager) ResetBlockState(ctx context.Context) error { +func (m *MockCoordinator) ResetBlockState(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ResetBlockState", ctx) ret0, _ := ret[0].(error) @@ -7440,13 +7724,13 @@ func (m *MockStateManager) ResetBlockState(ctx context.Context) error { } // ResetBlockState indicates an expected call of ResetBlockState. -func (mr *MockStateManagerMockRecorder) ResetBlockState(ctx interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) ResetBlockState(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetBlockState", reflect.TypeOf((*MockStateManager)(nil).ResetBlockState), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetBlockState", reflect.TypeOf((*MockCoordinator)(nil).ResetBlockState), ctx) } // SaveBlockState mocks base method. -func (m *MockStateManager) SaveBlockState(ctx context.Context) error { +func (m *MockCoordinator) SaveBlockState(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveBlockState", ctx) ret0, _ := ret[0].(error) @@ -7454,13 +7738,13 @@ func (m *MockStateManager) SaveBlockState(ctx context.Context) error { } // SaveBlockState indicates an expected call of SaveBlockState. -func (mr *MockStateManagerMockRecorder) SaveBlockState(ctx interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) SaveBlockState(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveBlockState", reflect.TypeOf((*MockStateManager)(nil).SaveBlockState), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveBlockState", reflect.TypeOf((*MockCoordinator)(nil).SaveBlockState), ctx) } // SaveBlockStateAndPublishToStream mocks base method. 
-func (m *MockStateManager) SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { +func (m *MockCoordinator) SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveBlockStateAndPublishToStream", ctx, bsState) ret0, _ := ret[0].(error) @@ -7468,13 +7752,13 @@ func (m *MockStateManager) SaveBlockStateAndPublishToStream(ctx context.Context, } // SaveBlockStateAndPublishToStream indicates an expected call of SaveBlockStateAndPublishToStream. -func (mr *MockStateManagerMockRecorder) SaveBlockStateAndPublishToStream(ctx, bsState interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) SaveBlockStateAndPublishToStream(ctx, bsState interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveBlockStateAndPublishToStream", reflect.TypeOf((*MockStateManager)(nil).SaveBlockStateAndPublishToStream), ctx, bsState) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveBlockStateAndPublishToStream", reflect.TypeOf((*MockCoordinator)(nil).SaveBlockStateAndPublishToStream), ctx, bsState) } // SaveExecutionHead mocks base method. -func (m *MockStateManager) SaveExecutionHead(ctx context.Context, head *types.ExecutionHead) error { +func (m *MockCoordinator) SaveExecutionHead(ctx context.Context, head *types.ExecutionHead) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveExecutionHead", ctx, head) ret0, _ := ret[0].(error) @@ -7482,13 +7766,13 @@ func (m *MockStateManager) SaveExecutionHead(ctx context.Context, head *types.Ex } // SaveExecutionHead indicates an expected call of SaveExecutionHead. -func (mr *MockStateManagerMockRecorder) SaveExecutionHead(ctx, head interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) SaveExecutionHead(ctx, head interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveExecutionHead", reflect.TypeOf((*MockStateManager)(nil).SaveExecutionHead), ctx, head) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveExecutionHead", reflect.TypeOf((*MockCoordinator)(nil).SaveExecutionHead), ctx, head) } // SaveExecutionHeadAndAck mocks base method. -func (m *MockStateManager) SaveExecutionHeadAndAck(ctx context.Context, head *types.ExecutionHead, messageID string) error { +func (m *MockCoordinator) SaveExecutionHeadAndAck(ctx context.Context, head *types.ExecutionHead, messageID string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveExecutionHeadAndAck", ctx, head, messageID) ret0, _ := ret[0].(error) @@ -7496,19 +7780,19 @@ func (m *MockStateManager) SaveExecutionHeadAndAck(ctx context.Context, head *ty } // SaveExecutionHeadAndAck indicates an expected call of SaveExecutionHeadAndAck. -func (mr *MockStateManagerMockRecorder) SaveExecutionHeadAndAck(ctx, head, messageID interface{}) *gomock.Call { +func (mr *MockCoordinatorMockRecorder) SaveExecutionHeadAndAck(ctx, head, messageID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveExecutionHeadAndAck", reflect.TypeOf((*MockStateManager)(nil).SaveExecutionHeadAndAck), ctx, head, messageID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveExecutionHeadAndAck", reflect.TypeOf((*MockCoordinator)(nil).SaveExecutionHeadAndAck), ctx, head, messageID) } // Stop mocks base method. 
-func (m *MockStateManager) Stop() { +func (m *MockCoordinator) Stop() { m.ctrl.T.Helper() m.ctrl.Call(m, "Stop") } // Stop indicates an expected call of Stop. -func (mr *MockStateManagerMockRecorder) Stop() *gomock.Call { +func (mr *MockCoordinatorMockRecorder) Stop() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockStateManager)(nil).Stop)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockCoordinator)(nil).Stop)) } diff --git a/cl/pb/pb/streamer.pb.go b/cl/pb/pb/streamer.pb.go new file mode 100644 index 000000000..c5fe2ad62 --- /dev/null +++ b/cl/pb/pb/streamer.pb.go @@ -0,0 +1,436 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc v3.19.1 +// source: streamer.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ClientMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Message: + // + // *ClientMessage_SubscribeRequest + // *ClientMessage_AckPayload + Message isClientMessage_Message `protobuf_oneof:"message"` +} + +func (x *ClientMessage) Reset() { + *x = ClientMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_streamer_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ClientMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientMessage) ProtoMessage() {} + +func (x *ClientMessage) ProtoReflect() protoreflect.Message { + mi := &file_streamer_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientMessage.ProtoReflect.Descriptor instead. 
+func (*ClientMessage) Descriptor() ([]byte, []int) { + return file_streamer_proto_rawDescGZIP(), []int{0} +} + +func (m *ClientMessage) GetMessage() isClientMessage_Message { + if m != nil { + return m.Message + } + return nil +} + +func (x *ClientMessage) GetSubscribeRequest() *SubscribeRequest { + if x, ok := x.GetMessage().(*ClientMessage_SubscribeRequest); ok { + return x.SubscribeRequest + } + return nil +} + +func (x *ClientMessage) GetAckPayload() *AckPayloadRequest { + if x, ok := x.GetMessage().(*ClientMessage_AckPayload); ok { + return x.AckPayload + } + return nil +} + +type isClientMessage_Message interface { + isClientMessage_Message() +} + +type ClientMessage_SubscribeRequest struct { + SubscribeRequest *SubscribeRequest `protobuf:"bytes,1,opt,name=subscribe_request,json=subscribeRequest,proto3,oneof"` +} + +type ClientMessage_AckPayload struct { + AckPayload *AckPayloadRequest `protobuf:"bytes,2,opt,name=ack_payload,json=ackPayload,proto3,oneof"` +} + +func (*ClientMessage_SubscribeRequest) isClientMessage_Message() {} + +func (*ClientMessage_AckPayload) isClientMessage_Message() {} + +type SubscribeRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` +} + +func (x *SubscribeRequest) Reset() { + *x = SubscribeRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_streamer_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SubscribeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeRequest) ProtoMessage() {} + +func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { + mi := &file_streamer_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. 
+func (*SubscribeRequest) Descriptor() ([]byte, []int) { + return file_streamer_proto_rawDescGZIP(), []int{1} +} + +func (x *SubscribeRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +type PayloadMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + PayloadId string `protobuf:"bytes,1,opt,name=payload_id,json=payloadId,proto3" json:"payload_id,omitempty"` + ExecutionPayload string `protobuf:"bytes,2,opt,name=execution_payload,json=executionPayload,proto3" json:"execution_payload,omitempty"` + SenderInstanceId string `protobuf:"bytes,3,opt,name=sender_instance_id,json=senderInstanceId,proto3" json:"sender_instance_id,omitempty"` + MessageId string `protobuf:"bytes,4,opt,name=message_id,json=messageId,proto3" json:"message_id,omitempty"` +} + +func (x *PayloadMessage) Reset() { + *x = PayloadMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_streamer_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PayloadMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PayloadMessage) ProtoMessage() {} + +func (x *PayloadMessage) ProtoReflect() protoreflect.Message { + mi := &file_streamer_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PayloadMessage.ProtoReflect.Descriptor instead. +func (*PayloadMessage) Descriptor() ([]byte, []int) { + return file_streamer_proto_rawDescGZIP(), []int{2} +} + +func (x *PayloadMessage) GetPayloadId() string { + if x != nil { + return x.PayloadId + } + return "" +} + +func (x *PayloadMessage) GetExecutionPayload() string { + if x != nil { + return x.ExecutionPayload + } + return "" +} + +func (x *PayloadMessage) GetSenderInstanceId() string { + if x != nil { + return x.SenderInstanceId + } + return "" +} + +func (x *PayloadMessage) GetMessageId() string { + if x != nil { + return x.MessageId + } + return "" +} + +type AckPayloadRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + PayloadId string `protobuf:"bytes,2,opt,name=payload_id,json=payloadId,proto3" json:"payload_id,omitempty"` + MessageId string `protobuf:"bytes,3,opt,name=message_id,json=messageId,proto3" json:"message_id,omitempty"` +} + +func (x *AckPayloadRequest) Reset() { + *x = AckPayloadRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_streamer_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AckPayloadRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AckPayloadRequest) ProtoMessage() {} + +func (x *AckPayloadRequest) ProtoReflect() protoreflect.Message { + mi := &file_streamer_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AckPayloadRequest.ProtoReflect.Descriptor instead. 
+func (*AckPayloadRequest) Descriptor() ([]byte, []int) { + return file_streamer_proto_rawDescGZIP(), []int{3} +} + +func (x *AckPayloadRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *AckPayloadRequest) GetPayloadId() string { + if x != nil { + return x.PayloadId + } + return "" +} + +func (x *AckPayloadRequest) GetMessageId() string { + if x != nil { + return x.MessageId + } + return "" +} + +var File_streamer_proto protoreflect.FileDescriptor + +var file_streamer_proto_rawDesc = []byte{ + 0x0a, 0x0e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x02, 0x70, 0x62, 0x22, 0x99, 0x01, 0x0a, 0x0d, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x43, 0x0a, 0x11, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, + 0x69, 0x62, 0x65, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x70, 0x62, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x10, 0x73, 0x75, 0x62, 0x73, 0x63, + 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, 0x0b, 0x61, + 0x63, 0x6b, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x15, 0x2e, 0x70, 0x62, 0x2e, 0x41, 0x63, 0x6b, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x61, 0x63, 0x6b, 0x50, 0x61, + 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x2f, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x22, 0xa9, 0x01, 0x0a, 0x0e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x49, 0x64, 0x12, 0x2b, 0x0a, 0x11, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, + 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, + 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x65, 0x6e, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x6e, 0x73, 0x74, 0x61, + 0x6e, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x73, 0x65, + 0x6e, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, + 0x0a, 0x0a, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x49, 0x64, 0x22, 0x6e, 0x0a, + 0x11, 0x41, 0x63, 0x6b, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, + 0x1d, 0x0a, 0x0a, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x12, 0x1d, + 0x0a, 0x0a, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x69, 
0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x49, 0x64, 0x32, 0x49, 0x0a, + 0x0f, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x65, 0x72, + 0x12, 0x36, 0x0a, 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x12, 0x11, 0x2e, + 0x70, 0x62, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x12, 0x2e, 0x70, 0x62, 0x2e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_streamer_proto_rawDescOnce sync.Once + file_streamer_proto_rawDescData = file_streamer_proto_rawDesc +) + +func file_streamer_proto_rawDescGZIP() []byte { + file_streamer_proto_rawDescOnce.Do(func() { + file_streamer_proto_rawDescData = protoimpl.X.CompressGZIP(file_streamer_proto_rawDescData) + }) + return file_streamer_proto_rawDescData +} + +var file_streamer_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_streamer_proto_goTypes = []any{ + (*ClientMessage)(nil), // 0: pb.ClientMessage + (*SubscribeRequest)(nil), // 1: pb.SubscribeRequest + (*PayloadMessage)(nil), // 2: pb.PayloadMessage + (*AckPayloadRequest)(nil), // 3: pb.AckPayloadRequest +} +var file_streamer_proto_depIdxs = []int32{ + 1, // 0: pb.ClientMessage.subscribe_request:type_name -> pb.SubscribeRequest + 3, // 1: pb.ClientMessage.ack_payload:type_name -> pb.AckPayloadRequest + 0, // 2: pb.PayloadStreamer.Subscribe:input_type -> pb.ClientMessage + 2, // 3: pb.PayloadStreamer.Subscribe:output_type -> pb.PayloadMessage + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_streamer_proto_init() } +func file_streamer_proto_init() { + if File_streamer_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_streamer_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*ClientMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_streamer_proto_msgTypes[1].Exporter = func(v any, i int) any { + switch v := v.(*SubscribeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_streamer_proto_msgTypes[2].Exporter = func(v any, i int) any { + switch v := v.(*PayloadMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_streamer_proto_msgTypes[3].Exporter = func(v any, i int) any { + switch v := v.(*AckPayloadRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_streamer_proto_msgTypes[0].OneofWrappers = []any{ + (*ClientMessage_SubscribeRequest)(nil), + (*ClientMessage_AckPayload)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_streamer_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_streamer_proto_goTypes, + DependencyIndexes: 
file_streamer_proto_depIdxs, + MessageInfos: file_streamer_proto_msgTypes, + }.Build() + File_streamer_proto = out.File + file_streamer_proto_rawDesc = nil + file_streamer_proto_goTypes = nil + file_streamer_proto_depIdxs = nil +} diff --git a/cl/pb/pb/streamer_grpc.pb.go b/cl/pb/pb/streamer_grpc.pb.go new file mode 100644 index 000000000..9b43768d8 --- /dev/null +++ b/cl/pb/pb/streamer_grpc.pb.go @@ -0,0 +1,141 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc v3.19.1 +// source: streamer.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +const ( + PayloadStreamer_Subscribe_FullMethodName = "/pb.PayloadStreamer/Subscribe" +) + +// PayloadStreamerClient is the client API for PayloadStreamer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type PayloadStreamerClient interface { + Subscribe(ctx context.Context, opts ...grpc.CallOption) (PayloadStreamer_SubscribeClient, error) +} + +type payloadStreamerClient struct { + cc grpc.ClientConnInterface +} + +func NewPayloadStreamerClient(cc grpc.ClientConnInterface) PayloadStreamerClient { + return &payloadStreamerClient{cc} +} + +func (c *payloadStreamerClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (PayloadStreamer_SubscribeClient, error) { + stream, err := c.cc.NewStream(ctx, &PayloadStreamer_ServiceDesc.Streams[0], PayloadStreamer_Subscribe_FullMethodName, opts...) + if err != nil { + return nil, err + } + x := &payloadStreamerSubscribeClient{stream} + return x, nil +} + +type PayloadStreamer_SubscribeClient interface { + Send(*ClientMessage) error + Recv() (*PayloadMessage, error) + grpc.ClientStream +} + +type payloadStreamerSubscribeClient struct { + grpc.ClientStream +} + +func (x *payloadStreamerSubscribeClient) Send(m *ClientMessage) error { + return x.ClientStream.SendMsg(m) +} + +func (x *payloadStreamerSubscribeClient) Recv() (*PayloadMessage, error) { + m := new(PayloadMessage) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// PayloadStreamerServer is the server API for PayloadStreamer service. +// All implementations must embed UnimplementedPayloadStreamerServer +// for forward compatibility +type PayloadStreamerServer interface { + Subscribe(PayloadStreamer_SubscribeServer) error + mustEmbedUnimplementedPayloadStreamerServer() +} + +// UnimplementedPayloadStreamerServer must be embedded to have forward compatible implementations. +type UnimplementedPayloadStreamerServer struct { +} + +func (UnimplementedPayloadStreamerServer) Subscribe(PayloadStreamer_SubscribeServer) error { + return status.Errorf(codes.Unimplemented, "method Subscribe not implemented") +} +func (UnimplementedPayloadStreamerServer) mustEmbedUnimplementedPayloadStreamerServer() {} + +// UnsafePayloadStreamerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to PayloadStreamerServer will +// result in compilation errors. 
+type UnsafePayloadStreamerServer interface { + mustEmbedUnimplementedPayloadStreamerServer() +} + +func RegisterPayloadStreamerServer(s grpc.ServiceRegistrar, srv PayloadStreamerServer) { + s.RegisterService(&PayloadStreamer_ServiceDesc, srv) +} + +func _PayloadStreamer_Subscribe_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(PayloadStreamerServer).Subscribe(&payloadStreamerSubscribeServer{stream}) +} + +type PayloadStreamer_SubscribeServer interface { + Send(*PayloadMessage) error + Recv() (*ClientMessage, error) + grpc.ServerStream +} + +type payloadStreamerSubscribeServer struct { + grpc.ServerStream +} + +func (x *payloadStreamerSubscribeServer) Send(m *PayloadMessage) error { + return x.ServerStream.SendMsg(m) +} + +func (x *payloadStreamerSubscribeServer) Recv() (*ClientMessage, error) { + m := new(ClientMessage) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// PayloadStreamer_ServiceDesc is the grpc.ServiceDesc for PayloadStreamer service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var PayloadStreamer_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "pb.PayloadStreamer", + HandlerType: (*PayloadStreamerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Subscribe", + Handler: _PayloadStreamer_Subscribe_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "streamer.proto", +} diff --git a/cl/pb/streamer.proto b/cl/pb/streamer.proto new file mode 100644 index 000000000..09ba27dcf --- /dev/null +++ b/cl/pb/streamer.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package pb; + +option go_package = "./pb"; + +service PayloadStreamer { + rpc Subscribe(stream ClientMessage) returns (stream PayloadMessage); +} + +message ClientMessage { + oneof message { + SubscribeRequest subscribe_request = 1; + AckPayloadRequest ack_payload = 2; + } +} + +message SubscribeRequest { + string client_id = 1; +} + +message PayloadMessage { + string payload_id = 1; + string execution_payload = 2; + string sender_instance_id = 3; + string message_id = 4; +} + +message AckPayloadRequest { + string client_id = 1; + string payload_id = 2; + string message_id = 3; +} diff --git a/cl/redis-cluster/docker-compose.yml b/cl/redis-cluster/docker-compose.yml index a5d9ba217..356888c33 100644 --- a/cl/redis-cluster/docker-compose.yml +++ b/cl/redis-cluster/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3.8' - services: redis-master: image: redis:7.0 diff --git a/cl/redisapp/blockbuilder/blockbuilder.go b/cl/redisapp/blockbuilder/blockbuilder.go index 27356e112..e4280a40e 100644 --- a/cl/redisapp/blockbuilder/blockbuilder.go +++ b/cl/redisapp/blockbuilder/blockbuilder.go @@ -2,9 +2,11 @@ package blockbuilder import ( "context" + "encoding/base64" "errors" "fmt" "log/slog" + "math/big" "regexp" "strconv" "strings" @@ -14,7 +16,6 @@ import ( "github.com/ethereum/go-ethereum/beacon/engine" "github.com/ethereum/go-ethereum/common" etypes "github.com/ethereum/go-ethereum/core/types" - "github.com/primev/mev-commit/cl/redisapp/state" "github.com/primev/mev-commit/cl/redisapp/types" "github.com/primev/mev-commit/cl/redisapp/util" "github.com/vmihailenco/msgpack/v5" @@ -30,10 +31,18 @@ type EngineClient interface { payloadAttributes *engine.PayloadAttributes) (engine.ForkChoiceResponse, error) GetPayloadV3(ctx context.Context, payloadID engine.PayloadID) (*engine.ExecutionPayloadEnvelope, error) + + 
HeaderByNumber(ctx context.Context, number *big.Int) (*etypes.Header, error) +} + +type stateManager interface { + SaveBlockStateAndPublishToStream(ctx context.Context, state *types.BlockBuildState) error + GetBlockBuildState(ctx context.Context) types.BlockBuildState + ResetBlockState(ctx context.Context) error } type BlockBuilder struct { - stateManager state.StateManager + stateManager stateManager engineCl EngineClient logger *slog.Logger buildDelay time.Duration @@ -42,10 +51,10 @@ type BlockBuilder struct { LastCallTime time.Time lastBlockTime time.Time feeRecipient common.Address - ctx context.Context + executionHead *types.ExecutionHead } -func NewBlockBuilder(stateManager state.StateManager, engineCl EngineClient, logger *slog.Logger, buildDelay, buildDelayEmptyBlocks time.Duration, feeReceipt string) *BlockBuilder { +func NewBlockBuilder(stateManager stateManager, engineCl EngineClient, logger *slog.Logger, buildDelay, buildDelayEmptyBlocks time.Duration, feeReceipt string) *BlockBuilder { return &BlockBuilder{ stateManager: stateManager, engineCl: engineCl, @@ -58,6 +67,13 @@ func NewBlockBuilder(stateManager state.StateManager, engineCl EngineClient, log } } +func NewMemberBlockBuilder(engineCL EngineClient, logger *slog.Logger) *BlockBuilder { + return &BlockBuilder{ + engineCl: engineCL, + logger: logger, + } +} + func (bb *BlockBuilder) SetLastCallTimeToZero() { bb.LastCallTime = time.Time{} } @@ -95,7 +111,7 @@ func (bb *BlockBuilder) GetPayload(ctx context.Context) error { currentCallTime := time.Now() // Load execution head to get previous block timestamp - head, err := bb.stateManager.LoadExecutionHead(ctx) + head, err := bb.loadExecutionHead(ctx) if err != nil { return fmt.Errorf("latest execution block: %w", err) } @@ -194,12 +210,14 @@ func (bb *BlockBuilder) GetPayload(ctx context.Context) error { return fmt.Errorf("failed to marshal payload: %w", err) } + encodedPayload := base64.StdEncoding.EncodeToString(payloadData) + payloadIDStr := payloadID.String() err = bb.stateManager.SaveBlockStateAndPublishToStream(ctx, &types.BlockBuildState{ CurrentStep: types.StepFinalizeBlock, PayloadID: payloadIDStr, - ExecutionPayload: string(payloadData), + ExecutionPayload: encodedPayload, }) if err != nil { return fmt.Errorf("failed to save state after GetPayload: %w", err) @@ -325,12 +343,16 @@ func (bb *BlockBuilder) FinalizeBlock(ctx context.Context, payloadIDStr, executi return errors.New("PayloadID or ExecutionPayload is missing in build state") } + executionPayloadBytes, err := base64.StdEncoding.DecodeString(executionPayloadStr) + if err != nil { + return fmt.Errorf("failed to decode ExecutionPayload: %w", err) + } + var executionPayload engine.ExecutableData - if err := msgpack.Unmarshal([]byte(executionPayloadStr), &executionPayload); err != nil { + if err := msgpack.Unmarshal(executionPayloadBytes, &executionPayload); err != nil { return fmt.Errorf("failed to deserialize ExecutionPayload: %w", err) } - - head, err := bb.stateManager.LoadExecutionHead(ctx) + head, err := bb.loadExecutionHead(ctx) if err != nil { return fmt.Errorf("failed to load execution head: %w", err) } @@ -347,25 +369,18 @@ func (bb *BlockBuilder) FinalizeBlock(ctx context.Context, payloadIDStr, executi } fcs := engine.ForkchoiceStateV1{ - HeadBlockHash: hash, - SafeBlockHash: hash, - FinalizedBlockHash: hash, + HeadBlockHash: executionPayload.BlockHash, + SafeBlockHash: executionPayload.BlockHash, + FinalizedBlockHash: executionPayload.BlockHash, } if err := bb.updateForkChoice(ctx, fcs, retryFunc); 
err != nil { return fmt.Errorf("failed to finalize fork choice update: %w", err) } - executionHead := &types.ExecutionHead{ - BlockHeight: executionPayload.Number, - BlockHash: executionPayload.BlockHash[:], - BlockTime: executionPayload.Timestamp, - } - - if err := bb.saveExecutionHead(ctx, executionHead, msgID); err != nil { + if err := bb.saveExecutionHead(executionPayload); err != nil { return fmt.Errorf("failed to save execution head: %w", err) } - return nil } @@ -435,9 +450,31 @@ func (bb *BlockBuilder) updateForkChoice(ctx context.Context, fcs engine.Forkcho }) } -func (bb *BlockBuilder) saveExecutionHead(ctx context.Context, executionHead *types.ExecutionHead, msgID string) error { - if msgID == "" { - return bb.stateManager.SaveExecutionHead(ctx, executionHead) +func (bb *BlockBuilder) loadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) { + if bb.executionHead != nil { + return bb.executionHead, nil + } + + header, err := bb.engineCl.HeaderByNumber(ctx, nil) // nil for the latest block + if err != nil { + return nil, fmt.Errorf("failed to get the latest block header: %w", err) + } + + bb.executionHead = &types.ExecutionHead{ + BlockHeight: header.Number.Uint64(), + BlockHash: header.Hash().Bytes(), + BlockTime: header.Time, } - return bb.stateManager.SaveExecutionHeadAndAck(ctx, executionHead, msgID) + + return bb.executionHead, nil +} + +func (bb *BlockBuilder) saveExecutionHead(executionPayload engine.ExecutableData) error { + bb.executionHead = &types.ExecutionHead{ + BlockHeight: executionPayload.Number, + BlockHash: executionPayload.BlockHash[:], + BlockTime: executionPayload.Timestamp, + } + + return nil } diff --git a/cl/redisapp/blockbuilder/blockbuilder_test.go b/cl/redisapp/blockbuilder/blockbuilder_test.go index e40c1919f..28349256c 100644 --- a/cl/redisapp/blockbuilder/blockbuilder_test.go +++ b/cl/redisapp/blockbuilder/blockbuilder_test.go @@ -2,6 +2,7 @@ package blockbuilder import ( "context" + "encoding/base64" "encoding/json" "errors" "log/slog" @@ -52,6 +53,11 @@ func (m *MockEngineClient) NewPayloadV3(ctx context.Context, executionPayload en return args.Get(0).(engine.PayloadStatusV1), args.Error(1) } +func (m *MockEngineClient) HeaderByNumber(ctx context.Context, number *big.Int) (*etypes.Header, error) { + args := m.Called(ctx, number) + return args.Get(0).(*etypes.Header), args.Error(1) +} + func TestBlockBuilder_startBuild(t *testing.T) { ctx := context.Background() @@ -64,7 +70,7 @@ func TestBlockBuilder_startBuild(t *testing.T) { BlockTime: uint64(time.Now().UnixMilli()) - 10, } - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) @@ -74,7 +80,6 @@ func TestBlockBuilder_startBuild(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } timestamp := time.Now() @@ -117,25 +122,18 @@ func TestBlockBuilder_getPayload(t *testing.T) { BlockHeight: 100, BlockTime: uint64(timestamp.UnixMilli()), } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) mockRedisClient.EXPECT(). XGroupCreateMkStream(gomock.Any(), "mevcommit_block_stream", "mevcommit_consumer_group:instanceID123", "0").Return(redis.NewStatusCmd(ctx)) - mockRedisClient.EXPECT(). - Get(gomock.Any(), executionHeadKey). - Return(redis.NewStringResult(string(executionHeadData), nil)). 
- Times(1) - - mockRedisClient.EXPECT().Pipeline().Return(mockPipeliner) + mockRedisClient.EXPECT().TxPipeline().Return(mockPipeliner) mockPipeliner.EXPECT().Set(ctx, "blockBuildState:instanceID123", gomock.Any(), time.Duration(0)).Return(redis.NewStatusCmd(ctx)) mockPipeliner.EXPECT().XAdd(ctx, gomock.Any()).Return(redis.NewStringCmd(ctx, "result")) mockPipeliner.EXPECT().Exec(ctx).Return([]redis.Cmder{}, nil) - stateManager, err := state.NewRedisStateManager("instanceID123", mockRedisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", mockRedisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) @@ -146,7 +144,6 @@ func TestBlockBuilder_getPayload(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } hash := common.BytesToHash(executionHead.BlockHash) @@ -178,6 +175,7 @@ func TestBlockBuilder_getPayload(t *testing.T) { } mockEngineClient.On("GetPayloadV3", mock.Anything, *payloadID).Return(executionPayload, nil) + blockBuilder.executionHead = executionHead err = blockBuilder.GetPayload(ctx) require.NoError(t, err) @@ -188,8 +186,12 @@ func TestBlockBuilder_getPayload(t *testing.T) { func TestBlockBuilder_FinalizeBlock(t *testing.T) { ctx := context.Background() - redisClient, redisMock := redismock.NewClientMock() - redisMock.ExpectXGroupCreateMkStream("mevcommit_block_stream", "mevcommit_consumer_group:instanceID123", "0").SetVal("OK") + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRedisClient := mocks.NewMockRedisClient(ctrl) + + mockRedisClient.EXPECT().XGroupCreateMkStream(ctx, "mevcommit_block_stream", "mevcommit_consumer_group:instanceID123", "0").Return(redis.NewStatusCmd(ctx)) timestamp := uint64(1728051707) // 0x66fff9fb executionHead := &types.ExecutionHead{ @@ -197,11 +199,8 @@ func TestBlockBuilder_FinalizeBlock(t *testing.T) { BlockHeight: 2, BlockTime: timestamp - 10, } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", mockRedisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) @@ -211,7 +210,6 @@ func TestBlockBuilder_FinalizeBlock(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } payloadIDStr := "payloadID123" @@ -246,12 +244,15 @@ func TestBlockBuilder_FinalizeBlock(t *testing.T) { msgpackData, err := msgpack.Marshal(executionPayload) require.NoError(t, err) + encodedPayload := base64.StdEncoding.EncodeToString(msgpackData) + payloadStatus := engine.PayloadStatusV1{ Status: engine.VALID, } + mockEngineClient.On("NewPayloadV3", mock.Anything, executionPayload, []common.Hash{}, mock.Anything).Return(payloadStatus, nil) - hash := common.BytesToHash(executionHead.BlockHash) + hash := executionPayload.BlockHash fcs := engine.ForkchoiceStateV1{ HeadBlockHash: hash, SafeBlockHash: hash, @@ -262,20 +263,12 @@ func TestBlockBuilder_FinalizeBlock(t *testing.T) { } mockEngineClient.On("ForkchoiceUpdatedV3", mock.Anything, fcs, (*engine.PayloadAttributes)(nil)).Return(forkChoiceResponse, nil) - executionHeadUpdate := &types.ExecutionHead{ - BlockHash: executionPayload.BlockHash.Bytes(), - BlockHeight: executionPayload.Number, - BlockTime: 
executionPayload.Timestamp, - } - executionHeadDataUpdated, _ := msgpack.Marshal(executionHeadUpdate) - redisMock.ExpectSet(executionHeadKey, executionHeadDataUpdated, 0).SetVal("OK") - - err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, string(msgpackData), msgID) + blockBuilder.executionHead = executionHead + err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, encodedPayload, msgID) require.NoError(t, err) mockEngineClient.AssertExpectations(t) - require.NoError(t, redisMock.ExpectationsWereMet()) } func TestBlockBuilder_startBuild_ForkchoiceUpdatedError(t *testing.T) { @@ -289,7 +282,7 @@ func TestBlockBuilder_startBuild_ForkchoiceUpdatedError(t *testing.T) { BlockTime: uint64(time.Now().UnixMilli()) - 10, } - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) @@ -299,7 +292,6 @@ func TestBlockBuilder_startBuild_ForkchoiceUpdatedError(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } timestamp := time.Now() @@ -334,7 +326,7 @@ func TestBlockBuilder_startBuild_InvalidPayloadStatus(t *testing.T) { BlockTime: uint64(time.Now().UnixMilli()) - 10, } - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) @@ -344,7 +336,6 @@ func TestBlockBuilder_startBuild_InvalidPayloadStatus(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } timestamp := time.Now() @@ -373,34 +364,6 @@ func TestBlockBuilder_startBuild_InvalidPayloadStatus(t *testing.T) { require.NoError(t, redisMock.ExpectationsWereMet()) } -func TestBlockBuilder_getPayload_startBuildFails(t *testing.T) { - ctx := context.Background() - redisClient, redisMock := redismock.NewClientMock() - redisMock.ExpectXGroupCreateMkStream("mevcommit_block_stream", "mevcommit_consumer_group:instanceID123", "0").SetVal("OK") - - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") - require.NoError(t, err) - mockEngineClient := new(MockEngineClient) - blockBuilder := &BlockBuilder{ - stateManager: stateManager, - engineCl: mockEngineClient, - buildDelay: buildDelay, - buildDelayMs: uint64(buildDelay.Milliseconds()), - logger: stLog, - ctx: ctx, - } - - executionHeadKey := "executionHead:instanceID123" - redisMock.ExpectGet(executionHeadKey).SetErr(errors.New("redis error")) - - err = blockBuilder.GetPayload(ctx) - - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to retrieve") - - require.NoError(t, redisMock.ExpectationsWereMet()) -} - func TestBlockBuilder_getPayload_GetPayloadUnknownPayload(t *testing.T) { ctx := context.Background() redisClient, redisMock := redismock.NewClientMock() @@ -412,11 +375,8 @@ func TestBlockBuilder_getPayload_GetPayloadUnknownPayload(t *testing.T) { BlockHeight: 100, BlockTime: uint64(timestamp.UnixMilli()) - 10, } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "010203") + stateManager, err := state.NewRedisCoordinator("instanceID123", 
redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) blockBuilder := &BlockBuilder{ @@ -424,7 +384,6 @@ func TestBlockBuilder_getPayload_GetPayloadUnknownPayload(t *testing.T) { engineCl: mockEngineClient, buildDelay: time.Duration(1 * time.Second), logger: stLog, - ctx: ctx, } hash := common.BytesToHash(executionHead.BlockHash) @@ -446,6 +405,7 @@ func TestBlockBuilder_getPayload_GetPayloadUnknownPayload(t *testing.T) { mockEngineClient.On("GetPayloadV3", mock.Anything, *payloadID).Return(&engine.ExecutionPayloadEnvelope{}, errors.New("Unknown payload")) + blockBuilder.executionHead = executionHead err = blockBuilder.GetPayload(ctx) require.Error(t, err) @@ -466,11 +426,8 @@ func TestBlockBuilder_FinalizeBlock_InvalidBlockHeight(t *testing.T) { BlockHeight: 100, BlockTime: uint64(timestamp.UnixMilli()) - 10, } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "000000") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) blockBuilder := &BlockBuilder{ @@ -479,7 +436,6 @@ func TestBlockBuilder_FinalizeBlock_InvalidBlockHeight(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } payloadIDStr := "payloadID123" @@ -504,7 +460,9 @@ func TestBlockBuilder_FinalizeBlock_InvalidBlockHeight(t *testing.T) { } executionPayloadData, _ := msgpack.Marshal(executionPayload) - err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, string(executionPayloadData), "") + executionPayloadEncoded := base64.StdEncoding.EncodeToString(executionPayloadData) + blockBuilder.executionHead = executionHead + err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, executionPayloadEncoded, "") require.Error(t, err) assert.Contains(t, err.Error(), "invalid block height") @@ -523,11 +481,8 @@ func TestBlockBuilder_FinalizeBlock_NewPayloadInvalidStatus(t *testing.T) { BlockHeight: 2, BlockTime: timestamp - 10, } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "000000") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) blockBuilder := &BlockBuilder{ @@ -536,7 +491,6 @@ func TestBlockBuilder_FinalizeBlock_NewPayloadInvalidStatus(t *testing.T) { buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } payloadIDStr := "payloadID123" @@ -561,13 +515,13 @@ func TestBlockBuilder_FinalizeBlock_NewPayloadInvalidStatus(t *testing.T) { } executionPayloadData, _ := msgpack.Marshal(executionPayload) - + executionPayloadEncoded := base64.StdEncoding.EncodeToString(executionPayloadData) payloadStatus := engine.PayloadStatusV1{ Status: "INVALID", } mockEngineClient.On("NewPayloadV3", mock.Anything, executionPayload, []common.Hash{}, mock.Anything).Return(payloadStatus, nil) - - err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, string(executionPayloadData), "") + blockBuilder.executionHead = executionHead + err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, executionPayloadEncoded, 
"") require.Error(t, err) assert.Contains(t, err.Error(), "failed to push new payload") @@ -587,11 +541,8 @@ func TestBlockBuilder_FinalizeBlock_ForkchoiceUpdatedInvalidStatus(t *testing.T) BlockHeight: 2, BlockTime: timestamp - 10, } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "000000") + stateManager, err := state.NewRedisCoordinator("instanceID123", redisClient, nil) require.NoError(t, err) mockEngineClient := new(MockEngineClient) blockBuilder := &BlockBuilder{ @@ -600,7 +551,6 @@ func TestBlockBuilder_FinalizeBlock_ForkchoiceUpdatedInvalidStatus(t *testing.T) buildDelay: buildDelay, buildDelayMs: uint64(buildDelay.Milliseconds()), logger: stLog, - ctx: ctx, } payloadIDStr := "payloadID123" @@ -624,16 +574,16 @@ func TestBlockBuilder_FinalizeBlock_ForkchoiceUpdatedInvalidStatus(t *testing.T) ExcessBlobGas: new(uint64), } executionPayloadData, _ := msgpack.Marshal(executionPayload) - + executionPayloadEncoded := base64.StdEncoding.EncodeToString(executionPayloadData) payloadStatus := engine.PayloadStatusV1{ Status: engine.VALID, } mockEngineClient.On("NewPayloadV3", mock.Anything, executionPayload, []common.Hash{}, mock.Anything).Return(payloadStatus, nil) fcs := engine.ForkchoiceStateV1{ - HeadBlockHash: executionPayload.ParentHash, - SafeBlockHash: executionPayload.ParentHash, - FinalizedBlockHash: executionPayload.ParentHash, + HeadBlockHash: executionPayload.BlockHash, + SafeBlockHash: executionPayload.BlockHash, + FinalizedBlockHash: executionPayload.BlockHash, } forkChoiceResponse := engine.ForkChoiceResponse{ PayloadStatus: engine.PayloadStatusV1{ @@ -642,7 +592,8 @@ func TestBlockBuilder_FinalizeBlock_ForkchoiceUpdatedInvalidStatus(t *testing.T) } mockEngineClient.On("ForkchoiceUpdatedV3", ctx, fcs, (*engine.PayloadAttributes)(nil)).Return(forkChoiceResponse, nil) - err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, string(executionPayloadData), "") + blockBuilder.executionHead = executionHead + err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, executionPayloadEncoded, "") require.Error(t, err) assert.Contains(t, err.Error(), "failed to finalize fork choice update") @@ -651,85 +602,6 @@ func TestBlockBuilder_FinalizeBlock_ForkchoiceUpdatedInvalidStatus(t *testing.T) require.NoError(t, redisMock.ExpectationsWereMet()) } -func TestBlockBuilder_FinalizeBlock_SaveExecutionHeadError(t *testing.T) { - ctx := context.Background() - redisClient, redisMock := redismock.NewClientMock() - redisMock.ExpectXGroupCreateMkStream("mevcommit_block_stream", "mevcommit_consumer_group:instanceID123", "0").SetVal("OK") - - timestamp := uint64(1728051707) // 0x66fff9fb - executionHead := &types.ExecutionHead{ - BlockHash: []byte{0xb, 0xf3, 0x9b, 0xc1, 0x8b, 0xe0, 0x59, 0xc1, 0xdc, 0xb8, 0x72, 0xac, 0x8c, 0xb, 0xc, 0x84, 0x56, 0x55, 0xa0, 0x1c, 0x2b, 0x7d, 0x8f, 0xd0, 0x1c, 0x4b, 0xec, 0xde, 0x6b, 0x3f, 0x93, 0xd7}, - BlockHeight: 2, - BlockTime: timestamp - 10, - } - executionHeadKey := "executionHead:instanceID123" - executionHeadData, _ := msgpack.Marshal(executionHead) - redisMock.ExpectGet(executionHeadKey).SetVal(string(executionHeadData)) - - stateManager, err := state.NewRedisStateManager("instanceID123", redisClient, nil, "000000") - require.NoError(t, err) - mockEngineClient := new(MockEngineClient) - blockBuilder := &BlockBuilder{ - stateManager: stateManager, - 
engineCl: mockEngineClient, - buildDelay: buildDelay, - buildDelayMs: uint64(buildDelay.Milliseconds()), - logger: stLog, - ctx: ctx, - } - - payloadIDStr := "payloadID123" - executionPayload := engine.ExecutableData{ - ParentHash: common.HexToHash("0x0bf39bc18be059c1dcb872ac8c0b0c845655a01c2b7d8fd01c4becde6b3f93d7"), - FeeRecipient: common.HexToAddress("0x0000000000000000000000000000000000000000"), - StateRoot: common.HexToHash("0xcdc166a6c2e7f8b873889a7256873144e61121f9fc1f027d79b8fa310b91ff0f"), - ReceiptsRoot: common.HexToHash("0x56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"), - LogsBloom: common.FromHex("0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), - Random: common.HexToHash("0x0bf39bc18be059c1dcb872ac8c0b0c845655a01c2b7d8fd01c4becde6b3f93d7"), - Number: 3, - GasLimit: 30000000, - GasUsed: 0, - Timestamp: 0x66fff9fb, - ExtraData: common.FromHex("0xd983010e08846765746888676f312e32322e368664617277696e"), - BaseFeePerGas: big.NewInt(0x27ee3253), - BlockHash: common.HexToHash("0x9a9b2f7e98934f8544c22cdcb00526f48886170b15c4e4e96bd43af189b5aac4"), - Transactions: [][]byte{}, - Withdrawals: []*etypes.Withdrawal{}, - BlobGasUsed: new(uint64), - ExcessBlobGas: new(uint64), - } - executionPayloadData, _ := msgpack.Marshal(executionPayload) - - payloadStatus := engine.PayloadStatusV1{ - Status: engine.VALID, - } - mockEngineClient.On("NewPayloadV3", mock.Anything, executionPayload, []common.Hash{}, mock.Anything).Return(payloadStatus, nil) - fcs := engine.ForkchoiceStateV1{ - HeadBlockHash: executionPayload.ParentHash, - SafeBlockHash: executionPayload.ParentHash, - FinalizedBlockHash: executionPayload.ParentHash, - } - mockEngineClient.On("ForkchoiceUpdatedV3", mock.Anything, fcs, (*engine.PayloadAttributes)(nil)).Return(engine.ForkChoiceResponse{ - PayloadStatus: payloadStatus, - }, nil) - - executionHeadUpdate := &types.ExecutionHead{ - BlockHash: executionPayload.BlockHash.Bytes(), - BlockHeight: executionPayload.Number, - BlockTime: executionPayload.Timestamp, - } - executionHeadDataUpdated, _ := msgpack.Marshal(executionHeadUpdate) - redisMock.ExpectSet(executionHeadKey, executionHeadDataUpdated, time.Duration(0)).SetErr(errors.New("redis error")) - - err = blockBuilder.FinalizeBlock(ctx, payloadIDStr, string(executionPayloadData), "") - - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to save execution head") - - mockEngineClient.AssertExpectations(t) - require.NoError(t, redisMock.ExpectationsWereMet()) -} - func matchPayloadAttributes(expectedHash common.Hash, executionHeadTime uint64) func(*engine.PayloadAttributes) bool { return func(attrs *engine.PayloadAttributes) bool { if attrs == nil { diff --git a/cl/redisapp/leaderfollower/leaderfollower.go b/cl/redisapp/leaderfollower/leaderfollower.go index 67c6c16ff..0a5f883a0 100644 --- a/cl/redisapp/leaderfollower/leaderfollower.go +++ b/cl/redisapp/leaderfollower/leaderfollower.go @@ -9,7 +9,6 @@ import ( "time" "github.com/heyvito/go-leader/leader" - "github.com/primev/mev-commit/cl/redisapp/state" 
"github.com/primev/mev-commit/cl/redisapp/types" "github.com/primev/mev-commit/cl/redisapp/util" "github.com/redis/go-redis/v9" @@ -18,8 +17,8 @@ import ( type LeaderFollowerManager struct { isLeader atomic.Bool isFollowerInitialized atomic.Bool - stateManager state.StateManager - blockBuilder BlockBuilder + stateManager stateManager + blockBuilder blockBuilder leaderProc leader.Leader logger *slog.Logger instanceID string @@ -30,7 +29,7 @@ type LeaderFollowerManager struct { erroredCh <-chan error } -type BlockBuilder interface { +type blockBuilder interface { // Retrieves the latest payload and ensures it meets necessary conditions GetPayload(ctx context.Context) error @@ -44,12 +43,23 @@ type BlockBuilder interface { SetLastCallTimeToZero() } +// todo: work with block state through block builder, not directly +type stateManager interface { + // state related methods + GetBlockBuildState(ctx context.Context) types.BlockBuildState + ResetBlockState(ctx context.Context) error + + // stream related methods + AckMessage(ctx context.Context, messageID string) error + ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) +} + func NewLeaderFollowerManager( instanceID string, logger *slog.Logger, redisClient *redis.Client, - stateManager state.StateManager, - blockBuilder BlockBuilder, + stateManager stateManager, + blockBuilder blockBuilder, ) (*LeaderFollowerManager, error) { // Initialize leader election leaderOpts := leader.Opts{ @@ -283,6 +293,13 @@ func (lfm *LeaderFollowerManager) followerWork(ctx context.Context) error { continue } + err = lfm.stateManager.AckMessage(ctx, field.ID) + if err != nil { + lfm.logger.Error("Follower: Failed to acknowledge message", "error", err) + } else { + lfm.logger.Info("Follower: Successfully acknowledged message", "PayloadID", payloadIDStr) + } + lfm.logger.Info("Follower: Successfully finalized block", "PayloadID", payloadIDStr) } } diff --git a/cl/redisapp/leaderfollower/leaderfollower_test.go b/cl/redisapp/leaderfollower/leaderfollower_test.go index c35703afd..66917fbd8 100644 --- a/cl/redisapp/leaderfollower/leaderfollower_test.go +++ b/cl/redisapp/leaderfollower/leaderfollower_test.go @@ -42,7 +42,7 @@ func TestNewLeaderFollowerManager(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - stateManager := mocks.NewMockStateManager(ctrl) + stateManager := mocks.NewMockCoordinator(ctrl) blockBuilder := mocks.NewMockBlockBuilder(ctrl) // Execute @@ -62,7 +62,7 @@ func TestHaveMessagesToProcess(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSM := mocks.NewMockStateManager(ctrl) + mockSM := mocks.NewMockCoordinator(ctrl) // Prepare mock state manager to return some messages messages := []redis.XStream{ @@ -104,7 +104,7 @@ func TestHaveMessagesToProcess_NoMessages(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSM := mocks.NewMockStateManager(ctrl) + mockSM := mocks.NewMockCoordinator(ctrl) // Set up expectations gomock.InOrder( @@ -130,7 +130,7 @@ func TestLeaderWork_StepBuildBlock(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSM := mocks.NewMockStateManager(ctrl) + mockSM := mocks.NewMockCoordinator(ctrl) mockBB := mocks.NewMockBlockBuilder(ctrl) lfm := &LeaderFollowerManager{ @@ -176,7 +176,7 @@ func TestFollowerWork_NoMessages(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSM := mocks.NewMockStateManager(ctrl) + mockSM := mocks.NewMockCoordinator(ctrl) mockBB := 
mocks.NewMockBlockBuilder(ctrl) lfm := &LeaderFollowerManager{ @@ -212,7 +212,7 @@ func TestFollowerWork_WithMessages(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSM := mocks.NewMockStateManager(ctrl) + mockSM := mocks.NewMockCoordinator(ctrl) mockBB := mocks.NewMockBlockBuilder(ctrl) lfm := &LeaderFollowerManager{ @@ -222,7 +222,6 @@ func TestFollowerWork_WithMessages(t *testing.T) { instanceID: "test-instance", } - // Prepare messages to return messages := []redis.XStream{ { Stream: "test-stream", @@ -239,26 +238,30 @@ func TestFollowerWork_WithMessages(t *testing.T) { }, } - // Set up expectations gomock.InOrder( mockSM.EXPECT().ReadMessagesFromStream(ctx, types.RedisMsgTypePending).Return([]redis.XStream{}, nil), mockSM.EXPECT().ReadMessagesFromStream(ctx, types.RedisMsgTypeNew).Return(messages, nil), mockBB.EXPECT().FinalizeBlock(ctx, "test-payload-id", "test-execution-payload", "1-0").Return(nil), + mockSM.EXPECT().AckMessage(ctx, "1-0").Return(nil).Do(func(ctx context.Context, msgID string) { + cancel() + }), ) mockSM.EXPECT().ReadMessagesFromStream(ctx, gomock.Any()).AnyTimes().Return([]redis.XStream{}, nil) - // Run followerWork in a separate goroutine to allow context cancellation done := make(chan error) go func() { err := lfm.followerWork(ctx) done <- err }() - // Wait for the function to return - err := <-done - if err != nil { - t.Errorf("followerWork returned error: %v", err) + select { + case err := <-done: + if err != nil { + t.Errorf("followerWork returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Errorf("followerWork timed out") } } diff --git a/cl/redisapp/rapp.go b/cl/redisapp/rapp.go index 8b1327e51..c87917bb7 100644 --- a/cl/redisapp/rapp.go +++ b/cl/redisapp/rapp.go @@ -24,7 +24,7 @@ type MevCommitChain struct { lfm *leaderfollower.LeaderFollowerManager } -func NewMevCommitChain(instanceID, ecURL, jwtSecret, genesisBlockHash, redisAddr, feeReceipt string, +func NewMevCommitChain(instanceID, ecURL, jwtSecret, redisAddr, feeReceipt string, logger *slog.Logger, buildDelay, buildDelayEmptyBlocks time.Duration) (*MevCommitChain, error) { // Create a context for cancellation @@ -56,19 +56,19 @@ func NewMevCommitChain(instanceID, ecURL, jwtSecret, genesisBlockHash, redisAddr return nil, err } - stateManager, err := state.NewRedisStateManager(instanceID, redisClient, logger, genesisBlockHash) + coordinator, err := state.NewRedisCoordinator(instanceID, redisClient, logger) if err != nil { cancel() logger.Error("Error creating state manager", "error", err) return nil, err } - blockBuilder := blockbuilder.NewBlockBuilder(stateManager, engineCL, logger, buildDelay, buildDelayEmptyBlocks, feeReceipt) + blockBuilder := blockbuilder.NewBlockBuilder(coordinator, engineCL, logger, buildDelay, buildDelayEmptyBlocks, feeReceipt) lfm, err := leaderfollower.NewLeaderFollowerManager( instanceID, logger, redisClient, - stateManager, + coordinator, blockBuilder, ) if err != nil { @@ -77,7 +77,7 @@ func NewMevCommitChain(instanceID, ecURL, jwtSecret, genesisBlockHash, redisAddr return nil, err } app := &MevCommitChain{ - stateManager: stateManager, + stateManager: coordinator, blockBuilder: blockBuilder, logger: logger, cancel: cancel, diff --git a/cl/redisapp/state/state.go b/cl/redisapp/state/state.go index 07e34e00e..c125286f2 100644 --- a/cl/redisapp/state/state.go +++ b/cl/redisapp/state/state.go @@ -2,7 +2,6 @@ package state import ( "context" - "encoding/hex" "errors" "fmt" "log/slog" @@ -21,97 +20,118 @@ type RedisClient interface { 
Close() error } +type PipelineOperation func(redis.Pipeliner) error + type StateManager interface { - SaveExecutionHead(ctx context.Context, head *types.ExecutionHead) error - LoadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) - LoadOrInitializeBlockState(ctx context.Context) error - SaveBlockState(ctx context.Context) error ResetBlockState(ctx context.Context) error - SaveExecutionHeadAndAck(ctx context.Context, head *types.ExecutionHead, messageID string) error - SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error GetBlockBuildState(ctx context.Context) types.BlockBuildState - CreateConsumerGroup(ctx context.Context) error + Stop() +} + +type StreamManager interface { ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) AckMessage(ctx context.Context, messageID string) error Stop() } +type Coordinator interface { + StreamManager + StateManager + SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error + Stop() +} + type RedisStateManager struct { - instanceID string - redisClient RedisClient - logger *slog.Logger - genesisBlockHash string - groupName string - consumerName string + instanceID string + redisClient RedisClient + logger *slog.Logger blockStateKey string blockBuildState *types.BlockBuildState } +type RedisStreamManager struct { + instanceID string + redisClient RedisClient + logger *slog.Logger + + groupName string + consumerName string +} + +type RedisCoordinator struct { + stateMgr *RedisStateManager + streamMgr *RedisStreamManager + redisClient RedisClient // added to close the client + logger *slog.Logger +} + func NewRedisStateManager( instanceID string, redisClient RedisClient, logger *slog.Logger, - genesisBlockHash string, -) (StateManager, error) { - rsm := &RedisStateManager{ - instanceID: instanceID, - redisClient: redisClient, - logger: logger, - genesisBlockHash: genesisBlockHash, - blockStateKey: fmt.Sprintf("blockBuildState:%s", instanceID), - groupName: fmt.Sprintf("mevcommit_consumer_group:%s", instanceID), - consumerName: fmt.Sprintf("follower:%s", instanceID), +) *RedisStateManager { + return &RedisStateManager{ + instanceID: instanceID, + redisClient: redisClient, + logger: logger, + blockStateKey: fmt.Sprintf("blockBuildState:%s", instanceID), } - if err := rsm.CreateConsumerGroup(context.Background()); err != nil { - return nil, err +} + +func NewRedisStreamManager( + instanceID string, + redisClient RedisClient, + logger *slog.Logger, +) *RedisStreamManager { + return &RedisStreamManager{ + instanceID: instanceID, + redisClient: redisClient, + logger: logger, + groupName: fmt.Sprintf("mevcommit_consumer_group:%s", instanceID), + consumerName: fmt.Sprintf("follower:%s", instanceID), } - return rsm, nil } -func (s *RedisStateManager) SaveExecutionHead(ctx context.Context, head *types.ExecutionHead) error { - data, err := msgpack.Marshal(head) - if err != nil { - return fmt.Errorf("failed to serialize execution head: %w", err) +func NewRedisCoordinator( + instanceID string, + redisClient RedisClient, + logger *slog.Logger, +) (*RedisCoordinator, error) { + stateMgr := NewRedisStateManager(instanceID, redisClient, logger) + streamMgr := NewRedisStreamManager(instanceID, redisClient, logger) + + coordinator := &RedisCoordinator{ + stateMgr: stateMgr, + streamMgr: streamMgr, + redisClient: redisClient, + logger: logger, } - key := fmt.Sprintf("executionHead:%s", s.instanceID) - if err := s.redisClient.Set(ctx, key, data, 
0).Err(); err != nil { - return fmt.Errorf("failed to save execution head to Redis: %w", err) + if err := streamMgr.createConsumerGroup(context.Background()); err != nil { + return nil, fmt.Errorf("failed to create consumer group: %w", err) } - return nil + return coordinator, nil } -func (s *RedisStateManager) LoadExecutionHead(ctx context.Context) (*types.ExecutionHead, error) { - key := fmt.Sprintf("executionHead:%s", s.instanceID) - data, err := s.redisClient.Get(ctx, key).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - s.logger.Info("executionHead not found in Redis, initializing with default values") - hashBytes, decodeErr := hex.DecodeString(s.genesisBlockHash) - if decodeErr != nil { - s.logger.Error("Error decoding genesis block hash", "error", decodeErr) - return nil, decodeErr - } - head := &types.ExecutionHead{BlockHash: hashBytes, BlockTime: uint64(time.Now().UnixMilli())} - if saveErr := s.SaveExecutionHead(ctx, head); saveErr != nil { - return nil, saveErr - } - return head, nil +func (s *RedisStateManager) executeTransaction(ctx context.Context, ops ...PipelineOperation) error { + pipe := s.redisClient.TxPipeline() + + for _, op := range ops { + if err := op(pipe); err != nil { + return fmt.Errorf("failed to execute operation: %w", err) } - return nil, fmt.Errorf("failed to retrieve execution head: %w", err) } - var head types.ExecutionHead - if err := msgpack.Unmarshal([]byte(data), &head); err != nil { - return nil, fmt.Errorf("failed to deserialize execution head: %w", err) + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("state transaction failed: %w", err) } - return &head, nil + return nil } -func (s *RedisStateManager) LoadOrInitializeBlockState(ctx context.Context) error { +func (s *RedisStateManager) loadOrInitializeBlockState(ctx context.Context) error { data, err := s.redisClient.Get(ctx, s.blockStateKey).Result() if err != nil { if errors.Is(err, redis.Nil) { @@ -119,7 +139,7 @@ func (s *RedisStateManager) LoadOrInitializeBlockState(ctx context.Context) erro CurrentStep: types.StepBuildBlock, } s.logger.Info("Leader block build state not found in Redis, initializing with default values") - return s.SaveBlockState(ctx) + return s.saveBlockState(ctx) } return fmt.Errorf("failed to retrieve leader block build state: %w", err) } @@ -134,17 +154,20 @@ func (s *RedisStateManager) LoadOrInitializeBlockState(ctx context.Context) erro return nil } -func (s *RedisStateManager) SaveBlockState(ctx context.Context) error { - data, err := msgpack.Marshal(s.blockBuildState) - if err != nil { - return fmt.Errorf("failed to serialize leader block build state: %w", err) - } +func (s *RedisStateManager) saveBlockState(ctx context.Context) error { + return s.executeTransaction(ctx, s.saveBlockStateFunc(ctx, s.blockBuildState)) +} - if err := s.redisClient.Set(ctx, s.blockStateKey, data, 0).Err(); err != nil { - return fmt.Errorf("failed to save leader block build state to Redis: %w", err) - } +func (s *RedisStateManager) saveBlockStateFunc(ctx context.Context, bsState *types.BlockBuildState) PipelineOperation { + return func(pipe redis.Pipeliner) error { + data, err := msgpack.Marshal(bsState) + if err != nil { + return fmt.Errorf("failed to serialize block build state: %w", err) + } - return nil + pipe.Set(ctx, s.blockStateKey, data, 0) + return nil + } } func (s *RedisStateManager) ResetBlockState(ctx context.Context) error { @@ -152,66 +175,17 @@ func (s *RedisStateManager) ResetBlockState(ctx context.Context) error { CurrentStep: 
types.StepBuildBlock, } - if err := s.SaveBlockState(ctx); err != nil { + if err := s.saveBlockState(ctx); err != nil { return fmt.Errorf("failed to reset leader state: %w", err) } return nil } -func (s *RedisStateManager) SaveExecutionHeadAndAck(ctx context.Context, head *types.ExecutionHead, messageID string) error { - data, err := msgpack.Marshal(head) - if err != nil { - return fmt.Errorf("failed to serialize execution head: %w", err) - } - - key := fmt.Sprintf("executionHead:%s", s.instanceID) - pipe := s.redisClient.TxPipeline() - - pipe.Set(ctx, key, data, 0) - pipe.XAck(ctx, blockStreamName, s.groupName, messageID) - - if _, err := pipe.Exec(ctx); err != nil { - return fmt.Errorf("transaction failed: %w", err) - } - - s.logger.Info("Follower: Execution head saved and message acknowledged", "MessageID", messageID) - return nil -} - -func (s *RedisStateManager) SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { - s.blockBuildState = bsState - data, err := msgpack.Marshal(bsState) - if err != nil { - return fmt.Errorf("failed to serialize leader block build state: %w", err) - } - - pipe := s.redisClient.Pipeline() - pipe.Set(ctx, s.blockStateKey, data, 0) - - message := map[string]interface{}{ - "payload_id": bsState.PayloadID, - "execution_payload": bsState.ExecutionPayload, - "timestamp": time.Now().UnixNano(), - "sender_instance_id": s.instanceID, - } - - pipe.XAdd(ctx, &redis.XAddArgs{ - Stream: blockStreamName, - Values: message, - }) - - if _, err := pipe.Exec(ctx); err != nil { - return fmt.Errorf("pipeline failed: %w", err) - } - - return nil -} - func (s *RedisStateManager) GetBlockBuildState(ctx context.Context) types.BlockBuildState { if s.blockBuildState == nil { s.logger.Error("Leader blockBuildState is not initialized") - if err := s.LoadOrInitializeBlockState(ctx); err != nil { + if err := s.loadOrInitializeBlockState(ctx); err != nil { s.logger.Warn("Failed to load/init state", "error", err) return types.BlockBuildState{} } @@ -227,7 +201,29 @@ func (s *RedisStateManager) GetBlockBuildState(ctx context.Context) types.BlockB return *s.blockBuildState } -func (s *RedisStateManager) CreateConsumerGroup(ctx context.Context) error { +func (s *RedisStateManager) Stop() { + if err := s.redisClient.Close(); err != nil { + s.logger.Error("Error closing Redis client in StateManager", "error", err) + } +} + +func (s *RedisStreamManager) executeTransaction(ctx context.Context, ops ...PipelineOperation) error { + pipe := s.redisClient.TxPipeline() + + for _, op := range ops { + if err := op(pipe); err != nil { + return fmt.Errorf("failed to execute operation: %w", err) + } + } + + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("stream transaction failed: %w", err) + } + + return nil +} + +func (s *RedisStreamManager) createConsumerGroup(ctx context.Context) error { if err := s.redisClient.XGroupCreateMkStream(ctx, blockStreamName, s.groupName, "0").Err(); err != nil { if !strings.Contains(err.Error(), "BUSYGROUP") { return fmt.Errorf("failed to create consumer group '%s': %w", s.groupName, err) @@ -236,7 +232,7 @@ func (s *RedisStateManager) CreateConsumerGroup(ctx context.Context) error { return nil } -func (s *RedisStateManager) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { +func (s *RedisStreamManager) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { args := &redis.XReadGroupArgs{ Group: s.groupName, Consumer: 
s.consumerName, @@ -253,15 +249,73 @@ func (s *RedisStateManager) ReadMessagesFromStream(ctx context.Context, msgType return messages, nil } -func (s *RedisStateManager) AckMessage(ctx context.Context, messageID string) error { - if err := s.redisClient.XAck(ctx, blockStreamName, s.groupName, messageID).Err(); err != nil { - return fmt.Errorf("failed to acknowledge message: %w", err) +func (s *RedisStreamManager) AckMessage(ctx context.Context, messageID string) error { + return s.executeTransaction(ctx, s.ackMessageFunc(ctx, messageID)) +} + +func (s *RedisStreamManager) ackMessageFunc(ctx context.Context, messageID string) PipelineOperation { + return func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, blockStreamName, s.groupName, messageID) + return nil } - return nil } -func (s *RedisStateManager) Stop() { +func (s *RedisStreamManager) publishToStreamFunc(ctx context.Context, bsState *types.BlockBuildState) PipelineOperation { + return func(pipe redis.Pipeliner) error { + message := map[string]interface{}{ + "payload_id": bsState.PayloadID, + "execution_payload": bsState.ExecutionPayload, + "timestamp": time.Now().UnixNano(), + "sender_instance_id": s.instanceID, + } + + pipe.XAdd(ctx, &redis.XAddArgs{ + Stream: blockStreamName, + Values: message, + }) + return nil + } +} + +func (s *RedisStreamManager) Stop() { if err := s.redisClient.Close(); err != nil { - s.logger.Error("Error closing Redis client", "error", err) + s.logger.Error("Error closing Redis client in StreamManager", "error", err) + } +} + +func (c *RedisCoordinator) SaveBlockStateAndPublishToStream(ctx context.Context, bsState *types.BlockBuildState) error { + c.stateMgr.blockBuildState = bsState + + err := c.stateMgr.executeTransaction( + ctx, + c.stateMgr.saveBlockStateFunc(ctx, bsState), + c.streamMgr.publishToStreamFunc(ctx, bsState), + ) + if err != nil { + return fmt.Errorf("transaction failed: %w", err) + } + + return nil +} + +func (c *RedisCoordinator) ResetBlockState(ctx context.Context) error { + return c.stateMgr.ResetBlockState(ctx) +} + +func (c *RedisCoordinator) GetBlockBuildState(ctx context.Context) types.BlockBuildState { + return c.stateMgr.GetBlockBuildState(ctx) +} + +func (c *RedisCoordinator) ReadMessagesFromStream(ctx context.Context, msgType types.RedisMsgType) ([]redis.XStream, error) { + return c.streamMgr.ReadMessagesFromStream(ctx, msgType) +} + +func (c *RedisCoordinator) AckMessage(ctx context.Context, messageID string) error { + return c.streamMgr.AckMessage(ctx, messageID) +} + +func (c *RedisCoordinator) Stop() { + if err := c.redisClient.Close(); err != nil { + c.logger.Error("Error closing Redis client in StateManager", "error", err) } } diff --git a/cl/streamer/streamer.go b/cl/streamer/streamer.go new file mode 100644 index 000000000..c3a728036 --- /dev/null +++ b/cl/streamer/streamer.go @@ -0,0 +1,215 @@ +package streamer + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" + + "log/slog" + + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" + + pb "github.com/primev/mev-commit/cl/pb/pb" +) + +const blockStreamName = "mevcommit_block_stream" + +type RedisMsgType string + +const ( + RedisMsgTypePending RedisMsgType = "0" + RedisMsgTypeNew RedisMsgType = ">" +) + +type PayloadStreamer struct { + pb.UnimplementedPayloadStreamerServer + redisClient *redis.Client + logger *slog.Logger + server *grpc.Server +} + +func NewPayloadStreamer(redisAddr string, logger *slog.Logger) (*PayloadStreamer, error) { + redisClient := 
+		Addr: redisAddr,
+	})
+
+	err := redisClient.ConfigSet(context.Background(), "min-replicas-to-write", "1").Err()
+	if err != nil {
+		logger.Error("Error setting min-replicas-to-write", "error", err)
+		return nil, err
+	}
+
+	return &PayloadStreamer{
+		redisClient: redisClient,
+		logger:      logger,
+		server:      grpc.NewServer(),
+	}, nil
+}
+
+func (s *PayloadStreamer) Start(address string) error {
+	lis, err := net.Listen("tcp", address)
+	if err != nil {
+		return err
+	}
+
+	pb.RegisterPayloadStreamerServer(s.server, s)
+	reflection.Register(s.server)
+
+	s.logger.Info("PayloadStreamer is listening", "address", address)
+	return s.server.Serve(lis)
+}
+
+func (s *PayloadStreamer) Stop() {
+	s.server.GracefulStop()
+	if err := s.redisClient.Close(); err != nil {
+		s.logger.Error("Error closing Redis client in PayloadStreamer", "error", err)
+	}
+}
+
+func (s *PayloadStreamer) Subscribe(stream pb.PayloadStreamer_SubscribeServer) error {
+	ctx := stream.Context()
+
+	var clientID string
+	firstMessage, err := stream.Recv()
+	if err != nil {
+		s.logger.Error("Failed to receive initial message", "error", err)
+		return err
+	}
+	if req := firstMessage.GetSubscribeRequest(); req != nil {
+		clientID = req.ClientId
+	} else {
+		return fmt.Errorf("expected SubscribeRequest, got %v", firstMessage)
+	}
+
+	groupName := "member_group:" + clientID
+	consumerName := "member_consumer:" + clientID
+
+	err = s.createConsumerGroup(ctx, groupName)
+	if err != nil {
+		s.logger.Error("Failed to create consumer group", "clientID", clientID, "error", err)
+		return err
+	}
+
+	s.logger.Info("Subscriber connected", "clientID", clientID)
+	return s.handleBidirectionalStream(stream, clientID, groupName, consumerName)
+}
+
+func (s *PayloadStreamer) createConsumerGroup(ctx context.Context, groupName string) error {
+	err := s.redisClient.XGroupCreateMkStream(ctx, blockStreamName, groupName, "0").Err()
+	if err != nil && !strings.Contains(err.Error(), "BUSYGROUP") {
+		return err
+	}
+	return nil
+}
+
+func (s *PayloadStreamer) handleBidirectionalStream(stream pb.PayloadStreamer_SubscribeServer, clientID, groupName, consumerName string) error {
+	ctx := stream.Context()
+	var pendingMessageID string
+
+	for {
+		if pendingMessageID == "" {
+			// No pending message, read the next message from Redis
+			messages, err := s.readMessages(ctx, groupName, consumerName)
+			if err != nil {
+				s.logger.Error("Error reading messages", "clientID", clientID, "error", err)
+				return err
+			}
+			if len(messages) == 0 {
+				continue
+			}
+
+			msg := messages[0]
+			field := msg.Messages[0]
+			pendingMessageID = field.ID
+
+			payloadIDStr, ok := field.Values["payload_id"].(string)
+			executionPayloadStr, okPayload := field.Values["execution_payload"].(string)
+			senderInstanceID, okSenderID := field.Values["sender_instance_id"].(string)
+			if !ok || !okPayload || !okSenderID {
+				s.logger.Error("Invalid message format", "clientID", clientID)
+				// Acknowledge malformed messages to prevent reprocessing
+				err = s.ackMessage(ctx, field.ID, groupName)
+				if err != nil {
+					s.logger.Error("Failed to acknowledge malformed message", "clientID", clientID, "error", err)
+				}
+				pendingMessageID = ""
+				continue
+			}
+
+			err = stream.Send(&pb.PayloadMessage{
+				PayloadId:        payloadIDStr,
+				ExecutionPayload: executionPayloadStr,
+				SenderInstanceId: senderInstanceID,
+				MessageId:        field.ID,
+			})
+			if err != nil {
+				s.logger.Error("Failed to send message to client", "clientID", clientID, "error", err)
+				return err
+			}
+		}
+
+		clientMsg, err := stream.Recv()
+		if err != nil {
+			s.logger.Error("Failed to receive acknowledgment", "clientID", clientID, "error", err)
+			return err
+		}
+
+		if ack := clientMsg.GetAckPayload(); ack != nil {
+			if ack.MessageId == pendingMessageID {
+				err := s.ackMessage(ctx, pendingMessageID, groupName)
+				if err != nil {
+					s.logger.Error("Failed to acknowledge message", "clientID", clientID, "error", err)
+					return err
+				}
+				s.logger.Info("Message acknowledged", "clientID", clientID, "messageID", pendingMessageID)
+				pendingMessageID = ""
+			} else {
+				s.logger.Error("Received acknowledgment for unknown message ID", "clientID", clientID, "messageID", ack.MessageId)
+			}
+		} else {
+			s.logger.Error("Expected AckPayloadRequest, got something else", "clientID", clientID)
+		}
+	}
+}
+
+func (s *PayloadStreamer) readMessages(ctx context.Context, groupName, consumerName string) ([]redis.XStream, error) {
+	messages, err := s.readMessagesFromStream(ctx, RedisMsgTypePending, groupName, consumerName)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(messages) == 0 || len(messages[0].Messages) == 0 {
+		messages, err = s.readMessagesFromStream(ctx, RedisMsgTypeNew, groupName, consumerName)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return messages, nil
+}
+
+func (s *PayloadStreamer) readMessagesFromStream(ctx context.Context, msgType RedisMsgType, groupName, consumerName string) ([]redis.XStream, error) {
+	args := &redis.XReadGroupArgs{
+		Group:    groupName,
+		Consumer: consumerName,
+		Streams:  []string{blockStreamName, string(msgType)},
+		Count:    1,
+		Block:    time.Second,
+	}
+
+	messages, err := s.redisClient.XReadGroup(ctx, args).Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return nil, fmt.Errorf("error reading messages: %w", err)
+	}
+
+	return messages, nil
+}
+
+func (s *PayloadStreamer) ackMessage(ctx context.Context, messageID, groupName string) error {
+	return s.redisClient.XAck(ctx, blockStreamName, groupName, messageID).Err()
+}
diff --git a/cl/streamer/streamer_test.go b/cl/streamer/streamer_test.go
new file mode 100644
index 000000000..eb7c6dbd9
--- /dev/null
+++ b/cl/streamer/streamer_test.go
@@ -0,0 +1,276 @@
+package streamer
+
+import (
+	"context"
+	"errors"
+	"net"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"log/slog"
+
+	"github.com/go-redis/redismock/v9"
+	"github.com/primev/mev-commit/cl/pb/pb"
+	"github.com/redis/go-redis/v9"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/test/bufconn"
+)
+
+func TestCreateConsumerGroup(t *testing.T) {
+	logger := slog.Default()
+	db, mock := redismock.NewClientMock()
+
+	r := &PayloadStreamer{
+		redisClient: db,
+		logger:      logger,
+		server:      grpc.NewServer(),
+	}
+
+	groupName := "member_group:testClient"
+	mock.ExpectXGroupCreateMkStream(blockStreamName, groupName, "0").SetVal("OK")
+	err := r.createConsumerGroup(context.Background(), groupName)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+
+	mock.ClearExpect()
+	mock.ExpectXGroupCreateMkStream(blockStreamName, groupName, "0").SetErr(errors.New("BUSYGROUP Consumer Group name already exists"))
+	err = r.createConsumerGroup(context.Background(), groupName)
+	if err != nil {
+		t.Fatalf("expected no error on BUSYGROUP, got %v", err)
+	}
+
+	if err := mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("unmet redis expectations: %v", err)
+	}
+}
+
+func TestAckMessage(t *testing.T) {
+	logger := slog.Default()
+	db, mock := redismock.NewClientMock()
+
+	r := &PayloadStreamer{
+		redisClient: db,
+		logger:      logger,
+	}
+
+	groupName := "member_group:testClient"
+	messageID := "123-1"
+
+	mock.ExpectXAck(blockStreamName, groupName, messageID).SetVal(1)
+
+	err := r.ackMessage(context.Background(), messageID, groupName)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if err := mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("unmet redis expectations: %v", err)
+	}
+}
+
+func TestReadMessages(t *testing.T) {
+	logger := slog.Default()
+	db, mock := redismock.NewClientMock()
+
+	r := &PayloadStreamer{
+		redisClient: db,
+		logger:      logger,
+	}
+
+	groupName := "member_group:testClient"
+	consumerName := "member_consumer:testClient"
+
+	mock.ExpectXReadGroup(&redis.XReadGroupArgs{
+		Group:    groupName,
+		Consumer: consumerName,
+		Streams:  []string{blockStreamName, string(RedisMsgTypePending)},
+		Count:    1,
+		Block:    time.Second,
+	}).SetErr(redis.Nil) // simulating no pending messages
+
+	mock.ExpectXReadGroup(&redis.XReadGroupArgs{
+		Group:    groupName,
+		Consumer: consumerName,
+		Streams:  []string{blockStreamName, string(RedisMsgTypeNew)},
+		Count:    1,
+		Block:    time.Second,
+	}).SetVal([]redis.XStream{
+		{
+			Stream: blockStreamName,
+			Messages: []redis.XMessage{
+				{
+					ID: "123-1",
+					Values: map[string]interface{}{
+						"payload_id":         "payload_123",
+						"execution_payload":  "some_encoded_payload",
+						"sender_instance_id": "instance_abc",
+					},
+				},
+			},
+		},
+	})
+
+	messages, err := r.readMessages(context.Background(), groupName, consumerName)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if len(messages) != 1 || len(messages[0].Messages) != 1 {
+		t.Fatalf("expected 1 message, got %v", messages)
+	}
+	if err := mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("unmet redis expectations: %v", err)
+	}
+}
+
+func TestSubscribe(t *testing.T) {
+	logger := slog.Default()
+	db, mock := redismock.NewClientMock()
+
+	r := &PayloadStreamer{
+		redisClient: db,
+		logger:      logger,
+		server:      grpc.NewServer(),
+	}
+
+	pb.RegisterPayloadStreamerServer(r.server, r)
+
+	lis := bufconn.Listen(1024 * 1024)
+
+	serverDone := make(chan struct{})
+	var serverErr atomic.Value
+
+	mock.ExpectXGroupCreateMkStream(blockStreamName, "member_group:testClient", "0").SetVal("OK")
+
+	mock.ExpectXReadGroup(&redis.XReadGroupArgs{
+		Group:    "member_group:testClient",
+		Consumer: "member_consumer:testClient",
+		Streams:  []string{blockStreamName, "0"},
+		Count:    1,
+		Block:    time.Second,
+	}).SetErr(redis.Nil)
+
+	mock.ExpectXReadGroup(&redis.XReadGroupArgs{
+		Group:    "member_group:testClient",
+		Consumer: "member_consumer:testClient",
+		Streams:  []string{blockStreamName, ">"},
+		Count:    1,
+		Block:    time.Second,
+	}).SetVal([]redis.XStream{
+		{
+			Stream: blockStreamName,
+			Messages: []redis.XMessage{
+				{
+					ID: "123-1",
+					Values: map[string]interface{}{
+						"payload_id":         "payload_123",
+						"execution_payload":  "some_encoded_payload",
+						"sender_instance_id": "instance_abc",
+					},
+				},
+			},
+		},
+	})
+
+	ackCalled := make(chan struct{})
+
+	customMatch := func(expected, actual []interface{}) error {
+		if len(actual) >= 1 {
+			cmdName, ok := actual[0].(string)
+			if ok && strings.ToUpper(cmdName) == "XACK" {
+				select {
+				case <-ackCalled:
+				default:
+					close(ackCalled)
+				}
+			}
+		}
+		return nil
+	}
+
+	mock.CustomMatch(customMatch).ExpectXAck(blockStreamName, "member_group:testClient", "123-1").SetVal(int64(1))
+
+	go func() {
+		err := r.server.Serve(lis)
+		if err != nil && err != grpc.ErrServerStopped {
+			serverErr.Store(err)
+		}
+		close(serverDone)
+	}()
+
+	defer func() {
+		r.server.GracefulStop()
+		<-serverDone
+		if err, ok := serverErr.Load().(error); ok {
+			t.Errorf("Server error: %v", err)
+		}
+	}()
+
+	// Create a gRPC client
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	conn, err := grpc.NewClient(
+		"passthrough:///",
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
+			return lis.Dial()
+		}),
+	)
+	if err != nil {
+		t.Fatalf("failed to dial bufconn: %v", err)
+	}
+	defer conn.Close()
+
+	client := pb.NewPayloadStreamerClient(conn)
+
+	// Call Subscribe
+	stream, err := client.Subscribe(ctx)
+	if err != nil {
+		t.Fatalf("failed to call Subscribe: %v", err)
+	}
+
+	err = stream.Send(&pb.ClientMessage{
+		Message: &pb.ClientMessage_SubscribeRequest{
+			SubscribeRequest: &pb.SubscribeRequest{
+				ClientId: "testClient",
+			},
+		},
+	})
+	if err != nil {
+		t.Fatalf("failed to send subscribe request: %v", err)
+	}
+
+	recvMsg, err := stream.Recv()
+	if err != nil {
+		t.Fatalf("failed to receive message from server: %v", err)
+	}
+	if recvMsg.GetPayloadId() != "payload_123" {
+		t.Errorf("expected payload_123, got %s", recvMsg.GetPayloadId())
+	}
+
+	err = stream.Send(&pb.ClientMessage{
+		Message: &pb.ClientMessage_AckPayload{
+			AckPayload: &pb.AckPayloadRequest{
+				ClientId:  "testClient",
+				PayloadId: "payload_123",
+				MessageId: "123-1",
+			},
+		},
+	})
+	if err != nil {
+		t.Fatalf("failed to send ack: %v", err)
+	}
+
+	select {
+	case <-ackCalled:
+	case <-time.After(2 * time.Second):
+		t.Fatalf("timeout waiting for XAck to be called")
+	}
+
+	if err := mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("unmet redis expectations: %v", err)
+	}
+}