Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions cmd/api/src/api/v2/opengraphschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type GraphExtensionPayload struct {
GraphSchemaRelationshipKinds []GraphSchemaRelationshipKindsPayload `json:"relationship_kinds"`
GraphSchemaNodeKinds []GraphSchemaNodeKindsPayload `json:"node_kinds"`
GraphEnvironments []EnvironmentPayload `json:"environments"`
GraphRelationshipFindings []RelationshipFindingsPayload `json:"relationship_findings"`
GraphRelationshipFindings []FindingsPayload `json:"relationship_findings"`
}

type GraphSchemaExtensionPayload struct {
Expand Down Expand Up @@ -88,13 +88,13 @@ type EnvironmentPayload struct {
PrincipalKinds []string `json:"principal_kinds"`
}

type RelationshipFindingsPayload struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
SourceKind string `json:"source_kind"`
RelationshipKind string `json:"relationship_kind"`
EnvironmentKind string `json:"environment_kind"`
Remediation RemediationPayload `json:"remediation"`
type FindingsPayload struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
SourceKind string `json:"source_kind"`
Kind string `json:"kind"`
EnvironmentKind string `json:"environment_kind"`
Remediation RemediationPayload `json:"remediation"`
}

type RemediationPayload struct {
Expand Down Expand Up @@ -240,12 +240,12 @@ func convertGraphExtensionPayloadToGraphExtension(payload GraphExtensionPayload)
})
}
for _, findingPayload := range payload.GraphRelationshipFindings {
graphExtension.RelationshipFindingsInput = append(graphExtension.RelationshipFindingsInput, model.RelationshipFindingInput{
Name: findingPayload.Name,
DisplayName: findingPayload.DisplayName,
SourceKindName: findingPayload.SourceKind,
RelationshipKindName: findingPayload.RelationshipKind,
EnvironmentKindName: findingPayload.EnvironmentKind,
graphExtension.FindingsInput = append(graphExtension.FindingsInput, model.FindingInput{
Name: findingPayload.Name,
DisplayName: findingPayload.DisplayName,
SourceKindName: findingPayload.SourceKind,
KindName: findingPayload.Kind,
EnvironmentKindName: findingPayload.EnvironmentKind,
RemediationInput: model.RemediationInput{
ShortDescription: findingPayload.Remediation.ShortDescription,
LongDescription: findingPayload.Remediation.LongDescription,
Expand Down
24 changes: 12 additions & 12 deletions cmd/api/src/api/v2/opengraphschema_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ func Test_convertGraphExtensionPayloadToGraphExtension(t *testing.T) {
PrincipalKinds: []string{"User"},
},
},
GraphRelationshipFindings: []RelationshipFindingsPayload{
GraphRelationshipFindings: []FindingsPayload{
{
Name: "Finding_1",
DisplayName: "Finding 1",
SourceKind: "Source_Kind_1",
RelationshipKind: "GraphSchemaEdgeKind_1",
EnvironmentKind: "EnvironmentInput",
Name: "Finding_1",
DisplayName: "Finding 1",
SourceKind: "Source_Kind_1",
Kind: "GraphSchemaEdgeKind_1",
EnvironmentKind: "EnvironmentInput",
Remediation: RemediationPayload{
ShortDescription: "remediation for Finding_1",
LongDescription: "a remediation for Finding 1",
Expand Down Expand Up @@ -130,13 +130,13 @@ func Test_convertGraphExtensionPayloadToGraphExtension(t *testing.T) {
PrincipalKinds: []string{"User"},
},
},
RelationshipFindingsInput: []model.RelationshipFindingInput{
FindingsInput: []model.FindingInput{
{
Name: "Finding_1",
DisplayName: "Finding 1",
SourceKindName: "Source_Kind_1",
RelationshipKindName: "GraphSchemaEdgeKind_1",
EnvironmentKindName: "EnvironmentInput",
Name: "Finding_1",
DisplayName: "Finding 1",
SourceKindName: "Source_Kind_1",
KindName: "GraphSchemaEdgeKind_1",
EnvironmentKindName: "EnvironmentInput",
RemediationInput: model.RemediationInput{
ShortDescription: "remediation for Finding_1",
LongDescription: "a remediation for Finding 1",
Expand Down
24 changes: 12 additions & 12 deletions cmd/api/src/api/v2/opengraphschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ func TestResources_OpenGraphSchemaIngest(t *testing.T) {
PrincipalKinds: []string{"User"},
},
},
GraphRelationshipFindings: []v2.RelationshipFindingsPayload{
GraphRelationshipFindings: []v2.FindingsPayload{
{
Name: "TEST_Finding_1",
DisplayName: "Finding 1",
SourceKind: "Source_Kind_1",
RelationshipKind: "TEST_GraphSchemaEdgeKind_1",
EnvironmentKind: "TEST_EnvironmentInput",
Name: "TEST_Finding_1",
DisplayName: "Finding 1",
SourceKind: "Source_Kind_1",
Kind: "TEST_GraphSchemaEdgeKind_1",
EnvironmentKind: "TEST_EnvironmentInput",
Remediation: v2.RemediationPayload{
ShortDescription: "remediation for Finding_1",
LongDescription: "a remediation for Finding 1",
Expand Down Expand Up @@ -145,13 +145,13 @@ func TestResources_OpenGraphSchemaIngest(t *testing.T) {
PrincipalKinds: []string{"User"},
},
},
RelationshipFindingsInput: []model.RelationshipFindingInput{
FindingsInput: []model.FindingInput{
{
Name: "TEST_Finding_1",
DisplayName: "Finding 1",
SourceKindName: "Source_Kind_1",
RelationshipKindName: "TEST_GraphSchemaEdgeKind_1",
EnvironmentKindName: "TEST_EnvironmentInput",
Name: "TEST_Finding_1",
DisplayName: "Finding 1",
SourceKindName: "Source_Kind_1",
KindName: "TEST_GraphSchemaEdgeKind_1",
EnvironmentKindName: "TEST_EnvironmentInput",
RemediationInput: model.RemediationInput{
ShortDescription: "remediation for Finding_1",
LongDescription: "a remediation for Finding 1",
Expand Down
8 changes: 2 additions & 6 deletions cmd/api/src/database/assetgrouptags.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ import (
"gorm.io/gorm"
)

const (
kindTable = "kind"
)

// AssetGroupTagData defines the methods required to interact with the asset_group_tags table
type AssetGroupTagData interface {
CreateAssetGroupTag(ctx context.Context, tagType model.AssetGroupTagType, user model.User, name string, description string, position null.Int32, requireCertify null.Bool, glyph null.String) (model.AssetGroupTag, error)
Expand Down Expand Up @@ -361,7 +357,7 @@ func (s *BloodhoundDB) CreateAssetGroupTag(ctx context.Context, tagType model.As
INSERT INTO %s (type, kind_id, name, description, created_at, created_by, updated_at, updated_by, position, require_certify, analysis_enabled, glyph)
VALUES (?, (SELECT id FROM inserted_kind), ?, ?, NOW(), ?, NOW(), ?, ?, ?, ?, ?)
RETURNING id, type, kind_id, name, description, created_at, created_by, updated_at, updated_by, position, require_certify, analysis_enabled, glyph
`, kindTable, tag.TableName())
`, model.Kind{}.TableName(), tag.TableName())

if result := tx.Raw(query,
tag.KindName(),
Expand Down Expand Up @@ -481,7 +477,7 @@ func (s *BloodhoundDB) UpdateAssetGroupTag(ctx context.Context, user model.User,
} else {
if origTag.Name != tag.Name {
if result := tx.Exec(
fmt.Sprintf(`UPDATE %s SET name = ? WHERE id = ?`, kindTable),
fmt.Sprintf(`UPDATE %s SET name = ? WHERE id = ?`, model.Kind{}.TableName()),
tag.KindName(),
tag.KindId,
); result.Error != nil {
Expand Down
34 changes: 17 additions & 17 deletions cmd/api/src/database/graphschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,13 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKinds(ctx context.Context, filters mode
FROM %s nk
JOIN %s k ON nk.kind_id = k.id
%s %s %s`,
model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit)
model.GraphSchemaNodeKind{}.TableName(), model.Kind{}.TableName(), filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit)
if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaNodeKinds); result.Error != nil {
return nil, 0, CheckError(result)
} else {
if limit > 0 || skip > 0 {
countSqlStr := fmt.Sprintf(`SELECT COUNT(*) FROM %s nk JOIN %s k ON nk.kind_id = k.id %s`,
model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause)
model.GraphSchemaNodeKind{}.TableName(), model.Kind{}.TableName(), filterAndPagination.WhereClause)
if countResult := s.db.WithContext(ctx).Raw(countSqlStr, filterAndPagination.Filter.params...).Scan(&totalRowCount); countResult.Error != nil {
return model.GraphSchemaNodeKinds{}, 0, CheckError(countResult)
}
Expand All @@ -346,8 +346,8 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKindById(ctx context.Context, schemaNod
var schemaNodeKind model.GraphSchemaNodeKind
if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(`
SELECT %s.id, name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
FROM %s JOIN %s ON %s.kind_id = %s.id WHERE %s.id = ?`, schemaNodeKind.TableName(), schemaNodeKind.TableName(), kindTable,
schemaNodeKind.TableName(), kindTable, schemaNodeKind.TableName()), schemaNodeKindId).First(&schemaNodeKind); result.Error != nil {
FROM %s JOIN %s ON %s.kind_id = %s.id WHERE %s.id = ?`, schemaNodeKind.TableName(), schemaNodeKind.TableName(), model.Kind{}.TableName(),
schemaNodeKind.TableName(), model.Kind{}.TableName(), schemaNodeKind.TableName()), schemaNodeKindId).First(&schemaNodeKind); result.Error != nil {
return model.GraphSchemaNodeKind{}, CheckError(result)
}
return schemaNodeKind, nil
Expand All @@ -366,10 +366,10 @@ func (s *BloodhoundDB) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNode
WHERE id = ?
RETURNING id, kind_id, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
)
SELECT updated_row.id, %s.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
SELECT updated_row.id, k.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
FROM updated_row
JOIN %s ON %s.id = updated_row.kind_id`,
schemaNodeKind.TableName(), kindTable, kindTable, kindTable), schemaNodeKind.SchemaExtensionId,
JOIN %s k ON k.id = updated_row.kind_id`,
schemaNodeKind.TableName(), model.Kind{}.TableName()), schemaNodeKind.SchemaExtensionId,
schemaNodeKind.DisplayName, schemaNodeKind.Description, schemaNodeKind.IsDisplayKind, schemaNodeKind.Icon,
schemaNodeKind.IconColor, schemaNodeKind.ID).Scan(&schemaNodeKind); result.Error != nil {
if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) {
Expand Down Expand Up @@ -541,14 +541,14 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKinds(ctx context.Context, rela
FROM %s ek
JOIN %s k ON ek.kind_id = k.id
%s %s %s`,
model.GraphSchemaRelationshipKind{}.TableName(), kindTable, filterAndPagination.WhereClause,
model.GraphSchemaRelationshipKind{}.TableName(), model.Kind{}.TableName(), filterAndPagination.WhereClause,
filterAndPagination.OrderSql, filterAndPagination.SkipLimit)
if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaRelationshipKinds); result.Error != nil {
return nil, 0, CheckError(result)
} else {
if limit > 0 || skip > 0 {
countSqlStr := fmt.Sprintf(`SELECT COUNT(*) FROM %s ek JOIN %s k on ek.kind_id = k.id %s`,
model.GraphSchemaRelationshipKind{}.TableName(), kindTable, filterAndPagination.WhereClause)
model.GraphSchemaRelationshipKind{}.TableName(), model.Kind{}.TableName(), filterAndPagination.WhereClause)
if countResult := s.db.WithContext(ctx).Raw(countSqlStr, filterAndPagination.Filter.params...).Scan(&totalRowCount); countResult.Error != nil {
return model.GraphSchemaRelationshipKinds{}, 0, CheckError(countResult)
}
Expand All @@ -573,7 +573,7 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKindsWithSchemaName(ctx context
FROM %s edge JOIN %s schema ON edge.schema_extension_id = schema.id JOIN %s k ON edge.kind_id = k.id %s %s %s`,
model.GraphSchemaRelationshipKind{}.TableName(),
model.GraphSchemaExtension{}.TableName(),
kindTable,
model.Kind{}.TableName(),
filterAndPagination.WhereClause,
filterAndPagination.OrderSql,
filterAndPagination.SkipLimit)
Expand All @@ -583,7 +583,7 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKindsWithSchemaName(ctx context
} else {
if limit > 0 || skip > 0 {
countSqlStr := fmt.Sprintf(`SELECT COUNT(*) FROM %s edge JOIN %s schema ON edge.schema_extension_id = schema.id JOIN %s k ON edge.kind_id = k.id %s`,
model.GraphSchemaRelationshipKind{}.TableName(), model.GraphSchemaExtension{}.TableName(), kindTable,
model.GraphSchemaRelationshipKind{}.TableName(), model.GraphSchemaExtension{}.TableName(), model.Kind{}.TableName(),
filterAndPagination.WhereClause)
if countResult := s.db.WithContext(ctx).Raw(countSqlStr, filterAndPagination.Filter.params...).Scan(&totalRowCount); countResult.Error != nil {
return model.GraphSchemaRelationshipKindsWithNamedSchema{}, 0, CheckError(countResult)
Expand All @@ -602,8 +602,8 @@ func (s *BloodhoundDB) GetGraphSchemaRelationshipKindById(ctx context.Context, s
var schemaRelationshipKind model.GraphSchemaRelationshipKind
if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(`
SELECT %s.id, name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
FROM %s JOIN %s ON %s.kind_id = %s.id WHERE %s.id = ?`, schemaRelationshipKind.TableName(), schemaRelationshipKind.TableName(), kindTable,
schemaRelationshipKind.TableName(), kindTable, schemaRelationshipKind.TableName()), schemaRelationshipKindId).First(&schemaRelationshipKind); result.Error != nil {
FROM %s JOIN %s ON %s.kind_id = %s.id WHERE %s.id = ?`, schemaRelationshipKind.TableName(), schemaRelationshipKind.TableName(), model.Kind{}.TableName(),
schemaRelationshipKind.TableName(), model.Kind{}.TableName(), schemaRelationshipKind.TableName()), schemaRelationshipKindId).First(&schemaRelationshipKind); result.Error != nil {
return schemaRelationshipKind, CheckError(result)
}
return schemaRelationshipKind, nil
Expand All @@ -622,10 +622,10 @@ func (s *BloodhoundDB) UpdateGraphSchemaRelationshipKind(ctx context.Context, sc
WHERE id = ?
RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
)
SELECT updated_row.id, %s.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
SELECT updated_row.id, k.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
FROM updated_row
JOIN %s ON %s.id = updated_row.kind_id`,
schemaRelationshipKind.TableName(), kindTable, kindTable, kindTable),
JOIN %s k ON k.id = updated_row.kind_id`,
schemaRelationshipKind.TableName(), model.Kind{}.TableName()),
schemaRelationshipKind.SchemaExtensionId, schemaRelationshipKind.Description, schemaRelationshipKind.IsTraversable,
schemaRelationshipKind.ID).Scan(&schemaRelationshipKind); result.Error != nil {
if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) {
Expand Down Expand Up @@ -710,7 +710,7 @@ func (s *BloodhoundDB) GetEnvironmentsByExtensionId(ctx context.Context, extensi
JOIN %s k ON e.environment_kind_id = k.id
WHERE schema_extension_id = ?
ORDER BY id`,
model.SchemaEnvironment{}.TableName(), kindTable), extensionId).Scan(&environments); result.Error != nil {
model.SchemaEnvironment{}.TableName(), model.Kind{}.TableName()), extensionId).Scan(&environments); result.Error != nil {
return nil, CheckError(result)
}

Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/database/graphschema_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2851,7 +2851,7 @@ func TestDatabase_GraphSchemaRelationshipKind_CRUD(t *testing.T) {
rk.SchemaExtensionId == want.SchemaExtensionId {

// Additional validations for the found item
assert.Greater(t, rk.ID, int32(0), "RelationshipKind %v - ID is invalid", rk.Name)
assert.Greater(t, rk.ID, int32(0), "Kind %v - ID is invalid", rk.Name)

found = true
break
Expand Down Expand Up @@ -3612,7 +3612,7 @@ func TestDatabase_GetGraphSchemaRelationshipKindsWithSchemaName(t *testing.T) {
rk.SchemaName == want.SchemaName {

// Additional validations for the found item
assert.Greater(t, rk.ID, int32(0), "RelationshipKind %v - ID is invalid", rk.Name)
assert.Greater(t, rk.ID, int32(0), "Kind %v - ID is invalid", rk.Name)

found = true
break
Expand Down
6 changes: 3 additions & 3 deletions cmd/api/src/database/upsert_schema_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s *BloodhoundDB) UpsertOpenGraphExtension(ctx context.Context, graphExtens
graphExtensionInput.EnvironmentsInput); err != nil {
return schemaExists, err
} else if err = bloodhoundDBTransaction.upsertFindingsAndRemediations(ctx, createdExtension.ID,
graphExtensionInput.RelationshipFindingsInput); err != nil {
graphExtensionInput.FindingsInput); err != nil {
return schemaExists, err
} else if err = tx.Commit().Error; err != nil {
return schemaExists, err
Expand Down Expand Up @@ -155,10 +155,10 @@ func (s *BloodhoundDB) upsertGraphEnvironments(ctx context.Context, extensionID
}

// upsertFindingsAndRemediations - inserts a slice of new findings/remediations for the provided extension.
func (s *BloodhoundDB) upsertFindingsAndRemediations(ctx context.Context, extensionId int32, findings model.RelationshipFindingsInput) error {
func (s *BloodhoundDB) upsertFindingsAndRemediations(ctx context.Context, extensionId int32, findings model.FindingsInput) error {
for _, finding := range findings {
if schemaFinding, err := s.UpsertFinding(ctx, extensionId, finding.SourceKindName,
finding.RelationshipKindName, finding.EnvironmentKindName, finding.Name, finding.DisplayName); err != nil {
finding.KindName, finding.EnvironmentKindName, finding.Name, finding.DisplayName); err != nil {
return fmt.Errorf("failed to upsert finding: %w", err)
} else {
if err := s.UpsertRemediation(ctx, schemaFinding.ID, finding.RemediationInput.ShortDescription,
Expand Down
Loading