diff --git a/async.go b/async.go
index 972f6dd..c4d2fab 100644
--- a/async.go
+++ b/async.go
@@ -15,6 +15,7 @@
 package riak
 
 import (
+	"context"
 	"sync"
 	"time"
 
@@ -23,6 +24,7 @@ import (
 
 // Async object is used to pass required arguments to execute a Command asynchronously
 type Async struct {
+	Context context.Context
 	Command Command
 	Done    chan Command
 	Wait    *sync.WaitGroup
diff --git a/client.go b/client.go
index 0cf9b3c..9bf424a 100644
--- a/client.go
+++ b/client.go
@@ -15,6 +15,7 @@
 package riak
 
 import (
+	"context"
 	"fmt"
 	"strconv"
 	"strings"
@@ -62,6 +63,11 @@ func (c *Client) Execute(cmd Command) error {
 	return c.cluster.Execute(cmd)
 }
 
+// ExecuteContext (synchronously) the provided Command against the cluster
+func (c *Client) ExecuteContext(ctx context.Context, cmd Command) error {
+	return c.cluster.ExecuteContext(ctx, cmd)
+}
+
 // Execute (asynchronously) the provided Command against the cluster
 func (c *Client) ExecuteAsync(a *Async) error {
 	return c.cluster.ExecuteAsync(a)
diff --git a/cluster.go b/cluster.go
index 867986d..eebe2a3 100644
--- a/cluster.go
+++ b/cluster.go
@@ -15,6 +15,7 @@
 package riak
 
 import (
+	"context"
 	"fmt"
 	"sync"
 	"time"
@@ -264,7 +265,7 @@ func (c *Cluster) RemoveNode(n *Node) error {
 	return nil
 }
 
-// Execute (asynchronously) the provided Command against the active pooled Nodes using the NodeManager
+// ExecuteAsync (asynchronously) the provided Command against the active pooled Nodes using the NodeManager
 func (c *Cluster) ExecuteAsync(async *Async) error {
 	if async.Command == nil {
 		return ErrClusterCommandRequired
@@ -281,11 +282,17 @@ func (c *Cluster) ExecuteAsync(async *Async) error {
 
 // Execute (synchronously) the provided Command against the active pooled Nodes using the NodeManager
 func (c *Cluster) Execute(command Command) error {
+	return c.ExecuteContext(context.Background(), command)
+}
+
+// ExecuteContext (synchronously) with a context the provided Command against the active pooled Nodes using the NodeManager
+func (c *Cluster) ExecuteContext(ctx context.Context, command Command) error {
 	if command == nil {
 		return ErrClusterCommandRequired
 	}
 	async := &Async{
 		Command: command,
+		Context: ctx,
 	}
 	c.execute(async)
 	if async.Error != nil {
@@ -322,7 +329,7 @@ func (c *Cluster) execute(async *Async) {
 		if err = c.stateCheck(clusterRunning); err != nil {
 			break
 		}
-		executed, err = c.nodeManager.ExecuteOnNode(c.nodes, cmd, lastExeNode)
+		executed, err = c.nodeManager.ExecuteOnNodeContext(async.Context, c.nodes, cmd, lastExeNode)
 		// NB: do *not* call cmd.onError here as it will have been called in connection
 		if executed {
 			// NB: "executed" means that a node sent the data to Riak and received a response
diff --git a/connection.go b/connection.go
index 0098cd7..15fd11c 100644
--- a/connection.go
+++ b/connection.go
@@ -15,6 +15,7 @@
 package riak
 
 import (
+	"context"
 	"crypto/tls"
 	"encoding/binary"
 	"errors"
@@ -132,7 +133,7 @@ func (c *connection) startTls() error {
 	}
 	c.setState(connTlsStarting)
 	startTlsCmd := &startTlsCommand{}
-	if err := c.execute(startTlsCmd); err != nil {
+	if err := c.execute(context.Background(), startTlsCmd); err != nil {
 		return err
 	}
 	var tlsConn *tls.Conn
@@ -147,7 +148,7 @@ func (c *connection) startTls() error {
 		user:     c.authOptions.User,
 		password: c.authOptions.Password,
 	}
-	return c.execute(authCmd)
+	return c.execute(context.Background(), authCmd)
 }
 
 func (c *connection) available() bool {
@@ -167,7 +168,7 @@ func (c *connection) setInFlight(inFlightVal bool) {
 	c.inFlight = inFlightVal
 }
 
-func (c *connection) execute(cmd Command) (err error) {
+func (c *connection) execute(ctx context.Context, cmd Command) (err error) {
 	if c.inFlight == true {
 		err = fmt.Errorf("[Connection] attempted to run '%s' command on in-use connection", cmd.Name())
 		return
@@ -202,14 +203,17 @@ func (c *connection) execute(cmd Command) (err error) {
 		}
 	}
 
-	if err = c.write(message, timeout); err != nil {
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	if err = c.write(ctx, message); err != nil {
 		return
 	}
 
 	var response []byte
 	var decoded proto.Message
 	for {
-		response, err = c.read(timeout) // NB: response *will* have entire pb message
+		response, err = c.read(ctx) // NB: response *will* have entire pb message
 		if err != nil {
 			cmd.onError(err)
 			return
@@ -249,7 +253,7 @@ func (c *connection) setReadDeadline(t time.Duration) {
 }
 
 // NB: This will read one full pb message from Riak, or error in doing so
-func (c *connection) read(timeout time.Duration) ([]byte, error) {
+func (c *connection) read(ctx context.Context) ([]byte, error) {
 	if !c.available() {
 		return nil, ErrCannotRead
 	}
@@ -257,7 +261,10 @@
 	var err error
 	var count int
 	var messageLength uint32
-	var rt time.Duration = timeout // rt = 'read timeout'
+	var rt time.Duration
+	if deadline, ok := ctx.Deadline(); ok {
+		rt = deadline.Sub(time.Now())
+	}
 	b := &backoff.Backoff{
 		Min:    rt,
 		Jitter: true,
@@ -265,6 +272,12 @@
 
 	try := uint16(0)
 	for {
+		select {
+		case <-ctx.Done():
+			return nil, context.Canceled
+		default:
+		}
+
 		c.setReadDeadline(rt)
 		if count, err = io.ReadFull(c.conn, c.sizeBuf); err == nil && count == 4 {
 			messageLength = binary.BigEndian.Uint32(c.sizeBuf)
@@ -304,11 +317,20 @@
 	}
 }
 
-func (c *connection) write(data []byte, timeout time.Duration) error {
+func (c *connection) write(ctx context.Context, data []byte) error {
 	if !c.available() {
 		return ErrCannotWrite
 	}
-	c.conn.SetWriteDeadline(time.Now().Add(timeout))
+
+	select {
+	case <-ctx.Done():
+		return context.Canceled
+	default:
+	}
+
+	if deadline, ok := ctx.Deadline(); ok {
+		c.conn.SetWriteDeadline(deadline)
+	}
 	count, err := c.conn.Write(data)
 	if err != nil {
 		c.setState(connInactive)
diff --git a/node.go b/node.go
index d7c2f07..f3250df 100644
--- a/node.go
+++ b/node.go
@@ -15,6 +15,7 @@
 package riak
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"time"
@@ -180,7 +181,7 @@ func (n *Node) stop() error {
 
 // Execute retrieves an available connection from the pool and executes the Command operation against
 // Riak
-func (n *Node) execute(cmd Command) (bool, error) {
+func (n *Node) execute(ctx context.Context, cmd Command) (bool, error) {
 	if err := n.stateCheck(nodeRunning, nodeHealthChecking); err != nil {
 		return false, err
 	}
@@ -202,7 +203,7 @@ func (n *Node) execute(ctx context.Context, cmd Command) (bool, error) {
 	}
 
 	logDebug("[Node]", "(%v) - executing command '%v'", n, cmd.Name())
-	err = conn.execute(cmd)
+	err = conn.execute(ctx, cmd)
 	if err == nil {
 		// NB: basically the success path of _responseReceived in Node.js client
 		if cmErr := n.cm.put(conn); cmErr != nil {
@@ -305,7 +306,7 @@ func (n *Node) healthCheck() {
 	}
 	hcmd := n.getHealthCheckCommand()
 	logDebug("[Node]", "(%v) healthcheck executing %v", n, hcmd.Name())
-	if hcerr := conn.execute(hcmd); hcerr != nil || !hcmd.Success() {
+	if hcerr := conn.execute(context.Background(), hcmd); hcerr != nil || !hcmd.Success() {
 		conn.close()
 		logError("[Node]", "(%v) failed healthcheck, err: %v", n, hcerr)
 	} else {
diff --git a/node_manager.go b/node_manager.go
index 0945900..80ed11a 100644
--- a/node_manager.go
+++ b/node_manager.go
@@ -15,12 +15,14 @@
 package riak
 
 import (
+	"context"
 	"sync"
 )
 
 // NodeManager enforces the structure needed to if going to implement your own NodeManager
 type NodeManager interface {
 	ExecuteOnNode(nodes []*Node, command Command, previousNode *Node) (bool, error)
+	ExecuteOnNodeContext(ctx context.Context, nodes []*Node, command Command, previousNode *Node) (bool, error)
 }
 
 var ErrDefaultNodeManagerRequiresNode = newClientError("Must pass at least one node to default node manager", nil)
@@ -33,6 +35,12 @@ type defaultNodeManager struct {
 // ExecuteOnNode selects a Node from the pool and executes the provided Command on that Node. The
 // defaultNodeManager uses a simple round robin approach to distributing load
 func (nm *defaultNodeManager) ExecuteOnNode(nodes []*Node, command Command, previous *Node) (bool, error) {
+	return nm.ExecuteOnNodeContext(context.Background(), nodes, command, previous)
+}
+
+// ExecuteOnNodeContext selects a Node from the pool and executes the provided Command on that Node. The
+// defaultNodeManager uses a simple round robin approach to distributing load
+func (nm *defaultNodeManager) ExecuteOnNodeContext(ctx context.Context, nodes []*Node, command Command, previous *Node) (bool, error) {
 	if nodes == nil {
 		panic("[defaultNodeManager] nil nodes argument")
 	}
@@ -61,7 +69,7 @@ func (nm *defaultNodeManager) ExecuteOnNode(nodes []*Node, command Command, prev
 			continue
 		}
 
-		executed, err = node.execute(command)
+		executed, err = node.execute(ctx, command)
 		if executed == true {
 			logDebug("[DefaultNodeManager]", "executed '%s' on node '%s', err '%v'", command.Name(), node, err)
 			break
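
Usage sketch (not part of the patch above): how a caller might drive the new context-aware entry point added to Client. Only ExecuteContext comes from this change; the client construction and the fetch-command builder are the library's existing API, and the bucket, key, and local address 127.0.0.1:8087 are placeholder assumptions.

package main

import (
	"context"
	"log"
	"time"

	riak "github.com/basho/riak-go-client"
)

func main() {
	// Existing client construction, unchanged by this patch.
	client, err := riak.NewClient(&riak.NewClientOptions{
		RemoteAddresses: []string{"127.0.0.1:8087"}, // assumed local node
	})
	if err != nil {
		log.Fatal(err)
	}
	defer client.Stop()

	// Existing command builder; bucket and key are placeholders.
	cmd, err := riak.NewFetchValueCommandBuilder().
		WithBucket("test").
		WithKey("key").
		Build()
	if err != nil {
		log.Fatal(err)
	}

	// New in this patch: read/write deadlines on the connection are derived
	// from the context, so the command is abandoned once the deadline passes
	// or the context is canceled.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := client.ExecuteContext(ctx, cmd); err != nil {
		log.Println("execute failed:", err)
	}
}

The asynchronous path gets the same capability through the new Context field on Async, which Cluster.execute forwards to NodeManager.ExecuteOnNodeContext.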