From 8e19d4518c36d08b1979ef8f9e3c2f56292a78b3 Mon Sep 17 00:00:00 2001 From: Steven Miller Date: Thu, 22 Jan 2026 13:28:30 -0500 Subject: [PATCH 1/2] GPU load balancing --- cmd/api/config/config.go | 6 ++ lib/devices/mdev.go | 199 +++++++++++++++++++++++++++++++++------ 2 files changed, 175 insertions(+), 30 deletions(-) diff --git a/cmd/api/config/config.go b/cmd/api/config/config.go index 3fcb23d..dd923fd 100644 --- a/cmd/api/config/config.go +++ b/cmd/api/config/config.go @@ -118,6 +118,9 @@ type Config struct { // Hypervisor configuration DefaultHypervisor string // Default hypervisor type: "cloud-hypervisor" or "qemu" + // GPU configuration + GPUProfileCacheTTL string // TTL for GPU profile metadata cache (e.g., "30m") + // Oversubscription ratios (1.0 = no oversubscription, 2.0 = 2x oversubscription) OversubCPU float64 // CPU oversubscription ratio OversubMemory float64 // Memory oversubscription ratio @@ -212,6 +215,9 @@ func Load() *Config { // Hypervisor configuration DefaultHypervisor: getEnv("DEFAULT_HYPERVISOR", "cloud-hypervisor"), + // GPU configuration + GPUProfileCacheTTL: getEnv("GPU_PROFILE_CACHE_TTL", "30m"), + // Oversubscription ratios (1.0 = no oversubscription) OversubCPU: getEnvFloat("OVERSUB_CPU", 4.0), OversubMemory: getEnvFloat("OVERSUB_MEMORY", 1.0), diff --git a/lib/devices/mdev.go b/lib/devices/mdev.go index 364bcfd..c5b33fb 100644 --- a/lib/devices/mdev.go +++ b/lib/devices/mdev.go @@ -8,9 +8,11 @@ import ( "os/exec" "path/filepath" "regexp" + "sort" "strconv" "strings" "sync" + "time" "github.com/google/uuid" "github.com/kernel/hypeman/lib/logger" @@ -32,12 +34,52 @@ type profileMetadata struct { FramebufferMB int } -// cachedProfiles holds static profile metadata, loaded once on first access +// cachedProfiles holds profile metadata with TTL-based expiry. +// The cache TTL is configurable via GPU_PROFILE_CACHE_TTL environment variable. 
var ( cachedProfiles []profileMetadata - cachedProfilesOnce sync.Once + cachedProfilesMu sync.RWMutex + cachedProfilesTime time.Time ) +// getProfileCacheTTL returns the TTL for profile metadata cache. +// Reads from GPU_PROFILE_CACHE_TTL env var, defaults to 30 minutes. +func getProfileCacheTTL() time.Duration { + if ttl := os.Getenv("GPU_PROFILE_CACHE_TTL"); ttl != "" { + if d, err := time.ParseDuration(ttl); err == nil { + return d + } + } + return 30 * time.Minute +} + +// getCachedProfiles returns cached profile metadata, refreshing if TTL has expired. +func getCachedProfiles(firstVF string) []profileMetadata { + ttl := getProfileCacheTTL() + + // Fast path: a non-empty cache younger than ttl is returned while holding only the read lock + cachedProfilesMu.RLock() + if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl { + profiles := cachedProfiles + cachedProfilesMu.RUnlock() + return profiles + } + cachedProfilesMu.RUnlock() + + // Slow path: cache is empty or expired, so reload under the write lock (an empty load result is never treated as cached) + cachedProfilesMu.Lock() + defer cachedProfilesMu.Unlock() + + // Re-check under the write lock: another caller may have refreshed the cache between RUnlock and Lock + if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl { + return cachedProfiles + } + + cachedProfiles = loadProfileMetadata(firstVF) + cachedProfilesTime = time.Now() + return cachedProfiles +} + // DiscoverVFs returns all SR-IOV Virtual Functions available for vGPU. // These are discovered by scanning /sys/class/mdev_bus/ which contains // VFs that can host mdev devices.
@@ -100,17 +142,15 @@ func ListGPUProfilesWithVFs(vfs []VirtualFunction) ([]GPUProfile, error) { return nil, nil } - // Load static profile metadata once (cached indefinitely) - cachedProfilesOnce.Do(func() { - cachedProfiles = loadProfileMetadata(vfs[0].PCIAddress) - }) + // Load profile metadata with TTL-based caching + cachedMeta := getCachedProfiles(vfs[0].PCIAddress) // Count availability for all profiles in parallel - availability := countAvailableVFsForProfilesParallel(vfs, cachedProfiles) + availability := countAvailableVFsForProfilesParallel(vfs, cachedMeta) // Build result with dynamic availability counts - profiles := make([]GPUProfile, 0, len(cachedProfiles)) - for _, meta := range cachedProfiles { + profiles := make([]GPUProfile, 0, len(cachedMeta)) + for _, meta := range cachedMeta { profiles = append(profiles, GPUProfile{ Name: meta.Name, FramebufferMB: meta.FramebufferMB, @@ -194,8 +234,8 @@ func parseFramebufferFromDescription(typeDir string) int { } // countAvailableVFsForProfilesParallel counts available instances for all profiles in parallel. -// Optimized: all VFs on the same parent GPU have identical profile support, -// so we only sample one VF per parent instead of reading from every VF. +// Groups VFs by parent GPU, then sums available_instances across all free VFs. +// For SR-IOV vGPU, each VF typically has available_instances of 0 or 1. func countAvailableVFsForProfilesParallel(vfs []VirtualFunction, profiles []profileMetadata) map[string]int { if len(vfs) == 0 || len(profiles) == 0 { return make(map[string]int) @@ -352,6 +392,118 @@ func getProfileNameFromType(profileType, vfAddress string) string { return strings.TrimSpace(string(data)) } +// getProfileFramebufferMB returns the framebuffer size in MB for a profile type. +// Uses cached profile metadata for fast lookup. 
+func getProfileFramebufferMB(profileType string) int { + cachedProfilesMu.RLock() + defer cachedProfilesMu.RUnlock() + + for _, p := range cachedProfiles { + if p.TypeName == profileType { + return p.FramebufferMB + } + } + return 0 +} + +// calculateGPUVRAMUsage sums the framebuffer size of every active mdev per parent GPU. +// Returns a map of parentGPU -> usedVRAMMB; GPUs without any mdev are absent from the map. +func calculateGPUVRAMUsage(vfs []VirtualFunction, mdevs []MdevDevice) map[string]int { + // Build VF -> parentGPU lookup + vfToParent := make(map[string]string, len(vfs)) + for _, vf := range vfs { + vfToParent[vf.PCIAddress] = vf.ParentGPU + } + + // Sum framebuffer usage per GPU; mdevs on VFs not in vfs are skipped, and profile types missing from the cache add 0 MB + usageByGPU := make(map[string]int) + for _, mdev := range mdevs { + parentGPU := vfToParent[mdev.VFAddress] + if parentGPU == "" { + continue + } + usageByGPU[parentGPU] += getProfileFramebufferMB(mdev.ProfileType) + } + + return usageByGPU +} + +// selectLeastLoadedVF selects a free VF from the GPU with the least VRAM currently in use +// that can create the requested profile. Returns empty string if none available.
+func selectLeastLoadedVF(ctx context.Context, vfs []VirtualFunction, profileType string) string { + log := logger.FromContext(ctx) + + // Get active mdevs to calculate VRAM usage (NOTE: the error from ListMdevDevices is discarded here) + mdevs, _ := ListMdevDevices() + + // Calculate VRAM usage per GPU + vramUsage := calculateGPUVRAMUsage(vfs, mdevs) + + // Group free VFs (those without an mdev) by parent GPU, and record every GPU seen + freeVFsByGPU := make(map[string][]VirtualFunction) + allGPUs := make(map[string]bool) + for _, vf := range vfs { + allGPUs[vf.ParentGPU] = true + if !vf.HasMdev { + freeVFsByGPU[vf.ParentGPU] = append(freeVFsByGPU[vf.ParentGPU], vf) + } + } + + // Build list of GPUs sorted by VRAM usage (ascending = least loaded first); GPUs with equal usage end up in unspecified order (map iteration + sort.Slice) + type gpuLoad struct { + gpu string + usedMB int + } + var gpuLoads []gpuLoad + for gpu := range allGPUs { + gpuLoads = append(gpuLoads, gpuLoad{gpu: gpu, usedMB: vramUsage[gpu]}) + } + sort.Slice(gpuLoads, func(i, j int) bool { + return gpuLoads[i].usedMB < gpuLoads[j].usedMB + }) + + log.DebugContext(ctx, "GPU VRAM usage for load balancing", + "gpu_count", len(gpuLoads), + "profile_type", profileType) + + // Try each GPU in order of least loaded + for _, gl := range gpuLoads { + freeVFs := freeVFsByGPU[gl.gpu] + if len(freeVFs) == 0 { + log.DebugContext(ctx, "skipping GPU: no free VFs", + "gpu", gl.gpu, + "used_mb", gl.usedMB) + continue + } + + // Check if any free VF on this GPU can create the profile (its available_instances file must parse to >= 1; unreadable or unparsable files skip the VF) + for _, vf := range freeVFs { + availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances") + data, err := os.ReadFile(availPath) + if err != nil { + continue + } + instances, err := strconv.Atoi(strings.TrimSpace(string(data))) + if err != nil || instances < 1 { + continue + } + + log.DebugContext(ctx, "selected VF from least loaded GPU", + "vf", vf.PCIAddress, + "gpu", gl.gpu, + "gpu_used_mb", gl.usedMB) + return vf.PCIAddress + } + + log.DebugContext(ctx, "skipping GPU: no VF can create profile", + "gpu", gl.gpu, + "used_mb", gl.usedMB, + "profile_type",
profileType) + } + + return "" +} + // CreateMdev creates an mdev device for the given profile and instance. // It finds an available VF and creates the mdev, returning the device info. // This function is thread-safe and uses a mutex to prevent race conditions @@ -369,32 +521,19 @@ func CreateMdev(ctx context.Context, profileName, instanceID string) (*MdevDevic return nil, err } - // Find an available VF + // Discover all VFs vfs, err := DiscoverVFs() if err != nil { return nil, fmt.Errorf("discover VFs: %w", err) } - var targetVF string - for _, vf := range vfs { - // Skip VFs that already have an mdev - if vf.HasMdev { - continue - } - // Check if this VF can create the profile - availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances") - data, err := os.ReadFile(availPath) - if err != nil { - continue - } - instances, err := strconv.Atoi(strings.TrimSpace(string(data))) - if err != nil || instances < 1 { - continue - } - targetVF = vf.PCIAddress - break + // Ensure profile cache is populated (needed for VRAM calculation) + if len(vfs) > 0 { + _ = getCachedProfiles(vfs[0].PCIAddress) } + // Select VF from the least loaded GPU (by VRAM usage) + targetVF := selectLeastLoadedVF(ctx, vfs, profileType) if targetVF == "" { return nil, fmt.Errorf("no available VF for profile %q", profileName) } From d7e7aaaa4a405a0cf55a9e19b84a72e991078ea8 Mon Sep 17 00:00:00 2001 From: Steven Miller Date: Thu, 22 Jan 2026 14:24:30 -0500 Subject: [PATCH 2/2] Use config --- cmd/api/api/api_test.go | 8 ++++---- cmd/api/config/config.go | 4 ++-- cmd/api/main.go | 5 ++++- lib/devices/mdev.go | 23 ++++++++++++++--------- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/cmd/api/api/api_test.go b/cmd/api/api/api_test.go index 8371f99..af71d36 100644 --- a/cmd/api/api/api_test.go +++ b/cmd/api/api/api_test.go @@ -25,10 +25,10 @@ import ( // newTestService creates an ApiService for testing with automatic cleanup func
newTestService(t *testing.T) *ApiService { cfg := &config.Config{ - DataDir: t.TempDir(), - BridgeName: "vmbr0", - SubnetCIDR: "10.100.0.0/16", - DNSServer: "1.1.1.1", + DataDir: t.TempDir(), + BridgeName: "vmbr0", + SubnetCIDR: "10.100.0.0/16", + DNSServer: "1.1.1.1", } p := paths.New(cfg.DataDir) diff --git a/cmd/api/config/config.go b/cmd/api/config/config.go index dd923fd..b010135 100644 --- a/cmd/api/config/config.go +++ b/cmd/api/config/config.go @@ -201,8 +201,8 @@ func Load() *Config { CloudflareApiToken: getEnv("CLOUDFLARE_API_TOKEN", ""), // API ingress configuration - ApiHostname: getEnv("API_HOSTNAME", ""), // Empty = disabled - ApiTLS: getEnvBool("API_TLS", true), // Default to TLS enabled + ApiHostname: getEnv("API_HOSTNAME", ""), // Empty = disabled + ApiTLS: getEnvBool("API_TLS", true), // Default to TLS enabled ApiRedirectHTTP: getEnvBool("API_REDIRECT_HTTP", true), // Build system configuration diff --git a/cmd/api/main.go b/cmd/api/main.go index 39587e9..5c835a1 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -18,7 +18,6 @@ import ( "github.com/ghodss/yaml" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - nethttpmiddleware "github.com/oapi-codegen/nethttp-middleware" "github.com/kernel/hypeman" "github.com/kernel/hypeman/cmd/api/api" "github.com/kernel/hypeman/cmd/api/config" @@ -30,6 +29,7 @@ import ( "github.com/kernel/hypeman/lib/oapi" "github.com/kernel/hypeman/lib/otel" "github.com/kernel/hypeman/lib/vmm" + nethttpmiddleware "github.com/oapi-codegen/nethttp-middleware" "github.com/riandyrn/otelchi" "golang.org/x/sync/errgroup" ) @@ -51,6 +51,9 @@ func run() error { return fmt.Errorf("invalid configuration: %w", err) } + // Configure GPU profile cache TTL + devices.SetGPUProfileCacheTTL(cfg.GPUProfileCacheTTL) + // Initialize OpenTelemetry (before wire initialization) otelCfg := otel.Config{ Enabled: cfg.OtelEnabled, diff --git a/lib/devices/mdev.go b/lib/devices/mdev.go index c5b33fb..de648e0 100644 --- 
a/lib/devices/mdev.go +++ b/lib/devices/mdev.go @@ -35,22 +35,27 @@ type profileMetadata struct { } // cachedProfiles holds profile metadata with TTL-based expiry. -// The cache TTL is configurable via GPU_PROFILE_CACHE_TTL environment variable. var ( cachedProfiles []profileMetadata cachedProfilesMu sync.RWMutex cachedProfilesTime time.Time + gpuProfileCacheTTL time.Duration = 30 * time.Minute // default until SetGPUProfileCacheTTL overrides it ) -// getProfileCacheTTL returns the TTL for profile metadata cache. -// Reads from GPU_PROFILE_CACHE_TTL env var, defaults to 30 minutes. -func getProfileCacheTTL() time.Duration { - if ttl := os.Getenv("GPU_PROFILE_CACHE_TTL"); ttl != "" { - if d, err := time.ParseDuration(ttl); err == nil { - return d - } +// SetGPUProfileCacheTTL sets the TTL for GPU profile metadata cache from a time.ParseDuration string (e.g. "30m"). +// Should be called during application startup, before the cache is used: the value is stored without locking. An empty or unparsable ttl is silently ignored and the current TTL is kept; a zero or negative duration is accepted and makes every cache lookup reload. +func SetGPUProfileCacheTTL(ttl string) { + if ttl == "" { + return + } + if d, err := time.ParseDuration(ttl); err == nil { + gpuProfileCacheTTL = d } - return 30 * time.Minute +} + +// getProfileCacheTTL returns the configured TTL for profile metadata cache (30 minutes unless SetGPUProfileCacheTTL changed it). +func getProfileCacheTTL() time.Duration { + return gpuProfileCacheTTL } // getCachedProfiles returns cached profile metadata, refreshing if TTL has expired.