Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/common/armadacontext/armada_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"golang.org/x/sync/errgroup"

"github.com/armadaproject/armada/internal/common/ctxkeys"
"github.com/armadaproject/armada/internal/common/logging"
)

Expand Down Expand Up @@ -94,8 +95,8 @@ func FromGrpcCtx(ctx context.Context) *Context {
return armadaCtx
}
logger := logging.StdLogger().
WithField("user", ctx.Value("user")).
WithField("requestId", ctx.Value("requestId"))
WithField("user", ctx.Value(ctxkeys.UserKey)).
WithField("requestId", ctx.Value(ctxkeys.RequestIDKey))
return New(ctx, logger)
}

Expand Down
7 changes: 4 additions & 3 deletions internal/common/auth/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/metadata"
"golang.org/x/exp/slices"

"github.com/armadaproject/armada/internal/common/ctxkeys"
"github.com/armadaproject/armada/internal/common/util"
)

// Name of the key used to store principals in contexts.
const principalKey = "principal"
// principalKey is the typed context key for storing the authenticated Principal.
const principalKey = ctxkeys.PrincipalKey

// All users are implicitly part of this group.
const EveryoneGroup = "everyone"
Expand Down Expand Up @@ -124,7 +125,7 @@ func CreateGrpcMiddlewareAuthFunction(authService AuthService) func(ctx context.
return nil, err
}
// record username for request logging
ctx = context.WithValue(ctx, "user", principal.GetName())
ctx = context.WithValue(ctx, ctxkeys.UserKey, principal.GetName())
return WithPrincipal(ctx, principal), nil
}
}
Expand Down
16 changes: 16 additions & 0 deletions internal/common/ctxkeys/ctxkeys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Package ctxkeys provides typed context keys shared across packages to prevent
// collisions with plain string keys.
package ctxkeys

// ContextKey is a typed key for context values, preventing collisions with
// string keys from other packages.
type ContextKey string

const (
// PrincipalKey stores the authenticated Principal in a context.
PrincipalKey ContextKey = "principal"
// UserKey stores the authenticated username in a context for logging.
UserKey ContextKey = "user"
// RequestIDKey stores the request ID in a context for logging.
RequestIDKey ContextKey = "requestId"
)
5 changes: 3 additions & 2 deletions internal/common/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/armadaproject/armada/internal/common/armadaerrors"
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/certs"
"github.com/armadaproject/armada/internal/common/ctxkeys"
"github.com/armadaproject/armada/internal/common/grpc/configuration"
log "github.com/armadaproject/armada/internal/common/logging"
"github.com/armadaproject/armada/internal/common/requestid"
Expand Down Expand Up @@ -156,8 +157,8 @@ func panicRecoveryHandler(p interface{}) (err error) {
func InterceptorLogger() grpc_logging.Logger {
return grpc_logging.LoggerFunc(func(ctx context.Context, lvl grpc_logging.Level, msg string, fields ...any) {
logFields := make(map[string]any, len(fields)/2+2)
logFields["user"] = ctx.Value("user")
logFields["requestId"] = ctx.Value("requestId")
logFields["user"] = ctx.Value(ctxkeys.UserKey)
logFields["requestId"] = ctx.Value(ctxkeys.RequestIDKey)
i := grpc_logging.Fields(fields).Iterator()
for i.Next() {
k, v := i.At()
Expand Down
4 changes: 3 additions & 1 deletion internal/common/requestid/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/renstrom/shortuuid"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/armadaproject/armada/internal/common/ctxkeys"
)

// MetadataKey is the HTTP header key using this key we use to store request ids.
Expand Down Expand Up @@ -43,7 +45,7 @@ func FromContextOrMissing(ctx context.Context) string {
// The second return value is true if the operation was successful.
func AddToIncomingContext(ctx context.Context, id string) (context.Context, bool) {
if md, ok := metadata.FromIncomingContext(ctx); ok {
ctx = context.WithValue(ctx, "requestId", id)
ctx = context.WithValue(ctx, ctxkeys.RequestIDKey, id)
md.Set(MetadataKey, id)
return metadata.NewIncomingContext(ctx, md), true
}
Expand Down
1 change: 0 additions & 1 deletion internal/scheduler/internaltypes/node_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ func NewNodeType(taints []v1.Taint, labels map[string]string, indexedTaints map[
// https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
// https://man.archlinux.org/man/community/kubectl/kubectl-taint.1.en
func nodeTypeIdFromTaintsAndLabels(taints []v1.Taint, labels, unsetIndexedLabels map[string]string) uint64 {
// TODO: We should test this function to ensure there are no collisions. And that the string is never empty.
h := fnv1a.Init64
for _, taint := range taints {
h = fnv1a.AddString64(h, taint.Key)
Expand Down
102 changes: 102 additions & 0 deletions internal/scheduler/internaltypes/node_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,108 @@ func TestNodeTypeLabels(t *testing.T) {
assert.False(t, ok4)
}

func TestNodeTypeIdFromTaintsAndLabels_NoCollisions(t *testing.T) {
taintKeys := []string{
"node.kubernetes.io/not-ready",
"node.kubernetes.io/unreachable",
"node.kubernetes.io/disk-pressure",
"node.kubernetes.io/memory-pressure",
"node.kubernetes.io/pid-pressure",
"node.kubernetes.io/unschedulable",
"node.cloudprovider.kubernetes.io/shutdown",
}
taintEffects := []v1.TaintEffect{
v1.TaintEffectNoSchedule,
v1.TaintEffectPreferNoSchedule,
v1.TaintEffectNoExecute,
}
labelKeys := []string{
"kubernetes.io/arch",
"kubernetes.io/os",
"topology.kubernetes.io/zone",
"topology.kubernetes.io/region",
"node.kubernetes.io/instance-type",
}
labelValues := []string{"amd64", "arm64", "linux", "us-east-1a", "us-west-2b", "m5.xlarge", "c5.2xlarge"}

type hashInput struct {
taints []v1.Taint
labels map[string]string
unsetIndexedLabels map[string]string
}

seen := make(map[uint64]hashInput)
collisions := 0

// Generate combinations: each taint key x effect as a single-taint node type
for _, key := range taintKeys {
for _, effect := range taintEffects {
taints := []v1.Taint{{Key: key, Value: "true", Effect: effect}}
h := nodeTypeIdFromTaintsAndLabels(taints, nil, nil)
assert.NotEqual(t, uint64(0), h, "hash should not be zero for taints=%v", taints)
input := hashInput{taints: taints}
if prev, exists := seen[h]; exists {
t.Errorf("hash collision: %v and %v both produce %d", prev, input, h)
collisions++
}
seen[h] = input
}
}

// Generate combinations: each label key x value as a single-label node type
for _, key := range labelKeys {
for _, value := range labelValues {
labels := map[string]string{key: value}
h := nodeTypeIdFromTaintsAndLabels(nil, labels, nil)
assert.NotEqual(t, uint64(0), h, "hash should not be zero for labels=%v", labels)
input := hashInput{labels: labels}
if prev, exists := seen[h]; exists {
t.Errorf("hash collision: %v and %v both produce %d", prev, input, h)
collisions++
}
seen[h] = input
}
}

// Generate combinations: unset indexed labels
for _, key := range labelKeys {
unset := map[string]string{key: ""}
h := nodeTypeIdFromTaintsAndLabels(nil, nil, unset)
assert.NotEqual(t, uint64(0), h, "hash should not be zero for unsetLabels=%v", unset)
input := hashInput{unsetIndexedLabels: unset}
if prev, exists := seen[h]; exists {
t.Errorf("hash collision: %v and %v both produce %d", prev, input, h)
collisions++
}
seen[h] = input
}

// Mixed: taint + label combinations
for _, tKey := range taintKeys[:3] {
for _, lKey := range labelKeys[:3] {
for _, lVal := range labelValues[:3] {
taints := []v1.Taint{{Key: tKey, Value: "true", Effect: v1.TaintEffectNoSchedule}}
labels := map[string]string{lKey: lVal}
h := nodeTypeIdFromTaintsAndLabels(taints, labels, nil)
input := hashInput{taints: taints, labels: labels}
if prev, exists := seen[h]; exists {
t.Errorf("hash collision: %v and %v both produce %d", prev, input, h)
collisions++
}
seen[h] = input
}
}
}

t.Logf("tested %d unique inputs with %d collisions", len(seen), collisions)
}

func TestNodeTypeIdFromTaintsAndLabels_NeverEmpty(t *testing.T) {
// Even with empty inputs, the hash should not be zero (FNV offset basis)
h := nodeTypeIdFromTaintsAndLabels(nil, nil, nil)
assert.NotEqual(t, uint64(0), h, "hash of empty inputs should not be zero")
}

func makeSut() *NodeType {
taints := []v1.Taint{
{Key: "taint1", Value: "value1", Effect: v1.TaintEffectNoSchedule},
Expand Down
9 changes: 5 additions & 4 deletions internal/server/submit/submit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/armadaproject/armada/internal/common/armadacontext"
"github.com/armadaproject/armada/internal/common/auth/permission"
"github.com/armadaproject/armada/internal/common/ctxkeys"
commonMocks "github.com/armadaproject/armada/internal/common/mocks"
"github.com/armadaproject/armada/internal/common/util"
"github.com/armadaproject/armada/internal/server/mocks"
Expand Down Expand Up @@ -75,7 +76,7 @@ func TestSubmit_Success(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
ctx = armadacontext.WithValue(ctx, "principal", testfixtures.DefaultPrincipal)
ctx = armadacontext.WithValue(ctx, ctxkeys.PrincipalKey, testfixtures.DefaultPrincipal)

server, mockedObjects := createTestServer(t)

Expand Down Expand Up @@ -221,7 +222,7 @@ func TestCancelJobs(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
ctx = armadacontext.WithValue(ctx, "principal", testfixtures.DefaultPrincipal)
ctx = armadacontext.WithValue(ctx, ctxkeys.PrincipalKey, testfixtures.DefaultPrincipal)

server, mockedObjects := createTestServer(t)

Expand Down Expand Up @@ -305,7 +306,7 @@ func TestPreemptJobs(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
ctx = armadacontext.WithValue(ctx, "principal", testfixtures.DefaultPrincipal)
ctx = armadacontext.WithValue(ctx, ctxkeys.PrincipalKey, testfixtures.DefaultPrincipal)

server, mockedObjects := createTestServer(t)

Expand Down Expand Up @@ -401,7 +402,7 @@ func TestReprioritizeJobs(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
ctx = armadacontext.WithValue(ctx, "principal", testfixtures.DefaultPrincipal)
ctx = armadacontext.WithValue(ctx, ctxkeys.PrincipalKey, testfixtures.DefaultPrincipal)

server, mockedObjects := createTestServer(t)

Expand Down