From 486232bc3a72abe1bab56fbe2c0f6c7fd29d42a3 Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Thu, 19 Feb 2026 10:32:14 -0800 Subject: [PATCH 1/7] GetEnvironmentsByExtensionId added to db interface --- cmd/api/src/database/graphschema.go | 1 + cmd/api/src/database/mocks/db.go | 35 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 6a3e69d0b92..5e90a002e24 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -57,6 +57,7 @@ type OpenGraphSchema interface { GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) + GetEnvironmentsByExtensionId(ctx context.Context, extensionId int32) ([]model.SchemaEnvironment, error) DeleteEnvironment(ctx context.Context, environmentId int32) error CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index c1852742624..a7a9b8ffa04 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1762,6 +1762,21 @@ func (mr *MockDatabaseMockRecorder) GetEnvironments(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetEnvironments), ctx) } +// GetEnvironmentsByExtensionId mocks base method. +func (m *MockDatabase) GetEnvironmentsByExtensionId(ctx context.Context, extensionId int32) ([]model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentsByExtensionId", ctx, extensionId) + ret0, _ := ret[0].([]model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentsByExtensionId indicates an expected call of GetEnvironmentsByExtensionId. +func (mr *MockDatabaseMockRecorder) GetEnvironmentsByExtensionId(ctx, extensionId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentsByExtensionId", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentsByExtensionId), ctx, extensionId) +} + // GetFlag mocks base method. func (m *MockDatabase) GetFlag(ctx context.Context, id int32) (appcfg.FeatureFlag, error) { m.ctrl.T.Helper() @@ -2423,6 +2438,26 @@ func (mr *MockDatabaseMockRecorder) GetSourceKinds(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKinds), ctx) } +// GetSourceKindsByIds mocks base method. +func (m *MockDatabase) GetSourceKindsByIds(ctx context.Context, ids ...int32) ([]database.SourceKind, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range ids { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetSourceKindsByIds", varargs...) + ret0, _ := ret[0].([]database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKindsByIds indicates an expected call of GetSourceKindsByIds. +func (mr *MockDatabaseMockRecorder) GetSourceKindsByIds(ctx any, ids ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, ids...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindsByIds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindsByIds), varargs...) +} + // GetTimeRangedAssetGroupCollections mocks base method. func (m *MockDatabase) GetTimeRangedAssetGroupCollections(ctx context.Context, assetGroupID int32, from, to int64, order string) (model.AssetGroupCollections, error) { m.ctrl.T.Helper() From abacc5061e4bdb670347d21c91a58b6e70229a07 Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Thu, 19 Feb 2026 10:52:00 -0800 Subject: [PATCH 2/7] GetSourceKindsByIDs variadic method --- cmd/api/src/database/mocks/db.go | 12 ++-- cmd/api/src/database/sourcekinds.go | 42 ++++++++++++++ .../database/sourcekinds_integration_test.go | 56 +++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index a7a9b8ffa04..c74060d1253 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2438,24 +2438,24 @@ func (mr *MockDatabaseMockRecorder) GetSourceKinds(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKinds), ctx) } -// GetSourceKindsByIds mocks base method. -func (m *MockDatabase) GetSourceKindsByIds(ctx context.Context, ids ...int32) ([]database.SourceKind, error) { +// GetSourceKindsByIDs mocks base method. +func (m *MockDatabase) GetSourceKindsByIDs(ctx context.Context, ids ...int32) ([]database.SourceKind, error) { m.ctrl.T.Helper() varargs := []any{ctx} for _, a := range ids { varargs = append(varargs, a) } - ret := m.ctrl.Call(m, "GetSourceKindsByIds", varargs...) + ret := m.ctrl.Call(m, "GetSourceKindsByIDs", varargs...) ret0, _ := ret[0].([]database.SourceKind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSourceKindsByIds indicates an expected call of GetSourceKindsByIds. -func (mr *MockDatabaseMockRecorder) GetSourceKindsByIds(ctx any, ids ...any) *gomock.Call { +// GetSourceKindsByIDs indicates an expected call of GetSourceKindsByIDs. +func (mr *MockDatabaseMockRecorder) GetSourceKindsByIDs(ctx any, ids ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx}, ids...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindsByIds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindsByIds), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindsByIDs", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindsByIDs), varargs...) } // GetTimeRangedAssetGroupCollections mocks base method. diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index 5a892fbb3e6..440c7369840 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -29,6 +29,7 @@ type SourceKindsData interface { DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) + GetSourceKindsByIDs(ctx context.Context, ids ...int32) ([]SourceKind, error) } // RegisterSourceKind returns a function that inserts a source kind by name, @@ -168,6 +169,47 @@ func (s *BloodhoundDB) GetSourceKindByID(ctx context.Context, id int) (SourceKin return kind, nil } +func (s *BloodhoundDB) GetSourceKindsByIDs(ctx context.Context, ids ...int32) ([]SourceKind, error) { + if len(ids) == 0 { + return []SourceKind{}, nil + } + + const query = ` + SELECT sk.id, k.name, sk.active + FROM source_kinds sk + JOIN kind k ON k.id = sk.kind_id + WHERE sk.id = ANY(?) + ORDER BY sk.id; + ` + + type rawSourceKind struct { + ID int + Name string + Active bool + } + + var rawKinds []rawSourceKind + result := s.db.WithContext(ctx).Raw(query, pq.Array(ids)).Scan(&rawKinds) + if err := result.Error; err != nil { + return nil, fmt.Errorf("failed to fetch source kinds by IDs: %w", err) + } + + if len(rawKinds) != len(ids) { + return nil, ErrNotFound + } + + sourceKinds := make([]SourceKind, len(rawKinds)) + for i, raw := range rawKinds { + sourceKinds[i] = SourceKind{ + ID: raw.ID, + Name: graph.StringKind(raw.Name), + Active: raw.Active, + } + } + + return sourceKinds, nil +} + func (s *BloodhoundDB) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error { if len(kinds) == 0 { return nil diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index cb768dbb2c7..ff1906a0442 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -495,3 +495,59 @@ func TestDeactivateSourceKindsByName(t *testing.T) { }) } } + +func TestBloodhoundDB_GetSourceKindByIDs(t *testing.T) { + var ( + testSuite = setupIntegrationTestSuite(t) + ) + defer teardownIntegrationTestSuite(t, &testSuite) + + tests := []struct { + name string + setup func(t *testing.T) []int32 + wantErr error + }{ + { + name: "empty input", + setup: func(t *testing.T) []int32 { + return []int32{} + }, + wantErr: nil, + }, + { + name: "fail - unknown source kind", + setup: func(t *testing.T) []int32 { + return []int32{ + 123, + } + }, + wantErr: database.ErrNotFound, + }, + { + name: "success", + setup: func(t *testing.T) []int32 { + return []int32{ + 1, 2, + } + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceKindIDs := tt.setup(t) + + if sourceKinds, err := testSuite.BHDatabase.GetSourceKindsByIDs(testSuite.Context, sourceKindIDs...); tt.wantErr != nil { + require.EqualErrorf(t, err, tt.wantErr.Error(), "error not equal") + } else { + require.NoError(t, err) + + actualIDs := make([]int32, len(sourceKinds)) + for i, sourceKind := range sourceKinds { + actualIDs[i] = int32(sourceKind.ID) + } + assert.Equal(t, sourceKindIDs, actualIDs) + } + }) + } +} From 7c18d623560af0cc3a2ed8a2aaa10f454809ed20 Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Thu, 19 Feb 2026 10:56:17 -0800 Subject: [PATCH 3/7] GetKindsByIDs variadic func --- cmd/api/src/database/kind.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index 17aaa2249c0..94e8fbcf9a6 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -17,13 +17,16 @@ package database import ( "context" + "fmt" + "github.com/lib/pq" "github.com/specterops/bloodhound/cmd/api/src/model" ) type Kind interface { GetKindByName(ctx context.Context, name string) (model.Kind, error) GetKindById(ctx context.Context, id int32) (model.Kind, error) + GetKindsByIDs(ctx context.Context, ids ...int32) ([]model.Kind, error) } func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { @@ -67,3 +70,29 @@ func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, e return kind, nil } + +func (s *BloodhoundDB) GetKindsByIDs(ctx context.Context, ids ...int32) ([]model.Kind, error) { + if len(ids) == 0 { + return []model.Kind{}, nil + } + + const query = ` + SELECT id, name + FROM kind + WHERE id = ANY(?) + ORDER BY id; + ` + + var kinds []model.Kind + result := s.db.WithContext(ctx).Raw(query, pq.Array(ids)).Scan(&kinds) + + if err := result.Error; err != nil { + return nil, fmt.Errorf("failed to fetch kinds by IDs: %w", err) + } + + if len(kinds) != len(ids) { + return nil, ErrNotFound + } + + return kinds, nil +} From 5ca19d431b39d3dfe7f6a9cf0749f31f20dd4e8a Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Thu, 19 Feb 2026 11:02:12 -0800 Subject: [PATCH 4/7] GetTraversableRelationshipKindsByExtensionID --- cmd/api/src/database/graphschema.go | 19 ++++++++++++++++ cmd/api/src/database/mocks/db.go | 35 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 5e90a002e24..527860f0dbf 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -51,6 +51,7 @@ type OpenGraphSchema interface { UpdateGraphSchemaRelationshipKind(ctx context.Context, schemaRelationshipKind model.GraphSchemaRelationshipKind) (model.GraphSchemaRelationshipKind, error) DeleteGraphSchemaRelationshipKind(ctx context.Context, schemaRelationshipKindId int32) error + GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) GetGraphSchemaRelationshipKindsWithSchemaName(ctx context.Context, filters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaRelationshipKindsWithNamedSchema, int, error) CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) @@ -560,6 +561,24 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKinds(ctx context.Context, rela } } +// GetTraversableRelationshipKindsByExtensionID returns all traversable relationship kinds for a given schema extension. +// This is a purpose-built query for the analysis pipeline that needs traversable edges for graph traversal. +func (s *BloodhoundDB) GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) { + const query = ` + SELECT rk.id, k.name, rk.schema_extension_id, rk.description, rk.is_traversable, + rk.created_at, rk.updated_at, rk.deleted_at + FROM schema_relationship_kinds rk + JOIN kind k ON rk.kind_id = k.id + WHERE rk.schema_extension_id = $1 AND rk.is_traversable = true + ` + + var kinds model.GraphSchemaRelationshipKinds + if result := s.db.WithContext(ctx).Raw(query, extensionID).Scan(&kinds); result.Error != nil { + return nil, CheckError(result) + } + return kinds, nil +} + func (s *BloodhoundDB) GetGraphSchemaRelationshipKindsWithSchemaName(ctx context.Context, relationshipKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaRelationshipKindsWithNamedSchema, int, error) { var ( schemaRelationshipKinds = model.GraphSchemaRelationshipKindsWithNamedSchema{} diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index c74060d1253..63205d12506 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2037,6 +2037,26 @@ func (mr *MockDatabaseMockRecorder) GetKindByName(ctx, name any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockDatabase)(nil).GetKindByName), ctx, name) } +// GetKindsByIDs mocks base method. +func (m *MockDatabase) GetKindsByIDs(ctx context.Context, ids ...int32) ([]model.Kind, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range ids { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetKindsByIDs", varargs...) + ret0, _ := ret[0].([]model.Kind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKindsByIDs indicates an expected call of GetKindsByIDs. +func (mr *MockDatabaseMockRecorder) GetKindsByIDs(ctx any, ids ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, ids...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindsByIDs", reflect.TypeOf((*MockDatabase)(nil).GetKindsByIDs), varargs...) +} + // GetLatestAssetGroupCollection mocks base method. func (m *MockDatabase) GetLatestAssetGroupCollection(ctx context.Context, assetGroupID int32) (model.AssetGroupCollection, error) { m.ctrl.T.Helper() @@ -2473,6 +2493,21 @@ func (mr *MockDatabaseMockRecorder) GetTimeRangedAssetGroupCollections(ctx, asse return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTimeRangedAssetGroupCollections", reflect.TypeOf((*MockDatabase)(nil).GetTimeRangedAssetGroupCollections), ctx, assetGroupID, from, to, order) } +// GetTraversableRelationshipKindsByExtensionID mocks base method. +func (m *MockDatabase) GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTraversableRelationshipKindsByExtensionID", ctx, extensionID) + ret0, _ := ret[0].(model.GraphSchemaRelationshipKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTraversableRelationshipKindsByExtensionID indicates an expected call of GetTraversableRelationshipKindsByExtensionID. +func (mr *MockDatabaseMockRecorder) GetTraversableRelationshipKindsByExtensionID(ctx, extensionID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTraversableRelationshipKindsByExtensionID", reflect.TypeOf((*MockDatabase)(nil).GetTraversableRelationshipKindsByExtensionID), ctx, extensionID) +} + // GetUser mocks base method. func (m *MockDatabase) GetUser(ctx context.Context, id uuid.UUID) (model.User, error) { m.ctrl.T.Helper() From 5ccd76f692f3319c1567b775eac2c4ae00d5fca1 Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Thu, 19 Feb 2026 11:42:47 -0800 Subject: [PATCH 5/7] Removed GetKindById from the Kind interface --- cmd/api/src/database/kind.go | 22 ----------- cmd/api/src/database/kind_integration_test.go | 18 ++++----- cmd/api/src/database/mocks/db.go | 15 ------- ...psert_schema_extension_integration_test.go | 39 ++++++++++--------- 4 files changed, 30 insertions(+), 64 deletions(-) diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index 94e8fbcf9a6..a537860f05a 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -25,7 +25,6 @@ import ( type Kind interface { GetKindByName(ctx context.Context, name string) (model.Kind, error) - GetKindById(ctx context.Context, id int32) (model.Kind, error) GetKindsByIDs(ctx context.Context, ids ...int32) ([]model.Kind, error) } @@ -50,27 +49,6 @@ func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Ki return kind, nil } -func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { - const query = ` - SELECT id, name - FROM kind - WHERE id = $1; - ` - - var kind model.Kind - result := s.db.WithContext(ctx).Raw(query, id).Scan(&kind) - - if result.Error != nil { - return model.Kind{}, result.Error - } - - if result.RowsAffected == 0 || kind.ID == 0 { - return model.Kind{}, ErrNotFound - } - - return kind, nil -} - func (s *BloodhoundDB) GetKindsByIDs(ctx context.Context, ids ...int32) ([]model.Kind, error) { if len(ids) == 0 { return []model.Kind{}, nil diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go index dea63dada97..b9bd6c1f25e 100644 --- a/cmd/api/src/database/kind_integration_test.go +++ b/cmd/api/src/database/kind_integration_test.go @@ -76,7 +76,7 @@ func TestGetKindByName(t *testing.T) { } } -func TestGetKindByID(t *testing.T) { +func TestGetKindsByIDs(t *testing.T) { testSuite := setupIntegrationTestSuite(t) defer teardownIntegrationTestSuite(t, &testSuite) @@ -110,15 +110,14 @@ func TestGetKindByID(t *testing.T) { var kind model.Kind result := testSuite.DB.WithContext(testSuite.Context).Raw(` INSERT INTO kind (name) - VALUES ('Test_Get_Kind_By_Id') + VALUES ('Test_Get_Kinds_By_IDs') RETURNING id, name;`).Scan(&kind) require.NoError(t, result.Error) return kind }, want: want{ kind: model.Kind{ - // Don't know what the ID will be - Name: "Test_Get_Kind_By_Id", + Name: "Test_Get_Kinds_By_IDs", }, }, }, @@ -129,15 +128,16 @@ func TestGetKindByID(t *testing.T) { var ( err error createdKind model.Kind - got model.Kind ) createdKind = tt.setup(t) - if got, err = testSuite.BHDatabase.GetKindById(testSuite.Context, createdKind.ID); tt.want.err != nil { + if kinds, getErr := testSuite.BHDatabase.GetKindsByIDs(testSuite.Context, createdKind.ID); tt.want.err != nil { + err = getErr assert.EqualError(t, err, tt.want.err.Error()) } else { - assert.NoError(t, err) - assert.Equal(t, tt.want.kind.Name, got.Name) - assert.Greater(t, got.ID, int32(0)) + assert.NoError(t, getErr) + assert.Len(t, kinds, 1) + assert.Equal(t, tt.want.kind.Name, kinds[0].Name) + assert.Greater(t, kinds[0].ID, int32(0)) } }) } diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 63205d12506..45198617d92 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2007,21 +2007,6 @@ func (mr *MockDatabaseMockRecorder) GetInstallation(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstallation", reflect.TypeOf((*MockDatabase)(nil).GetInstallation), ctx) } -// GetKindById mocks base method. -func (m *MockDatabase) GetKindById(ctx context.Context, id int32) (model.Kind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindById", ctx, id) - ret0, _ := ret[0].(model.Kind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetKindById indicates an expected call of GetKindById. -func (mr *MockDatabaseMockRecorder) GetKindById(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockDatabase)(nil).GetKindById), ctx, id) -} - // GetKindByName mocks base method. func (m *MockDatabase) GetKindByName(ctx context.Context, name string) (model.Kind, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index ca3cff540b3..b05debba52a 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1066,18 +1066,18 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * SetOperator: model.FilterAnd, } - gotNodeKinds model.GraphSchemaNodeKinds - gotRelationshipKinds model.GraphSchemaRelationshipKinds - gotProperties model.GraphSchemaProperties - gotSchemaEnvironments []model.SchemaEnvironment - gotPrincipalKinds model.SchemaEnvironmentPrincipalKinds - sourceKind database.SourceKind - dawgsPrincipalKind model.Kind - dawgsFindingRelationshipKind model.Kind - dawgsFindingEnvironmentKind model.Kind - gotSchemaRelationshipFinding []model.SchemaRelationshipFinding - gotRemediation model.Remediation - findingEnvironment model.SchemaEnvironment + gotNodeKinds model.GraphSchemaNodeKinds + gotRelationshipKinds model.GraphSchemaRelationshipKinds + gotProperties model.GraphSchemaProperties + gotSchemaEnvironments []model.SchemaEnvironment + gotPrincipalKinds model.SchemaEnvironmentPrincipalKinds + sourceKind database.SourceKind + dawgsPrincipalKinds []model.Kind + dawgsFindingRelationshipKinds []model.Kind + dawgsFindingEnvironmentKinds []model.Kind + gotSchemaRelationshipFinding []model.SchemaRelationshipFinding + gotRemediation model.Remediation + findingEnvironment model.SchemaEnvironment ) // Test Node Kinds @@ -1147,9 +1147,10 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * require.Equalf(t, len(want.EnvironmentsInput[idx].PrincipalKinds), len(gotPrincipalKinds), "PrincipalKinds - count mismatch") for _, gotPrincipalKind := range gotPrincipalKinds { require.Equalf(t, gotEnvironment.ID, gotPrincipalKind.EnvironmentId, "PrincipalKind - EnvironmentId is invalid") - dawgsPrincipalKind, err = db.GetKindById(testContext, gotPrincipalKind.PrincipalKind) + dawgsPrincipalKinds, err = db.GetKindsByIDs(testContext, gotPrincipalKind.PrincipalKind) require.NoError(t, err) - require.Containsf(t, want.EnvironmentsInput[idx].PrincipalKinds, dawgsPrincipalKind.Name, "PrincipalKind - Name mismatch") + require.Len(t, dawgsPrincipalKinds, 1) + require.Containsf(t, want.EnvironmentsInput[idx].PrincipalKinds, dawgsPrincipalKinds[0].Name, "PrincipalKind - Name mismatch") } } @@ -1163,15 +1164,17 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * require.Greater(t, finding.ID, int32(0)) require.Equalf(t, gotGraphExtension.ID, finding.SchemaExtensionId, "RelationshipFindingInput - graph schema extension id should be greater than 0") - dawgsFindingRelationshipKind, err = db.GetKindById(testContext, finding.RelationshipKindId) + dawgsFindingRelationshipKinds, err = db.GetKindsByIDs(testContext, finding.RelationshipKindId) require.NoError(t, err) - require.Equalf(t, want.RelationshipFindingsInput[i].RelationshipKindName, dawgsFindingRelationshipKind.Name, "RelationshipFindingInput - relationship kind name mismatch") + require.Len(t, dawgsFindingRelationshipKinds, 1) + require.Equalf(t, want.RelationshipFindingsInput[i].RelationshipKindName, dawgsFindingRelationshipKinds[0].Name, "RelationshipFindingInput - relationship kind name mismatch") findingEnvironment, err = db.GetEnvironmentById(testContext, finding.EnvironmentId) require.NoError(t, err) - dawgsFindingEnvironmentKind, err = db.GetKindById(testContext, findingEnvironment.EnvironmentKindId) + dawgsFindingEnvironmentKinds, err = db.GetKindsByIDs(testContext, findingEnvironment.EnvironmentKindId) require.NoError(t, err) - require.Equalf(t, want.RelationshipFindingsInput[i].EnvironmentKindName, dawgsFindingEnvironmentKind.Name, "RelationshipFindingInput - environment kind name mismatch") + require.Len(t, dawgsFindingEnvironmentKinds, 1) + require.Equalf(t, want.RelationshipFindingsInput[i].EnvironmentKindName, dawgsFindingEnvironmentKinds[0].Name, "RelationshipFindingInput - environment kind name mismatch") require.Equalf(t, want.RelationshipFindingsInput[i].Name, finding.Name, "RelationshipFindingInput - name mismatch") require.Equalf(t, want.RelationshipFindingsInput[i].DisplayName, finding.DisplayName, "RelationshipFindingInput - display name mismatch") From a08cc0751c6874eceeb3cbd4259246e445e2fad3 Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Mon, 23 Feb 2026 15:51:01 -0800 Subject: [PATCH 6/7] sql string: use IN instead of ANY --- cmd/api/src/database/sourcekinds.go | 6 +++--- cmd/api/src/database/sourcekinds_integration_test.go | 11 ++++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index 440c7369840..f921f8856d0 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -174,11 +174,11 @@ func (s *BloodhoundDB) GetSourceKindsByIDs(ctx context.Context, ids ...int32) ([ return []SourceKind{}, nil } - const query = ` + query := ` SELECT sk.id, k.name, sk.active FROM source_kinds sk JOIN kind k ON k.id = sk.kind_id - WHERE sk.id = ANY(?) + WHERE sk.id IN (?) AND sk.active = true ORDER BY sk.id; ` @@ -189,7 +189,7 @@ func (s *BloodhoundDB) GetSourceKindsByIDs(ctx context.Context, ids ...int32) ([ } var rawKinds []rawSourceKind - result := s.db.WithContext(ctx).Raw(query, pq.Array(ids)).Scan(&rawKinds) + result := s.db.WithContext(ctx).Raw(query, ids).Scan(&rawKinds) if err := result.Error; err != nil { return nil, fmt.Errorf("failed to fetch source kinds by IDs: %w", err) } diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index ff1906a0442..fb9114af136 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -524,7 +524,16 @@ func TestBloodhoundDB_GetSourceKindByIDs(t *testing.T) { wantErr: database.ErrNotFound, }, { - name: "success", + name: "success - single", + setup: func(t *testing.T) []int32 { + return []int32{ + 1, + } + }, + wantErr: nil, + }, + { + name: "success - multiple", setup: func(t *testing.T) []int32 { return []int32{ 1, 2, From b85c988eb2873e5674c7afe4fac8cc49d97fbadb Mon Sep 17 00:00:00 2001 From: Brandon Shearin Date: Tue, 24 Feb 2026 12:51:14 -0800 Subject: [PATCH 7/7] add tierID to searchTierNodesCtx --- packages/go/analysis/tiering/helpers.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/go/analysis/tiering/helpers.go b/packages/go/analysis/tiering/helpers.go index 556c4ea7b74..d1275929ba3 100644 --- a/packages/go/analysis/tiering/helpers.go +++ b/packages/go/analysis/tiering/helpers.go @@ -24,6 +24,7 @@ import ( ) type SearchTierNodesCtx struct { + TierID int // Tier ID for findings processing IsTierZero bool PrimaryTierKind graph.Kind SearchTierNodes graph.Criteria @@ -32,8 +33,9 @@ type SearchTierNodesCtx struct { SearchPrimaryTierNodesRel graph.Criteria } -func NewSearchTierNodesCtx(tieringEnabled bool, isTierZero bool, primaryKind graph.Kind, tierKinds ...graph.Kind) SearchTierNodesCtx { +func NewSearchTierNodesCtx(tieringEnabled bool, isTierZero bool, tierID int, primaryKind graph.Kind, tierKinds ...graph.Kind) SearchTierNodesCtx { return SearchTierNodesCtx{ + TierID: tierID, IsTierZero: isTierZero, PrimaryTierKind: primaryKind, SearchTierNodesRel: searchTierNodesRel(tieringEnabled, tierKinds...),