From b48999042f0949546b647c9c17e5a51ce175c8b7 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 27 Aug 2025 09:53:59 -0700 Subject: [PATCH 1/2] fix(BREV-1636): allow backoff to be altered via options pattern --- internal/lambdalabs/v1/client.go | 85 +++++++++++++++++++++----- internal/lambdalabs/v1/common_test.go | 6 +- internal/lambdalabs/v1/credential.go | 2 +- internal/lambdalabs/v1/instance.go | 12 ++-- internal/lambdalabs/v1/instancetype.go | 2 +- 5 files changed, 82 insertions(+), 25 deletions(-) diff --git a/internal/lambdalabs/v1/client.go b/internal/lambdalabs/v1/client.go index b05a1b62..ec5fb693 100644 --- a/internal/lambdalabs/v1/client.go +++ b/internal/lambdalabs/v1/client.go @@ -19,22 +19,82 @@ type LambdaLabsClient struct { baseURL string client *openapi.APIClient location string + backoff backoff.BackOff } var _ v1.CloudClient = &LambdaLabsClient{} +type options struct { + baseURL string + client *openapi.APIClient + location string + backoff backoff.BackOff +} + +type Option func(options *options) error + +func WithBaseURL(baseURL string) Option { + return func(options *options) error { + options.baseURL = baseURL + return nil + } +} + +func WithClient(client *openapi.APIClient) Option { + return func(options *options) error { + options.client = client + return nil + } +} + +func WithLocation(location string) Option { + return func(options *options) error { + options.location = location + return nil + } +} + +func WithBackoff(backoff backoff.BackOff) Option { + return func(options *options) error { + options.backoff = backoff + return nil + } +} + // NewLambdaLabsClient creates a new Lambda Labs client -func NewLambdaLabsClient(refID, apiKey string) *LambdaLabsClient { - config := openapi.NewConfiguration() - config.HTTPClient = http.DefaultClient - client := openapi.NewAPIClient(config) +func NewLambdaLabsClient(refID, apiKey string, opts ...Option) (*LambdaLabsClient, error) { + var options options + for _, opt := range opts { + if err := opt(&options); err != nil { + return nil, err + } + } - return &LambdaLabsClient{ - refID: refID, - apiKey: apiKey, - baseURL: "https://cloud.lambda.ai/api/v1", - client: client, + if options.baseURL == "" { + options.baseURL = "https://cloud.lambda.ai/api/v1" + } + + if options.client == nil { + config := openapi.NewConfiguration() + config.HTTPClient = http.DefaultClient + options.client = openapi.NewAPIClient(config) + } + + if options.backoff == nil { + bo := backoff.NewExponentialBackOff() + bo.InitialInterval = 1000 * time.Millisecond + bo.MaxElapsedTime = 120 * time.Second + options.backoff = bo } + + return &LambdaLabsClient{ + refID: refID, + apiKey: apiKey, + baseURL: options.baseURL, + client: options.client, + location: options.location, + backoff: options.backoff, + }, nil } // GetAPIType returns the API type for Lambda Labs @@ -73,10 +133,3 @@ func (c *LambdaLabsClient) makeAuthContext(ctx context.Context) context.Context UserName: c.apiKey, }) } - -func getBackoff() backoff.BackOff { - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = 1000 * time.Millisecond - bo.MaxElapsedTime = 120 * time.Second - return bo -} diff --git a/internal/lambdalabs/v1/common_test.go b/internal/lambdalabs/v1/common_test.go index 8c8be2fc..df34d3fa 100644 --- a/internal/lambdalabs/v1/common_test.go +++ b/internal/lambdalabs/v1/common_test.go @@ -2,6 +2,7 @@ package v1 import ( openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + "github.com/cenkalti/backoff/v4" "github.com/jarcoal/httpmock" ) @@ -12,7 +13,10 @@ const ( func setupMockClient() (*LambdaLabsClient, func()) { httpmock.Activate() - client := NewLambdaLabsClient("test-ref-id", "test-api-key") + client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBackoff(&backoff.StopBackOff{})) + if err != nil { + panic(err) + } return client, httpmock.DeactivateAndReset } diff --git a/internal/lambdalabs/v1/credential.go b/internal/lambdalabs/v1/credential.go index cd87d41e..e0de8bac 100644 --- a/internal/lambdalabs/v1/credential.go +++ b/internal/lambdalabs/v1/credential.go @@ -50,5 +50,5 @@ func (c *LambdaLabsCredential) GetTenantID() (string, error) { // MakeClient creates a new Lambda Labs client from this credential func (c *LambdaLabsCredential) MakeClient(_ context.Context, _ string) (v1.CloudClient, error) { - return NewLambdaLabsClient(c.RefID, c.APIKey), nil + return NewLambdaLabsClient(c.RefID, c.APIKey) } diff --git a/internal/lambdalabs/v1/instance.go b/internal/lambdalabs/v1/instance.go index c2c5439c..412841b9 100644 --- a/internal/lambdalabs/v1/instance.go +++ b/internal/lambdalabs/v1/instance.go @@ -238,7 +238,7 @@ func (c *LambdaLabsClient) addSSHKey(ctx context.Context, request openapi.AddSSH return &openapi.AddSSHKey200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func (c *LambdaLabsClient) launchInstance(ctx context.Context, request openapi.L return &openapi.LaunchInstance200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } @@ -272,7 +272,7 @@ func (c *LambdaLabsClient) getInstance(ctx context.Context, instanceID string) ( return &openapi.GetInstance200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } @@ -289,7 +289,7 @@ func (c *LambdaLabsClient) terminateInstance(ctx context.Context, request openap return &openapi.TerminateInstance200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } @@ -306,7 +306,7 @@ func (c *LambdaLabsClient) listInstances(ctx context.Context) (*openapi.ListInst return &openapi.ListInstances200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } @@ -323,7 +323,7 @@ func (c *LambdaLabsClient) restartInstance(ctx context.Context, request openapi. return &openapi.RestartInstance200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } diff --git a/internal/lambdalabs/v1/instancetype.go b/internal/lambdalabs/v1/instancetype.go index 4493f360..219f5b19 100644 --- a/internal/lambdalabs/v1/instancetype.go +++ b/internal/lambdalabs/v1/instancetype.go @@ -99,7 +99,7 @@ func (c *LambdaLabsClient) getInstanceTypes(ctx context.Context) (*openapi.Insta return &openapi.InstanceTypes200Response{}, handleAPIError(ctx, resp, err) } return res, nil - }, getBackoff()) + }, c.backoff) if err != nil { return nil, err } From ccd2d497f6535c69adaf179bd86254c32ae7fb2b Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 27 Aug 2025 10:01:57 -0700 Subject: [PATCH 2/2] tests --- go.mod | 1 + go.sum | 2 ++ internal/lambdalabs/v1/client.go | 21 ++++++++++++--- internal/lambdalabs/v1/client_test.go | 39 +++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 546e671a..c80bea05 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b github.com/bojanz/currency v1.3.1 + github.com/cenkalti/backoff v2.2.1+incompatible github.com/cenkalti/backoff/v4 v4.3.0 github.com/gliderlabs/ssh v0.3.8 github.com/google/go-cmp v0.7.0 diff --git a/go.sum b/go.sum index 2e2b9714..91dfffaa 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/bojanz/currency v1.3.1 h1:3BUAvy/5hU/Pzqg5nrQslVihV50QG+A2xKPoQw1RKH4= github.com/bojanz/currency v1.3.1/go.mod h1:jNoZiJyRTqoU5DFoa+n+9lputxPUDa8Fz8BdDrW06Go= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= diff --git a/internal/lambdalabs/v1/client.go b/internal/lambdalabs/v1/client.go index ec5fb693..7eab9999 100644 --- a/internal/lambdalabs/v1/client.go +++ b/internal/lambdalabs/v1/client.go @@ -2,6 +2,7 @@ package v1 import ( "context" + "fmt" "net/http" "time" @@ -10,6 +11,12 @@ import ( "github.com/cenkalti/backoff/v4" ) +const ( + defaultBaseURL = "https://cloud.lambda.ai/api/v1" + defaultBackoffInitialInterval = 1000 * time.Millisecond + defaultBackoffMaxElapsedTime = 120 * time.Second +) + // LambdaLabsClient implements the CloudClient interface for Lambda Labs // It embeds NotImplCloudClient to handle unsupported features type LambdaLabsClient struct { @@ -33,6 +40,7 @@ type options struct { type Option func(options *options) error +// WithBaseURL sets the base URL for the Lambda Labs client func WithBaseURL(baseURL string) Option { return func(options *options) error { options.baseURL = baseURL @@ -40,6 +48,7 @@ func WithBaseURL(baseURL string) Option { } } +// WithClient sets the OpenAPI HTTP client for the Lambda Labs client func WithClient(client *openapi.APIClient) Option { return func(options *options) error { options.client = client @@ -47,6 +56,7 @@ func WithClient(client *openapi.APIClient) Option { } } +// WithLocation sets the location for the Lambda Labs client func WithLocation(location string) Option { return func(options *options) error { options.location = location @@ -54,6 +64,7 @@ func WithLocation(location string) Option { } } +// WithBackoff sets the backoff settings used to retry API calls for the Lambda Labs client func WithBackoff(backoff backoff.BackOff) Option { return func(options *options) error { options.backoff = backoff @@ -70,8 +81,12 @@ func NewLambdaLabsClient(refID, apiKey string, opts ...Option) (*LambdaLabsClien } } + if refID == "" || apiKey == "" { + return nil, fmt.Errorf("refID and apiKey are required") + } + if options.baseURL == "" { - options.baseURL = "https://cloud.lambda.ai/api/v1" + options.baseURL = defaultBaseURL } if options.client == nil { @@ -82,8 +97,8 @@ func NewLambdaLabsClient(refID, apiKey string, opts ...Option) (*LambdaLabsClien if options.backoff == nil { bo := backoff.NewExponentialBackOff() - bo.InitialInterval = 1000 * time.Millisecond - bo.MaxElapsedTime = 120 * time.Second + bo.InitialInterval = defaultBackoffInitialInterval + bo.MaxElapsedTime = defaultBackoffMaxElapsedTime options.backoff = bo } diff --git a/internal/lambdalabs/v1/client_test.go b/internal/lambdalabs/v1/client_test.go index 16e1692e..0c026bcf 100644 --- a/internal/lambdalabs/v1/client_test.go +++ b/internal/lambdalabs/v1/client_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/cenkalti/backoff" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -51,3 +52,41 @@ func TestLambdaLabsClient_makeAuthContext(t *testing.T) { assert.Equal(t, "test-api-key", basicAuth.UserName) assert.Equal(t, "", basicAuth.Password) } + +func TestLambdaLabsClient_NewLambdaLabsClientRequiredFields(t *testing.T) { + _, err := NewLambdaLabsClient("", "") + require.Error(t, err) + assert.Equal(t, "refID and apiKey are required", err.Error()) +} + +func TestLambdaLabsClient_NewLambdaLabsClientWithBaseURL(t *testing.T) { + baseURL := "https://test.lambda.ai/api/v1" + + client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBaseURL(baseURL)) + require.NoError(t, err) + assert.Equal(t, baseURL, client.baseURL) +} + +func TestLambdaLabsClient_NewLambdaLabsClientWithClient(t *testing.T) { + apiClient := openapi.NewAPIClient(openapi.NewConfiguration()) + + client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithClient(apiClient)) + require.NoError(t, err) + assert.Equal(t, apiClient, client.client) +} + +func TestLambdaLabsClient_NewLambdaLabsClientWithLocation(t *testing.T) { + location := "us-west-1" + + client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithLocation(location)) + require.NoError(t, err) + assert.Equal(t, location, client.location) +} + +func TestLambdaLabsClient_NewLambdaLabsClientWithBackoff(t *testing.T) { + backoff := &backoff.ZeroBackOff{} + + client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBackoff(backoff)) + require.NoError(t, err) + assert.Equal(t, backoff, client.backoff) +}