From 6a3fedafc600ae464aefc3151eb2bf7e31e9879a Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Fri, 5 Sep 2025 12:32:30 -0700 Subject: [PATCH 1/2] feat(BREV-1659): Manufacturer --- v1/instancetype.go | 23 ++++++++++++++- .../shadeform/model_instance_configuration.go | 2 +- .../model_instance_type_configuration.go | 2 +- v1/providers/shadeform/instancetype.go | 29 +++++++++++++++++-- 4 files changed, 50 insertions(+), 6 deletions(-) diff --git a/v1/instancetype.go b/v1/instancetype.go index 171907ff..9e42d4e4 100644 --- a/v1/instancetype.go +++ b/v1/instancetype.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" "time" "github.com/alecthomas/units" @@ -12,6 +13,25 @@ import ( "github.com/google/go-cmp/cmp" ) +type Manufacturer string + +const ( + ManufacturerNVIDIA Manufacturer = "NVIDIA" + ManufacturerIntel Manufacturer = "Intel" + ManufacturerUnknown Manufacturer = "unknown" +) + +func GetManufacturer(manufacturer string) Manufacturer { + switch strings.ToLower(manufacturer) { + case "nvidia": + return ManufacturerNVIDIA + case "intel": + return ManufacturerIntel + default: + return ManufacturerUnknown + } +} + type InstanceTypeID string type InstanceType struct { @@ -76,7 +96,7 @@ type GPU struct { Memory units.Base2Bytes MemoryDetails string // "", "HBM", "GDDR", "DDR", etc. NetworkDetails string // "PCIe", "SXM4", "SXM5", etc. - Manufacturer string + Manufacturer Manufacturer Name string Type string } @@ -97,6 +117,7 @@ type GetInstanceTypeArgs struct { Locations LocationsFilter SupportedArchitectures []string InstanceTypes []string + GPUManufacterers []Manufacturer } // ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works correctly diff --git a/v1/providers/shadeform/gen/shadeform/model_instance_configuration.go b/v1/providers/shadeform/gen/shadeform/model_instance_configuration.go index ceadf426..201084fc 100644 --- a/v1/providers/shadeform/gen/shadeform/model_instance_configuration.go +++ b/v1/providers/shadeform/gen/shadeform/model_instance_configuration.go @@ -443,4 +443,4 @@ func (v NullableInstanceConfiguration) MarshalJSON() ([]byte, error) { func (v *NullableInstanceConfiguration) UnmarshalJSON(src []byte) error { v.isSet = true return json.Unmarshal(src, &v.value) -} \ No newline at end of file +} diff --git a/v1/providers/shadeform/gen/shadeform/model_instance_type_configuration.go b/v1/providers/shadeform/gen/shadeform/model_instance_type_configuration.go index 410e4f2b..eebfad14 100644 --- a/v1/providers/shadeform/gen/shadeform/model_instance_type_configuration.go +++ b/v1/providers/shadeform/gen/shadeform/model_instance_type_configuration.go @@ -443,4 +443,4 @@ func (v NullableInstanceTypeConfiguration) MarshalJSON() ([]byte, error) { func (v *NullableInstanceTypeConfiguration) UnmarshalJSON(src []byte) error { v.isSet = true return json.Unmarshal(src, &v.value) -} \ No newline at end of file +} diff --git a/v1/providers/shadeform/instancetype.go b/v1/providers/shadeform/instancetype.go index ad2a76c8..0f17856d 100644 --- a/v1/providers/shadeform/instancetype.go +++ b/v1/providers/shadeform/instancetype.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -44,8 +45,11 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta if err != nil { return nil, err } - // Filter the list down to the instance types that are allowed by the configuration filter + // Filter the list down to the instance types that are allowed by the configuration filter and the args for _, singleInstanceType := range instanceTypesFromShadeformInstanceType { + if !isSelectedByArgs(singleInstanceType, args) { + continue + } if c.isInstanceTypeAllowed(singleInstanceType.Type) { instanceTypes = append(instanceTypes, singleInstanceType) } @@ -55,6 +59,24 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta return instanceTypes, nil } +func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool { + if len(args.GPUManufacterers) > 0 { + if len(instanceType.SupportedGPUs) == 0 { + return false + } + + // For each supported GPU, check to see if the manufacture matches the args. The supported GPUs + // must be a full subset of the args value. + for _, supportedGPU := range instanceType.SupportedGPUs { + if !slices.Contains(args.GPUManufacterers, supportedGPU.Manufacturer) { + return false + } + } + } + + return true +} + func (c *ShadeformClient) GetInstanceTypePollTime() time.Duration { return 5 * time.Minute } @@ -153,6 +175,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform } gpuName := shadeformGPUTypeToBrevGPUName(shadeformInstanceType.Configuration.GpuType) + gpuManufacturer := v1.GetManufacturer(shadeformInstanceType.Configuration.GpuManufacturer) for _, region := range shadeformInstanceType.Availability { instanceTypes = append(instanceTypes, v1.InstanceType{ @@ -164,9 +187,9 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform { Count: shadeformInstanceType.Configuration.NumGpus, Memory: units.Base2Bytes(shadeformInstanceType.Configuration.VramPerGpuInGb) * units.GiB, - MemoryDetails: "", // TODO: add memory details + MemoryDetails: "", NetworkDetails: shadeformInstanceType.Configuration.Interconnect, - Manufacturer: shadeformInstanceType.Configuration.GpuManufacturer, + Manufacturer: gpuManufacturer, Name: gpuName, Type: shadeformInstanceType.Configuration.GpuType, }, From 1df376a0dd1c3e3f26d643cb19dccfd70364bc9f Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Fri, 5 Sep 2025 12:42:23 -0700 Subject: [PATCH 2/2] fix tests --- v1/providers/lambdalabs/instancetype_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/v1/providers/lambdalabs/instancetype_test.go b/v1/providers/lambdalabs/instancetype_test.go index 6d306ab9..9295d7f5 100644 --- a/v1/providers/lambdalabs/instancetype_test.go +++ b/v1/providers/lambdalabs/instancetype_test.go @@ -34,7 +34,7 @@ func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) { assert.True(t, a10Type.IsAvailable) assert.Len(t, a10Type.SupportedGPUs, 1) assert.Equal(t, int32(1), a10Type.SupportedGPUs[0].Count) - assert.Equal(t, "NVIDIA", a10Type.SupportedGPUs[0].Manufacturer) + assert.Equal(t, v1.ManufacturerNVIDIA, a10Type.SupportedGPUs[0].Manufacturer) assert.Equal(t, "A10", a10Type.SupportedGPUs[0].Name) } @@ -141,7 +141,7 @@ func TestConvertLambdaLabsInstanceTypeToV1InstanceType(t *testing.T) { gpu := v1InstanceType.SupportedGPUs[0] assert.Equal(t, int32(1), gpu.Count) - assert.Equal(t, "NVIDIA", gpu.Manufacturer) + assert.Equal(t, v1.ManufacturerNVIDIA, gpu.Manufacturer) assert.Equal(t, "NVIDIA A10", gpu.Name) assert.Equal(t, "NVIDIA A10", gpu.Type) assert.Equal(t, units.Base2Bytes(24*1024*1024*1024), gpu.Memory) @@ -172,7 +172,7 @@ func TestParseGPUFromDescription(t *testing.T) { description: "1x H100 (80 GB SXM5)", expected: v1.GPU{ Count: 1, - Manufacturer: "NVIDIA", + Manufacturer: v1.ManufacturerNVIDIA, Name: "H100", Type: "H100.SXM5", Memory: 80 * 1024 * 1024 * 1024, @@ -184,7 +184,7 @@ func TestParseGPUFromDescription(t *testing.T) { description: "8x Tesla V100 (16 GB)", expected: v1.GPU{ Count: 8, - Manufacturer: "NVIDIA", + Manufacturer: v1.ManufacturerNVIDIA, Name: "V100", Type: "V100", Memory: 16 * 1024 * 1024 * 1024,