8 changes: 4 additions & 4 deletions cmd/api/api/api_test.go
@@ -25,10 +25,10 @@ import (
// newTestService creates an ApiService for testing with automatic cleanup
func newTestService(t *testing.T) *ApiService {
cfg := &config.Config{
DataDir: t.TempDir(),
BridgeName: "vmbr0",
SubnetCIDR: "10.100.0.0/16",
DNSServer: "1.1.1.1",
DataDir: t.TempDir(),
BridgeName: "vmbr0",
SubnetCIDR: "10.100.0.0/16",
DNSServer: "1.1.1.1",
}

p := paths.New(cfg.DataDir)
10 changes: 8 additions & 2 deletions cmd/api/config/config.go
@@ -118,6 +118,9 @@ type Config struct {
// Hypervisor configuration
DefaultHypervisor string // Default hypervisor type: "cloud-hypervisor" or "qemu"

// GPU configuration
GPUProfileCacheTTL string // TTL for GPU profile metadata cache (e.g., "30m")

// Oversubscription ratios (1.0 = no oversubscription, 2.0 = 2x oversubscription)
OversubCPU float64 // CPU oversubscription ratio
OversubMemory float64 // Memory oversubscription ratio
@@ -198,8 +201,8 @@ func Load() *Config {
CloudflareApiToken: getEnv("CLOUDFLARE_API_TOKEN", ""),

// API ingress configuration
ApiHostname: getEnv("API_HOSTNAME", ""), // Empty = disabled
ApiTLS: getEnvBool("API_TLS", true), // Default to TLS enabled
ApiHostname: getEnv("API_HOSTNAME", ""), // Empty = disabled
ApiTLS: getEnvBool("API_TLS", true), // Default to TLS enabled
ApiRedirectHTTP: getEnvBool("API_REDIRECT_HTTP", true),

// Build system configuration
@@ -212,6 +215,9 @@
// Hypervisor configuration
DefaultHypervisor: getEnv("DEFAULT_HYPERVISOR", "cloud-hypervisor"),

// GPU configuration
GPUProfileCacheTTL: getEnv("GPU_PROFILE_CACHE_TTL", "30m"),

// Oversubscription ratios (1.0 = no oversubscription)
OversubCPU: getEnvFloat("OVERSUB_CPU", 4.0),
OversubMemory: getEnvFloat("OVERSUB_MEMORY", 1.0),
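The TTL string is parsed with Go's time.ParseDuration (see SetGPUProfileCacheTTL in lib/devices/mdev.go below), so any Go duration literal is accepted; unparsable values are silently ignored and the 30-minute default is kept. A standalone sketch of the accepted syntax, for illustration only:

package main

import (
	"fmt"
	"time"
)

func main() {
	// GPU_PROFILE_CACHE_TTL values follow time.ParseDuration syntax.
	for _, s := range []string{"30m", "1h", "90s", "1h30m", "bogus"} {
		d, err := time.ParseDuration(s)
		fmt.Printf("%-6s -> %v (err: %v)\n", s, d, err)
	}
}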
5 changes: 4 additions & 1 deletion cmd/api/main.go
@@ -18,7 +18,6 @@ import (
"github.com/ghodss/yaml"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
nethttpmiddleware "github.com/oapi-codegen/nethttp-middleware"
"github.com/kernel/hypeman"
"github.com/kernel/hypeman/cmd/api/api"
"github.com/kernel/hypeman/cmd/api/config"
@@ -30,6 +29,7 @@ import (
"github.com/kernel/hypeman/lib/oapi"
"github.com/kernel/hypeman/lib/otel"
"github.com/kernel/hypeman/lib/vmm"
nethttpmiddleware "github.com/oapi-codegen/nethttp-middleware"
"github.com/riandyrn/otelchi"
"golang.org/x/sync/errgroup"
)
@@ -51,6 +51,9 @@ func run() error {
return fmt.Errorf("invalid configuration: %w", err)
}

// Configure GPU profile cache TTL
devices.SetGPUProfileCacheTTL(cfg.GPUProfileCacheTTL)

// Initialize OpenTelemetry (before wire initialization)
otelCfg := otel.Config{
Enabled: cfg.OtelEnabled,
204 changes: 174 additions & 30 deletions lib/devices/mdev.go
@@ -8,9 +8,11 @@ import (
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/google/uuid"
"github.com/kernel/hypeman/lib/logger"
@@ -32,12 +34,57 @@ type profileMetadata struct {
FramebufferMB int
}

// cachedProfiles holds static profile metadata, loaded once on first access
// cachedProfiles holds profile metadata with TTL-based expiry.
var (
cachedProfiles []profileMetadata
cachedProfilesOnce sync.Once
cachedProfilesMu sync.RWMutex
cachedProfilesTime time.Time
gpuProfileCacheTTL time.Duration = 30 * time.Minute // default
)

// SetGPUProfileCacheTTL sets the TTL for GPU profile metadata cache.
// Should be called during application startup with the config value.
func SetGPUProfileCacheTTL(ttl string) {
if ttl == "" {
return
}
if d, err := time.ParseDuration(ttl); err == nil {
gpuProfileCacheTTL = d
}
}

// getProfileCacheTTL returns the configured TTL for profile metadata cache.
func getProfileCacheTTL() time.Duration {
return gpuProfileCacheTTL
}

// getCachedProfiles returns cached profile metadata, refreshing if TTL has expired.
func getCachedProfiles(firstVF string) []profileMetadata {
ttl := getProfileCacheTTL()

// Fast path: check with read lock
cachedProfilesMu.RLock()
if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl {
profiles := cachedProfiles
cachedProfilesMu.RUnlock()
return profiles
}
cachedProfilesMu.RUnlock()

// Slow path: refresh cache with write lock
cachedProfilesMu.Lock()
defer cachedProfilesMu.Unlock()

// Double-check after acquiring write lock
if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl {
return cachedProfiles
}

cachedProfiles = loadProfileMetadata(firstVF)
cachedProfilesTime = time.Now()
return cachedProfiles
}

// DiscoverVFs returns all SR-IOV Virtual Functions available for vGPU.
// These are discovered by scanning /sys/class/mdev_bus/ which contains
// VFs that can host mdev devices.
@@ -100,17 +147,15 @@ func ListGPUProfilesWithVFs(vfs []VirtualFunction) ([]GPUProfile, error) {
return nil, nil
}

// Load static profile metadata once (cached indefinitely)
cachedProfilesOnce.Do(func() {
cachedProfiles = loadProfileMetadata(vfs[0].PCIAddress)
})
// Load profile metadata with TTL-based caching
cachedMeta := getCachedProfiles(vfs[0].PCIAddress)

// Count availability for all profiles in parallel
availability := countAvailableVFsForProfilesParallel(vfs, cachedProfiles)
availability := countAvailableVFsForProfilesParallel(vfs, cachedMeta)

// Build result with dynamic availability counts
profiles := make([]GPUProfile, 0, len(cachedProfiles))
for _, meta := range cachedProfiles {
profiles := make([]GPUProfile, 0, len(cachedMeta))
for _, meta := range cachedMeta {
profiles = append(profiles, GPUProfile{
Name: meta.Name,
FramebufferMB: meta.FramebufferMB,
@@ -194,8 +239,8 @@ }
}

// countAvailableVFsForProfilesParallel counts available instances for all profiles in parallel.
// Optimized: all VFs on the same parent GPU have identical profile support,
// so we only sample one VF per parent instead of reading from every VF.
// Groups VFs by parent GPU, then sums available_instances across all free VFs.
// For SR-IOV vGPU, each VF typically has available_instances of 0 or 1.
func countAvailableVFsForProfilesParallel(vfs []VirtualFunction, profiles []profileMetadata) map[string]int {
if len(vfs) == 0 || len(profiles) == 0 {
return make(map[string]int)
@@ -352,6 +397,118 @@ func getProfileNameFromType(profileType, vfAddress string) string {
return strings.TrimSpace(string(data))
}

// getProfileFramebufferMB returns the framebuffer size in MB for a profile type.
// Uses cached profile metadata for fast lookup.
func getProfileFramebufferMB(profileType string) int {
cachedProfilesMu.RLock()
defer cachedProfilesMu.RUnlock()

for _, p := range cachedProfiles {
if p.TypeName == profileType {
return p.FramebufferMB
}
}
return 0
}

// calculateGPUVRAMUsage calculates VRAM usage per GPU from active mdevs.
// Returns a map of parentGPU -> usedVRAMMB.
func calculateGPUVRAMUsage(vfs []VirtualFunction, mdevs []MdevDevice) map[string]int {
// Build VF -> parentGPU lookup
vfToParent := make(map[string]string, len(vfs))
for _, vf := range vfs {
vfToParent[vf.PCIAddress] = vf.ParentGPU
}

// Sum framebuffer usage per GPU
usageByGPU := make(map[string]int)
for _, mdev := range mdevs {
parentGPU := vfToParent[mdev.VFAddress]
if parentGPU == "" {
continue
}
Review comment (Medium Severity): VFs without parent GPU have VRAM usage ignored

In calculateGPUVRAMUsage, mdevs on VFs with empty ParentGPU are skipped (if parentGPU == "" { continue }), so their VRAM is never counted. However, in selectLeastLoadedVF, these same VFs ARE included in allGPUs and freeVFsByGPU for selection. This means VFs without a physfn symlink are grouped under an empty-string "GPU" that always appears to have 0 VRAM usage, making them preferentially selected even when they already have active mdevs. This could cause load imbalance.

Additional Locations (1)

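A minimal sketch of one possible fix (not part of this diff): skip VFs whose ParentGPU is empty when grouping candidates in selectLeastLoadedVF, so VFs that cannot be attributed to a GPU are never treated as load-free:

// Sketch: exclude VFs with no resolvable parent GPU from load balancing.
for _, vf := range vfs {
	if vf.ParentGPU == "" {
		continue // no physfn symlink; VRAM usage cannot be attributed
	}
	allGPUs[vf.ParentGPU] = true
	if !vf.HasMdev {
		freeVFsByGPU[vf.ParentGPU] = append(freeVFsByGPU[vf.ParentGPU], vf)
	}
}

An alternative is to fall back to such VFs only after every attributable GPU is exhausted; either way the empty-string grouping disappears.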

usageByGPU[parentGPU] += getProfileFramebufferMB(mdev.ProfileType)
}

return usageByGPU
}

// selectLeastLoadedVF selects a VF from the GPU with the most available VRAM
// that can create the requested profile. Returns empty string if none available.
func selectLeastLoadedVF(ctx context.Context, vfs []VirtualFunction, profileType string) string {
log := logger.FromContext(ctx)

// Get active mdevs to calculate VRAM usage
mdevs, _ := ListMdevDevices()

// Calculate VRAM usage per GPU
vramUsage := calculateGPUVRAMUsage(vfs, mdevs)

// Group free VFs by parent GPU
freeVFsByGPU := make(map[string][]VirtualFunction)
allGPUs := make(map[string]bool)
for _, vf := range vfs {
allGPUs[vf.ParentGPU] = true
if !vf.HasMdev {
freeVFsByGPU[vf.ParentGPU] = append(freeVFsByGPU[vf.ParentGPU], vf)
}
}

// Build list of GPUs sorted by VRAM usage (ascending = least loaded first)
type gpuLoad struct {
gpu string
usedMB int
}
var gpuLoads []gpuLoad
for gpu := range allGPUs {
gpuLoads = append(gpuLoads, gpuLoad{gpu: gpu, usedMB: vramUsage[gpu]})
}
sort.Slice(gpuLoads, func(i, j int) bool {
return gpuLoads[i].usedMB < gpuLoads[j].usedMB
})

log.DebugContext(ctx, "GPU VRAM usage for load balancing",
"gpu_count", len(gpuLoads),
"profile_type", profileType)

// Try each GPU in order of least loaded
for _, gl := range gpuLoads {
freeVFs := freeVFsByGPU[gl.gpu]
if len(freeVFs) == 0 {
log.DebugContext(ctx, "skipping GPU: no free VFs",
"gpu", gl.gpu,
"used_mb", gl.usedMB)
continue
}

// Check if any free VF on this GPU can create the profile
for _, vf := range freeVFs {
availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances")
data, err := os.ReadFile(availPath)
if err != nil {
continue
}
instances, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil || instances < 1 {
continue
}

log.DebugContext(ctx, "selected VF from least loaded GPU",
"vf", vf.PCIAddress,
"gpu", gl.gpu,
"gpu_used_mb", gl.usedMB)
return vf.PCIAddress
}

log.DebugContext(ctx, "skipping GPU: no VF can create profile",
"gpu", gl.gpu,
"used_mb", gl.usedMB,
"profile_type", profileType)
}

return ""
}

// CreateMdev creates an mdev device for the given profile and instance.
// It finds an available VF and creates the mdev, returning the device info.
// This function is thread-safe and uses a mutex to prevent race conditions
@@ -369,32 +526,19 @@ func CreateMdev(ctx context.Context, profileName, instanceID string) (*MdevDevice, error) {
return nil, err
}

// Find an available VF
// Discover all VFs
vfs, err := DiscoverVFs()
if err != nil {
return nil, fmt.Errorf("discover VFs: %w", err)
}

var targetVF string
for _, vf := range vfs {
// Skip VFs that already have an mdev
if vf.HasMdev {
continue
}
// Check if this VF can create the profile
availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances")
data, err := os.ReadFile(availPath)
if err != nil {
continue
}
instances, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil || instances < 1 {
continue
}
targetVF = vf.PCIAddress
break
// Ensure profile cache is populated (needed for VRAM calculation)
if len(vfs) > 0 {
_ = getCachedProfiles(vfs[0].PCIAddress)
}

// Select VF from the least loaded GPU (by VRAM usage)
targetVF := selectLeastLoadedVF(ctx, vfs, profileType)
if targetVF == "" {
return nil, fmt.Errorf("no available VF for profile %q", profileName)
}