diff --git a/cmd/api/src/api/v2/assetgrouptags.go b/cmd/api/src/api/v2/assetgrouptags.go index ca264e4b46e..a4de0760745 100644 --- a/cmd/api/src/api/v2/assetgrouptags.go +++ b/cmd/api/src/api/v2/assetgrouptags.go @@ -988,7 +988,7 @@ func buildAssetGroupMembersByTagGraphDbFilters(ctx context.Context, db database. return filters, err } else { for _, kind := range sourceKinds { - sourceKindsMap[kind.Name.String()] = true + sourceKindsMap[kind.Name] = true } } } diff --git a/cmd/api/src/api/v2/database_wipe.go b/cmd/api/src/api/v2/database_wipe.go index 605f0a030f2..3347d3b3a4c 100644 --- a/cmd/api/src/api/v2/database_wipe.go +++ b/cmd/api/src/api/v2/database_wipe.go @@ -279,7 +279,7 @@ func (s Resources) BuildDeleteRequest(ctx context.Context, userID string, payloa found := false for _, sk := range sourceKinds { if sk.ID == id { - requestedKinds = append(requestedKinds, sk.Name) + requestedKinds = append(requestedKinds, sk.ToKind()) found = true break } diff --git a/cmd/api/src/api/v2/kinds.go b/cmd/api/src/api/v2/kinds.go index f60538268e7..7ed32945c3c 100644 --- a/cmd/api/src/api/v2/kinds.go +++ b/cmd/api/src/api/v2/kinds.go @@ -59,7 +59,7 @@ func (s Resources) ListSourceKinds(response http.ResponseWriter, request *http.R } else { // inject 0, Sourceless into the payload. We don't track this as an official kind // but it will facilitate delete requests for data that isn't associated with a kind. - kinds = append(kinds, database.SourceKind{ID: 0, Name: graph.StringKind("Sourceless")}) + kinds = append(kinds, database.SourceKind{ID: 0, Name: "Sourceless"}) api.WriteBasicResponse(request.Context(), ListSourceKindsResponse{Kinds: kinds}, http.StatusOK, response) } } diff --git a/cmd/api/src/daemons/datapipe/pipeline.go b/cmd/api/src/daemons/datapipe/pipeline.go index 835bedf4ccf..189d9fcdc93 100644 --- a/cmd/api/src/daemons/datapipe/pipeline.go +++ b/cmd/api/src/daemons/datapipe/pipeline.go @@ -127,7 +127,7 @@ func PurgeGraphData( func extractKindNames(sourceKinds []database.SourceKind) graph.Kinds { var kinds graph.Kinds for _, k := range sourceKinds { - kinds = append(kinds, k.Name) + kinds = append(kinds, k.ToKind()) } return kinds } diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 9ea298cb7f4..fc787811d70 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -47,20 +47,23 @@ type OpenGraphSchema interface { CreateGraphSchemaRelationshipKind(ctx context.Context, name string, schemaExtensionId int32, description string, isTraversable bool) (model.GraphSchemaRelationshipKind, error) GetGraphSchemaRelationshipKinds(ctx context.Context, filters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaRelationshipKinds, int, error) GetGraphSchemaRelationshipKindById(ctx context.Context, schemaRelationshipKindId int32) (model.GraphSchemaRelationshipKind, error) + GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) UpdateGraphSchemaRelationshipKind(ctx context.Context, schemaRelationshipKind model.GraphSchemaRelationshipKind) (model.GraphSchemaRelationshipKind, error) DeleteGraphSchemaRelationshipKind(ctx context.Context, schemaRelationshipKindId int32) error GetGraphSchemaRelationshipKindsWithSchemaName(ctx context.Context, filters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaRelationshipKindsWithNamedSchema, int, error) CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) + GetEnvironmentByEnvironmentKindId(ctx context.Context, environmentKindId int32) (model.SchemaEnvironment, error) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) + GetEnvironmentsFiltered(ctx context.Context, filters model.Filters) ([]model.SchemaEnvironment, error) DeleteEnvironment(ctx context.Context, environmentId int32) error CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) + GetSchemaRelationshipFindingsByEnvironmentId(ctx context.Context, environmentId int32) ([]model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) @@ -494,6 +497,24 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKinds(ctx context.Context, rela } } +// GetTraversableRelationshipKindsByExtensionID returns all traversable relationship kinds for a given schema extension. +// This is a purpose-built query for the analysis pipeline that needs traversable edges for graph traversal. +func (s *BloodhoundDB) GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) { + const query = ` + SELECT rk.id, k.name, rk.schema_extension_id, rk.description, rk.is_traversable, + rk.created_at, rk.updated_at, rk.deleted_at + FROM schema_relationship_kinds rk + JOIN kind k ON rk.kind_id = k.id + WHERE rk.schema_extension_id = $1 AND rk.is_traversable = true + ` + + var kinds model.GraphSchemaRelationshipKinds + if result := s.db.WithContext(ctx).Raw(query, extensionID).Scan(&kinds); result.Error != nil { + return nil, CheckError(result) + } + return kinds, nil +} + func (s *BloodhoundDB) GetGraphSchemaRelationshipKindsWithSchemaName(ctx context.Context, relationshipKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaRelationshipKindsWithNamedSchema, int, error) { var ( schemaRelationshipKinds = model.GraphSchemaRelationshipKindsWithNamedSchema{} @@ -601,11 +622,24 @@ func (s *BloodhoundDB) CreateEnvironment(ctx context.Context, extensionId int32, return schemaEnvironment, nil } -// GetEnvironments - retrieves list of schema environments. -func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { +// GetEnvironmentsFiltered - retrieves schema environments filtered by the given criteria. +// This is the core implementation that all other GetEnvironment* methods delegate to. +// Common use case: filter by schema_extension_id to get all environments for a specific extension. +// Example: filters := model.Filters{"se.schema_extension_id": []model.Filter{{Operator: model.Equals, Value: "1"}}} +func (s *BloodhoundDB) GetEnvironmentsFiltered(ctx context.Context, filters model.Filters) ([]model.SchemaEnvironment, error) { var result []model.SchemaEnvironment - query := ` + sqlFilter, err := buildSQLFilter(filters) + if err != nil { + return nil, err + } + + whereClause := "" + if sqlFilter.sqlString != "" { + whereClause = fmt.Sprintf("WHERE %s", sqlFilter.sqlString) + } + + query := fmt.Sprintf(` SELECT se.id, se.schema_extension_id, @@ -619,9 +653,10 @@ func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvir FROM schema_environments se INNER JOIN kind k ON se.environment_kind_id = k.id INNER JOIN schema_extensions ext ON se.schema_extension_id = ext.id - ORDER BY se.id` + %s + ORDER BY se.id`, whereClause) - if err := CheckError(s.db.WithContext(ctx).Raw(query).Scan(&result)); err != nil { + if err := CheckError(s.db.WithContext(ctx).Raw(query, sqlFilter.params...).Scan(&result)); err != nil { return nil, err } @@ -632,6 +667,12 @@ func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvir return result, nil } +// GetEnvironments - retrieves all schema environments. +func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { + return s.GetEnvironmentsFiltered(ctx, model.Filters{}) + +} + // GetEnvironmentsByExtensionId - retrieves a slice of model.SchemaEnvironment by extension id. func (s *BloodhoundDB) GetEnvironmentsByExtensionId(ctx context.Context, extensionId int32) ([]model.SchemaEnvironment, error) { var ( @@ -652,22 +693,6 @@ func (s *BloodhoundDB) GetEnvironmentsByExtensionId(ctx context.Context, extensi } -// GetEnvironmentByKinds - retrieves an environment by its environment kind and source kind. -func (s *BloodhoundDB) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - var env model.SchemaEnvironment - - if result := s.db.WithContext(ctx).Raw( - "SELECT * FROM schema_environments WHERE environment_kind_id = ? AND source_kind_id = ? AND deleted_at IS NULL", - environmentKindId, sourceKindId, - ).Scan(&env); result.Error != nil { - return model.SchemaEnvironment{}, CheckError(result) - } else if result.RowsAffected == 0 { - return model.SchemaEnvironment{}, ErrNotFound - } - - return env, nil -} - // GetEnvironmentById - retrieves a schema environment by id. func (s *BloodhoundDB) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment @@ -685,6 +710,23 @@ func (s *BloodhoundDB) GetEnvironmentById(ctx context.Context, environmentId int return schemaEnvironment, nil } +// GetEnvironmentByEnvironmentKindId - retrieves a schema environment by environment_kind_id. +func (s *BloodhoundDB) GetEnvironmentByEnvironmentKindId(ctx context.Context, environmentKindId int32) (model.SchemaEnvironment, error) { + filters := model.Filters{ + "se.environment_kind_id": []model.Filter{{Operator: model.Equals, Value: fmt.Sprintf("%d", environmentKindId)}}, + } + + envs, err := s.GetEnvironmentsFiltered(ctx, filters) + if err != nil { + return model.SchemaEnvironment{}, err + } + if len(envs) == 0 { + return model.SchemaEnvironment{}, ErrNotFound + } + + return envs[0], nil +} + // DeleteEnvironment - deletes a schema environment by id. func (s *BloodhoundDB) DeleteEnvironment(ctx context.Context, environmentId int32) error { var schemaEnvironment model.SchemaEnvironment @@ -716,38 +758,86 @@ func (s *BloodhoundDB) CreateSchemaRelationshipFinding(ctx context.Context, exte return finding, nil } +// getSchemaRelationshipFindingsFiltered - retrieves schema relationship findings filtered by the given criteria. +// This is the core implementation that all other GetSchemaRelationshipFinding* methods delegate to. +func (s *BloodhoundDB) getSchemaRelationshipFindingsFiltered(ctx context.Context, filters model.Filters) ([]model.SchemaRelationshipFinding, error) { + var result []model.SchemaRelationshipFinding + + sqlFilter, err := buildSQLFilter(filters) + if err != nil { + return nil, err + } + + whereClause := "" + if sqlFilter.sqlString != "" { + whereClause = fmt.Sprintf("WHERE %s", sqlFilter.sqlString) + } + + query := fmt.Sprintf(` + SELECT + srf.id, + srf.schema_extension_id, + srf.relationship_kind_id, + srf.environment_id, + srf.name, + srf.display_name, + srf.created_at + FROM schema_relationship_findings srf + %s + ORDER BY srf.id`, whereClause) + + if err := CheckError(s.db.WithContext(ctx).Raw(query, sqlFilter.params...).Scan(&result)); err != nil { + return nil, err + } + + if result == nil { + result = []model.SchemaRelationshipFinding{} + } + + return result, nil +} + // GetSchemaRelationshipFindingById - retrieves a schema relationship finding by id. func (s *BloodhoundDB) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { - var finding model.SchemaRelationshipFinding + filters := model.Filters{ + "srf.id": []model.Filter{{Operator: model.Equals, Value: fmt.Sprintf("%d", findingId)}}, + } - if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` - SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at - FROM %s WHERE id = ?`, - finding.TableName()), - findingId).Scan(&finding); result.Error != nil { - return model.SchemaRelationshipFinding{}, CheckError(result) - } else if result.RowsAffected == 0 { + findings, err := s.getSchemaRelationshipFindingsFiltered(ctx, filters) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + if len(findings) == 0 { return model.SchemaRelationshipFinding{}, ErrNotFound } - return finding, nil + return findings[0], nil } // GetSchemaRelationshipFindingByName - retrieves a schema relationship finding by finding name. func (s *BloodhoundDB) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { - var finding model.SchemaRelationshipFinding + filters := model.Filters{ + "srf.name": []model.Filter{{Operator: model.Equals, Value: name}}, + } - if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` - SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at - FROM %s WHERE name = ?`, - finding.TableName()), - name).Scan(&finding); result.Error != nil { - return model.SchemaRelationshipFinding{}, CheckError(result) - } else if result.RowsAffected == 0 { + findings, err := s.getSchemaRelationshipFindingsFiltered(ctx, filters) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + if len(findings) == 0 { return model.SchemaRelationshipFinding{}, ErrNotFound } - return finding, nil + return findings[0], nil +} + +// GetSchemaRelationshipFindingsByEnvironmentId - retrieves all schema relationship findings for a given environment. +func (s *BloodhoundDB) GetSchemaRelationshipFindingsByEnvironmentId(ctx context.Context, environmentId int32) ([]model.SchemaRelationshipFinding, error) { + filters := model.Filters{ + "srf.environment_id": []model.Filter{{Operator: model.Equals, Value: fmt.Sprintf("%d", environmentId)}}, + } + + return s.getSchemaRelationshipFindingsFiltered(ctx, filters) } // DeleteSchemaRelationshipFinding - deletes a schema relationship finding by id. diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 11b517ecae5..a85f4ec18de 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -4023,7 +4023,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { assert.NoError(t, err, "unexpected error occurred when creating environment") // Validate created environment is as expected - retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, newEnvironment.ID) + retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, newEnvironment.EnvironmentKindId) assert.NoError(t, err, "unexpected error occurred when retrieving environment by id") assertContainsEnvironment(t, retrievedEnvironment, environment) @@ -4080,7 +4080,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { newEnvironment, err := testSuite.BHDatabase.CreateEnvironment(testSuite.Context, environment.SchemaExtensionId, environment.EnvironmentKindId, environment.SourceKindId) require.NoError(t, err, "unexpected error occurred when creating environment") - retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentByKinds(testSuite.Context, newEnvironment.EnvironmentKindId, newEnvironment.SourceKindId) + retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, newEnvironment.EnvironmentKindId) assert.NoError(t, err, database.ErrNotFound) assertContainsEnvironment(t, retrievedEnvironment, environment) @@ -4097,7 +4097,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { SourceKindId: 257958, } - _, err := testSuite.BHDatabase.GetEnvironmentByKinds(testSuite.Context, environment.EnvironmentKindId, environment.SourceKindId) + _, err := testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, environment.EnvironmentKindId) assert.EqualError(t, err, database.ErrNotFound.Error(), "expected entity not found") }, }, @@ -4124,7 +4124,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { require.NoError(t, err, "unexpected error occurred when creating environment") // Validate environment - retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, newEnvironment.ID) + retrievedEnvironment, err := testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, newEnvironment.EnvironmentKindId) assert.NoError(t, err, "failed to get environment by id") assertContainsEnvironment(t, retrievedEnvironment, environment) @@ -4136,7 +4136,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { assert: func(t *testing.T, testSuite IntegrationTestSuite) { t.Helper() - _, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, int32(5000)) + _, err := testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, int32(5000)) require.ErrorIs(t, err, database.ErrNotFound) }, }, @@ -4265,7 +4265,7 @@ func TestDatabase_Environments_CRUD(t *testing.T) { assert.NoError(t, err, "unexpected error occurred when deleting environment for extension") // Validate environment no longer exists - _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, newEnvironment.ID) + _, err = testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, newEnvironment.EnvironmentKindId) require.EqualError(t, err, database.ErrNotFound.Error()) }, }, @@ -5351,7 +5351,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.GetGraphSchemaRelationshipKindById(testSuite.Context, edgeKind.ID) assert.ErrorIs(t, err, database.ErrNotFound) - _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, environment.ID) + _, err = testSuite.BHDatabase.GetEnvironmentByEnvironmentKindId(testSuite.Context, environment.EnvironmentKindId) assert.ErrorIs(t, err, database.ErrNotFound) _, err = testSuite.BHDatabase.GetSchemaRelationshipFindingById(testSuite.Context, relationshipFinding.ID) diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index 17aaa2249c0..fa253765b6d 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -18,12 +18,13 @@ package database import ( "context" + "github.com/lib/pq" "github.com/specterops/bloodhound/cmd/api/src/model" ) type Kind interface { GetKindByName(ctx context.Context, name string) (model.Kind, error) - GetKindById(ctx context.Context, id int32) (model.Kind, error) + GetKindsByIds(ctx context.Context, ids ...int32) ([]model.Kind, error) } func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { @@ -47,23 +48,25 @@ func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Ki return kind, nil } -func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +func (s *BloodhoundDB) GetKindsByIds(ctx context.Context, ids ...int32) ([]model.Kind, error) { + if len(ids) == 0 { + return []model.Kind{}, nil + } + const query = ` SELECT id, name FROM kind - WHERE id = $1; + WHERE id = ANY(?); ` - var kind model.Kind - result := s.db.WithContext(ctx).Raw(query, id).Scan(&kind) - - if result.Error != nil { - return model.Kind{}, result.Error + var kinds []model.Kind + if err := s.db.WithContext(ctx).Raw(query, pq.Array(ids)).Scan(&kinds).Error; err != nil { + return nil, err } - if result.RowsAffected == 0 || kind.ID == 0 { - return model.Kind{}, ErrNotFound + if len(kinds) != len(ids) { + return nil, ErrNotFound } - return kind, nil + return kinds, nil } diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go index dea63dada97..17b117181d7 100644 --- a/cmd/api/src/database/kind_integration_test.go +++ b/cmd/api/src/database/kind_integration_test.go @@ -129,16 +129,44 @@ func TestGetKindByID(t *testing.T) { var ( err error createdKind model.Kind - got model.Kind + got []model.Kind ) createdKind = tt.setup(t) - if got, err = testSuite.BHDatabase.GetKindById(testSuite.Context, createdKind.ID); tt.want.err != nil { + if got, err = testSuite.BHDatabase.GetKindsByIds(testSuite.Context, createdKind.ID); tt.want.err != nil { assert.EqualError(t, err, tt.want.err.Error()) } else { assert.NoError(t, err) - assert.Equal(t, tt.want.kind.Name, got.Name) - assert.Greater(t, got.ID, int32(0)) + assert.Equal(t, tt.want.kind.Name, got[0].Name) + assert.Greater(t, got[0].ID, int32(0)) } }) } } + +func TestGetKindsByIds_MultipleKinds(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + // Create two kinds + var kind1, kind2 model.Kind + result := testSuite.DB.WithContext(testSuite.Context).Raw(` + INSERT INTO kind (name) + VALUES ('Test_Kind_One') + RETURNING id, name;`).Scan(&kind1) + require.NoError(t, result.Error) + + result = testSuite.DB.WithContext(testSuite.Context).Raw(` + INSERT INTO kind (name) + VALUES ('Test_Kind_Two') + RETURNING id, name;`).Scan(&kind2) + require.NoError(t, result.Error) + + // Fetch both kinds by their IDs + kinds, err := testSuite.BHDatabase.GetKindsByIds(testSuite.Context, kind1.ID, kind2.ID) + require.NoError(t, err) + assert.Len(t, kinds, 2) + + // Verify both kinds are returned (order not guaranteed) + names := []string{kinds[0].Name, kinds[1].Name} + require.ElementsMatch(t, []string{"Test_Kind_One", "Test_Kind_Two"}, names) +} diff --git a/cmd/api/src/database/migration/extensions/ad_graph_schema.sql b/cmd/api/src/database/migration/extensions/ad_graph_schema.sql index c930a9f0aa6..e92abafb206 100644 --- a/cmd/api/src/database/migration/extensions/ad_graph_schema.sql +++ b/cmd/api/src/database/migration/extensions/ad_graph_schema.sql @@ -71,14 +71,16 @@ DECLARE BEGIN SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; IF retreived_environment_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_kind(v_environment_kind_name); + SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; END IF; SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; IF retreived_source_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_source_kind(v_source_kind_name); + SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; END IF; - + IF NOT EXISTS (SELECT id FROM schema_environments se WHERE se.schema_extension_id = v_extension_id) THEN INSERT INTO schema_environments (schema_extension_id, environment_kind_id, source_kind_id) VALUES (v_extension_id, retreived_environment_kind_id, retreived_source_kind_id) RETURNING id INTO schema_environment_id; ELSE @@ -327,8 +329,6 @@ BEGIN PERFORM genscript_upsert_schema_relationship_kind(extension_id, 'HasTrustKeys', '', true); PERFORM genscript_upsert_schema_relationship_kind(extension_id, 'ProtectAdminGroups', '', false); - PERFORM genscript_upsert_source_kind('Base'); - PERFORM genscript_upsert_kind('Domain'); SELECT genscript_upsert_schema_environments(extension_id, 'Domain', 'Base') INTO environment_id; PERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'User'); PERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'Computer'); diff --git a/cmd/api/src/database/migration/extensions/az_graph_schema.sql b/cmd/api/src/database/migration/extensions/az_graph_schema.sql index 7625a81fa4e..2275e16c4f8 100644 --- a/cmd/api/src/database/migration/extensions/az_graph_schema.sql +++ b/cmd/api/src/database/migration/extensions/az_graph_schema.sql @@ -71,14 +71,16 @@ DECLARE BEGIN SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; IF retreived_environment_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_kind(v_environment_kind_name); + SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; END IF; SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; IF retreived_source_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_source_kind(v_source_kind_name); + SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; END IF; - + IF NOT EXISTS (SELECT id FROM schema_environments se WHERE se.schema_extension_id = v_extension_id) THEN INSERT INTO schema_environments (schema_extension_id, environment_kind_id, source_kind_id) VALUES (v_extension_id, retreived_environment_kind_id, retreived_source_kind_id) RETURNING id INTO schema_environment_id; ELSE @@ -261,9 +263,7 @@ BEGIN PERFORM genscript_upsert_schema_relationship_kind(extension_id, 'AZRoleEligible', '', true); PERFORM genscript_upsert_schema_relationship_kind(extension_id, 'AZRoleApprover', '', true); - PERFORM genscript_upsert_source_kind('AZBase'); - PERFORM genscript_upsert_kind('Tenant'); - SELECT genscript_upsert_schema_environments(extension_id, 'Tenant', 'AZBase') INTO environment_id; + SELECT genscript_upsert_schema_environments(extension_id, 'AZTenant', 'AZBase') INTO environment_id; PERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZUser'); PERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZVM'); PERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZServicePrincipal'); diff --git a/cmd/api/src/database/migration/migrations/v8.7.0.sql b/cmd/api/src/database/migration/migrations/v8.7.0.sql index 9c88f5109c0..d1ab0b910d7 100644 --- a/cmd/api/src/database/migration/migrations/v8.7.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.7.0.sql @@ -13,6 +13,22 @@ -- limitations under the License. -- -- SPDX-License-Identifier: Apache-2.0 +-- Drop the compound unique constraint on schema_environments (environment_kind_id, source_kind_id) +-- and add a unique constraint on just environment_kind_id +ALTER TABLE IF EXISTS schema_environments + DROP CONSTRAINT IF EXISTS schema_environments_environment_kind_id_source_kind_id_key; + +DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'schema_environments_environment_kind_id_key' + ) THEN + ALTER TABLE schema_environments + ADD CONSTRAINT schema_environments_environment_kind_id_key UNIQUE (environment_kind_id); + END IF; + END$$; -- OpenGraph Extension Management feature flag INSERT INTO feature_flags (created_at, updated_at, key, name, description, enabled, user_updatable) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index c1852742624..42d6a3a01e7 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1702,34 +1702,34 @@ func (mr *MockDatabaseMockRecorder) GetDatapipeStatus(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatapipeStatus", reflect.TypeOf((*MockDatabase)(nil).GetDatapipeStatus), ctx) } -// GetEnvironmentById mocks base method. -func (m *MockDatabase) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentByEnvironmentKindId mocks base method. +func (m *MockDatabase) GetEnvironmentByEnvironmentKindId(ctx context.Context, environmentKindId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEnvironmentById", ctx, environmentId) + ret := m.ctrl.Call(m, "GetEnvironmentByEnvironmentKindId", ctx, environmentKindId) ret0, _ := ret[0].(model.SchemaEnvironment) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetEnvironmentById indicates an expected call of GetEnvironmentById. -func (mr *MockDatabaseMockRecorder) GetEnvironmentById(ctx, environmentId any) *gomock.Call { +// GetEnvironmentByEnvironmentKindId indicates an expected call of GetEnvironmentByEnvironmentKindId. +func (mr *MockDatabaseMockRecorder) GetEnvironmentByEnvironmentKindId(ctx, environmentKindId any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentById), ctx, environmentId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentByEnvironmentKindId", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentByEnvironmentKindId), ctx, environmentKindId) } -// GetEnvironmentByKinds mocks base method. -func (m *MockDatabase) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentById mocks base method. +func (m *MockDatabase) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEnvironmentByKinds", ctx, environmentKindId, sourceKindId) + ret := m.ctrl.Call(m, "GetEnvironmentById", ctx, environmentId) ret0, _ := ret[0].(model.SchemaEnvironment) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetEnvironmentByKinds indicates an expected call of GetEnvironmentByKinds. -func (mr *MockDatabaseMockRecorder) GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { +// GetEnvironmentById indicates an expected call of GetEnvironmentById. +func (mr *MockDatabaseMockRecorder) GetEnvironmentById(ctx, environmentId any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentByKinds), ctx, environmentKindId, sourceKindId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentById), ctx, environmentId) } // GetEnvironmentTargetedAccessControlForUser mocks base method. @@ -1762,6 +1762,21 @@ func (mr *MockDatabaseMockRecorder) GetEnvironments(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetEnvironments), ctx) } +// GetEnvironmentsFiltered mocks base method. +func (m *MockDatabase) GetEnvironmentsFiltered(ctx context.Context, filters model.Filters) ([]model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentsFiltered", ctx, filters) + ret0, _ := ret[0].([]model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentsFiltered indicates an expected call of GetEnvironmentsFiltered. +func (mr *MockDatabaseMockRecorder) GetEnvironmentsFiltered(ctx, filters any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentsFiltered", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentsFiltered), ctx, filters) +} + // GetFlag mocks base method. func (m *MockDatabase) GetFlag(ctx context.Context, id int32) (appcfg.FeatureFlag, error) { m.ctrl.T.Helper() @@ -1992,34 +2007,39 @@ func (mr *MockDatabaseMockRecorder) GetInstallation(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstallation", reflect.TypeOf((*MockDatabase)(nil).GetInstallation), ctx) } -// GetKindById mocks base method. -func (m *MockDatabase) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +// GetKindByName mocks base method. +func (m *MockDatabase) GetKindByName(ctx context.Context, name string) (model.Kind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindById", ctx, id) + ret := m.ctrl.Call(m, "GetKindByName", ctx, name) ret0, _ := ret[0].(model.Kind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetKindById indicates an expected call of GetKindById. -func (mr *MockDatabaseMockRecorder) GetKindById(ctx, id any) *gomock.Call { +// GetKindByName indicates an expected call of GetKindByName. +func (mr *MockDatabaseMockRecorder) GetKindByName(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockDatabase)(nil).GetKindById), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockDatabase)(nil).GetKindByName), ctx, name) } -// GetKindByName mocks base method. -func (m *MockDatabase) GetKindByName(ctx context.Context, name string) (model.Kind, error) { +// GetKindsByIds mocks base method. +func (m *MockDatabase) GetKindsByIds(ctx context.Context, ids ...int32) ([]model.Kind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindByName", ctx, name) - ret0, _ := ret[0].(model.Kind) + varargs := []any{ctx} + for _, a := range ids { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetKindsByIds", varargs...) + ret0, _ := ret[0].([]model.Kind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetKindByName indicates an expected call of GetKindByName. -func (mr *MockDatabaseMockRecorder) GetKindByName(ctx, name any) *gomock.Call { +// GetKindsByIds indicates an expected call of GetKindsByIds. +func (mr *MockDatabaseMockRecorder) GetKindsByIds(ctx any, ids ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockDatabase)(nil).GetKindByName), ctx, name) + varargs := append([]any{ctx}, ids...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindsByIds", reflect.TypeOf((*MockDatabase)(nil).GetKindsByIds), varargs...) } // GetLatestAssetGroupCollection mocks base method. @@ -2307,6 +2327,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, name return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, name) } +// GetSchemaRelationshipFindingsByEnvironmentId mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingsByEnvironmentId(ctx context.Context, environmentId int32) ([]model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].([]model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingsByEnvironmentId indicates an expected call of GetSchemaRelationshipFindingsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingsByEnvironmentId), ctx, environmentId) +} + // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper() @@ -2423,6 +2458,26 @@ func (mr *MockDatabaseMockRecorder) GetSourceKinds(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKinds), ctx) } +// GetSourceKindsByIds mocks base method. +func (m *MockDatabase) GetSourceKindsByIds(ctx context.Context, ids ...int32) ([]database.SourceKind, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range ids { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetSourceKindsByIds", varargs...) + ret0, _ := ret[0].([]database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKindsByIds indicates an expected call of GetSourceKindsByIds. +func (mr *MockDatabaseMockRecorder) GetSourceKindsByIds(ctx any, ids ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, ids...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindsByIds", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindsByIds), varargs...) +} + // GetTimeRangedAssetGroupCollections mocks base method. func (m *MockDatabase) GetTimeRangedAssetGroupCollections(ctx context.Context, assetGroupID int32, from, to int64, order string) (model.AssetGroupCollections, error) { m.ctrl.T.Helper() @@ -2438,6 +2493,21 @@ func (mr *MockDatabaseMockRecorder) GetTimeRangedAssetGroupCollections(ctx, asse return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTimeRangedAssetGroupCollections", reflect.TypeOf((*MockDatabase)(nil).GetTimeRangedAssetGroupCollections), ctx, assetGroupID, from, to, order) } +// GetTraversableRelationshipKindsByExtensionID mocks base method. +func (m *MockDatabase) GetTraversableRelationshipKindsByExtensionID(ctx context.Context, extensionID int32) (model.GraphSchemaRelationshipKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTraversableRelationshipKindsByExtensionID", ctx, extensionID) + ret0, _ := ret[0].(model.GraphSchemaRelationshipKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTraversableRelationshipKindsByExtensionID indicates an expected call of GetTraversableRelationshipKindsByExtensionID. +func (mr *MockDatabaseMockRecorder) GetTraversableRelationshipKindsByExtensionID(ctx, extensionID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTraversableRelationshipKindsByExtensionID", reflect.TypeOf((*MockDatabase)(nil).GetTraversableRelationshipKindsByExtensionID), ctx, extensionID) +} + // GetUser mocks base method. func (m *MockDatabase) GetUser(ctx context.Context, id uuid.UUID) (model.User, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index b005759d1d4..978ddb5799f 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -29,6 +29,7 @@ type SourceKindsData interface { DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) + GetSourceKindsByIds(ctx context.Context, ids ...int32) ([]SourceKind, error) } // RegisterSourceKind returns a function that inserts a source kind by name, @@ -57,9 +58,13 @@ func (s *BloodhoundDB) RegisterSourceKind(ctx context.Context) func(sourceKind g } type SourceKind struct { - ID int `json:"id"` - Name graph.Kind `json:"name"` - Active bool `json:"active"` + ID int `json:"id"` + Name string `json:"name"` + Active bool `json:"active"` +} + +func (s *SourceKind) ToKind() graph.Kind { + return graph.StringKind(s.Name) } func (s *BloodhoundDB) GetSourceKinds(ctx context.Context) ([]SourceKind, error) { @@ -70,28 +75,12 @@ func (s *BloodhoundDB) GetSourceKinds(ctx context.Context) ([]SourceKind, error) ORDER BY name ASC; ` - type rawSourceKind struct { - ID int - Name string - Active bool - } - - var kinds []rawSourceKind - result := s.db.WithContext(ctx).Raw(query).Scan(&kinds) - if err := result.Error; err != nil { + var kinds []SourceKind + if err := s.db.WithContext(ctx).Raw(query).Scan(&kinds).Error; err != nil { return nil, fmt.Errorf("failed to fetch source kinds: %w", err) } - out := make([]SourceKind, len(kinds)) - for i, k := range kinds { - out[i] = SourceKind{ - ID: k.ID, - Name: graph.StringKind(k.Name), - Active: k.Active, - } - } - - return out, nil + return kinds, nil } func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) { @@ -101,30 +90,41 @@ func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (So WHERE name = $1 AND active = true; ` - type rawSourceKind struct { - ID int - Name string - Active bool - } - - var raw rawSourceKind - result := s.db.WithContext(ctx).Raw(query, name).Scan(&raw) + var sourceKind SourceKind + result := s.db.WithContext(ctx).Raw(query, name).Scan(&sourceKind) if result.Error != nil { return SourceKind{}, result.Error } - if result.RowsAffected == 0 || raw.ID == 0 { + if result.RowsAffected == 0 || sourceKind.ID == 0 { return SourceKind{}, ErrNotFound } - kind := SourceKind{ - ID: raw.ID, - Name: graph.StringKind(raw.Name), - Active: raw.Active, + return sourceKind, nil +} + +func (s *BloodhoundDB) GetSourceKindsByIds(ctx context.Context, ids ...int32) ([]SourceKind, error) { + if len(ids) == 0 { + return []SourceKind{}, nil + } + + const query = ` + SELECT id, name, active + FROM source_kinds + WHERE id = ANY(?) AND active = true; + ` + + var sourceKinds []SourceKind + if err := s.db.WithContext(ctx).Raw(query, pq.Array(ids)).Scan(&sourceKinds).Error; err != nil { + return nil, err + } + + if len(sourceKinds) != len(ids) { + return nil, ErrNotFound } - return kind, nil + return sourceKinds, nil } func (s *BloodhoundDB) GetSourceKindByID(ctx context.Context, id int) (SourceKind, error) { @@ -133,30 +133,19 @@ func (s *BloodhoundDB) GetSourceKindByID(ctx context.Context, id int) (SourceKin FROM source_kinds WHERE id = $1 AND active = true; ` - type rawSourceKind struct { - ID int - Name string - Active bool - } - var raw rawSourceKind - result := s.db.WithContext(ctx).Raw(query, id).Scan(&raw) + var sourceKind SourceKind + result := s.db.WithContext(ctx).Raw(query, id).Scan(&sourceKind) if result.Error != nil { return SourceKind{}, result.Error } - if result.RowsAffected == 0 || raw.ID == 0 { + if result.RowsAffected == 0 || sourceKind.ID == 0 { return SourceKind{}, ErrNotFound } - kind := SourceKind{ - ID: raw.ID, - Name: graph.StringKind(raw.Name), - Active: raw.Active, - } - - return kind, nil + return sourceKind, nil } func (s *BloodhoundDB) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error { diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index cb768dbb2c7..d328f298c05 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -56,12 +56,12 @@ func TestRegisterSourceKind(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -80,17 +80,17 @@ func TestRegisterSourceKind(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, { ID: 3, - Name: graph.StringKind("harnessEdge.Kind"), + Name: "harnessEdge.Kind", Active: true, }, }, @@ -118,17 +118,17 @@ func TestRegisterSourceKind(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, { ID: 3, - Name: graph.StringKind("Kind"), + Name: "Kind", Active: true, }, }, @@ -178,12 +178,12 @@ func TestGetSourceKinds(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -234,7 +234,7 @@ func TestGetSourceKindByName(t *testing.T) { // simply testing the default returned source_kinds sourceKind: database.SourceKind{ ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, }, @@ -342,12 +342,12 @@ func TestDeactivateSourceKindsByName(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -366,12 +366,12 @@ func TestDeactivateSourceKindsByName(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -390,12 +390,12 @@ func TestDeactivateSourceKindsByName(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -424,17 +424,17 @@ func TestDeactivateSourceKindsByName(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 4, - Name: graph.StringKind("AnotherKind"), + Name: "AnotherKind", Active: true, }, { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -463,12 +463,12 @@ func TestDeactivateSourceKindsByName(t *testing.T) { sourceKinds: []database.SourceKind{ { ID: 2, - Name: graph.StringKind("AZBase"), + Name: "AZBase", Active: true, }, { ID: 1, - Name: graph.StringKind("Base"), + Name: "Base", Active: true, }, }, @@ -495,3 +495,55 @@ func TestDeactivateSourceKindsByName(t *testing.T) { }) } } + +func TestGetSourceKindsByIds(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + t.Run("not found", func(t *testing.T) { + _, err := testSuite.BHDatabase.GetSourceKindsByIds(testSuite.Context, 9999) + require.ErrorIs(t, err, database.ErrNotFound) + }) + + t.Run("single kind", func(t *testing.T) { + // Create a source kind + var sourceKind database.SourceKind + result := testSuite.DB.WithContext(testSuite.Context).Raw(` + INSERT INTO source_kinds (name, active) + VALUES ('TestSourceKindOne', true) + RETURNING id, name, active;`).Scan(&sourceKind) + require.NoError(t, result.Error) + + // Fetch it by ID + kinds, err := testSuite.BHDatabase.GetSourceKindsByIds(testSuite.Context, int32(sourceKind.ID)) + require.NoError(t, err) + require.Len(t, kinds, 1) + assert.Equal(t, "TestSourceKindOne", kinds[0].Name) + assert.True(t, kinds[0].Active) + }) + + t.Run("multiple kinds", func(t *testing.T) { + // Create two source kinds + var kind1, kind2 database.SourceKind + result := testSuite.DB.WithContext(testSuite.Context).Raw(` + INSERT INTO source_kinds (name, active) + VALUES ('TestSourceKindTwo', true) + RETURNING id, name, active;`).Scan(&kind1) + require.NoError(t, result.Error) + + result = testSuite.DB.WithContext(testSuite.Context).Raw(` + INSERT INTO source_kinds (name, active) + VALUES ('TestSourceKindThree', true) + RETURNING id, name, active;`).Scan(&kind2) + require.NoError(t, result.Error) + + // Fetch both by their IDs + kinds, err := testSuite.BHDatabase.GetSourceKindsByIds(testSuite.Context, int32(kind1.ID), int32(kind2.ID)) + require.NoError(t, err) + require.Len(t, kinds, 2) + + // Verify both kinds are returned (order not guaranteed) + names := []string{kinds[0].Name, kinds[1].Name} + assert.ElementsMatch(t, []string{"TestSourceKindTwo", "TestSourceKindThree"}, names) + }) +} diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index f266623d9df..42d2293863e 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -115,11 +115,11 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p } // replaceSchemaEnvironment creates or updates a schema environment. -// If an environment with the given kinds exists, it deletes it first before creating the new one. -// The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no -// duplicate pairs exist, enabling this upsert logic. +// If an environment with the given environment_kind_id exists, it deletes it first before creating the new one. +// The unique constraint on environment_kind_id of the Schema Environment table ensures no +// duplicates exist, enabling this upsert logic. func (s *BloodhoundDB) replaceSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { - if existing, err := s.GetEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { + if existing, err := s.GetEnvironmentByEnvironmentKindId(ctx, graphSchema.EnvironmentKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) } else if !errors.Is(err, ErrNotFound) { // Environment exists - delete it first diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index 54b03d9a823..330c8cec9b9 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -23,7 +23,6 @@ import ( "testing" "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/dawgs/graph" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,13 +57,10 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB, args args) { t.Helper() - sourceKind, err := db.GetSourceKindByName(context.Background(), args.sourceKind) - require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - environmentKind, err := db.GetKindByName(context.Background(), args.environmentKind) require.NoError(t, err, "unexpected error occurred when retrieving kind by name") - environment, err := db.GetEnvironmentByKinds(context.Background(), environmentKind.ID, int32(sourceKind.ID)) + environment, err := db.GetEnvironmentByEnvironmentKindId(context.Background(), environmentKind.ID) require.NoError(t, err, "unexpected error occurred when getting environment by kinds") principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environment.ID) @@ -116,13 +112,10 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB, args args) { t.Helper() - sourceKind, err := db.GetSourceKindByName(context.Background(), args.sourceKind) - require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - environmentKind, err := db.GetKindByName(context.Background(), args.environmentKind) require.NoError(t, err, "unexpected error occurred when retrieving kind by name") - environment, err := db.GetEnvironmentByKinds(context.Background(), environmentKind.ID, int32(sourceKind.ID)) + environment, err := db.GetEnvironmentByEnvironmentKindId(context.Background(), environmentKind.ID) require.NoError(t, err, "unexpected error occurred when getting environment by kinds") principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environment.ID) @@ -155,7 +148,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { sourceKind, err := db.GetSourceKindByName(context.Background(), args.sourceKind) require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - assert.Equal(t, graph.StringKind(args.sourceKind), sourceKind.Name) + assert.Equal(t, args.sourceKind, sourceKind.Name) assert.True(t, sourceKind.Active) }, }, @@ -199,14 +192,10 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB, args args) { t.Helper() - // Verify no environment was created for this extension - sourceKind, err := db.GetSourceKindByName(context.Background(), args.sourceKind) - require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - environmentKind, err := db.GetKindByName(context.Background(), args.environmentKind) require.NoError(t, err, "unexpected error occurred when retrieving kind by name") - _, err = db.GetEnvironmentByKinds(context.Background(), environmentKind.ID, int32(sourceKind.ID)) + _, err = db.GetEnvironmentByEnvironmentKindId(context.Background(), environmentKind.ID) assert.Error(t, err, "Environment should not exist after rollback") }, }, @@ -228,14 +217,10 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB, args args) { t.Helper() - // Verify no environment was created for this extension - sourceKind, err := db.GetSourceKindByName(context.Background(), args.sourceKind) - require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - environmentKind, err := db.GetKindByName(context.Background(), args.environmentKind) require.NoError(t, err, "unexpected error occurred when retrieving kind by name") - _, err = db.GetEnvironmentByKinds(context.Background(), environmentKind.ID, int32(sourceKind.ID)) + _, err = db.GetEnvironmentByEnvironmentKindId(context.Background(), environmentKind.ID) assert.Error(t, err, "Environment should not exist after rollback") }, }, @@ -274,10 +259,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { envKind1, err := db.GetKindByName(context.Background(), "Tag_Tier_Zero") require.NoError(t, err, "unexpected error occurred when retrieving kind 1 by name") - sourceKind, err := db.GetSourceKindByName(context.Background(), "Base") - require.NoError(t, err, "unexpected error occurred when retrieving source kind by name") - - env1, err := db.GetEnvironmentByKinds(context.Background(), envKind1.ID, int32(sourceKind.ID)) + env1, err := db.GetEnvironmentByEnvironmentKindId(context.Background(), envKind1.ID) require.NoError(t, err, "unexpected error occurred when retrieving environment by kinds") principalKinds1, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env1.ID) @@ -288,7 +270,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { envKind2, err := db.GetKindByName(context.Background(), args.environmentKind) require.NoError(t, err, "unexpected error occurred when retrieving kind 2 by name") - env2, err := db.GetEnvironmentByKinds(context.Background(), envKind2.ID, int32(sourceKind.ID)) + env2, err := db.GetEnvironmentByEnvironmentKindId(context.Background(), envKind2.ID) require.NoError(t, err, "unexpected error occurred when getting environment by kinds") principalKinds2, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env2.ID) diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index ef432480c09..662aec9e1a4 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1068,18 +1068,18 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * SetOperator: model.FilterAnd, } - gotNodeKinds model.GraphSchemaNodeKinds - gotRelationshipKinds model.GraphSchemaRelationshipKinds - gotProperties model.GraphSchemaProperties - gotSchemaEnvironments []model.SchemaEnvironment - gotPrincipalKinds model.SchemaEnvironmentPrincipalKinds - sourceKind database.SourceKind - dawgsPrincipalKind model.Kind - dawgsFindingRelationshipKind model.Kind - dawgsFindingEnvironmentKind model.Kind - gotSchemaRelationshipFinding []model.SchemaRelationshipFinding - gotRemediation model.Remediation - findingEnvironment model.SchemaEnvironment + gotNodeKinds model.GraphSchemaNodeKinds + gotRelationshipKinds model.GraphSchemaRelationshipKinds + gotProperties model.GraphSchemaProperties + gotSchemaEnvironments []model.SchemaEnvironment + gotPrincipalKinds model.SchemaEnvironmentPrincipalKinds + sourceKind database.SourceKind + dawgsPrincipalKinds []model.Kind + dawgsPrincipalKind model.Kind + dawgsFindingRelationshipKinds []model.Kind + dawgsFindingRelationshipKind model.Kind + gotSchemaRelationshipFinding []model.SchemaRelationshipFinding + gotRemediation model.Remediation ) // Test Node Kinds @@ -1143,13 +1143,14 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * require.Equalf(t, want.EnvironmentsInput[idx].EnvironmentKindName, gotEnvironment.EnvironmentKindName, "EnvironmentInput - EnvironmentKindName mismatch") sourceKind, err = db.GetSourceKindByID(testContext, int(gotEnvironment.SourceKindId)) require.NoError(t, err) - require.Equalf(t, want.EnvironmentsInput[idx].SourceKindName, sourceKind.Name.String(), "EnvironmentInput - EnvironmentKindName mismatch") + require.Equalf(t, want.EnvironmentsInput[idx].SourceKindName, sourceKind.Name, "EnvironmentInput - EnvironmentKindName mismatch") gotPrincipalKinds, err = db.GetPrincipalKindsByEnvironmentId(testContext, gotEnvironment.ID) require.NoError(t, err) require.Equalf(t, len(want.EnvironmentsInput[idx].PrincipalKinds), len(gotPrincipalKinds), "PrincipalKinds - count mismatch") for _, gotPrincipalKind := range gotPrincipalKinds { require.Equalf(t, gotEnvironment.ID, gotPrincipalKind.EnvironmentId, "PrincipalKind - EnvironmentId is invalid") - dawgsPrincipalKind, err = db.GetKindById(testContext, gotPrincipalKind.PrincipalKind) + dawgsPrincipalKinds, err = db.GetKindsByIds(testContext, gotPrincipalKind.PrincipalKind) + dawgsPrincipalKind = dawgsPrincipalKinds[0] require.NoError(t, err) require.Containsf(t, want.EnvironmentsInput[idx].PrincipalKinds, dawgsPrincipalKind.Name, "PrincipalKind - Name mismatch") } @@ -1165,14 +1166,17 @@ func getAndCompareGraphExtension(t *testing.T, testContext context.Context, db * require.Greater(t, finding.ID, int32(0)) require.Equalf(t, gotGraphExtension.ID, finding.SchemaExtensionId, "RelationshipFindingInput - graph schema extension id should be greater than 0") - dawgsFindingRelationshipKind, err = db.GetKindById(testContext, finding.RelationshipKindId) + dawgsFindingRelationshipKinds, err = db.GetKindsByIds(testContext, finding.RelationshipKindId) + dawgsFindingRelationshipKind = dawgsFindingRelationshipKinds[0] require.NoError(t, err) require.Equalf(t, want.RelationshipFindingsInput[i].RelationshipKindName, dawgsFindingRelationshipKind.Name, "RelationshipFindingInput - relationship kind name mismatch") - findingEnvironment, err = db.GetEnvironmentById(testContext, finding.EnvironmentId) + findingEnvironment, err := db.GetEnvironmentById(testContext, finding.EnvironmentId) require.NoError(t, err) - dawgsFindingEnvironmentKind, err = db.GetKindById(testContext, findingEnvironment.EnvironmentKindId) + + dawgsFindingEnvironmentKinds, err := db.GetKindsByIds(testContext, findingEnvironment.EnvironmentKindId) require.NoError(t, err) + dawgsFindingEnvironmentKind := dawgsFindingEnvironmentKinds[0] require.Equalf(t, want.RelationshipFindingsInput[i].EnvironmentKindName, dawgsFindingEnvironmentKind.Name, "RelationshipFindingInput - environment kind name mismatch") require.Equalf(t, want.RelationshipFindingsInput[i].Name, finding.Name, "RelationshipFindingInput - name mismatch") diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go index 9bca51a1271..2f15e6c93ad 100644 --- a/cmd/api/src/database/upsert_schema_finding.go +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -36,14 +36,14 @@ func (s *BloodhoundDB) UpsertFinding(ctx context.Context, extensionId int32, sou return model.SchemaRelationshipFinding{}, err } - sourceKindId, err := s.validateAndTranslateSourceKind(ctx, sourceKindName) + _, err = s.validateAndTranslateSourceKind(ctx, sourceKindName) if err != nil { return model.SchemaRelationshipFinding{}, err } - // The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no - // duplicate pairs exist, enabling this logic. - environment, err := s.GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId) + // The unique constraint on environment_kind_id of the Schema Environment table ensures no + // duplicates exist, enabling this logic. + environment, err := s.GetEnvironmentByEnvironmentKindId(ctx, environmentKindId) if err != nil { return model.SchemaRelationshipFinding{}, err } diff --git a/cmd/api/src/services/graphify/convertors.go b/cmd/api/src/services/graphify/convertors.go index 1acab6af23a..b6c5e7830c3 100644 --- a/cmd/api/src/services/graphify/convertors.go +++ b/cmd/api/src/services/graphify/convertors.go @@ -52,6 +52,13 @@ func ConvertGenericNode(entity ein.GenericNode, converted *ConvertedData) error } } + // Uppercase environment_id if present + if rawEnvID, ok := node.PropertyMap["environment_id"]; ok { + if envID, ok := rawEnvID.(string); ok { + node.PropertyMap["environment_id"] = strings.ToUpper(envID) + } + } + // the first element in node.Labels determines which icon the UI renders for the node. // it is critical to specify this information because a node can have up to 3 kinds. if len(node.Labels) > 0 { diff --git a/packages/go/analysis/tiering/helpers.go b/packages/go/analysis/tiering/helpers.go index 556c4ea7b74..d1275929ba3 100644 --- a/packages/go/analysis/tiering/helpers.go +++ b/packages/go/analysis/tiering/helpers.go @@ -24,6 +24,7 @@ import ( ) type SearchTierNodesCtx struct { + TierID int // Tier ID for findings processing IsTierZero bool PrimaryTierKind graph.Kind SearchTierNodes graph.Criteria @@ -32,8 +33,9 @@ type SearchTierNodesCtx struct { SearchPrimaryTierNodesRel graph.Criteria } -func NewSearchTierNodesCtx(tieringEnabled bool, isTierZero bool, primaryKind graph.Kind, tierKinds ...graph.Kind) SearchTierNodesCtx { +func NewSearchTierNodesCtx(tieringEnabled bool, isTierZero bool, tierID int, primaryKind graph.Kind, tierKinds ...graph.Kind) SearchTierNodesCtx { return SearchTierNodesCtx{ + TierID: tierID, IsTierZero: isTierZero, PrimaryTierKind: primaryKind, SearchTierNodesRel: searchTierNodesRel(tieringEnabled, tierKinds...), diff --git a/packages/go/schemagen/generator/sql.go b/packages/go/schemagen/generator/sql.go index 29e7e770553..d5e5871b810 100644 --- a/packages/go/schemagen/generator/sql.go +++ b/packages/go/schemagen/generator/sql.go @@ -241,14 +241,16 @@ DECLARE BEGIN SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; IF retreived_environment_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_kind(v_environment_kind_name); + SELECT id INTO retreived_environment_kind_id FROM kind WHERE name = v_environment_kind_name; END IF; SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; IF retreived_source_kind_id IS NULL THEN - RAISE EXCEPTION 'couldn''t find matching kind_id'; + PERFORM genscript_upsert_source_kind(v_source_kind_name); + SELECT id INTO retreived_source_kind_id FROM source_kinds WHERE name = v_source_kind_name; END IF; - + IF NOT EXISTS (SELECT id FROM schema_environments se WHERE se.schema_extension_id = v_extension_id) THEN INSERT INTO schema_environments (schema_extension_id, environment_kind_id, source_kind_id) VALUES (v_extension_id, retreived_environment_kind_id, retreived_source_kind_id) RETURNING id INTO schema_environment_id; ELSE @@ -356,17 +358,13 @@ DROP FUNCTION IF EXISTS genscript_upsert_schema_environments_principal_kinds;`) } func GenerateADSpecifics(sb io.StringWriter) { - sb.WriteString("\tPERFORM genscript_upsert_source_kind('Base');\n") - sb.WriteString("\tPERFORM genscript_upsert_kind('Domain');\n") sb.WriteString("\tSELECT genscript_upsert_schema_environments(extension_id, 'Domain', 'Base') INTO environment_id;\n") sb.WriteString("\tPERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'User');\n") sb.WriteString("\tPERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'Computer');\n") } func GenerateAZSpecifics(sb io.StringWriter) { - sb.WriteString("\tPERFORM genscript_upsert_source_kind('AZBase');\n") - sb.WriteString("\tPERFORM genscript_upsert_kind('Tenant');\n") - sb.WriteString("\tSELECT genscript_upsert_schema_environments(extension_id, 'Tenant', 'AZBase') INTO environment_id;\n") + sb.WriteString("\tSELECT genscript_upsert_schema_environments(extension_id, 'AZTenant', 'AZBase') INTO environment_id;\n") sb.WriteString("\tPERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZUser');\n") sb.WriteString("\tPERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZVM');\n") sb.WriteString("\tPERFORM genscript_upsert_schema_environments_principal_kinds(environment_id, 'AZServicePrincipal');\n")