Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ type FunctionValue struct {
func (f *FunctionValue) Eval(ctx *Ctx) Result {
fn, ok := StdlibFuncs[f.fn]
if !ok {
if ctx.Macros != nil {
if macro, ok := ctx.Macros[f.fn]; ok {
if len(f.args.vals) > 0 {
return Result{
Error: fmt.Errorf("macro %q expects 0 arguments, got %d", f.fn, len(f.args.vals)),
EvaluatedRule: f,
}
}
return macro.Eval(ctx)
}
}

return Result{
Error: fmt.Errorf("unknown function %q", f.fn),
EvaluatedRule: f,
Expand Down
3 changes: 2 additions & 1 deletion rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ func MustParse(str string) Rule {
type KV = map[string]any

type Ctx struct {
KV KV
KV KV
Macros map[string]Rule
}

func (c *Ctx) Eval(r Rule) Result {
Expand Down
48 changes: 46 additions & 2 deletions rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,12 @@ func (k kv) Ctx() *Ctx {
return &Ctx{KV: k}
}

type ctx Ctx

func (c *ctx) Ctx() *Ctx {
return (*Ctx)(c)
}

func assertParseEval(t *testing.T, rule string, input Ctxer, pass bool) {
t.Helper()
r, err := Parse(rule)
Expand All @@ -716,7 +722,7 @@ func assertParseEval(t *testing.T, rule string, input Ctxer, pass bool) {
// assertEval is a helper function to assert the result of a rule evaluation.
// It enforces strict evaluation.
func assertEval(t *testing.T, r Rule, input Ctxer, value any) {
var ctx *Ctx
ctx := &Ctx{}
if input != nil {
ctx = input.Ctx()
}
Expand Down Expand Up @@ -943,6 +949,44 @@ func TestFunctionParsing(t *testing.T) {
assert.IsType(t, &FunctionValue{}, fn.args.vals[3])
}

func TestMacros(t *testing.T) {
r := MustParse(`dst_k8s_svc() && user != "root"`)
macros := map[string]Rule{
"dst_k8s_svc": MustParse(`ip in 172.16.0.0/16 or host matches /svc.cluster.local$/`),
}

assertRule(t, r, &ctx{
Macros: macros,
KV: KV{
"ip": net.ParseIP("172.16.0.1"),
"user": "nouser",
},
}).
Pass().
EvaluatedRule(`ip == 172.16.0.0/16 and user != "root"`)

assertRule(t, r, &ctx{
Macros: macros,
KV: KV{
"ip": net.ParseIP("1.1.1.1"),
"host": "1.1.1.1",
"user": "nouser",
},
}).
Fail().
EvaluatedRule(`ip == 172.16.0.0/16 or host =~ /svc.cluster.local$/`)

assertRule(t, r, &ctx{
Macros: macros,
KV: KV{
"host": "test.svc.cluster.local",
"user": "nouser",
},
}).
Pass().
EvaluatedRule(`host =~ /svc.cluster.local$/ and user != "root"`)
}

type ruleAssertion struct {
t *testing.T
rule Rule
Expand All @@ -958,7 +1002,7 @@ func assertRulep(t *testing.T, rule string, input Ctxer) *ruleAssertion {
}

func assertRule(t *testing.T, rule Rule, input Ctxer) *ruleAssertion {
var ctx *Ctx
ctx := &Ctx{}
if input != nil {
ctx = input.Ctx()
}
Expand Down