Skip to content
6 changes: 5 additions & 1 deletion .github/workflows/validation-nebius.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
go-version-file: "go.mod"

- name: Cache Go modules
uses: actions/cache@v4
Expand All @@ -47,6 +47,10 @@ jobs:
NEBIUS_SERVICE_ACCOUNT_ID: ${{ secrets.NEBIUS_SERVICE_ACCOUNT_ID }}
NEBIUS_PROJECT_ID: ${{ secrets.NEBIUS_PROJECT_ID }}
TEST_USER_PRIVATE_KEY_PEM_BASE64: ${{ secrets.TEST_USER_PRIVATE_KEY_PEM_BASE64 }}
NEBIUS_SERVICE_ACCOUNT_JSON: ${{ secrets.NEBIUS_SERVICE_ACCOUNT_JSON }}
NEBIUS_TENANT_ID: ${{ secrets.NEBIUS_TENANT_ID }}
TEST_PRIVATE_KEY_BASE64: ${{ secrets.TEST_PRIVATE_KEY_BASE64 }}
TEST_PUBLIC_KEY_BASE64: ${{ secrets.TEST_PUBLIC_KEY_BASE64 }}
VALIDATION_TEST: true
run: |
cd v1/providers/nebius
Expand Down
6 changes: 5 additions & 1 deletion internal/validation/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) {
capabilities, err := client.GetCapabilities(ctx)
require.NoError(t, err)

types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{})
types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{
ArchitectureFilter: &v1.ArchitectureFilter{
IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64},
},
})
require.NoError(t, err)
require.NotEmpty(t, types, "Should have instance types")

Expand Down
2 changes: 1 addition & 1 deletion v1/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func validateOSVersion(ctx context.Context, sshClient *ssh.Client) (string, erro
}

osVersion := strings.Trim(parts[1], "\"")
ubuntuRegex := regexp.MustCompile(`Ubuntu 20\.04|22\.04`)
ubuntuRegex := regexp.MustCompile(`Ubuntu 20\.04|22\.04|24\.04`)
if !ubuntuRegex.MatchString(osVersion) {
return "", fmt.Errorf("expected Ubuntu 20.04 or 22.04, got: %s", osVersion)
}
Expand Down
2 changes: 2 additions & 0 deletions v1/providers/aws/validation_kubernetes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
)

func TestAWSKubernetesValidation(t *testing.T) {
t.Skip("Skipping AWS Kubernetes validation tests")

if isValidationTest == "" {
t.Skip("VALIDATION_TEST is not set, skipping AWS Kubernetes validation tests")
}
Expand Down
2 changes: 2 additions & 0 deletions v1/providers/aws/validation_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ var (
)

func TestAWSNetworkValidation(t *testing.T) {
t.Skip("Skipping AWS Network validation tests")

if isValidationTest == "" {
t.Skip("VALIDATION_TEST is not set, skipping AWS Network validation tests")
}
Expand Down
5 changes: 5 additions & 0 deletions v1/providers/nebius/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func findProjectForRegion(ctx context.Context, sdk *gosdk.SDK, tenantID, region
return "", fmt.Errorf("no projects found in tenant %s", tenantID)
}

// TODO: I don't think the following code is correct: monikers like "default" or "default-project",
// or even the Nebius convention of "default-project-{region}", are unlikely to work with the Nebius SDK,
// because the SDK expects the project *ID* to be used, not the name. If we get to this part of the code,
// it likely implies that we will not be able to proceed.

// Sort projects by ID for deterministic selection
// This ensures CreateInstance and ListInstances always use the same project!
sort.Slice(projects, func(i, j int) bool {
Expand Down
18 changes: 12 additions & 6 deletions v1/providers/nebius/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ func getImageDescription(image *compute.Image) string {
return ""
}

// Architecture name strings used by extractArchitecture when it inspects
// an image name. Each canonical value is listed next to the alternate
// spelling that is matched for the same architecture.
const (
	// ArchitectureX86_64 is the canonical 64-bit x86 name; it is also the
	// value extractArchitecture returns when nothing else matches.
	ArchitectureX86_64 = "x86_64"

	// ArchitectureAMD64 is an alternate spelling that is matched in image
	// names and reported as ArchitectureX86_64.
	ArchitectureAMD64 = "amd64"

	// ArchitectureArm64 is the canonical 64-bit ARM name.
	ArchitectureArm64 = "arm64"

	// ArchitectureAArch64 is an alternate spelling that is matched in image
	// names and reported as ArchitectureArm64.
	ArchitectureAArch64 = "aarch64"
)

// extractArchitecture extracts architecture information from image metadata
func extractArchitecture(image *compute.Image) string {
// Check labels for architecture info
Expand All @@ -217,16 +224,15 @@ func extractArchitecture(image *compute.Image) string {
// Infer from image name
if image.Metadata != nil {
name := strings.ToLower(image.Metadata.Name)
if strings.Contains(name, "arm64") || strings.Contains(name, "aarch64") {
return "arm64"
if strings.Contains(name, ArchitectureArm64) || strings.Contains(name, ArchitectureAArch64) {
return ArchitectureArm64
}
if strings.Contains(name, "x86_64") || strings.Contains(name, "amd64") {
//nolint:goconst // Architecture string used in detection and returned as default
return "x86_64"
if strings.Contains(name, ArchitectureX86_64) || strings.Contains(name, ArchitectureAMD64) {
return ArchitectureX86_64
}
}

return "x86_64"
return ArchitectureX86_64
}

// filterImagesByArchitectures filters images by multiple architectures
Expand Down
9 changes: 5 additions & 4 deletions v1/providers/nebius/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ func (c *NebiusClient) convertNebiusInstanceToV1(ctx context.Context, instance *
InstanceType: instanceTypeID, // Full instance type ID (e.g., "gpu-h100-sxm.8gpu-128vcpu-1600gb")
InstanceTypeID: v1.InstanceTypeID(instanceTypeID), // Same as InstanceType - required for dev-plane lookup
ImageID: imageFamily,
DiskSize: units.Base2Bytes(diskSize),
DiskSizeBytes: v1.NewBytes(v1.BytesValue(diskSize), v1.Byte), // diskSize is already in bytes from getBootDiskSize
Tags: tags,
Status: v1.Status{LifecycleStatus: lifecycleStatus},
Expand Down Expand Up @@ -1150,6 +1151,10 @@ func (c *NebiusClient) createBootDisk(ctx context.Context, attrs v1.CreateInstan

// buildDiskCreateRequest builds a disk creation request, trying image family first, then image ID
func (c *NebiusClient) buildDiskCreateRequest(ctx context.Context, diskName string, attrs v1.CreateInstanceAttrs) (*compute.CreateDiskRequest, error) {
if attrs.DiskSize == 0 {
attrs.DiskSize = 1280 * units.Gibibyte // Defaulted by the Nebius Console
}
Comment on lines +1154 to +1156
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot be 0


baseReq := &compute.CreateDiskRequest{
Metadata: &common.ResourceMetadata{
ParentId: c.projectID,
Expand Down Expand Up @@ -1553,7 +1558,6 @@ func (c *NebiusClient) resolveImageFamily(ctx context.Context, imageID string) (
"mk8s-worker-node-v-1-31-ubuntu24.04-cuda12",
"ubuntu22.04",
"ubuntu20.04",
"ubuntu18.04",
}

// Check if ImageID is already a known family name
Expand Down Expand Up @@ -1600,9 +1604,6 @@ func (c *NebiusClient) resolveImageFamily(ctx context.Context, imageID string) (
if strings.Contains(name, "ubuntu20") || strings.Contains(name, "ubuntu-20") {
return "ubuntu20.04", nil
}
if strings.Contains(name, "ubuntu18") || strings.Contains(name, "ubuntu-18") {
return "ubuntu18.04", nil
}
}

// Default fallback - use the original ImageID as family
Expand Down
34 changes: 21 additions & 13 deletions v1/providers/nebius/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ func (c *NebiusClient) GetInstanceTypes(ctx context.Context, args v1.GetInstance
// Default behavior: check ALL regions to show all available quota
var locations []v1.Location

if len(args.Locations) > 0 && !args.Locations.IsAll() {
if args.Locations.IsAll() { //nolint:gocritic // prefer if statement over switch statement
allLocations, err := c.GetLocations(ctx, v1.GetLocationsArgs{})
if err != nil {
return nil, errors.WrapAndTrace(err)
}
locations = allLocations
} else if len(args.Locations) > 0 {
// User requested specific locations - filter to those
allLocations, err := c.GetLocations(ctx, v1.GetLocationsArgs{})
if err == nil {
Expand All @@ -48,15 +54,8 @@ func (c *NebiusClient) GetInstanceTypes(ctx context.Context, args v1.GetInstance
locations = []v1.Location{{Name: c.location}}
}
} else {
// Default behavior: enumerate ALL regions for quota-aware discovery
// This shows users all instance types they have quota for, regardless of region
allLocations, err := c.GetLocations(ctx, v1.GetLocationsArgs{})
if err == nil {
locations = allLocations
} else {
// Fallback to client's configured location if we can't get all locations
locations = []v1.Location{{Name: c.location}}
}
// No specific locations were requested: use the client's configured location
locations = []v1.Location{{Name: c.location}}
}

// Get quota information for all regions
Expand Down Expand Up @@ -176,10 +175,10 @@ func (c *NebiusClient) getInstanceTypesForLocation(ctx context.Context, platform

// Convert Nebius platform preset to our InstanceType format
instanceType := v1.InstanceType{
ID: v1.InstanceTypeID(instanceTypeID), // Dot-separated format (e.g., "gpu-h100-sxm.8gpu-128vcpu-1600gb")
Location: location.Name,
Type: instanceTypeID, // Same as ID - both use dot-separated format
VCPU: preset.Resources.VcpuCount,
Memory: units.Base2Bytes(preset.Resources.MemoryGibibytes) * units.Gibibyte,
MemoryBytes: v1.NewBytes(v1.BytesValue(preset.Resources.MemoryGibibytes), v1.Gibibyte), // Memory in GiB
NetworkPerformance: "standard", // Default network performance
IsAvailable: isAvailable,
Expand All @@ -191,12 +190,14 @@ func (c *NebiusClient) getInstanceTypesForLocation(ctx context.Context, platform

// Add GPU information if available
if preset.Resources.GpuCount > 0 && !isCPUOnly {
memory := getGPUMemory(gpuType)
gpu := v1.GPU{
Count: preset.Resources.GpuCount,
Type: gpuType,
Name: gpuName,
Manufacturer: v1.ManufacturerNVIDIA, // Nebius currently only supports NVIDIA GPUs
Memory: getGPUMemory(gpuType), // Populate VRAM based on GPU type
Memory: memory, // Populate VRAM based on GPU type
MemoryBytes: v1.NewBytes(v1.BytesValue(int64(memory)/int64(units.Gibibyte)), v1.Gibibyte),
}
instanceType.SupportedGPUs = []v1.GPU{gpu}
}
Expand All @@ -207,6 +208,9 @@ func (c *NebiusClient) getInstanceTypesForLocation(ctx context.Context, platform
instanceType.BasePrice = pricing
}

// Make the instance type ID
instanceType.ID = v1.MakeGenericInstanceTypeID(instanceType)
Comment on lines +211 to +212
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should now resolve IDs properly


instanceTypes = append(instanceTypes, instanceType)
}
}
Expand Down Expand Up @@ -368,7 +372,9 @@ func (c *NebiusClient) buildSupportedStorage() []v1.Storage {
// Nebius supports dynamically allocatable network SSD disks
// Minimum: 50GB, Maximum: 2560GB
minSize := 50 * units.GiB
minSizeBytes := v1.NewBytes(50, v1.Gibibyte)
maxSize := 2560 * units.GiB
maxSizeBytes := v1.NewBytes(2560, v1.Gibibyte)

// Pricing is roughly $0.10 per GB-month, which is ~$0.00014 per GB-hour
pricePerGBHr, _ := currency.NewAmount("0.00014", "USD")
Expand All @@ -379,6 +385,8 @@ func (c *NebiusClient) buildSupportedStorage() []v1.Storage {
Count: 1,
MinSize: &minSize,
MaxSize: &maxSize,
MinSizeBytes: &minSizeBytes,
MaxSizeBytes: &maxSizeBytes,
IsElastic: true,
PricePerGBHr: &pricePerGBHr,
},
Expand All @@ -396,7 +404,7 @@ func (c *NebiusClient) applyInstanceTypeFilters(instanceTypes []v1.InstanceType,
if len(args.InstanceTypes) > 0 {
found := false
for _, requestedType := range args.InstanceTypes {
if string(instanceType.ID) == requestedType {
if instanceType.Type == requestedType {
Comment on lines -399 to +407
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We filter on Type, not Type ID

found = true
break
}
Expand Down
29 changes: 16 additions & 13 deletions v1/providers/nebius/scripts/images_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (

// Test_EnumerateImages enumerates all available images in Nebius
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateImages
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateImages
func Test_EnumerateImages(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down Expand Up @@ -93,7 +94,7 @@ func Test_EnumerateImages(t *testing.T) {
t.Fatalf("Error marshaling JSON: %v", err)
}

err = os.WriteFile(outputFile, output, 0644)
err = os.WriteFile(outputFile, output, 0o644)
if err != nil {
t.Fatalf("Error writing to file: %v", err)
}
Expand All @@ -103,9 +104,10 @@ func Test_EnumerateImages(t *testing.T) {

// Test_EnumerateImagesAllRegions enumerates images across all Nebius regions
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// go test -tags scripts -v -run Test_EnumerateImagesAllRegions
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// go test -tags scripts -v -run Test_EnumerateImagesAllRegions
func Test_EnumerateImagesAllRegions(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down Expand Up @@ -173,7 +175,7 @@ func Test_EnumerateImagesAllRegions(t *testing.T) {
t.Fatalf("Error marshaling JSON: %v", err)
}

err = os.WriteFile(outputFile, output, 0644)
err = os.WriteFile(outputFile, output, 0o644)
if err != nil {
t.Fatalf("Error writing to file: %v", err)
}
Expand All @@ -183,10 +185,11 @@ func Test_EnumerateImagesAllRegions(t *testing.T) {

// Test_FilterGPUImages filters images suitable for GPU instances
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_FilterGPUImages
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_FilterGPUImages
func Test_FilterGPUImages(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down
29 changes: 16 additions & 13 deletions v1/providers/nebius/scripts/instancetypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (

// Test_EnumerateInstanceTypes enumerates all instance types across all Nebius regions
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// go test -tags scripts -v -run Test_EnumerateInstanceTypes
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// go test -tags scripts -v -run Test_EnumerateInstanceTypes
func Test_EnumerateInstanceTypes(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down Expand Up @@ -120,7 +121,7 @@ func Test_EnumerateInstanceTypes(t *testing.T) {
t.Fatalf("Error marshaling JSON: %v", err)
}

err = os.WriteFile(outputFile, output, 0644)
err = os.WriteFile(outputFile, output, 0o644)
if err != nil {
t.Fatalf("Error writing to file: %v", err)
}
Expand All @@ -130,10 +131,11 @@ func Test_EnumerateInstanceTypes(t *testing.T) {

// Test_EnumerateInstanceTypesSingleRegion enumerates instance types for a specific region
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateInstanceTypesSingleRegion
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateInstanceTypesSingleRegion
func Test_EnumerateInstanceTypesSingleRegion(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down Expand Up @@ -217,7 +219,7 @@ func Test_EnumerateInstanceTypesSingleRegion(t *testing.T) {
t.Fatalf("Error marshaling JSON: %v", err)
}

err = os.WriteFile(outputFile, output, 0644)
err = os.WriteFile(outputFile, output, 0o644)
if err != nil {
t.Fatalf("Error writing to file: %v", err)
}
Expand All @@ -227,10 +229,11 @@ func Test_EnumerateInstanceTypesSingleRegion(t *testing.T) {

// Test_EnumerateGPUTypes filters and displays only GPU instance types with detailed specs
// Usage:
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateGPUTypes
//
// export NEBIUS_SERVICE_ACCOUNT_JSON='/path/to/service-account.json'
// export NEBIUS_TENANT_ID='tenant-e00xxx'
// export NEBIUS_LOCATION='eu-north1'
// go test -tags scripts -v -run Test_EnumerateGPUTypes
func Test_EnumerateGPUTypes(t *testing.T) {
serviceAccountJSON := os.Getenv("NEBIUS_SERVICE_ACCOUNT_JSON")
tenantID := os.Getenv("NEBIUS_TENANT_ID")
Expand Down
2 changes: 2 additions & 0 deletions v1/providers/nebius/validation_kubernetes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
)

func TestKubernetesValidation(t *testing.T) {
t.Skip("Skipping Nebius Kubernetes validation tests")

isValidationTest := os.Getenv("VALIDATION_TEST")
if isValidationTest == "" {
t.Skip("VALIDATION_TEST is not set, skipping Nebius Kubernetes validation tests")
Expand Down
2 changes: 2 additions & 0 deletions v1/providers/nebius/validation_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ var (
)

func TestNetworkValidation(t *testing.T) {
t.Skip("Skipping Nebius Network validation tests")

if isValidationTest == "" {
t.Skip("VALIDATION_TEST is not set, skipping Nebius Network validation tests")
}
Expand Down
Loading
Loading