Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package riak

import (
"context"
"sync"
"time"

Expand All @@ -23,6 +24,7 @@ import (

// Async object is used to pass required arguments to execute a Command asynchronously
type Async struct {
Context context.Context
Command Command
Done chan Command
Wait *sync.WaitGroup
Expand Down
6 changes: 6 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package riak

import (
"context"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -62,6 +63,11 @@ func (c *Client) Execute(cmd Command) error {
return c.cluster.Execute(cmd)
}

// ExecuteContext (synchronously) the provided Command against the cluster
func (c *Client) ExecuteContext(ctx context.Context, cmd Command) error {
return c.cluster.ExecuteContext(ctx, cmd)
}

// Execute (asynchronously) the provided Command against the cluster
func (c *Client) ExecuteAsync(a *Async) error {
return c.cluster.ExecuteAsync(a)
Expand Down
11 changes: 9 additions & 2 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package riak

import (
"context"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -264,7 +265,7 @@ func (c *Cluster) RemoveNode(n *Node) error {
return nil
}

// Execute (asynchronously) the provided Command against the active pooled Nodes using the NodeManager
// ExecuteAsync (asynchronously) the provided Command against the active pooled Nodes using the NodeManager
func (c *Cluster) ExecuteAsync(async *Async) error {
if async.Command == nil {
return ErrClusterCommandRequired
Expand All @@ -281,11 +282,17 @@ func (c *Cluster) ExecuteAsync(async *Async) error {

// Execute (synchronously) the provided Command against the active pooled Nodes using the NodeManager
func (c *Cluster) Execute(command Command) error {
return c.ExecuteContext(context.Background(), command)
}

// ExecuteContext (synchronously) with a context the provided Command against the active pooled Nodes using the NodeManager
func (c *Cluster) ExecuteContext(ctx context.Context, command Command) error {
if command == nil {
return ErrClusterCommandRequired
}
async := &Async{
Command: command,
Context: ctx,
}
c.execute(async)
if async.Error != nil {
Expand Down Expand Up @@ -322,7 +329,7 @@ func (c *Cluster) execute(async *Async) {
if err = c.stateCheck(clusterRunning); err != nil {
break
}
executed, err = c.nodeManager.ExecuteOnNode(c.nodes, cmd, lastExeNode)
executed, err = c.nodeManager.ExecuteOnNodeContext(async.Context, c.nodes, cmd, lastExeNode)
// NB: do *not* call cmd.onError here as it will have been called in connection
if executed {
// NB: "executed" means that a node sent the data to Riak and received a response
Expand Down
40 changes: 31 additions & 9 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package riak

import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
Expand Down Expand Up @@ -132,7 +133,7 @@ func (c *connection) startTls() error {
}
c.setState(connTlsStarting)
startTlsCmd := &startTlsCommand{}
if err := c.execute(startTlsCmd); err != nil {
if err := c.execute(context.Background(), startTlsCmd); err != nil {
return err
}
var tlsConn *tls.Conn
Expand All @@ -147,7 +148,7 @@ func (c *connection) startTls() error {
user: c.authOptions.User,
password: c.authOptions.Password,
}
return c.execute(authCmd)
return c.execute(context.Background(), authCmd)
}

func (c *connection) available() bool {
Expand All @@ -167,7 +168,7 @@ func (c *connection) setInFlight(inFlightVal bool) {
c.inFlight = inFlightVal
}

func (c *connection) execute(cmd Command) (err error) {
func (c *connection) execute(ctx context.Context, cmd Command) (err error) {
if c.inFlight == true {
err = fmt.Errorf("[Connection] attempted to run '%s' command on in-use connection", cmd.Name())
return
Expand Down Expand Up @@ -202,14 +203,17 @@ func (c *connection) execute(cmd Command) (err error) {
}
}

if err = c.write(message, timeout); err != nil {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

if err = c.write(ctx, message); err != nil {
return
}

var response []byte
var decoded proto.Message
for {
response, err = c.read(timeout) // NB: response *will* have entire pb message
response, err = c.read(ctx) // NB: response *will* have entire pb message
if err != nil {
cmd.onError(err)
return
Expand Down Expand Up @@ -249,22 +253,31 @@ func (c *connection) setReadDeadline(t time.Duration) {
}

// NB: This will read one full pb message from Riak, or error in doing so
func (c *connection) read(timeout time.Duration) ([]byte, error) {
func (c *connection) read(ctx context.Context) ([]byte, error) {
if !c.available() {
return nil, ErrCannotRead
}

var err error
var count int
var messageLength uint32
var rt time.Duration = timeout // rt = 'read timeout'
var rt time.Duration
if deadline, ok := ctx.Deadline(); ok {
rt = deadline.Sub(time.Now())
}
b := &backoff.Backoff{
Min: rt,
Jitter: true,
}
try := uint16(0)

for {
select {
case <-ctx.Done():
return nil, context.Canceled
default:
}

c.setReadDeadline(rt)
if count, err = io.ReadFull(c.conn, c.sizeBuf); err == nil && count == 4 {
messageLength = binary.BigEndian.Uint32(c.sizeBuf)
Expand Down Expand Up @@ -304,11 +317,20 @@ func (c *connection) read(timeout time.Duration) ([]byte, error) {
}
}

func (c *connection) write(data []byte, timeout time.Duration) error {
func (c *connection) write(ctx context.Context, data []byte) error {
if !c.available() {
return ErrCannotWrite
}
c.conn.SetWriteDeadline(time.Now().Add(timeout))

select {
case <-ctx.Done():
return context.Canceled
default:
}

if deadline, ok := ctx.Deadline(); ok {
c.conn.SetWriteDeadline(deadline)
}
count, err := c.conn.Write(data)
if err != nil {
c.setState(connInactive)
Expand Down
7 changes: 4 additions & 3 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package riak

import (
"context"
"fmt"
"net"
"time"
Expand Down Expand Up @@ -180,7 +181,7 @@ func (n *Node) stop() error {

// Execute retrieves an available connection from the pool and executes the Command operation against
// Riak
func (n *Node) execute(cmd Command) (bool, error) {
func (n *Node) execute(ctx context.Context, cmd Command) (bool, error) {
if err := n.stateCheck(nodeRunning, nodeHealthChecking); err != nil {
return false, err
}
Expand All @@ -202,7 +203,7 @@ func (n *Node) execute(cmd Command) (bool, error) {
}

logDebug("[Node]", "(%v) - executing command '%v'", n, cmd.Name())
err = conn.execute(cmd)
err = conn.execute(ctx, cmd)
if err == nil {
// NB: basically the success path of _responseReceived in Node.js client
if cmErr := n.cm.put(conn); cmErr != nil {
Expand Down Expand Up @@ -305,7 +306,7 @@ func (n *Node) healthCheck() {
}
hcmd := n.getHealthCheckCommand()
logDebug("[Node]", "(%v) healthcheck executing %v", n, hcmd.Name())
if hcerr := conn.execute(hcmd); hcerr != nil || !hcmd.Success() {
if hcerr := conn.execute(context.Background(), hcmd); hcerr != nil || !hcmd.Success() {
conn.close()
logError("[Node]", "(%v) failed healthcheck, err: %v", n, hcerr)
} else {
Expand Down
10 changes: 9 additions & 1 deletion node_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
package riak

import (
"context"
"sync"
)

// NodeManager enforces the structure needed to if going to implement your own NodeManager
type NodeManager interface {
ExecuteOnNode(nodes []*Node, command Command, previousNode *Node) (bool, error)
ExecuteOnNodeContext(ctx context.Context, nodes []*Node, command Command, previousNode *Node) (bool, error)
}

var ErrDefaultNodeManagerRequiresNode = newClientError("Must pass at least one node to default node manager", nil)
Expand All @@ -33,6 +35,12 @@ type defaultNodeManager struct {
// ExecuteOnNode selects a Node from the pool and executes the provided Command on that Node. The
// defaultNodeManager uses a simple round robin approach to distributing load
func (nm *defaultNodeManager) ExecuteOnNode(nodes []*Node, command Command, previous *Node) (bool, error) {
return nm.ExecuteOnNodeContext(context.Background(), nodes, command, previous)
}

// ExecuteOnNodeContext selects a Node from the pool and executes the provided Command on that Node. The
// defaultNodeManager uses a simple round robin approach to distributing load
func (nm *defaultNodeManager) ExecuteOnNodeContext(ctx context.Context, nodes []*Node, command Command, previous *Node) (bool, error) {
if nodes == nil {
panic("[defaultNodeManager] nil nodes argument")
}
Expand Down Expand Up @@ -61,7 +69,7 @@ func (nm *defaultNodeManager) ExecuteOnNode(nodes []*Node, command Command, prev
continue
}

executed, err = node.execute(command)
executed, err = node.execute(ctx, command)
if executed == true {
logDebug("[DefaultNodeManager]", "executed '%s' on node '%s', err '%v'", command.Name(), node, err)
break
Expand Down