Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion v1/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,33 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"time"

"github.com/alecthomas/units"
"github.com/bojanz/currency"
"github.com/google/go-cmp/cmp"
)

type Manufacturer string

const (
ManufacturerNVIDIA Manufacturer = "NVIDIA"
ManufacturerIntel Manufacturer = "Intel"
ManufacturerUnknown Manufacturer = "unknown"
)

func GetManufacturer(manufacturer string) Manufacturer {
switch strings.ToLower(manufacturer) {
case "nvidia":
return ManufacturerNVIDIA
case "intel":
return ManufacturerIntel
default:
return ManufacturerUnknown
}
}

type InstanceTypeID string

type InstanceType struct {
Expand Down Expand Up @@ -76,7 +96,7 @@ type GPU struct {
Memory units.Base2Bytes
MemoryDetails string // "", "HBM", "GDDR", "DDR", etc.
NetworkDetails string // "PCIe", "SXM4", "SXM5", etc.
Manufacturer string
Manufacturer Manufacturer
Name string
Type string
}
Expand All @@ -97,6 +117,7 @@ type GetInstanceTypeArgs struct {
Locations LocationsFilter
SupportedArchitectures []string
InstanceTypes []string
GPUManufacterers []Manufacturer
}

// ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works correctly
Expand Down
8 changes: 4 additions & 4 deletions v1/providers/lambdalabs/instancetype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestLambdaLabsClient_GetInstanceTypes_Success(t *testing.T) {
assert.True(t, a10Type.IsAvailable)
assert.Len(t, a10Type.SupportedGPUs, 1)
assert.Equal(t, int32(1), a10Type.SupportedGPUs[0].Count)
assert.Equal(t, "NVIDIA", a10Type.SupportedGPUs[0].Manufacturer)
assert.Equal(t, v1.ManufacturerNVIDIA, a10Type.SupportedGPUs[0].Manufacturer)
assert.Equal(t, "A10", a10Type.SupportedGPUs[0].Name)
}

Expand Down Expand Up @@ -141,7 +141,7 @@ func TestConvertLambdaLabsInstanceTypeToV1InstanceType(t *testing.T) {

gpu := v1InstanceType.SupportedGPUs[0]
assert.Equal(t, int32(1), gpu.Count)
assert.Equal(t, "NVIDIA", gpu.Manufacturer)
assert.Equal(t, v1.ManufacturerNVIDIA, gpu.Manufacturer)
assert.Equal(t, "NVIDIA A10", gpu.Name)
assert.Equal(t, "NVIDIA A10", gpu.Type)
assert.Equal(t, units.Base2Bytes(24*1024*1024*1024), gpu.Memory)
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestParseGPUFromDescription(t *testing.T) {
description: "1x H100 (80 GB SXM5)",
expected: v1.GPU{
Count: 1,
Manufacturer: "NVIDIA",
Manufacturer: v1.ManufacturerNVIDIA,
Name: "H100",
Type: "H100.SXM5",
Memory: 80 * 1024 * 1024 * 1024,
Expand All @@ -184,7 +184,7 @@ func TestParseGPUFromDescription(t *testing.T) {
description: "8x Tesla V100 (16 GB)",
expected: v1.GPU{
Count: 8,
Manufacturer: "NVIDIA",
Manufacturer: v1.ManufacturerNVIDIA,
Name: "V100",
Type: "V100",
Memory: 16 * 1024 * 1024 * 1024,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 26 additions & 3 deletions v1/providers/shadeform/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -44,8 +45,11 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
if err != nil {
return nil, err
}
// Filter the list down to the instance types that are allowed by the configuration filter
// Filter the list down to the instance types that are allowed by the configuration filter and the args
for _, singleInstanceType := range instanceTypesFromShadeformInstanceType {
if !isSelectedByArgs(singleInstanceType, args) {
continue
}
if c.isInstanceTypeAllowed(singleInstanceType.Type) {
instanceTypes = append(instanceTypes, singleInstanceType)
}
Expand All @@ -55,6 +59,24 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
return instanceTypes, nil
}

func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool {
if len(args.GPUManufacterers) > 0 {
if len(instanceType.SupportedGPUs) == 0 {
return false
}

// For each supported GPU, check to see if the manufacture matches the args. The supported GPUs
// must be a full subset of the args value.
for _, supportedGPU := range instanceType.SupportedGPUs {
if !slices.Contains(args.GPUManufacterers, supportedGPU.Manufacturer) {
return false
}
}
}

return true
}

func (c *ShadeformClient) GetInstanceTypePollTime() time.Duration {
return 5 * time.Minute
}
Expand Down Expand Up @@ -153,6 +175,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
}

gpuName := shadeformGPUTypeToBrevGPUName(shadeformInstanceType.Configuration.GpuType)
gpuManufacturer := v1.GetManufacturer(shadeformInstanceType.Configuration.GpuManufacturer)

for _, region := range shadeformInstanceType.Availability {
instanceTypes = append(instanceTypes, v1.InstanceType{
Expand All @@ -164,9 +187,9 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
{
Count: shadeformInstanceType.Configuration.NumGpus,
Memory: units.Base2Bytes(shadeformInstanceType.Configuration.VramPerGpuInGb) * units.GiB,
MemoryDetails: "", // TODO: add memory details
MemoryDetails: "",
NetworkDetails: shadeformInstanceType.Configuration.Interconnect,
Manufacturer: shadeformInstanceType.Configuration.GpuManufacturer,
Manufacturer: gpuManufacturer,
Name: gpuName,
Type: shadeformInstanceType.Configuration.GpuType,
},
Expand Down
Loading