diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index df5189e4..1133a346 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -221,7 +221,7 @@ func (e *Engine) Eval(ctx context.Context, options *EvalOptions) *models.Results allRuleResults := []models.RuleResults{} totalResults := 0 for _, p := range e.policySets { - err := withtimeout.Do(ctx, e.timeouts.Eval, ErrEvalTimedOut, func(ctx context.Context) error { + err := withtimeout.Do(ctx, e.timeouts.Eval, ErrEvalTimedOut{Timeout: e.timeouts.Eval}, func(ctx context.Context) error { ruleResults, err := p.eval(ctx, ¶llelEvalOptions{ resourcesResolver: options.ResourcesResolver, input: &input, diff --git a/pkg/engine/errors.go b/pkg/engine/errors.go index 4e7d8929..6ad3b900 100644 --- a/pkg/engine/errors.go +++ b/pkg/engine/errors.go @@ -16,6 +16,8 @@ package engine import ( "errors" + "fmt" + "time" ) // FailedToLoadRegoAPI indicates that an error occurred while initializing the snyk @@ -35,10 +37,28 @@ var FailedToCompile = errors.New("Failed to compile rules") var ErrFailedToReadBundle = errors.New("failed to load bundle") // ErrInitTimedOut indicates that initialization took too long and was cancelled. -var ErrInitTimedOut = errors.New("initialization timed out") +type ErrInitTimedOut struct { + Timeout time.Duration +} + +func (e ErrInitTimedOut) Error() string { + return fmt.Sprintf("initialization timed out after %s", e.Timeout.String()) +} // ErrEvalTimedOut indicates that evaluation took too long and was cancelled. -var ErrEvalTimedOut = errors.New("evaluation timed out") +type ErrEvalTimedOut struct { + Timeout time.Duration +} + +func (e ErrEvalTimedOut) Error() string { + return fmt.Sprintf("initialization timed out after %s", e.Timeout.String()) +} // ErrQueryTimedOut indicates that a query took too long and was cancelled. -var ErrQueryTimedOut = errors.New("query timed out") +type ErrQueryTimedOut struct { + Timeout time.Duration +} + +func (e ErrQueryTimedOut) Error() string { + return fmt.Sprintf("query timed out after %s", e.Timeout.String()) +} diff --git a/pkg/engine/policyset.go b/pkg/engine/policyset.go index a0e7846d..dd63e4ea 100644 --- a/pkg/engine/policyset.go +++ b/pkg/engine/policyset.go @@ -96,7 +96,7 @@ func newPolicySet(ctx context.Context, options policySetOptions) (*policySet, er s.instrumentation.startInitialization(ctx) defer s.instrumentation.finishInitialization(ctx, s) - err := withtimeout.Do(ctx, options.timeouts.Init, ErrInitTimedOut, func(ctx context.Context) error { + err := withtimeout.Do(ctx, options.timeouts.Init, ErrInitTimedOut{options.timeouts.Init}, func(ctx context.Context) error { if err := s.loadRegoAPI(ctx); err != nil { return fmt.Errorf("%w: %v", FailedToLoadRegoAPI, err) } @@ -211,7 +211,7 @@ type policyFilter func(ctx context.Context, pol policy.Policy) (bool, error) func (s *policySet) selectPolicies(ctx context.Context, filters []policyFilter) ([]policy.Policy, error) { s.instrumentation.startPolicySelection(ctx) var subset []policy.Policy - err := withtimeout.Do(ctx, s.timeouts.Query, ErrQueryTimedOut, func(ctx context.Context) error { + err := withtimeout.Do(ctx, s.timeouts.Query, ErrQueryTimedOut{s.timeouts.Query}, func(ctx context.Context) error { for _, pol := range s.policies { include := true for _, filter := range filters { @@ -408,7 +408,7 @@ func (s *policySet) metadata(ctx context.Context) ([]MetadataResult, error) { return policies[i].Package() < policies[j].Package() }) metadata := make([]MetadataResult, len(policies)) - err := withtimeout.Do(ctx, s.timeouts.Query, ErrQueryTimedOut, func(ctx context.Context) error { + err := withtimeout.Do(ctx, s.timeouts.Query, ErrQueryTimedOut{s.timeouts.Query}, func(ctx context.Context) error { for idx, p := range policies { m, err := p.Metadata(ctx, s.rego) result := MetadataResult{ diff --git a/pkg/interfacetricks/extract.go b/pkg/interfacetricks/extract.go index 7385ace8..e117dd2f 100644 --- a/pkg/interfacetricks/extract.go +++ b/pkg/interfacetricks/extract.go @@ -15,15 +15,24 @@ package interfacetricks import ( + "errors" "fmt" "reflect" "strings" ) +var SetError = errors.New("cannot set destination (hint: use pointer receiver?)") +var TypeError = errors.New("type error") + type ExtractError struct { - SrcPath []interface{} - SrcType reflect.Type - DstType reflect.Type + underlying error + SrcPath []interface{} + SrcType reflect.Type + DstType reflect.Type +} + +func (e ExtractError) Unwrap() error { + return e.underlying } func (e ExtractError) Error() string { @@ -72,9 +81,27 @@ func Extract(src interface{}, dst interface{}) []error { // there as well. func extract(path []interface{}, src interface{}, dst reflect.Value) (errs []error) { ty := dst.Type() - switch ty.Kind() { - case reflect.Pointer: + + makeExtractError := func(err error) ExtractError { + pcopy := make([]interface{}, len(path)) + copy(pcopy, path) + return ExtractError{ + underlying: err, + SrcPath: pcopy, + SrcType: reflect.TypeOf(src), + DstType: ty, + } + } + + if ty.Kind() == reflect.Pointer { return extract(path, src, dst.Elem()) + } + + if !dst.CanSet() { + return []error{makeExtractError(SetError)} + } + + switch ty.Kind() { case reflect.Struct: if srcObject, ok := src.(map[string]interface{}); ok { for i := 0; i < ty.NumField(); i++ { @@ -92,13 +119,12 @@ func extract(path []interface{}, src interface{}, dst reflect.Value) (errs []err errs = append(errs, extract(path, srcFieldVal, goFieldVal)...) path = path[:len(path)-1] } - } else { } } return } case reflect.Slice: - if srcArray, ok := src.([]interface{}); ok && dst.CanSet() { + if srcArray, ok := src.([]interface{}); ok { dst.Set(reflect.MakeSlice(ty, len(srcArray), len(srcArray))) for i := 0; i < len(srcArray); i++ { path = append(path, i) @@ -108,7 +134,7 @@ func extract(path []interface{}, src interface{}, dst reflect.Value) (errs []err return } case reflect.Map: - if srcObject, ok := src.(map[string]interface{}); ok && dst.CanSet() { + if srcObject, ok := src.(map[string]interface{}); ok { dst.Set(reflect.MakeMap(ty)) for k, v := range srcObject { path = append(path, k) @@ -125,57 +151,37 @@ func extract(path []interface{}, src interface{}, dst reflect.Value) (errs []err return } case reflect.Interface: - if dst.CanSet() { - dst.Set(reflect.ValueOf(src)) - return - } + dst.Set(reflect.ValueOf(src)) + return case reflect.Bool: - if boolean, ok := src.(bool); ok && dst.CanSet() { + if boolean, ok := src.(bool); ok { dst.SetBool(boolean) return } case reflect.Int: if number, ok := src.(int64); ok { - if dst.CanSet() { - dst.SetInt(number) - return - } + dst.SetInt(number) + return } else if number, ok := src.(int); ok { - if dst.CanSet() { - dst.SetInt(int64(number)) - return - } + dst.SetInt(int64(number)) + return } else if number, ok := src.(float64); ok { - if dst.CanSet() { - dst.SetInt(int64(number)) - return - } + dst.SetInt(int64(number)) + return } case reflect.Float64: if number, ok := src.(float64); ok { - if dst.CanSet() { - dst.SetFloat(number) - return - } + dst.SetFloat(number) + return } case reflect.String: - if str, ok := src.(string); ok && dst.CanSet() { + if str, ok := src.(string); ok { dst.SetString(str) return } } - return []error{newExtractError(path, src, ty)} -} - -func newExtractError(path []interface{}, src interface{}, dst reflect.Type) ExtractError { - pcopy := make([]interface{}, len(path)) - copy(pcopy, path) - return ExtractError{ - SrcPath: pcopy, - SrcType: reflect.TypeOf(src), - DstType: dst, - } + return []error{makeExtractError(TypeError)} } func getJsonFieldName(field reflect.StructField) (string, bool) { diff --git a/pkg/interfacetricks/extract_test.go b/pkg/interfacetricks/extract_test.go index 524d3ff8..783b5833 100644 --- a/pkg/interfacetricks/extract_test.go +++ b/pkg/interfacetricks/extract_test.go @@ -34,6 +34,23 @@ func TestPrimitives(t *testing.T) { }, &p1, &p2) } +func TestPrimitiveErrors(t *testing.T) { + p1 := Primitives{} + errs := Extract(map[string]interface{}{ + "int": int(1), + }, p1) + require.Equal(t, []error{ + ExtractError{ + underlying: SetError, + SrcPath: []interface{}{}, + SrcType: reflect.TypeOf(map[string]interface{}{}), + DstType: reflect.TypeOf(p1), + }, + }, errs) + require.ErrorAs(t, errs[0], &ExtractError{}) + require.ErrorIs(t, errs[0], SetError) +} + type Collections struct { Slice []Primitives `json:"slice"` Map map[string]int `json:"map"` @@ -70,16 +87,20 @@ func TestCollectionErrors(t *testing.T) { }, &dst) require.Equal(t, []error{ ExtractError{ - SrcPath: []interface{}{"slice", 0}, - SrcType: intType, - DstType: reflect.TypeOf(Primitives{}), + underlying: TypeError, + SrcPath: []interface{}{"slice", 0}, + SrcType: intType, + DstType: reflect.TypeOf(Primitives{}), }, ExtractError{ - SrcPath: []interface{}{"slice", 1, "bool"}, - SrcType: stringType, - DstType: boolType, + underlying: TypeError, + SrcPath: []interface{}{"slice", 1, "bool"}, + SrcType: stringType, + DstType: boolType, }, }, errs) + require.ErrorAs(t, errs[0], &ExtractError{}) + require.ErrorIs(t, errs[0], TypeError) require.Equal(t, Collections{ Slice: []Primitives{ {},