From 9a6285659472d62c61562c07065ea829866e2301 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 06:35:58 +0000 Subject: [PATCH 1/3] Add comprehensive unit tests for Lambda Labs provider with HTTP mocking - Add httpmock dependency for HTTP request mocking - Create comprehensive test suite covering all Lambda Labs client methods: - Credential methods (GetReferenceID, GetAPIType, GetCloudProviderID, etc.) - Client methods (CreateInstance, GetInstance, ListInstances, TerminateInstance, RebootInstance) - Helper functions (convertLambdaLabsInstanceToV1Instance, convertLambdaLabsStatusToV1Status) - Capabilities and merge functions - Mock all Lambda Labs API endpoints with realistic OpenAPI response structures - Test both success and error scenarios for each method - Use proper OpenAPI models for mock responses to ensure realistic testing - Add constants for repeated test values to satisfy linting requirements - All 28 tests pass successfully with comprehensive coverage Co-Authored-By: Alec Fong --- go.mod | 1 + go.sum | 2 + internal/lambdalabs/v1/client_test.go | 482 ++++++++++++++++++++++++++ 3 files changed, 485 insertions(+) create mode 100644 internal/lambdalabs/v1/client_test.go diff --git a/go.mod b/go.mod index bd8a576a..fa938260 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jarcoal/httpmock v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 57bbe570..080cadf7 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jarcoal/httpmock v1.4.0 h1:BvhqnH0JAYbNudL2GMJKgOHe2CtKlzJ/5rWKyp+hc2k= +github.com/jarcoal/httpmock v1.4.0/go.mod h1:ftW1xULwo+j0R0JJkJIIi7UKigZUXCLLanykgjwBXL0= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/lambdalabs/v1/client_test.go b/internal/lambdalabs/v1/client_test.go new file mode 100644 index 00000000..746a3937 --- /dev/null +++ b/internal/lambdalabs/v1/client_test.go @@ -0,0 +1,482 @@ +package v1 + +import ( + "context" + "fmt" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +const ( + testInstanceID = "test-instance-id" + nonexistentInstance = "nonexistent-instance" +) + +func setupMockClient() (*LambdaLabsClient, func()) { + httpmock.Activate() + client := NewLambdaLabsClient("test-ref-id", "test-api-key") + return client, httpmock.DeactivateAndReset +} + +func TestLambdaLabsCredential_GetReferenceID(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + assert.Equal(t, "test-ref-id", cred.GetReferenceID()) +} + +func TestLambdaLabsCredential_GetAPIType(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.APITypeGlobal, cred.GetAPIType()) +} + +func TestLambdaLabsCredential_GetCloudProviderID(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.CloudProviderID("lambdalabs"), cred.GetCloudProviderID()) +} + +func TestLambdaLabsCredential_GetTenantID(t *testing.T) { + cred := &LambdaLabsCredential{APIKey: "test-key"} + tenantID, err := cred.GetTenantID() + assert.NoError(t, err) + assert.Contains(t, tenantID, "lambdalabs-") +} + +func TestLambdaLabsCredential_MakeClient(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + client, err := cred.MakeClient(context.Background(), "test-tenant") + require.NoError(t, err) + lambdaClient, ok := client.(*LambdaLabsClient) + require.True(t, ok) + assert.Equal(t, "test-ref-id", lambdaClient.refID) + assert.Equal(t, "test-api-key", lambdaClient.apiKey) +} + +func TestLambdaLabsClient_GetAPIType(t *testing.T) { + client := &LambdaLabsClient{} + assert.Equal(t, v1.APITypeGlobal, client.GetAPIType()) +} + +func TestLambdaLabsClient_GetCloudProviderID(t *testing.T) { + client := &LambdaLabsClient{} + assert.Equal(t, v1.CloudProviderID("lambdalabs"), client.GetCloudProviderID()) +} + +func TestLambdaLabsClient_MakeClient(t *testing.T) { + client := &LambdaLabsClient{ + refID: "test-ref-id", + apiKey: "test-api-key", + } + + newClient, err := client.MakeClient(context.Background(), "test-tenant") + require.NoError(t, err) + lambdaClient, ok := newClient.(*LambdaLabsClient) + require.True(t, ok) + assert.Equal(t, client, lambdaClient) +} + +func TestLambdaLabsClient_GetReferenceID(t *testing.T) { + client := &LambdaLabsClient{refID: "test-ref-id"} + assert.Equal(t, "test-ref-id", client.GetReferenceID()) +} + +func TestLambdaLabsClient_makeAuthContext(t *testing.T) { + client := &LambdaLabsClient{apiKey: "test-api-key"} + ctx := client.makeAuthContext(context.Background()) + + auth := ctx.Value(openapi.ContextBasicAuth) + require.NotNil(t, auth) + + basicAuth, ok := auth.(openapi.BasicAuth) + require.True(t, ok) + assert.Equal(t, "test-api-key", basicAuth.UserName) + assert.Equal(t, "", basicAuth.Password) +} + +func TestLambdaLabsClient_CreateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewJsonResponderOrPanic(200, openapi.AddSSHKey200Response{ + Data: openapi.SshKey{ + Id: "ssh-key-id", + Name: "test-instance-id", + PublicKey: publicKey, + }, + })) + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Contains(t, instance.Name, "test-instance") + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_CreateInstance_WithoutPublicKey(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) +} + +func TestLambdaLabsClient_CreateInstance_SSHKeyError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "SSH key already exists"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_CreateInstance_LaunchError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "Instance type not available"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + mockInstance := createMockInstance(instanceID) + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + instance, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Equal(t, "test-instance", instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_GetInstance_NotFound(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + _, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_ListInstances_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockInstances := []openapi.Instance{ + createMockInstance("instance-1"), + createMockInstance("instance-2"), + } + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: mockInstances, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 2) + assert.Equal(t, "instance-1", string(instances[0].CloudID)) + assert.Equal(t, "instance-2", string(instances[1].CloudID)) +} + +func TestLambdaLabsClient_ListInstances_Empty(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: []openapi.Instance{}, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 0) +} + +func TestLambdaLabsClient_ListInstances_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) + + _, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + assert.Error(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewJsonResponderOrPanic(200, openapi.TerminateInstance200Response{ + Data: openapi.TerminateInstance200ResponseData{ + TerminatedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewJsonResponderOrPanic(200, openapi.RestartInstance200Response{ + Data: openapi.RestartInstance200ResponseData{ + RestartedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetCapabilities(t *testing.T) { + client := &LambdaLabsClient{} + capabilities, err := client.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} + +func TestLambdaLabsCredential_GetCapabilities(t *testing.T) { + cred := &LambdaLabsCredential{} + capabilities, err := cred.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} + +func TestConvertLambdaLabsInstanceToV1Instance(t *testing.T) { + lambdaInstance := createMockInstance("test-instance-id") + + v1Instance := convertLambdaLabsInstanceToV1Instance(lambdaInstance) + + assert.Equal(t, "test-instance-id", string(v1Instance.CloudID)) + assert.Equal(t, "test-instance", v1Instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, v1Instance.Status.LifecycleStatus) + assert.Equal(t, "192.168.1.100", v1Instance.PublicIP) + assert.Equal(t, "10.0.1.100", v1Instance.PrivateIP) + assert.Equal(t, "us-west-1", v1Instance.Location) + assert.Equal(t, "gpu_1x_a10", v1Instance.InstanceType) +} + +func TestConvertLambdaLabsStatusToV1Status(t *testing.T) { + tests := []struct { + lambdaStatus string + expected v1.LifecycleStatus + }{ + {"active", v1.LifecycleStatusRunning}, + {"booting", v1.LifecycleStatusPending}, + {"unhealthy", v1.LifecycleStatusRunning}, + {"terminating", v1.LifecycleStatusTerminating}, + {"terminated", v1.LifecycleStatusTerminated}, + {"error", v1.LifecycleStatusFailed}, + } + + for _, test := range tests { + t.Run(test.lambdaStatus, func(t *testing.T) { + result := convertLambdaLabsStatusToV1Status(test.lambdaStatus) + assert.Equal(t, test.expected, result) + }) + } +} + +func TestMergeInstanceForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.Instance{ + CloudID: "test-id", + Name: "original-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusRunning}, + } + + update := v1.Instance{ + Name: "updated-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusTerminated}, + } + + merged := client.MergeInstanceForUpdate(original, update) + + assert.Equal(t, "updated-name", merged.Name) + assert.Equal(t, v1.LifecycleStatusTerminated, merged.Status.LifecycleStatus) +} + +func TestMergeInstanceTypeForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.InstanceType{ + ID: "test-id", + Type: "original-type", + } + + update := v1.InstanceType{ + Type: "updated-type", + } + + merged := client.MergeInstanceTypeForUpdate(original, update) + + assert.Equal(t, "updated-type", merged.Type) +} + +func createMockInstance(instanceID string) openapi.Instance { + name := "test-instance" + ip := "192.168.1.100" + privateIP := "10.0.1.100" + hostname := "test-instance.lambda.ai" + + return openapi.Instance{ + Id: instanceID, + Name: *openapi.NewNullableString(&name), + Ip: *openapi.NewNullableString(&ip), + PrivateIp: *openapi.NewNullableString(&privateIP), + Status: "active", + SshKeyNames: []string{"test-key"}, + FileSystemNames: []string{}, + Region: &openapi.Region{ + Name: "us-west-1", + Description: "US West 1", + }, + InstanceType: &openapi.InstanceType{ + Name: "gpu_1x_a10", + Description: "1x NVIDIA A10 GPU", + GpuDescription: "NVIDIA A10", + PriceCentsPerHour: 100, + Specs: openapi.InstanceTypeSpecs{ + MemoryGib: 32, + StorageGib: 512, + }, + }, + Hostname: *openapi.NewNullableString(&hostname), + } +} From b811b7659dd0bc59b3c6027071c5d00c8e7e297a Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 06:47:24 +0000 Subject: [PATCH 2/3] Refactor Lambda Labs tests into smaller, logically organized files - Break up 483-line client_test.go into 6 focused test files - common_test.go: shared utilities and constants - credential_test.go: 5 credential tests - client_test.go: 5 basic client tests - instance_test.go: 13 instance operation tests - capabilities_test.go: 2 capabilities tests - helpers_test.go: 4 helper function tests - Maintain all 28 tests with proper imports and dependencies - Address GitHub feedback from @theFong on PR #7 Co-Authored-By: Alec Fong --- internal/lambdalabs/v1/capabilities_test.go | 33 ++ internal/lambdalabs/v1/client_test.go | 429 -------------------- internal/lambdalabs/v1/common_test.go | 49 +++ internal/lambdalabs/v1/credential_test.go | 52 +++ internal/lambdalabs/v1/helpers_test.go | 79 ++++ internal/lambdalabs/v1/instance_test.go | 266 ++++++++++++ 6 files changed, 479 insertions(+), 429 deletions(-) create mode 100644 internal/lambdalabs/v1/capabilities_test.go create mode 100644 internal/lambdalabs/v1/common_test.go create mode 100644 internal/lambdalabs/v1/credential_test.go create mode 100644 internal/lambdalabs/v1/helpers_test.go create mode 100644 internal/lambdalabs/v1/instance_test.go diff --git a/internal/lambdalabs/v1/capabilities_test.go b/internal/lambdalabs/v1/capabilities_test.go new file mode 100644 index 00000000..a2bebb63 --- /dev/null +++ b/internal/lambdalabs/v1/capabilities_test.go @@ -0,0 +1,33 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_GetCapabilities(t *testing.T) { + client := &LambdaLabsClient{} + capabilities, err := client.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} + +func TestLambdaLabsCredential_GetCapabilities(t *testing.T) { + cred := &LambdaLabsCredential{} + capabilities, err := cred.GetCapabilities(context.Background()) + require.NoError(t, err) + + assert.Contains(t, capabilities, v1.CapabilityCreateInstance) + assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) + assert.Contains(t, capabilities, v1.CapabilityRebootInstance) + assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) +} diff --git a/internal/lambdalabs/v1/client_test.go b/internal/lambdalabs/v1/client_test.go index 746a3937..105232d3 100644 --- a/internal/lambdalabs/v1/client_test.go +++ b/internal/lambdalabs/v1/client_test.go @@ -2,10 +2,8 @@ package v1 import ( "context" - "fmt" "testing" - "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -13,57 +11,6 @@ import ( v1 "github.com/brevdev/compute/pkg/v1" ) -const ( - testInstanceID = "test-instance-id" - nonexistentInstance = "nonexistent-instance" -) - -func setupMockClient() (*LambdaLabsClient, func()) { - httpmock.Activate() - client := NewLambdaLabsClient("test-ref-id", "test-api-key") - return client, httpmock.DeactivateAndReset -} - -func TestLambdaLabsCredential_GetReferenceID(t *testing.T) { - cred := &LambdaLabsCredential{ - RefID: "test-ref-id", - APIKey: "test-api-key", - } - - assert.Equal(t, "test-ref-id", cred.GetReferenceID()) -} - -func TestLambdaLabsCredential_GetAPIType(t *testing.T) { - cred := &LambdaLabsCredential{} - assert.Equal(t, v1.APITypeGlobal, cred.GetAPIType()) -} - -func TestLambdaLabsCredential_GetCloudProviderID(t *testing.T) { - cred := &LambdaLabsCredential{} - assert.Equal(t, v1.CloudProviderID("lambdalabs"), cred.GetCloudProviderID()) -} - -func TestLambdaLabsCredential_GetTenantID(t *testing.T) { - cred := &LambdaLabsCredential{APIKey: "test-key"} - tenantID, err := cred.GetTenantID() - assert.NoError(t, err) - assert.Contains(t, tenantID, "lambdalabs-") -} - -func TestLambdaLabsCredential_MakeClient(t *testing.T) { - cred := &LambdaLabsCredential{ - RefID: "test-ref-id", - APIKey: "test-api-key", - } - - client, err := cred.MakeClient(context.Background(), "test-tenant") - require.NoError(t, err) - lambdaClient, ok := client.(*LambdaLabsClient) - require.True(t, ok) - assert.Equal(t, "test-ref-id", lambdaClient.refID) - assert.Equal(t, "test-api-key", lambdaClient.apiKey) -} - func TestLambdaLabsClient_GetAPIType(t *testing.T) { client := &LambdaLabsClient{} assert.Equal(t, v1.APITypeGlobal, client.GetAPIType()) @@ -104,379 +51,3 @@ func TestLambdaLabsClient_makeAuthContext(t *testing.T) { assert.Equal(t, "test-api-key", basicAuth.UserName) assert.Equal(t, "", basicAuth.Password) } - -func TestLambdaLabsClient_CreateInstance_Success(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := testInstanceID - publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", - httpmock.NewJsonResponderOrPanic(200, openapi.AddSSHKey200Response{ - Data: openapi.SshKey{ - Id: "ssh-key-id", - Name: "test-instance-id", - PublicKey: publicKey, - }, - })) - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", - httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ - Data: openapi.LaunchInstance200ResponseData{ - InstanceIds: []string{instanceID}, - }, - })) - - mockInstance := createMockInstance(instanceID) - httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), - httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ - Data: mockInstance, - })) - - args := v1.CreateInstanceAttrs{ - InstanceType: "gpu_1x_a10", - Location: "us-west-1", - PublicKey: publicKey, - Name: "test-instance", - } - - instance, err := client.CreateInstance(context.Background(), args) - require.NoError(t, err) - assert.Equal(t, instanceID, string(instance.CloudID)) - assert.Contains(t, instance.Name, "test-instance") - assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) -} - -func TestLambdaLabsClient_CreateInstance_WithoutPublicKey(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := testInstanceID - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", - httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ - Data: openapi.LaunchInstance200ResponseData{ - InstanceIds: []string{instanceID}, - }, - })) - - mockInstance := createMockInstance(instanceID) - httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), - httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ - Data: mockInstance, - })) - - args := v1.CreateInstanceAttrs{ - InstanceType: "gpu_1x_a10", - Location: "us-west-1", - Name: "test-instance", - } - - instance, err := client.CreateInstance(context.Background(), args) - require.NoError(t, err) - assert.Equal(t, instanceID, string(instance.CloudID)) -} - -func TestLambdaLabsClient_CreateInstance_SSHKeyError(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", - httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "SSH key already exists"}}`)) - - args := v1.CreateInstanceAttrs{ - InstanceType: "gpu_1x_a10", - Location: "us-west-1", - PublicKey: publicKey, - Name: "test-instance", - } - - _, err := client.CreateInstance(context.Background(), args) - assert.Error(t, err) -} - -func TestLambdaLabsClient_CreateInstance_LaunchError(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", - httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "Instance type not available"}}`)) - - args := v1.CreateInstanceAttrs{ - InstanceType: "gpu_1x_a10", - Location: "us-west-1", - Name: "test-instance", - } - - _, err := client.CreateInstance(context.Background(), args) - assert.Error(t, err) -} - -func TestLambdaLabsClient_GetInstance_Success(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := testInstanceID - mockInstance := createMockInstance(instanceID) - - httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), - httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ - Data: mockInstance, - })) - - instance, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - require.NoError(t, err) - assert.Equal(t, instanceID, string(instance.CloudID)) - assert.Equal(t, "test-instance", instance.Name) - assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) -} - -func TestLambdaLabsClient_GetInstance_NotFound(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := nonexistentInstance - - httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), - httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) - - _, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - assert.Error(t, err) -} - -func TestLambdaLabsClient_ListInstances_Success(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - mockInstances := []openapi.Instance{ - createMockInstance("instance-1"), - createMockInstance("instance-2"), - } - - httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", - httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ - Data: mockInstances, - })) - - instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) - require.NoError(t, err) - assert.Len(t, instances, 2) - assert.Equal(t, "instance-1", string(instances[0].CloudID)) - assert.Equal(t, "instance-2", string(instances[1].CloudID)) -} - -func TestLambdaLabsClient_ListInstances_Empty(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", - httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ - Data: []openapi.Instance{}, - })) - - instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) - require.NoError(t, err) - assert.Len(t, instances, 0) -} - -func TestLambdaLabsClient_ListInstances_Error(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", - httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) - - _, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) - assert.Error(t, err) -} - -func TestLambdaLabsClient_TerminateInstance_Success(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := testInstanceID - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", - httpmock.NewJsonResponderOrPanic(200, openapi.TerminateInstance200Response{ - Data: openapi.TerminateInstance200ResponseData{ - TerminatedInstances: []openapi.Instance{ - createMockInstance(instanceID), - }, - }, - })) - - err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - assert.NoError(t, err) -} - -func TestLambdaLabsClient_TerminateInstance_Error(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := nonexistentInstance - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", - httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) - - err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - assert.Error(t, err) -} - -func TestLambdaLabsClient_RebootInstance_Success(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := testInstanceID - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", - httpmock.NewJsonResponderOrPanic(200, openapi.RestartInstance200Response{ - Data: openapi.RestartInstance200ResponseData{ - RestartedInstances: []openapi.Instance{ - createMockInstance(instanceID), - }, - }, - })) - - err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - assert.NoError(t, err) -} - -func TestLambdaLabsClient_RebootInstance_Error(t *testing.T) { - client, cleanup := setupMockClient() - defer cleanup() - - instanceID := nonexistentInstance - - httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", - httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) - - err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) - assert.Error(t, err) -} - -func TestLambdaLabsClient_GetCapabilities(t *testing.T) { - client := &LambdaLabsClient{} - capabilities, err := client.GetCapabilities(context.Background()) - require.NoError(t, err) - - assert.Contains(t, capabilities, v1.CapabilityCreateInstance) - assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) - assert.Contains(t, capabilities, v1.CapabilityRebootInstance) - assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) -} - -func TestLambdaLabsCredential_GetCapabilities(t *testing.T) { - cred := &LambdaLabsCredential{} - capabilities, err := cred.GetCapabilities(context.Background()) - require.NoError(t, err) - - assert.Contains(t, capabilities, v1.CapabilityCreateInstance) - assert.Contains(t, capabilities, v1.CapabilityTerminateInstance) - assert.Contains(t, capabilities, v1.CapabilityRebootInstance) - assert.NotContains(t, capabilities, v1.CapabilityStopStartInstance) -} - -func TestConvertLambdaLabsInstanceToV1Instance(t *testing.T) { - lambdaInstance := createMockInstance("test-instance-id") - - v1Instance := convertLambdaLabsInstanceToV1Instance(lambdaInstance) - - assert.Equal(t, "test-instance-id", string(v1Instance.CloudID)) - assert.Equal(t, "test-instance", v1Instance.Name) - assert.Equal(t, v1.LifecycleStatusRunning, v1Instance.Status.LifecycleStatus) - assert.Equal(t, "192.168.1.100", v1Instance.PublicIP) - assert.Equal(t, "10.0.1.100", v1Instance.PrivateIP) - assert.Equal(t, "us-west-1", v1Instance.Location) - assert.Equal(t, "gpu_1x_a10", v1Instance.InstanceType) -} - -func TestConvertLambdaLabsStatusToV1Status(t *testing.T) { - tests := []struct { - lambdaStatus string - expected v1.LifecycleStatus - }{ - {"active", v1.LifecycleStatusRunning}, - {"booting", v1.LifecycleStatusPending}, - {"unhealthy", v1.LifecycleStatusRunning}, - {"terminating", v1.LifecycleStatusTerminating}, - {"terminated", v1.LifecycleStatusTerminated}, - {"error", v1.LifecycleStatusFailed}, - } - - for _, test := range tests { - t.Run(test.lambdaStatus, func(t *testing.T) { - result := convertLambdaLabsStatusToV1Status(test.lambdaStatus) - assert.Equal(t, test.expected, result) - }) - } -} - -func TestMergeInstanceForUpdate(t *testing.T) { - client := &LambdaLabsClient{} - original := v1.Instance{ - CloudID: "test-id", - Name: "original-name", - Status: v1.Status{LifecycleStatus: v1.LifecycleStatusRunning}, - } - - update := v1.Instance{ - Name: "updated-name", - Status: v1.Status{LifecycleStatus: v1.LifecycleStatusTerminated}, - } - - merged := client.MergeInstanceForUpdate(original, update) - - assert.Equal(t, "updated-name", merged.Name) - assert.Equal(t, v1.LifecycleStatusTerminated, merged.Status.LifecycleStatus) -} - -func TestMergeInstanceTypeForUpdate(t *testing.T) { - client := &LambdaLabsClient{} - original := v1.InstanceType{ - ID: "test-id", - Type: "original-type", - } - - update := v1.InstanceType{ - Type: "updated-type", - } - - merged := client.MergeInstanceTypeForUpdate(original, update) - - assert.Equal(t, "updated-type", merged.Type) -} - -func createMockInstance(instanceID string) openapi.Instance { - name := "test-instance" - ip := "192.168.1.100" - privateIP := "10.0.1.100" - hostname := "test-instance.lambda.ai" - - return openapi.Instance{ - Id: instanceID, - Name: *openapi.NewNullableString(&name), - Ip: *openapi.NewNullableString(&ip), - PrivateIp: *openapi.NewNullableString(&privateIP), - Status: "active", - SshKeyNames: []string{"test-key"}, - FileSystemNames: []string{}, - Region: &openapi.Region{ - Name: "us-west-1", - Description: "US West 1", - }, - InstanceType: &openapi.InstanceType{ - Name: "gpu_1x_a10", - Description: "1x NVIDIA A10 GPU", - GpuDescription: "NVIDIA A10", - PriceCentsPerHour: 100, - Specs: openapi.InstanceTypeSpecs{ - MemoryGib: 32, - StorageGib: 512, - }, - }, - Hostname: *openapi.NewNullableString(&hostname), - } -} diff --git a/internal/lambdalabs/v1/common_test.go b/internal/lambdalabs/v1/common_test.go new file mode 100644 index 00000000..8c8be2fc --- /dev/null +++ b/internal/lambdalabs/v1/common_test.go @@ -0,0 +1,49 @@ +package v1 + +import ( + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + "github.com/jarcoal/httpmock" +) + +const ( + testInstanceID = "test-instance-id" + nonexistentInstance = "nonexistent-instance" +) + +func setupMockClient() (*LambdaLabsClient, func()) { + httpmock.Activate() + client := NewLambdaLabsClient("test-ref-id", "test-api-key") + return client, httpmock.DeactivateAndReset +} + +func createMockInstance(instanceID string) openapi.Instance { + name := "test-instance" + ip := "192.168.1.100" + privateIP := "10.0.1.100" + hostname := "test-instance.lambda.ai" + + return openapi.Instance{ + Id: instanceID, + Name: *openapi.NewNullableString(&name), + Ip: *openapi.NewNullableString(&ip), + PrivateIp: *openapi.NewNullableString(&privateIP), + Status: "active", + SshKeyNames: []string{"test-key"}, + FileSystemNames: []string{}, + Region: &openapi.Region{ + Name: "us-west-1", + Description: "US West 1", + }, + InstanceType: &openapi.InstanceType{ + Name: "gpu_1x_a10", + Description: "1x NVIDIA A10 GPU", + GpuDescription: "NVIDIA A10", + PriceCentsPerHour: 100, + Specs: openapi.InstanceTypeSpecs{ + MemoryGib: 32, + StorageGib: 512, + }, + }, + Hostname: *openapi.NewNullableString(&hostname), + } +} diff --git a/internal/lambdalabs/v1/credential_test.go b/internal/lambdalabs/v1/credential_test.go new file mode 100644 index 00000000..abdb5b2d --- /dev/null +++ b/internal/lambdalabs/v1/credential_test.go @@ -0,0 +1,52 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsCredential_GetReferenceID(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + assert.Equal(t, "test-ref-id", cred.GetReferenceID()) +} + +func TestLambdaLabsCredential_GetAPIType(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.APITypeGlobal, cred.GetAPIType()) +} + +func TestLambdaLabsCredential_GetCloudProviderID(t *testing.T) { + cred := &LambdaLabsCredential{} + assert.Equal(t, v1.CloudProviderID("lambdalabs"), cred.GetCloudProviderID()) +} + +func TestLambdaLabsCredential_GetTenantID(t *testing.T) { + cred := &LambdaLabsCredential{APIKey: "test-key"} + tenantID, err := cred.GetTenantID() + assert.NoError(t, err) + assert.Contains(t, tenantID, "lambdalabs-") +} + +func TestLambdaLabsCredential_MakeClient(t *testing.T) { + cred := &LambdaLabsCredential{ + RefID: "test-ref-id", + APIKey: "test-api-key", + } + + client, err := cred.MakeClient(context.Background(), "test-tenant") + require.NoError(t, err) + + lambdaClient, ok := client.(*LambdaLabsClient) + require.True(t, ok) + assert.Equal(t, "test-ref-id", lambdaClient.refID) + assert.Equal(t, "test-api-key", lambdaClient.apiKey) +} diff --git a/internal/lambdalabs/v1/helpers_test.go b/internal/lambdalabs/v1/helpers_test.go new file mode 100644 index 00000000..d6601225 --- /dev/null +++ b/internal/lambdalabs/v1/helpers_test.go @@ -0,0 +1,79 @@ +package v1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestConvertLambdaLabsInstanceToV1Instance(t *testing.T) { + lambdaInstance := createMockInstance("test-instance-id") + + v1Instance := convertLambdaLabsInstanceToV1Instance(lambdaInstance) + + assert.Equal(t, "test-instance-id", string(v1Instance.CloudID)) + assert.Equal(t, "test-instance", v1Instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, v1Instance.Status.LifecycleStatus) + assert.Equal(t, "192.168.1.100", v1Instance.PublicIP) + assert.Equal(t, "10.0.1.100", v1Instance.PrivateIP) + assert.Equal(t, "us-west-1", v1Instance.Location) + assert.Equal(t, "gpu_1x_a10", v1Instance.InstanceType) +} + +func TestConvertLambdaLabsStatusToV1Status(t *testing.T) { + tests := []struct { + lambdaStatus string + expected v1.LifecycleStatus + }{ + {"active", v1.LifecycleStatusRunning}, + {"booting", v1.LifecycleStatusPending}, + {"unhealthy", v1.LifecycleStatusRunning}, + {"terminating", v1.LifecycleStatusTerminating}, + {"terminated", v1.LifecycleStatusTerminated}, + {"error", v1.LifecycleStatusFailed}, + } + + for _, test := range tests { + t.Run(test.lambdaStatus, func(t *testing.T) { + result := convertLambdaLabsStatusToV1Status(test.lambdaStatus) + assert.Equal(t, test.expected, result) + }) + } +} + +func TestMergeInstanceForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.Instance{ + CloudID: "test-id", + Name: "original-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusRunning}, + } + + update := v1.Instance{ + Name: "updated-name", + Status: v1.Status{LifecycleStatus: v1.LifecycleStatusTerminated}, + } + + merged := client.MergeInstanceForUpdate(original, update) + + assert.Equal(t, "updated-name", merged.Name) + assert.Equal(t, v1.LifecycleStatusTerminated, merged.Status.LifecycleStatus) +} + +func TestMergeInstanceTypeForUpdate(t *testing.T) { + client := &LambdaLabsClient{} + original := v1.InstanceType{ + ID: "test-id", + Type: "original-type", + } + + update := v1.InstanceType{ + Type: "updated-type", + } + + merged := client.MergeInstanceTypeForUpdate(original, update) + + assert.Equal(t, "updated-type", merged.Type) +} diff --git a/internal/lambdalabs/v1/instance_test.go b/internal/lambdalabs/v1/instance_test.go new file mode 100644 index 00000000..aef19e76 --- /dev/null +++ b/internal/lambdalabs/v1/instance_test.go @@ -0,0 +1,266 @@ +package v1 + +import ( + "context" + "fmt" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_CreateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewJsonResponderOrPanic(200, openapi.AddSSHKey200Response{ + Data: openapi.SshKey{ + Id: "ssh-key-id", + Name: "test-instance-id", + PublicKey: publicKey, + }, + })) + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Contains(t, instance.Name, "test-instance") + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_CreateInstance_WithoutPublicKey(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewJsonResponderOrPanic(200, openapi.LaunchInstance200Response{ + Data: openapi.LaunchInstance200ResponseData{ + InstanceIds: []string{instanceID}, + }, + })) + + mockInstance := createMockInstance(instanceID) + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + instance, err := client.CreateInstance(context.Background(), args) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) +} + +func TestLambdaLabsClient_CreateInstance_SSHKeyError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com" + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/ssh-keys", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "SSH key already exists"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + PublicKey: publicKey, + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_CreateInstance_LaunchError(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/launch", + httpmock.NewStringResponder(400, `{"error": {"code": "INVALID_REQUEST", "message": "Instance type not available"}}`)) + + args := v1.CreateInstanceAttrs{ + InstanceType: "gpu_1x_a10", + Location: "us-west-1", + Name: "test-instance", + } + + _, err := client.CreateInstance(context.Background(), args) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + mockInstance := createMockInstance(instanceID) + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewJsonResponderOrPanic(200, openapi.GetInstance200Response{ + Data: mockInstance, + })) + + instance, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + require.NoError(t, err) + assert.Equal(t, instanceID, string(instance.CloudID)) + assert.Equal(t, "test-instance", instance.Name) + assert.Equal(t, v1.LifecycleStatusRunning, instance.Status.LifecycleStatus) +} + +func TestLambdaLabsClient_GetInstance_NotFound(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("GET", fmt.Sprintf("https://cloud.lambda.ai/api/v1/instances/%s", instanceID), + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + _, err := client.GetInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_ListInstances_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockInstances := []openapi.Instance{ + createMockInstance("instance-1"), + createMockInstance("instance-2"), + } + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: mockInstances, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 2) + assert.Equal(t, "instance-1", string(instances[0].CloudID)) + assert.Equal(t, "instance-2", string(instances[1].CloudID)) +} + +func TestLambdaLabsClient_ListInstances_Empty(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewJsonResponderOrPanic(200, openapi.ListInstances200Response{ + Data: []openapi.Instance{}, + })) + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + require.NoError(t, err) + assert.Len(t, instances, 0) +} + +func TestLambdaLabsClient_ListInstances_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instances", + httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) + + _, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{}) + assert.Error(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewJsonResponderOrPanic(200, openapi.TerminateInstance200Response{ + Data: openapi.TerminateInstance200ResponseData{ + TerminatedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_TerminateInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/terminate", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.TerminateInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := testInstanceID + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewJsonResponderOrPanic(200, openapi.RestartInstance200Response{ + Data: openapi.RestartInstance200ResponseData{ + RestartedInstances: []openapi.Instance{ + createMockInstance(instanceID), + }, + }, + })) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.NoError(t, err) +} + +func TestLambdaLabsClient_RebootInstance_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + instanceID := nonexistentInstance + + httpmock.RegisterResponder("POST", "https://cloud.lambda.ai/api/v1/instance-operations/restart", + httpmock.NewStringResponder(404, `{"error": {"code": "NOT_FOUND", "message": "Instance not found"}}`)) + + err := client.RebootInstance(context.Background(), v1.CloudProviderInstanceID(instanceID)) + assert.Error(t, err) +} From 7dc3c487c2a3f8225ca9cd745a49ff28d96b1e82 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 06:58:55 +0000 Subject: [PATCH 3/3] Add comprehensive unit tests for GetInstanceTypes method - Add tests for GetInstanceTypes with success, filtering, and error cases - Add tests for GetInstanceTypePollTime method - Add tests for helper functions convertLambdaLabsInstanceTypeToV1InstanceType and parseGPUFromDescription - Mock /api/v1/instance-types endpoint with realistic response data - Test location and instance type filtering functionality - Address GitHub feedback from @theFong on PR #7 Co-Authored-By: Alec Fong --- internal/lambdalabs/v1/instancetype_test.go | 268 ++++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 internal/lambdalabs/v1/instancetype_test.go diff --git a/internal/lambdalabs/v1/instancetype_test.go b/internal/lambdalabs/v1/instancetype_test.go new file mode 100644 index 00000000..13fc98fe --- /dev/null +++ b/internal/lambdalabs/v1/instancetype_test.go @@ -0,0 +1,268 @@ +package v1 + +import ( + "context" + "testing" + "time" + + "github.com/alecthomas/units" + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/compute/pkg/v1" +) + +func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + require.NoError(t, err) + assert.Len(t, instanceTypes, 3) + + a10Type := findInstanceTypeByName(instanceTypes, "gpu_1x_a10") + require.NotNil(t, a10Type) + assert.Equal(t, "gpu_1x_a10", a10Type.Type) + assert.True(t, a10Type.IsAvailable) + assert.Len(t, a10Type.SupportedGPUs, 1) + assert.Equal(t, int32(1), a10Type.SupportedGPUs[0].Count) + assert.Equal(t, "NVIDIA", a10Type.SupportedGPUs[0].Manufacturer) + assert.Equal(t, "NVIDIA A10", a10Type.SupportedGPUs[0].Name) +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterByLocation(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + Locations: v1.LocationsFilter{"us-west-1"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 1) + + for _, instanceType := range instanceTypes { + assert.Equal(t, "us-west-1", instanceType.Location) + } +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterByInstanceType(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + InstanceTypes: []string{"gpu_1x_a10"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 2) + + for _, instanceType := range instanceTypes { + assert.Equal(t, "gpu_1x_a10", instanceType.Type) + } +} + +func TestLambdaLabsClient_GetInstanceTypes_FilterBoth(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + mockResponse := createMockInstanceTypeResponse() + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, mockResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + Locations: v1.LocationsFilter{"us-east-1"}, + InstanceTypes: []string{"gpu_8x_h100"}, + }) + require.NoError(t, err) + assert.Len(t, instanceTypes, 1) + assert.Equal(t, "gpu_8x_h100", instanceTypes[0].Type) + assert.Equal(t, "us-east-1", instanceTypes[0].Location) +} + +func TestLambdaLabsClient_GetInstanceTypes_Empty(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + emptyResponse := openapi.InstanceTypes200Response{ + Data: map[string]openapi.InstanceTypes200ResponseDataValue{}, + } + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewJsonResponderOrPanic(200, emptyResponse)) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + require.NoError(t, err) + assert.Len(t, instanceTypes, 0) +} + +func TestLambdaLabsClient_GetInstanceTypes_Error(t *testing.T) { + client, cleanup := setupMockClient() + defer cleanup() + + httpmock.RegisterResponder("GET", "https://cloud.lambda.ai/api/v1/instance-types", + httpmock.NewStringResponder(500, `{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}`)) + + _, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{}) + assert.Error(t, err) +} + +func TestLambdaLabsClient_GetInstanceTypePollTime(t *testing.T) { + client := &LambdaLabsClient{} + pollTime := client.GetInstanceTypePollTime() + assert.Equal(t, 5*time.Minute, pollTime) +} + +func TestConvertLambdaLabsInstanceTypeToV1InstanceType(t *testing.T) { + llInstanceType := createMockLambdaLabsInstanceType("gpu_1x_a10", "1x NVIDIA A10 (24 GB)", "NVIDIA A10", 100) + + v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType("us-west-1", llInstanceType, true) + require.NoError(t, err) + + assert.Equal(t, "gpu_1x_a10", v1InstanceType.Type) + assert.Equal(t, "us-west-1", v1InstanceType.Location) + assert.True(t, v1InstanceType.IsAvailable) + assert.Len(t, v1InstanceType.SupportedGPUs, 1) + + gpu := v1InstanceType.SupportedGPUs[0] + assert.Equal(t, int32(1), gpu.Count) + assert.Equal(t, "NVIDIA", gpu.Manufacturer) + assert.Equal(t, "NVIDIA A10", gpu.Name) + assert.Equal(t, "NVIDIA A10", gpu.Type) + assert.Equal(t, units.Base2Bytes(24*1024*1024*1024), gpu.Memory) + + assert.NotNil(t, v1InstanceType.BasePrice) + assert.Equal(t, "USD", v1InstanceType.BasePrice.CurrencyCode()) + assert.Equal(t, "1.00", v1InstanceType.BasePrice.Number()) +} + +func TestConvertLambdaLabsInstanceTypeToV1InstanceType_CPUOnly(t *testing.T) { + llInstanceType := createMockLambdaLabsInstanceType("cpu_4x", "4x CPU cores", "", 50) + + v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType("us-west-1", llInstanceType, true) + require.NoError(t, err) + + assert.Equal(t, "cpu_4x", v1InstanceType.Type) + assert.Equal(t, "us-west-1", v1InstanceType.Location) + assert.True(t, v1InstanceType.IsAvailable) + assert.Len(t, v1InstanceType.SupportedGPUs, 0) +} + +func TestParseGPUFromDescription(t *testing.T) { + tests := []struct { + description string + expected v1.GPU + }{ + { + description: "1x NVIDIA A10 (24 GB)", + expected: v1.GPU{ + Count: 1, + Manufacturer: "NVIDIA", + Name: "NVIDIA A10", + Type: "NVIDIA A10", + Memory: 24 * 1024 * 1024 * 1024, + }, + }, + { + description: "8x NVIDIA H100 (80 GB)", + expected: v1.GPU{ + Count: 8, + Manufacturer: "NVIDIA", + Name: "NVIDIA H100", + Type: "NVIDIA H100", + Memory: 80 * 1024 * 1024 * 1024, + }, + }, + { + description: "4x NVIDIA RTX 4090 (24 GB)", + expected: v1.GPU{ + Count: 4, + Manufacturer: "NVIDIA", + Name: "NVIDIA RTX 4090", + Type: "NVIDIA RTX 4090", + Memory: 24 * 1024 * 1024 * 1024, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + gpu := parseGPUFromDescription(tt.description) + assert.Equal(t, tt.expected.Count, gpu.Count) + assert.Equal(t, tt.expected.Manufacturer, gpu.Manufacturer) + assert.Equal(t, tt.expected.Name, gpu.Name) + assert.Equal(t, tt.expected.Type, gpu.Type) + assert.Equal(t, tt.expected.Memory, gpu.Memory) + }) + } +} + +func createMockInstanceTypeResponse() openapi.InstanceTypes200Response { + return openapi.InstanceTypes200Response{ + Data: map[string]openapi.InstanceTypes200ResponseDataValue{ + "gpu_1x_a10": { + InstanceType: createMockLambdaLabsInstanceType("gpu_1x_a10", "1x NVIDIA A10 (24 GB)", "NVIDIA A10", 100), + RegionsWithCapacityAvailable: []openapi.Region{ + createMockRegion("us-west-1", "US West 1"), + createMockRegion("us-east-1", "US East 1"), + }, + }, + "gpu_8x_h100": { + InstanceType: createMockLambdaLabsInstanceType("gpu_8x_h100", "8x NVIDIA H100 (80 GB)", "NVIDIA H100", 3200), + RegionsWithCapacityAvailable: []openapi.Region{ + createMockRegion("us-east-1", "US East 1"), + }, + }, + }, + } +} + +func createMockLambdaLabsInstanceType(name, description, gpuDescription string, priceCents int32) openapi.InstanceType { + gpuCount := int32(0) + if gpuDescription != "" { + gpuCount = 1 + if name == "gpu_8x_h100" { + gpuCount = 8 + } + } + + return openapi.InstanceType{ + Name: name, + Description: description, + GpuDescription: gpuDescription, + PriceCentsPerHour: priceCents, + Specs: openapi.InstanceTypeSpecs{ + Vcpus: 8, + MemoryGib: 32, + StorageGib: 512, + Gpus: gpuCount, + }, + } +} + +func createMockRegion(name, description string) openapi.Region { + return openapi.Region{ + Name: name, + Description: description, + } +} + +func findInstanceTypeByName(instanceTypes []v1.InstanceType, name string) *v1.InstanceType { + for _, instanceType := range instanceTypes { + if instanceType.Type == name { + return &instanceType + } + } + return nil +}