diff --git a/internal/common/slices/slices.go b/internal/common/slices/slices.go index 0eb84145472..a68169e58af 100644 --- a/internal/common/slices/slices.go +++ b/internal/common/slices/slices.go @@ -87,6 +87,24 @@ func Unique[S ~[]E, E comparable](s S) S { return rv } +// UniqueBy returns a copy of s with duplicate elements removed based on the result of keyFunc(e), +// keeping only the first occurrence of each unique key. +func UniqueBy[S ~[]E, E any, K comparable](s S, keyFunc func(E) K) S { + if s == nil { + return nil + } + rv := make(S, 0) + seen := make(map[K]bool) + for _, v := range s { + key := keyFunc(v) + if !seen[key] { + rv = append(rv, v) + seen[key] = true + } + } + return rv +} + // GroupByFunc groups the elements e_1, ..., e_n of s into separate slices by keyFunc(e). func GroupByFunc[S ~[]E, E any, K comparable](s S, keyFunc func(E) K) map[K]S { rv := make(map[K]S) diff --git a/internal/common/slices/slices_test.go b/internal/common/slices/slices_test.go index dd733ad98bf..ef05180c41c 100644 --- a/internal/common/slices/slices_test.go +++ b/internal/common/slices/slices_test.go @@ -322,6 +322,87 @@ func TestUnique(t *testing.T) { } } +func TestUniqueBy(t *testing.T) { + type item struct { + Id string + Name string + } + + tests := map[string]struct { + input []item + keyFunc func(i item) string + expected []item + }{ + "nil": { + input: nil, + keyFunc: func(i item) string { return i.Id }, + expected: nil, + }, + "empty": { + input: []item{}, + keyFunc: func(i item) string { return i.Name }, + expected: []item{}, + }, + "no duplicates": { + input: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + {Id: "3", Name: "Charlie"}, + }, + keyFunc: func(i item) string { return i.Id }, + expected: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + {Id: "3", Name: "Charlie"}, + }, + }, + "consecutive duplicates": { + input: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + {Id: "2", Name: "Bobby"}, + }, + keyFunc: func(i item) string { return i.Id }, + expected: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + }, + }, + "non-consecutive duplicates": { + input: []item{ + {Id: "2", Name: "Bob"}, + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bobby"}, + }, + keyFunc: func(i item) string { return i.Id }, + expected: []item{ + {Id: "2", Name: "Bob"}, + {Id: "1", Name: "Alice"}, + }, + }, + "duplicate based on custom key function Name": { + input: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + {Id: "2", Name: "Bobby"}, + }, + keyFunc: func(i item) string { return i.Name }, + expected: []item{ + {Id: "1", Name: "Alice"}, + {Id: "2", Name: "Bob"}, + {Id: "2", Name: "Bobby"}, + }, + }, + } + + // Run all test cases + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + assert.Equal(t, tc.expected, UniqueBy(tc.input, tc.keyFunc)) + }) + } +} + func TestFilter(t *testing.T) { includeOver5 := func(val int) bool { return val > 5 } input := []int{1, 3, 5, 7, 9} diff --git a/internal/scheduler/jobdb/job.go b/internal/scheduler/jobdb/job.go index ebfbe170519..9145b433613 100644 --- a/internal/scheduler/jobdb/job.go +++ b/internal/scheduler/jobdb/job.go @@ -893,6 +893,11 @@ func (job *Job) WithValidated(validated bool) *Job { return j } +// Leased returns true if the job is currently leased +func (job *Job) Leased() bool { + return !job.queued && !job.InTerminalState() && job.LatestRun() != nil +} + // Validated returns true if the job has been validated func (job *Job) Validated() bool { return job.validated diff --git a/internal/scheduler/jobdb/job_test.go b/internal/scheduler/jobdb/job_test.go index 4db9dc11e96..147c3ae862c 100644 --- a/internal/scheduler/jobdb/job_test.go +++ b/internal/scheduler/jobdb/job_test.go @@ -288,6 +288,22 @@ func TestJob_TestWithNewRun(t *testing.T) { ) } +func TestJob_Leased(t *testing.T) { + leasedJob := baseJob.WithQueued(false).WithUpdatedRun(baseRun) + assert.True(t, leasedJob.Leased()) + + queuedJob := leasedJob.WithQueued(true) + assert.False(t, queuedJob.Leased()) + + // Terminal jobs + assert.False(t, leasedJob.WithSucceeded(true).Leased()) + assert.False(t, leasedJob.WithFailed(true).Leased()) + assert.False(t, leasedJob.WithCancelled(true).Leased()) + + jobWithoutRun := baseJob.WithQueued(false) + assert.False(t, jobWithoutRun.Leased()) +} + func TestJob_TestWithUpdatedRun_NewRun(t *testing.T) { jobWithRun := baseJob.WithUpdatedRun(baseRun) assert.Equal(t, true, jobWithRun.HasRuns()) diff --git a/internal/scheduler/jobdb/jobdb.go b/internal/scheduler/jobdb/jobdb.go index 6f8b93148a9..6ed2716391c 100644 --- a/internal/scheduler/jobdb/jobdb.go +++ b/internal/scheduler/jobdb/jobdb.go @@ -69,6 +69,8 @@ type JobDb struct { jobsByGangKey map[gangKey]immutable.Set[string] jobsByQueue map[string]immutable.SortedSet[*Job] jobsByPoolAndQueue map[string]map[string]immutable.SortedSet[*Job] + leasedJobs *immutable.Set[*Job] + terminalJobs *immutable.Set[*Job] unvalidatedJobs *immutable.Set[*Job] // Configured priority classes. priorityClasses map[string]types.PriorityClass @@ -128,12 +130,16 @@ func NewJobDbWithSchedulingKeyGenerator( panic(fmt.Sprintf("unknown default priority class %s", defaultPriorityClassName)) } unvalidatedJobs := immutable.NewSet[*Job](JobHasher{}) + leasedJobs := immutable.NewSet[*Job](JobHasher{}) + terminalJobs := immutable.NewSet[*Job](JobHasher{}) return &JobDb{ jobsById: immutable.NewMap[string, *Job](nil), jobsByRunId: immutable.NewMap[string, string](nil), jobsByGangKey: map[gangKey]immutable.Set[string]{}, jobsByQueue: map[string]immutable.SortedSet[*Job]{}, jobsByPoolAndQueue: map[string]map[string]immutable.SortedSet[*Job]{}, + leasedJobs: &leasedJobs, + terminalJobs: &terminalJobs, unvalidatedJobs: &unvalidatedJobs, priorityClasses: priorityClasses, defaultPriorityClass: defaultPriorityClass, @@ -161,6 +167,8 @@ func (jobDb *JobDb) Clone() *JobDb { jobsByGangKey: maps.Clone(jobDb.jobsByGangKey), jobsByQueue: maps.Clone(jobDb.jobsByQueue), jobsByPoolAndQueue: deepClone(jobDb.jobsByPoolAndQueue), + leasedJobs: jobDb.leasedJobs, + terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, priorityClasses: jobDb.priorityClasses, defaultPriorityClass: jobDb.defaultPriorityClass, @@ -318,6 +326,8 @@ func (jobDb *JobDb) ReadTxn() *Txn { jobsByGangKey: jobDb.jobsByGangKey, jobsByQueue: jobDb.jobsByQueue, jobsByPoolAndQueue: jobDb.jobsByPoolAndQueue, + leasedJobs: jobDb.leasedJobs, + terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, active: true, jobDb: jobDb, @@ -338,6 +348,8 @@ func (jobDb *JobDb) WriteTxn() *Txn { jobsByGangKey: maps.Clone(jobDb.jobsByGangKey), jobsByQueue: maps.Clone(jobDb.jobsByQueue), jobsByPoolAndQueue: deepClone(jobDb.jobsByPoolAndQueue), + leasedJobs: jobDb.leasedJobs, + terminalJobs: jobDb.terminalJobs, unvalidatedJobs: jobDb.unvalidatedJobs, active: true, jobDb: jobDb, @@ -376,6 +388,10 @@ type Txn struct { // Queued jobs for each queue and pool. // Stored as a set and needs sorting to determine the order they should be scheduled in. jobsByPoolAndQueue map[string]map[string]immutable.SortedSet[*Job] + // Jobs that are currently leased + leasedJobs *immutable.Set[*Job] + // Jobs that are currently in a terminal state + terminalJobs *immutable.Set[*Job] // Jobs that require submit checking unvalidatedJobs *immutable.Set[*Job] // The jobDb from which this transaction was created. @@ -396,6 +412,8 @@ func (txn *Txn) Commit() { txn.jobDb.jobsByGangKey = txn.jobsByGangKey txn.jobDb.jobsByQueue = txn.jobsByQueue txn.jobDb.jobsByPoolAndQueue = txn.jobsByPoolAndQueue + txn.jobDb.leasedJobs = txn.leasedJobs + txn.jobDb.terminalJobs = txn.terminalJobs txn.jobDb.unvalidatedJobs = txn.unvalidatedJobs txn.active = false @@ -526,6 +544,16 @@ func (txn *Txn) Upsert(jobs []*Job) error { txn.jobsByPoolAndQueue[pool][job.queue] = existingJobs.Delete(existingJob) } + if existingJob.Leased() { + newLeasedJobs := txn.leasedJobs.Delete(existingJob) + txn.leasedJobs = &newLeasedJobs + } + + if existingJob.InTerminalState() { + newTerminalJobs := txn.terminalJobs.Delete(existingJob) + txn.terminalJobs = &newTerminalJobs + } + if !existingJob.Validated() { newUnvalidatedJobs := txn.unvalidatedJobs.Delete(existingJob) txn.unvalidatedJobs = &newUnvalidatedJobs @@ -536,7 +564,7 @@ func (txn *Txn) Upsert(jobs []*Job) error { // Now need to insert jobs, runs and queuedJobs. This can be done in parallel. wg := sync.WaitGroup{} - wg.Add(5) + wg.Add(7) // jobs go func() { @@ -675,6 +703,54 @@ func (txn *Txn) Upsert(jobs []*Job) error { } }() + // Leased jobs + go func() { + defer wg.Done() + if hasJobs { + for _, job := range jobs { + if job.Leased() { + leasedJobs := txn.leasedJobs.Add(job) + txn.leasedJobs = &leasedJobs + } + } + } else { + leasedJobs := map[*Job]bool{} + + for _, job := range jobs { + if job.Leased() { + leasedJobs[job] = true + } + } + + leasedJobsImmutable := immutable.NewSet[*Job](JobHasher{}, maps.Keys(leasedJobs)...) + txn.leasedJobs = &leasedJobsImmutable + } + }() + + // Terminal jobs + go func() { + defer wg.Done() + if hasJobs { + for _, job := range jobs { + if job.InTerminalState() { + terminalJobs := txn.terminalJobs.Add(job) + txn.terminalJobs = &terminalJobs + } + } + } else { + terminalJobs := map[*Job]bool{} + + for _, job := range jobs { + if job.InTerminalState() { + terminalJobs[job] = true + } + } + + terminalJobsImmutable := immutable.NewSet[*Job](JobHasher{}, maps.Keys(terminalJobs)...) + txn.terminalJobs = &terminalJobsImmutable + } + }() + // Unvalidated jobs go func() { defer wg.Done() @@ -811,6 +887,16 @@ func (txn *Txn) UnvalidatedJobs() *immutable.SetIterator[*Job] { return txn.unvalidatedJobs.Iterator() } +// GetAllLeasedJobs returns all leased jobs in the database +func (txn *Txn) GetAllLeasedJobs() []*Job { + return txn.leasedJobs.Items() +} + +// GetAllTerminalJobs returns all terminal jobs in the database +func (txn *Txn) GetAllTerminalJobs() []*Job { + return txn.terminalJobs.Items() +} + // GetAll returns all jobs in the database. func (txn *Txn) GetAll() []*Job { allJobs := make([]*Job, 0, txn.jobsById.Len()) @@ -822,6 +908,15 @@ func (txn *Txn) GetAll() []*Job { return allJobs } +// GetQueuedJobsByPool returns all queued jobs against a given pool +func (txn *Txn) GetQueuedJobsByPool(pool string) []*Job { + allJobs := make([]*Job, 0) + for _, jobs := range txn.jobsByPoolAndQueue[pool] { + allJobs = append(allJobs, jobs.Items()...) + } + return allJobs +} + // BatchDelete deletes the jobs with the given ids from the database. // Any ids not in the database are ignored. func (txn *Txn) BatchDelete(jobIds []string) error { @@ -871,6 +966,12 @@ func (txn *Txn) delete(jobId string) { } } } + newLeasedJobs := txn.leasedJobs.Delete(job) + txn.leasedJobs = &newLeasedJobs + + newTerminalJobs := txn.terminalJobs.Delete(job) + txn.terminalJobs = &newTerminalJobs + newUnvalidatedJobs := txn.unvalidatedJobs.Delete(job) txn.unvalidatedJobs = &newUnvalidatedJobs } diff --git a/internal/scheduler/jobdb/jobdb_test.go b/internal/scheduler/jobdb/jobdb_test.go index 971095e0374..9add27bd512 100644 --- a/internal/scheduler/jobdb/jobdb_test.go +++ b/internal/scheduler/jobdb/jobdb_test.go @@ -74,6 +74,138 @@ func TestJobDb_TestGetById(t *testing.T) { assert.Nil(t, txn.GetById(util.NewULID())) } +func TestJobDb_TestGetLeased(t *testing.T) { + jobDb := NewTestJobDb() + job1 := newJob().WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + job2 := newJob().WithQueued(true) + job3 := newJob().WithQueued(false).WithSucceeded(true) + job4 := newJob().WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + txn := jobDb.WriteTxn() + + err := txn.Upsert([]*Job{job1, job2, job3, job4}) + require.NoError(t, err) + + expected := []*Job{job1, job4} + actual := txn.GetAllLeasedJobs() + sort.SliceStable(actual, func(i, j int) bool { return actual[i].id < actual[j].id }) + sort.SliceStable(expected, func(i, j int) bool { return expected[i].id < expected[j].id }) + assert.Equal(t, expected, actual) +} + +func TestJobDb_LeasedJobs_Lifecycle(t *testing.T) { + jobDb := NewTestJobDb() + + upsert := func(jobDb *JobDb, job *Job) { + txn := jobDb.WriteTxn() + err := txn.Upsert([]*Job{job}) + require.NoError(t, err) + txn.Commit() + } + + job1 := newJob().WithQueued(true) + upsert(jobDb, job1) + assert.Empty(t, jobDb.ReadTxn().GetAllLeasedJobs()) + + // leased + job1 = job1.WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + upsert(jobDb, job1) + assert.NotEmpty(t, jobDb.ReadTxn().GetAllLeasedJobs()) + + // requeued + job1 = job1.WithQueued(true) + upsert(jobDb, job1) + assert.Empty(t, jobDb.ReadTxn().GetAllLeasedJobs()) + + // leased + job1 = job1.WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + upsert(jobDb, job1) + assert.NotEmpty(t, jobDb.ReadTxn().GetAllLeasedJobs()) + + // finished + job1 = job1.WithSucceeded(true) + upsert(jobDb, job1) + assert.Empty(t, jobDb.ReadTxn().GetAllLeasedJobs()) +} + +func TestJobDb_LeasedJobs_Deleted(t *testing.T) { + jobDb := NewTestJobDb() + job1 := newJob().WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + txn := jobDb.WriteTxn() + + err := txn.Upsert([]*Job{job1}) + require.NoError(t, err) + + expected := []*Job{job1} + actual := txn.GetAllLeasedJobs() + assert.Equal(t, expected, actual) + + err = txn.BatchDelete([]string{job1.Id()}) + require.NoError(t, err) + assert.Empty(t, txn.GetAllLeasedJobs()) +} + +func TestJobDb_TestGetTerminalJobs(t *testing.T) { + jobDb := NewTestJobDb() + job1 := newJob().WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + job2 := newJob().WithQueued(true) + job3 := newJob().WithQueued(false).WithSucceeded(true) + job4 := newJob().WithQueued(false).WithCancelled(true) + job5 := newJob().WithQueued(false).WithFailed(true) + job6 := newJob().WithQueued(true).WithFailed(true) + txn := jobDb.WriteTxn() + + err := txn.Upsert([]*Job{job1, job2, job3, job4, job5, job6}) + require.NoError(t, err) + + expected := []*Job{job3, job4, job5, job6} + actual := txn.GetAllTerminalJobs() + sort.SliceStable(actual, func(i, j int) bool { return actual[i].id < actual[j].id }) + sort.SliceStable(expected, func(i, j int) bool { return expected[i].id < expected[j].id }) + assert.Equal(t, expected, actual) +} + +func TestJobDb_TerminalJobs_Lifecycle(t *testing.T) { + jobDb := NewTestJobDb() + + upsert := func(jobDb *JobDb, job *Job) { + txn := jobDb.WriteTxn() + err := txn.Upsert([]*Job{job}) + require.NoError(t, err) + txn.Commit() + } + + job1 := newJob().WithQueued(true) + upsert(jobDb, job1) + assert.Empty(t, jobDb.ReadTxn().GetAllTerminalJobs()) + + // leased + job1 = job1.WithQueued(false).WithNewRun("executor", "nodeId", "nodeName", "pool", 5) + upsert(jobDb, job1) + assert.Empty(t, jobDb.ReadTxn().GetAllTerminalJobs()) + + // finished + job1 = job1.WithSucceeded(true) + upsert(jobDb, job1) + assert.NotEmpty(t, jobDb.ReadTxn().GetAllTerminalJobs()) +} + +func TestJobDb_TerminalJobs_Deleted(t *testing.T) { + jobDb := NewTestJobDb() + job1 := newJob().WithFailed(true) + txn := jobDb.WriteTxn() + + err := txn.Upsert([]*Job{job1}) + require.NoError(t, err) + + expected := []*Job{job1} + actual := txn.GetAllTerminalJobs() + assert.Equal(t, expected, actual) + + err = txn.BatchDelete([]string{job1.Id()}) + require.NoError(t, err) + assert.Empty(t, txn.GetAllTerminalJobs()) +} + func TestJobDb_TestGetUnvalidated(t *testing.T) { jobDb := NewTestJobDb() job1 := newJob().WithValidated(false) @@ -163,6 +295,28 @@ func TestJobDb_TestGetByRunId(t *testing.T) { assert.Nil(t, txn.GetByRunId(job1.LatestRun().id)) } +func TestJobDb_TestGetQueuedJobsByPool(t *testing.T) { + jobDb := NewTestJobDb() + job1 := newJob().WithQueued(true).WithPools([]string{"pool-1", "pool-2", "pool-3"}) + job2 := newJob().WithQueued(true).WithPools([]string{"pool-1", "pool-2"}) + job3 := newJob().WithQueued(true).WithPools([]string{"pool-1"}) + txn := jobDb.WriteTxn() + + err := txn.Upsert([]*Job{job1, job2, job3}) + require.NoError(t, err) + + assertEqual := func(expected []*Job, actual []*Job) { + sort.SliceStable(actual, func(i, j int) bool { return actual[i].id < actual[j].id }) + sort.SliceStable(expected, func(i, j int) bool { return expected[i].id < expected[j].id }) + assert.Equal(t, expected, actual) + } + + assertEqual([]*Job{job1, job2, job3}, txn.GetQueuedJobsByPool("pool-1")) + assertEqual([]*Job{job1, job2}, txn.GetQueuedJobsByPool("pool-2")) + assertEqual([]*Job{job1}, txn.GetQueuedJobsByPool("pool-3")) + assertEqual([]*Job{}, txn.GetQueuedJobsByPool("pool-4")) +} + func TestJobDb_TestHasQueuedJobs(t *testing.T) { jobDb := NewTestJobDb() job1 := newJob().WithNewRun("executor", "nodeId", "nodeName", "pool", 5) diff --git a/internal/scheduler/scheduling/scheduling_algo.go b/internal/scheduler/scheduling/scheduling_algo.go index 3ca82c858a1..b19dacf5872 100644 --- a/internal/scheduler/scheduling/scheduling_algo.go +++ b/internal/scheduler/scheduling/scheduling_algo.go @@ -370,13 +370,29 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con allPools := []string{currentPool.Name} allPools = append(allPools, currentPool.AwayPools...) allPools = append(allPools, awayAllocationPools...) + allPools = armadaslices.Unique(allPools) + + // We must include jobs in the following states: + // - Jobs active on the nodes of this pool + // - These are used to populate the jobdb, calculate demand/fairshare + // - This may include nodes from other pools, especially if the nodes pool has changed + // - Terminal jobs of this pool + // - For calculating short job penalty + // - Jobs queued against home/away pools relevant to the pool being computed + // - This is to calculate demand on both home and away pools + allJobs := txn.GetAllLeasedJobs() + allJobs = append(allJobs, txn.GetAllTerminalJobs()...) + for _, pool := range allPools { + allJobs = append(allJobs, txn.GetQueuedJobsByPool(pool)...) + } + allJobs = armadaslices.UniqueBy(allJobs, func(job *jobdb.Job) string { return job.Id() }) jobSchedulingInfo, err := l.calculateJobSchedulingInfo(ctx, armadamaps.FromSlice(executors, func(ex *schedulerobjects.Executor) string { return ex.Id }, func(_ *schedulerobjects.Executor) bool { return true }), queueByName, - txn.GetAll(), + allJobs, currentPool.Name, awayAllocationPools, allPools)