diff --git a/functions.go b/functions.go index d110062..54cb4aa 100644 --- a/functions.go +++ b/functions.go @@ -13,6 +13,18 @@ type FunctionValue struct { func (f *FunctionValue) Eval(ctx *Ctx) Result { fn, ok := StdlibFuncs[f.fn] if !ok { + if ctx.Macros != nil { + if macro, ok := ctx.Macros[f.fn]; ok { + if len(f.args.vals) > 0 { + return Result{ + Error: fmt.Errorf("macro %q expects 0 arguments, got %d", f.fn, len(f.args.vals)), + EvaluatedRule: f, + } + } + return macro.Eval(ctx) + } + } + return Result{ Error: fmt.Errorf("unknown function %q", f.fn), EvaluatedRule: f, diff --git a/rule.go b/rule.go index 5445267..76f531c 100644 --- a/rule.go +++ b/rule.go @@ -129,7 +129,8 @@ func MustParse(str string) Rule { type KV = map[string]any type Ctx struct { - KV KV + KV KV + Macros map[string]Rule } func (c *Ctx) Eval(r Rule) Result { diff --git a/rule_test.go b/rule_test.go index a6c47dc..b00fd60 100644 --- a/rule_test.go +++ b/rule_test.go @@ -706,6 +706,12 @@ func (k kv) Ctx() *Ctx { return &Ctx{KV: k} } +type ctx Ctx + +func (c *ctx) Ctx() *Ctx { + return (*Ctx)(c) +} + func assertParseEval(t *testing.T, rule string, input Ctxer, pass bool) { t.Helper() r, err := Parse(rule) @@ -716,7 +722,7 @@ func assertParseEval(t *testing.T, rule string, input Ctxer, pass bool) { // assertEval is a helper function to assert the result of a rule evaluation. // It enforces strict evaluation. func assertEval(t *testing.T, r Rule, input Ctxer, value any) { - var ctx *Ctx + ctx := &Ctx{} if input != nil { ctx = input.Ctx() } @@ -943,6 +949,44 @@ func TestFunctionParsing(t *testing.T) { assert.IsType(t, &FunctionValue{}, fn.args.vals[3]) } +func TestMacros(t *testing.T) { + r := MustParse(`dst_k8s_svc() && user != "root"`) + macros := map[string]Rule{ + "dst_k8s_svc": MustParse(`ip in 172.16.0.0/16 or host matches /svc.cluster.local$/`), + } + + assertRule(t, r, &ctx{ + Macros: macros, + KV: KV{ + "ip": net.ParseIP("172.16.0.1"), + "user": "nouser", + }, + }). + Pass(). + EvaluatedRule(`ip == 172.16.0.0/16 and user != "root"`) + + assertRule(t, r, &ctx{ + Macros: macros, + KV: KV{ + "ip": net.ParseIP("1.1.1.1"), + "host": "1.1.1.1", + "user": "nouser", + }, + }). + Fail(). + EvaluatedRule(`ip == 172.16.0.0/16 or host =~ /svc.cluster.local$/`) + + assertRule(t, r, &ctx{ + Macros: macros, + KV: KV{ + "host": "test.svc.cluster.local", + "user": "nouser", + }, + }). + Pass(). + EvaluatedRule(`host =~ /svc.cluster.local$/ and user != "root"`) +} + type ruleAssertion struct { t *testing.T rule Rule @@ -958,7 +1002,7 @@ func assertRulep(t *testing.T, rule string, input Ctxer) *ruleAssertion { } func assertRule(t *testing.T, rule Rule, input Ctxer) *ruleAssertion { - var ctx *Ctx + ctx := &Ctx{} if input != nil { ctx = input.Ctx() }