diff --git a/go.mod b/go.mod index bd8a576a..fa938260 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jarcoal/httpmock v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 57bbe570..080cadf7 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jarcoal/httpmock v1.4.0 h1:BvhqnH0JAYbNudL2GMJKgOHe2CtKlzJ/5rWKyp+hc2k= +github.com/jarcoal/httpmock v1.4.0/go.mod h1:ftW1xULwo+j0R0JJkJIIi7UKigZUXCLLanykgjwBXL0= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/lambdalabs/v1/capabilities_test.go b/internal/lambdalabs/v1/capabilities_test.go new file mode 100644 index 00000000..a2bebb63 --- /dev/null +++ b/internal/lambdalabs/v1/capabilities_test.go @@ -0,0 +1,33 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_GetCapabilities(t *testing.T) { + client := &LambdaLabsClient{} + capabilities, err := client.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} + +func TestLambdaLabsCredential_GetCapabilities(t *testing.T) { + cred := &LambdaLabsCredential{} + capabilities, err := cred.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} diff --git a/internal/lambdalabs/v1/client_test.go b/internal/lambdalabs/v1/client_test.go new file mode 100644 index 00000000..105232d3 --- /dev/null +++ b/internal/lambdalabs/v1/client_test.go @@ -0,0 +1,53 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_GetAPIType(t *testing.T) { + client := &LambdaLabsClient{} + assert.Equal(t, v1.APITypeGlobal, client.GetAPIType()) +} + +func TestLambdaLabsClient_GetCloudProviderID(t *testing.T) { + client := &LambdaLabsClient{} + assert.Equal(t, v1.CloudProviderID("lambdalabs"), client.GetCloudProviderID()) +} + +func TestLambdaLabsClient_MakeClient(t *testing.T) { + client := &LambdaLabsClient{ + refID: "test-ref-id", + apiKey: "test-api-key", + } + + newClient, err := client.MakeClient(context.Background(), "test-tenant") + require.NoError(t, err) + lambdaClient, ok := newClient.(*LambdaLabsClient) + require.True(t, ok) + assert.Equal(t, client, lambdaClient) +} + +func TestLambdaLabsClient_GetReferenceID(t *testing.T) { + client := &LambdaLabsClient{refID: "test-ref-id"} + assert.Equal(t, "test-ref-id", client.GetReferenceID()) +} + +func TestLambdaLabsClient_makeAuthContext(t *testing.T) { + client := &LambdaLabsClient{apiKey: "test-api-key"} + ctx := client.makeAuthContext(context.Background()) + + auth := ctx.Value(openapi.ContextBasicAuth) + require.NotNil(t, auth) + + basicAuth, ok := auth.(openapi.BasicAuth) + require.True(t, ok) + assert.Equal(t, "test-api-key", basicAuth.UserName) + assert.Equal(t, "", basicAuth.Password) +} diff --git a/internal/lambdalabs/v1/common_test.go b/internal/lambdalabs/v1/common_test.go new file mode 100644 index 00000000..8c8be2fc --- /dev/null +++ b/internal/lambdalabs/v1/common_test.go @@ -0,0 +1,49 @@ +package v1 + +import ( + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + "github.com/jarcoal/httpmock" +) + +const ( + testInstanceID = "test-instance-id" + nonexistentInstance = "nonexistent-instance" +) + +func setupMockClient() (*LambdaLabsClient, func()) { + httpmock.Activate() + client := NewLambdaLabsClient("test-ref-id", "test-api-key") + return client, httpmock.DeactivateAndReset +} + +func createMockInstance(instanceID string) openapi.Instance { + name := "test-instance" + ip := "192.168.1.100" + privateIP := "10.0.1.100" + hostname := "test-instance.lambda.ai" + + return openapi.Instance{ + Id: instanceID, + Name: *openapi.NewNullableString(&name), + Ip: *openapi.NewNullableString(&ip), + PrivateIp: *openapi.NewNullableString(&privateIP), + Status: "active", + SshKeyNames: []string{"test-key"}, + FileSystemNames: []string{}, + Region: &openapi.Region{ + Name: "us-west-1", + Description: "US West 1", + }, + InstanceType: &openapi.InstanceType{ + Name: "gpu_1x_a10", + Description: "1x NVIDIA A10 GPU", + GpuDescription: "NVIDIA A10", + PriceCentsPerHour: 100, + Specs: openapi.InstanceTypeSpecs{ + MemoryGib: 32, + StorageGib: 512, + }, + }, + Hostname: *openapi.NewNullableString(&hostname), + } +} diff --git a/internal/lambdalabs/v1/credential_test.go b/internal/lambdalabs/v1/credential_test.go new file mode 100644 index 00000000..abdb5b2d --- /dev/null +++ b/internal/lambdalabs/v1/credential_test.go @@ -0,0 +1,52 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsCredential_GetReferenceID(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + assert.Equal(t, "test-ref-id", cred.GetReferenceID()) +} + +func TestLambdaLabsCredential_GetAPIType(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.APITypeGlobal, cred.GetAPIType()) +} + +func TestLambdaLabsCredential_GetCloudProviderID(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.CloudProviderID("lambdalabs"), cred.GetCloudProviderID()) +} + +func TestLambdaLabsCredential_GetTenantID(t *testing.T) { + cred := &LambdaLabsCredential{APIKey: "test-key"} + tenantID, err := cred.GetTenantID() + assert.NoError(t, err) + assert.Contains(t, tenantID, "lambdalabs-") +} + +func TestLambdaLabsCredential_MakeClient(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + client, err := cred.MakeClient(context.Background(), "test-tenant") + require.NoError(t, err) + + lambdaClient, ok := client.(*LambdaLabsClient) + require.True(t, ok) + assert.Equal(t, "test-ref-id", lambdaClient.refID) + assert.Equal(t, "test-api-key", lambdaClient.apiKey) +} diff --git a/internal/lambdalabs/v1/helpers_test.go b/internal/lambdalabs/v1/helpers_test.go new file mode 100644 index 00000000..d6601225 --- /dev/null +++ b/internal/lambdalabs/v1/helpers_test.go @@ -0,0 +1,79 @@ +package v1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestConvertLambdaLabsInstanceToV1Instance(t *testing.T) { + lambdaInstance := createMockInstance("test-instance-id") + + v1Instance := convertLambdaLabsInstanceToV1Instance(lambdaInstance) + + assert.Equal(t, "test-instance-id", string(v1Instance.CloudID)) + assert.Equal(t, "test-instance", v1Instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, v1Instance.Status.LifecycleStatus) + assert.Equal(t, "192.168.1.100", v1Instance.PublicIP) + assert.Equal(t, "10.0.1.100", v1Instance.PrivateIP) + assert.Equal(t, "us-west-1", v1Instance.Location) + assert.Equal(t, "gpu_1x_a10", v1Instance.InstanceType) +} + +func TestConvertLambdaLabsStatusToV1Status(t *testing.T) { + tests := []struct { + lambdaStatus string + expected v1.LifecycleStatus + }{ + {"active", v1.LifecycleStatusRunning}, + {"booting", v1.LifecycleStatusPending}, + {"unhealthy", v1.LifecycleStatusRunning}, + {"terminating", v1.LifecycleStatusTerminating}, + {"terminated", v1.LifecycleStatusTerminated}, + {"error", v1.LifecycleStatusFailed}, + } + + for _, test := range tests { + t.Run(test.lambdaStatus, func(t *testing.T) { + result := convertLambdaLabsStatusToV1Status(test.lambdaStatus) + assert.Equal(t, test.expected, result) + }) + } +} + +func TestMergeInstanceForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.Instance{ + CloudID: "test-id", + Name: "original-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusRunning}, + } + + update := v1.Instance{ + Name: "updated-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusTerminated}, + } + + merged := client.MergeInstanceForUpdate(original, update) + + assert.Equal(t, "updated-name", merged.Name) + assert.Equal(t, v1.LifecycleStatusTerminated, merged.Status.LifecycleStatus) +} + +func TestMergeInstanceTypeForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.InstanceType{ + ID: "test-id", + Type: "original-type", + } + + update := v1.InstanceType{ + Type: "updated-type", + } + + merged := client.MergeInstanceTypeForUpdate(original, update) + + assert.Equal(t, "updated-type", merged.Type) +} diff --git a/internal/lambdalabs/v1/instance_test.go b/internal/lambdalabs/v1/instance_test.go new file mode 100644 index 00000000..aef19e76 --- /dev/null +++ b/internal/lambdalabs/v1/instance_test.go @@ -0,0 +1,266 @@ +package v1 + +import ( + "context" + "fmt" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_CreateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewJsonResponderOrPanic(200, openapi.AddSSHKey200Response{ + Data: openapi.SshKey{ + Id: "ssh-key-id", + Name: "test-instance-id", + PublicKey: publicKey, + }, + })) + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Contains(t, instance.Name, "test-instance") + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_CreateInstance_WithoutPublicKey(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) +} + +func TestLambdaLabsClient_CreateInstance_SSHKeyError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "SSH key already exists"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_CreateInstance_LaunchError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "Instance type not available"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + mockInstance := createMockInstance(instanceID) + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + instance, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Equal(t, "test-instance", instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_GetInstance_NotFound(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + _, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_ListInstances_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockInstances := []openapi.Instance{ + createMockInstance("instance-1"), + createMockInstance("instance-2"), + } + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: mockInstances, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 2) + assert.Equal(t, "instance-1", string(instances[0].CloudID)) + assert.Equal(t, "instance-2", string(instances[1].CloudID)) +} + +func TestLambdaLabsClient_ListInstances_Empty(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: []openapi.Instance{}, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 0) +} + +func TestLambdaLabsClient_ListInstances_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) + + _, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + assert.Error(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewJsonResponderOrPanic(200, openapi.TerminateInstance200Response{ + Data: openapi.TerminateInstance200ResponseData{ + TerminatedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewJsonResponderOrPanic(200, openapi.RestartInstance200Response{ + Data: openapi.RestartInstance200ResponseData{ + RestartedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} diff --git a/internal/lambdalabs/v1/instancetype_test.go b/internal/lambdalabs/v1/instancetype_test.go new file mode 100644 index 00000000..13fc98fe --- /dev/null +++ b/internal/lambdalabs/v1/instancetype_test.go @@ -0,0 +1,268 @@ +package v1 + +import ( + "context" + "testing" + "time" + + "github.com/alecthomas/units" + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + require.NoError(t, err) + assert.Len(t, instanceTypes, 3) + + a10Type := findInstanceTypeByName(instanceTypes, "gpu_1x_a10") + require.NotNil(t, a10Type) + assert.Equal(t, "gpu_1x_a10", a10Type.Type) + assert.True(t, a10Type.IsAvailable) + assert.Len(t, a10Type.SupportedGPUs, 1) + assert.Equal(t, int32(1), a10Type.SupportedGPUs[0].Count) + assert.Equal(t, "NVIDIA", a10Type.SupportedGPUs[0].Manufacturer) + assert.Equal(t, "NVIDIA A10", a10Type.SupportedGPUs[0].Name) +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterByLocation(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + Locations: v1.LocationsFilter{"us-west-1"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 1) + + for _, instanceType := range instanceTypes { + assert.Equal(t, "us-west-1", instanceType.Location) + } +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterByInstanceType(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + InstanceTypes: []string{"gpu_1x_a10"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 2) + + for _, instanceType := range instanceTypes { + assert.Equal(t, "gpu_1x_a10", instanceType.Type) + } +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterBoth(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + Locations: v1.LocationsFilter{"us-east-1"}, + InstanceTypes: []string{"gpu_8x_h100"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 1) + assert.Equal(t, "gpu_8x_h100", instanceTypes[0].Type) + assert.Equal(t, "us-east-1", instanceTypes[0].Location) +} + +func TestLambdaLabsClient_GetInstanceTypes_Empty(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + emptyResponse := openapi.InstanceTypes200Response{ + Data: map[string]openapi.InstanceTypes200ResponseDataValue{}, + } + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, emptyResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + require.NoError(t, err) + assert.Len(t, instanceTypes, 0) +} + +func TestLambdaLabsClient_GetInstanceTypes_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) + + _, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetInstanceTypePollTime(t *testing.T) { + client := &LambdaLabsClient{} + pollTime := client.GetInstanceTypePollTime() + assert.Equal(t, 5*time.Minute, pollTime) +} + +func TestConvertLambdaLabsInstanceTypeToV1InstanceType(t *testing.T) { + llInstanceType := createMockLambdaLabsInstanceType("gpu_1x_a10", "1x NVIDIA A10 (24 GB)", "NVIDIA A10", 100) + + v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType("us-west-1", llInstanceType, true) + require.NoError(t, err) + + assert.Equal(t, "gpu_1x_a10", v1InstanceType.Type) + assert.Equal(t, "us-west-1", v1InstanceType.Location) + assert.True(t, v1InstanceType.IsAvailable) + assert.Len(t, v1InstanceType.SupportedGPUs, 1) + + gpu := v1InstanceType.SupportedGPUs[0] + assert.Equal(t, int32(1), gpu.Count) + assert.Equal(t, "NVIDIA", gpu.Manufacturer) + assert.Equal(t, "NVIDIA A10", gpu.Name) + assert.Equal(t, "NVIDIA A10", gpu.Type) + assert.Equal(t, units.Base2Bytes(24*1024*1024*1024), gpu.Memory) + + assert.NotNil(t, v1InstanceType.BasePrice) + assert.Equal(t, "USD", v1InstanceType.BasePrice.CurrencyCode()) + assert.Equal(t, "1.00", v1InstanceType.BasePrice.Number()) +} + +func TestConvertLambdaLabsInstanceTypeToV1InstanceType_CPUOnly(t *testing.T) { + llInstanceType := createMockLambdaLabsInstanceType("cpu_4x", "4x CPU cores", "", 50) + + v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType("us-west-1", llInstanceType, true) + require.NoError(t, err) + + assert.Equal(t, "cpu_4x", v1InstanceType.Type) + assert.Equal(t, "us-west-1", v1InstanceType.Location) + assert.True(t, v1InstanceType.IsAvailable) + assert.Len(t, v1InstanceType.SupportedGPUs, 0) +} + +func TestParseGPUFromDescription(t *testing.T) { + tests := []struct { + description string + expected v1.GPU + }{ + { + description: "1x NVIDIA A10 (24 GB)", + expected: v1.GPU{ + Count: 1, + Manufacturer: "NVIDIA", + Name: "NVIDIA A10", + Type: "NVIDIA A10", + Memory: 24 * 1024 * 1024 * 1024, + }, + }, + { + description: "8x NVIDIA H100 (80 GB)", + expected: v1.GPU{ + Count: 8, + Manufacturer: "NVIDIA", + Name: "NVIDIA H100", + Type: "NVIDIA H100", + Memory: 80 * 1024 * 1024 * 1024, + }, + }, + { + description: "4x NVIDIA RTX 4090 (24 GB)", + expected: v1.GPU{ + Count: 4, + Manufacturer: "NVIDIA", + Name: "NVIDIA RTX 4090", + Type: "NVIDIA RTX 4090", + Memory: 24 * 1024 * 1024 * 1024, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + gpu := parseGPUFromDescription(tt.description) + assert.Equal(t, tt.expected.Count, gpu.Count) + assert.Equal(t, tt.expected.Manufacturer, gpu.Manufacturer) + assert.Equal(t, tt.expected.Name, gpu.Name) + assert.Equal(t, tt.expected.Type, gpu.Type) + assert.Equal(t, tt.expected.Memory, gpu.Memory) + }) + } +} + +func createMockInstanceTypeResponse() openapi.InstanceTypes200Response { + return openapi.InstanceTypes200Response{ + Data: map[string]openapi.InstanceTypes200ResponseDataValue{ + "gpu_1x_a10": { + InstanceType: createMockLambdaLabsInstanceType("gpu_1x_a10", "1x NVIDIA A10 (24 GB)", "NVIDIA A10", 100), + RegionsWithCapacityAvailable: []openapi.Region{ + createMockRegion("us-west-1", "US West 1"), + createMockRegion("us-east-1", "US East 1"), + }, + }, + "gpu_8x_h100": { + InstanceType: createMockLambdaLabsInstanceType("gpu_8x_h100", "8x NVIDIA H100 (80 GB)", "NVIDIA H100", 3200), + RegionsWithCapacityAvailable: []openapi.Region{ + createMockRegion("us-east-1", "US East 1"), + }, + }, + }, + } +} + +func createMockLambdaLabsInstanceType(name, description, gpuDescription string, priceCents int32) openapi.InstanceType { + gpuCount := int32(0) + if gpuDescription != "" { + gpuCount = 1 + if name == "gpu_8x_h100" { + gpuCount = 8 + } + } + + return openapi.InstanceType{ + Name: name, + Description: description, + GpuDescription: gpuDescription, + PriceCentsPerHour: priceCents, + Specs: openapi.InstanceTypeSpecs{ + Vcpus: 8, + MemoryGib: 32, + StorageGib: 512, + Gpus: gpuCount, + }, + } +} + +func createMockRegion(name, description string) openapi.Region { + return openapi.Region{ + Name: name, + Description: description, + } +} + +func findInstanceTypeByName(instanceTypes []v1.InstanceType, name string) *v1.InstanceType { + for _, instanceType := range instanceTypes { + if instanceType.Type == name { + return &instanceType + } + } + return nil +}