Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b
github.com/bojanz/currency v1.3.1
github.com/cenkalti/backoff v2.2.1+incompatible
github.com/cenkalti/backoff/v4 v4.3.0
github.com/gliderlabs/ssh v0.3.8
github.com/google/go-cmp v0.7.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/bojanz/currency v1.3.1 h1:3BUAvy/5hU/Pzqg5nrQslVihV50QG+A2xKPoQw1RKH4=
github.com/bojanz/currency v1.3.1/go.mod h1:jNoZiJyRTqoU5DFoa+n+9lputxPUDa8Fz8BdDrW06Go=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg=
Expand Down
100 changes: 84 additions & 16 deletions internal/lambdalabs/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v1

import (
"context"
"fmt"
"net/http"
"time"

Expand All @@ -10,6 +11,12 @@ import (
"github.com/cenkalti/backoff/v4"
)

const (
defaultBaseURL = "https://cloud.lambda.ai/api/v1"
defaultBackoffInitialInterval = 1000 * time.Millisecond
defaultBackoffMaxElapsedTime = 120 * time.Second
)

// LambdaLabsClient implements the CloudClient interface for Lambda Labs
// It embeds NotImplCloudClient to handle unsupported features
type LambdaLabsClient struct {
Expand All @@ -19,22 +26,90 @@ type LambdaLabsClient struct {
baseURL string
client *openapi.APIClient
location string
backoff backoff.BackOff
}

var _ v1.CloudClient = &LambdaLabsClient{}

type options struct {
baseURL string
client *openapi.APIClient
location string
backoff backoff.BackOff
}

type Option func(options *options) error

// WithBaseURL sets the base URL for the Lambda Labs client
func WithBaseURL(baseURL string) Option {
return func(options *options) error {
options.baseURL = baseURL
return nil
}
}

// WithClient sets the OpenAPI HTTP client for the Lambda Labs client
func WithClient(client *openapi.APIClient) Option {
return func(options *options) error {
options.client = client
return nil
}
}

// WithLocation sets the location for the Lambda Labs client
func WithLocation(location string) Option {
return func(options *options) error {
options.location = location
return nil
}
}

// WithBackoff sets the backoff settings used to retry API calls for the Lambda Labs client
func WithBackoff(backoff backoff.BackOff) Option {
return func(options *options) error {
options.backoff = backoff
return nil
}
}

// NewLambdaLabsClient creates a new Lambda Labs client
func NewLambdaLabsClient(refID, apiKey string) *LambdaLabsClient {
config := openapi.NewConfiguration()
config.HTTPClient = http.DefaultClient
client := openapi.NewAPIClient(config)
func NewLambdaLabsClient(refID, apiKey string, opts ...Option) (*LambdaLabsClient, error) {
var options options
for _, opt := range opts {
if err := opt(&options); err != nil {
return nil, err
}
}

return &LambdaLabsClient{
refID: refID,
apiKey: apiKey,
baseURL: "https://cloud.lambda.ai/api/v1",
client: client,
if refID == "" || apiKey == "" {
return nil, fmt.Errorf("refID and apiKey are required")
}

if options.baseURL == "" {
options.baseURL = defaultBaseURL
}

if options.client == nil {
config := openapi.NewConfiguration()
config.HTTPClient = http.DefaultClient
options.client = openapi.NewAPIClient(config)
}

if options.backoff == nil {
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = defaultBackoffInitialInterval
bo.MaxElapsedTime = defaultBackoffMaxElapsedTime
options.backoff = bo
}

return &LambdaLabsClient{
refID: refID,
apiKey: apiKey,
baseURL: options.baseURL,
client: options.client,
location: options.location,
backoff: options.backoff,
}, nil
}

// GetAPIType returns the API type for Lambda Labs
Expand Down Expand Up @@ -73,10 +148,3 @@ func (c *LambdaLabsClient) makeAuthContext(ctx context.Context) context.Context
UserName: c.apiKey,
})
}

func getBackoff() backoff.BackOff {
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 1000 * time.Millisecond
bo.MaxElapsedTime = 120 * time.Second
return bo
}
39 changes: 39 additions & 0 deletions internal/lambdalabs/v1/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

"github.com/cenkalti/backoff"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -51,3 +52,41 @@ func TestLambdaLabsClient_makeAuthContext(t *testing.T) {
assert.Equal(t, "test-api-key", basicAuth.UserName)
assert.Equal(t, "", basicAuth.Password)
}

func TestLambdaLabsClient_NewLambdaLabsClientRequiredFields(t *testing.T) {
_, err := NewLambdaLabsClient("", "")
require.Error(t, err)
assert.Equal(t, "refID and apiKey are required", err.Error())
}

func TestLambdaLabsClient_NewLambdaLabsClientWithBaseURL(t *testing.T) {
baseURL := "https://test.lambda.ai/api/v1"

client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBaseURL(baseURL))
require.NoError(t, err)
assert.Equal(t, baseURL, client.baseURL)
}

func TestLambdaLabsClient_NewLambdaLabsClientWithClient(t *testing.T) {
apiClient := openapi.NewAPIClient(openapi.NewConfiguration())

client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithClient(apiClient))
require.NoError(t, err)
assert.Equal(t, apiClient, client.client)
}

func TestLambdaLabsClient_NewLambdaLabsClientWithLocation(t *testing.T) {
location := "us-west-1"

client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithLocation(location))
require.NoError(t, err)
assert.Equal(t, location, client.location)
}

func TestLambdaLabsClient_NewLambdaLabsClientWithBackoff(t *testing.T) {
backoff := &backoff.ZeroBackOff{}

client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBackoff(backoff))
require.NoError(t, err)
assert.Equal(t, backoff, client.backoff)
}
6 changes: 5 additions & 1 deletion internal/lambdalabs/v1/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v1

import (
openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs"
"github.com/cenkalti/backoff/v4"
"github.com/jarcoal/httpmock"
)

Expand All @@ -12,7 +13,10 @@ const (

func setupMockClient() (*LambdaLabsClient, func()) {
httpmock.Activate()
client := NewLambdaLabsClient("test-ref-id", "test-api-key")
client, err := NewLambdaLabsClient("test-ref-id", "test-api-key", WithBackoff(&backoff.StopBackOff{}))
if err != nil {
panic(err)
}
return client, httpmock.DeactivateAndReset
}

Expand Down
2 changes: 1 addition & 1 deletion internal/lambdalabs/v1/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ func (c *LambdaLabsCredential) GetTenantID() (string, error) {

// MakeClient creates a new Lambda Labs client from this credential
func (c *LambdaLabsCredential) MakeClient(_ context.Context, _ string) (v1.CloudClient, error) {
return NewLambdaLabsClient(c.RefID, c.APIKey), nil
return NewLambdaLabsClient(c.RefID, c.APIKey)
}
12 changes: 6 additions & 6 deletions internal/lambdalabs/v1/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (c *LambdaLabsClient) addSSHKey(ctx context.Context, request openapi.AddSSH
return &openapi.AddSSHKey200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand All @@ -255,7 +255,7 @@ func (c *LambdaLabsClient) launchInstance(ctx context.Context, request openapi.L
return &openapi.LaunchInstance200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand All @@ -272,7 +272,7 @@ func (c *LambdaLabsClient) getInstance(ctx context.Context, instanceID string) (
return &openapi.GetInstance200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand All @@ -289,7 +289,7 @@ func (c *LambdaLabsClient) terminateInstance(ctx context.Context, request openap
return &openapi.TerminateInstance200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand All @@ -306,7 +306,7 @@ func (c *LambdaLabsClient) listInstances(ctx context.Context) (*openapi.ListInst
return &openapi.ListInstances200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand All @@ -323,7 +323,7 @@ func (c *LambdaLabsClient) restartInstance(ctx context.Context, request openapi.
return &openapi.RestartInstance200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/lambdalabs/v1/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (c *LambdaLabsClient) getInstanceTypes(ctx context.Context) (*openapi.Insta
return &openapi.InstanceTypes200Response{}, handleAPIError(ctx, resp, err)
}
return res, nil
}, getBackoff())
}, c.backoff)
if err != nil {
return nil, err
}
Expand Down
Loading