3232
3333#include " cel/expr/syntax.pb.h"
3434#include " absl/base/macros.h"
35+ #include " absl/base/nullability.h"
3536#include " absl/base/optimization.h"
3637#include " absl/container/btree_map.h"
3738#include " absl/container/flat_hash_map.h"
@@ -601,23 +602,151 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) {
601602 return factory_.NewCall (ops_[mid], function_, std::move (arguments));
602603}
603604
605+ // Lightweight overlay for a registry.
606+ // Adds stateful macros that are relevant per Parse call.
607+ class AugmentedMacroRegistry {
608+ public:
609+ explicit AugmentedMacroRegistry (const cel::MacroRegistry& registry)
610+ : base_(registry) {}
611+
612+ cel::MacroRegistry& overlay () { return overlay_; }
613+
614+ absl::optional<Macro> FindMacro (absl::string_view name, size_t arg_count,
615+ bool receiver_style) const ;
616+
617+ private:
618+ const cel::MacroRegistry& base_;
619+ cel::MacroRegistry overlay_;
620+ };
621+
622+ absl::optional<Macro> AugmentedMacroRegistry::FindMacro (
623+ absl::string_view name, size_t arg_count, bool receiver_style) const {
624+ auto result = overlay_.FindMacro (name, arg_count, receiver_style);
625+ if (result.has_value ()) {
626+ return result;
627+ }
628+
629+ return base_.FindMacro (name, arg_count, receiver_style);
630+ }
631+
632+ bool IsSupportedAnnotation (const Expr& e) {
633+ if (e.has_const_expr () && e.const_expr ().has_string_value ()) {
634+ return true ;
635+ } else if (e.has_struct_expr () &&
636+ e.struct_expr ().name () == " cel.Annotation" ) {
637+ for (const auto & field : e.struct_expr ().fields ()) {
638+ if (field.name () != " name" && field.name () != " inspect_only" &&
639+ field.name () != " value" ) {
640+ return false ;
641+ }
642+ }
643+ return true ;
644+ }
645+ return false ;
646+ }
647+
648+ class AnnotationCollector {
649+ private:
650+ struct AnnotationRep {
651+ Expr expr;
652+ };
653+
654+ struct MacroImpl {
655+ absl::Nonnull<AnnotationCollector*> parent;
656+
657+ // Record a single annotation. Returns a non-empty optional if
658+ // an error is encountered.
659+ absl::optional<Expr> RecordAnnotation (cel::MacroExprFactory& mef,
660+ int64_t id, Expr e) const ;
661+
662+ // MacroExpander for "cel.annotate"
663+ absl::optional<Expr> operator ()(cel::MacroExprFactory& mef, Expr& target,
664+ absl::Span<Expr> args) const ;
665+ };
666+
667+ void Add (int64_t annotated_expr, Expr value);
668+
669+ public:
670+ const absl::btree_map<int64_t , std::vector<AnnotationRep>>& annotations () {
671+ return annotations_;
672+ }
673+
674+ absl::btree_map<int64_t , std::vector<AnnotationRep>> consume_annotations () {
675+ using std::swap;
676+ absl::btree_map<int64_t , std::vector<AnnotationRep>> result;
677+ swap (result, annotations_);
678+ return result;
679+ }
680+
681+ Macro MakeAnnotationImpl () {
682+ auto impl = Macro::Receiver (" annotate" , 2 , MacroImpl{this });
683+ ABSL_CHECK_OK (impl.status ());
684+ return std::move (impl).value ();
685+ }
686+
687+ private:
688+ absl::btree_map<int64_t , std::vector<AnnotationRep>> annotations_;
689+ };
690+
691+ absl::optional<Expr> AnnotationCollector::MacroImpl::RecordAnnotation (
692+ cel::MacroExprFactory& mef, int64_t id, Expr e) const {
693+ if (IsSupportedAnnotation (e)) {
694+ parent->Add (id, std::move (e));
695+ return absl::nullopt ;
696+ }
697+
698+ return mef.ReportErrorAt (
699+ e,
700+ " cel.annotate argument is not a cel.Annotation{} or string expression" );
701+ }
702+
703+ absl::optional<Expr> AnnotationCollector::MacroImpl::operator ()(
704+ cel::MacroExprFactory& mef, Expr& target, absl::Span<Expr> args) const {
705+ if (!target.has_ident_expr () || target.ident_expr ().name () != " cel" ) {
706+ return absl::nullopt ;
707+ }
708+
709+ if (args.size () != 2 ) {
710+ return mef.ReportErrorAt (
711+ target, " wrong number of arguments for cel.annotate macro" );
712+ }
713+
714+ // arg0 (the annotated expression) is the expansion result. The remainder are
715+ // annotations to record.
716+ int64_t id = args[0 ].id ();
717+
718+ absl::optional<Expr> result;
719+ if (args[1 ].has_list_expr ()) {
720+ auto list = args[1 ].release_list_expr ();
721+ for (auto & e : list.mutable_elements ()) {
722+ result = RecordAnnotation (mef, id, e.release_expr ());
723+ if (result) {
724+ break ;
725+ }
726+ }
727+ } else {
728+ result = RecordAnnotation (mef, id, std::move (args[1 ]));
729+ }
730+
731+ if (result) {
732+ return result;
733+ }
734+
735+ return std::move (args[0 ]);
736+ }
737+
738+ void AnnotationCollector::Add (int64_t annotated_expr, Expr value) {
739+ annotations_[annotated_expr].push_back ({std::move (value)});
740+ }
741+
604742class ParserVisitor final : public CelBaseVisitor,
605743 public antlr4::BaseErrorListener {
606744 public:
607745 ParserVisitor (const cel::Source& source, int max_recursion_depth,
608746 absl::string_view accu_var,
609- const cel::MacroRegistry& macro_registry,
610- bool add_macro_calls = false ,
611- bool enable_optional_syntax = false ,
612- bool enable_quoted_identifiers = false )
613- : source_(source),
614- factory_ (source_, accu_var),
615- macro_registry_(macro_registry),
616- recursion_depth_(0 ),
617- max_recursion_depth_(max_recursion_depth),
618- add_macro_calls_(add_macro_calls),
619- enable_optional_syntax_(enable_optional_syntax),
620- enable_quoted_identifiers_(enable_quoted_identifiers) {}
747+ const cel::MacroRegistry& macro_registry, bool add_macro_calls,
748+ bool enable_optional_syntax, bool enable_quoted_identifiers,
749+ bool enable_annotations);
621750
622751 ~ParserVisitor () override = default ;
623752
@@ -675,6 +804,8 @@ class ParserVisitor final : public CelBaseVisitor,
675804
676805 std::string ErrorMessage ();
677806
807+ Expr PackAnnotations (Expr ast);
808+
678809 private:
679810 template <typename ... Args>
680811 Expr GlobalCallOrMacro (int64_t expr_id, absl::string_view function,
@@ -702,14 +833,38 @@ class ParserVisitor final : public CelBaseVisitor,
702833 private:
703834 const cel::Source& source_;
704835 cel::ParserMacroExprFactory factory_;
705- const cel::MacroRegistry& macro_registry_;
836+ AugmentedMacroRegistry macro_registry_;
837+ AnnotationCollector annotations_;
706838 int recursion_depth_;
707839 const int max_recursion_depth_;
708840 const bool add_macro_calls_;
709841 const bool enable_optional_syntax_;
710842 const bool enable_quoted_identifiers_;
843+ const bool enable_annotations_;
711844};
712845
846+ ParserVisitor::ParserVisitor (const cel::Source& source, int max_recursion_depth,
847+ absl::string_view accu_var,
848+ const cel::MacroRegistry& macro_registry,
849+ bool add_macro_calls, bool enable_optional_syntax,
850+ bool enable_quoted_identifiers,
851+ bool enable_annotations)
852+ : source_(source),
853+ factory_(source_, accu_var),
854+ macro_registry_(macro_registry),
855+ recursion_depth_(0 ),
856+ max_recursion_depth_(max_recursion_depth),
857+ add_macro_calls_(add_macro_calls),
858+ enable_optional_syntax_(enable_optional_syntax),
859+ enable_quoted_identifiers_(enable_quoted_identifiers),
860+ enable_annotations_(enable_annotations) {
861+ if (enable_annotations_) {
862+ macro_registry_.overlay ()
863+ .RegisterMacro (annotations_.MakeAnnotationImpl ())
864+ .IgnoreError ();
865+ }
866+ }
867+
713868template <typename T, typename = std::enable_if_t <
714869 std::is_base_of<antlr4::tree::ParseTree, T>::value>>
715870T* tree_as (antlr4::tree::ParseTree* tree) {
@@ -1638,6 +1793,61 @@ struct ParseResult {
16381793 EnrichedSourceInfo enriched_source_info;
16391794};
16401795
1796+ Expr NormalizeAnnotation (cel::ParserMacroExprFactory& mef, Expr expr) {
1797+ if (expr.has_struct_expr ()) {
1798+ return expr;
1799+ }
1800+
1801+ if (expr.has_const_expr ()) {
1802+ std::vector<cel::StructExprField> fields;
1803+ fields.reserve (2 );
1804+ fields.push_back (
1805+ mef.NewStructField (mef.NextId ({}), " name" , std::move (expr)));
1806+ fields.push_back (
1807+ mef.NewStructField (mef.NextId ({}), " inspect_only" ,
1808+ mef.NewBoolConst (mef.NextId ({}), true )));
1809+ return mef.NewStruct (mef.NextId ({}), " cel.Annotation" , std::move (fields));
1810+ }
1811+
1812+ return mef.ReportError (" invalid annotation encountered finalizing AST" );
1813+ }
1814+
1815+ Expr ParserVisitor::PackAnnotations (Expr ast) {
1816+ if (annotations_.annotations ().empty ()) {
1817+ return ast;
1818+ }
1819+
1820+ auto annotations = annotations_.consume_annotations ();
1821+ std::vector<MapExprEntry> entries;
1822+ entries.reserve (annotations.size ());
1823+
1824+ for (auto & annotation : annotations) {
1825+ std::vector<cel::ListExprElement> annotation_values;
1826+ annotation_values.reserve (annotation.second .size ());
1827+
1828+ for (auto & annotation_value : annotation.second ) {
1829+ auto annotation =
1830+ NormalizeAnnotation (factory_, std::move (annotation_value.expr ));
1831+ annotation_values.push_back (
1832+ factory_.NewListElement (std::move (annotation)));
1833+ }
1834+
1835+ entries.push_back (factory_.NewMapEntry (
1836+ factory_.NextId ({}),
1837+ factory_.NewIntConst (factory_.NextId ({}), annotation.first ),
1838+ factory_.NewList (factory_.NextId ({}), std::move (annotation_values))));
1839+ }
1840+
1841+ std::vector<Expr> args;
1842+ args.push_back (std::move (ast));
1843+ args.push_back (factory_.NewMap (factory_.NextId ({}), std::move (entries)));
1844+
1845+ auto result =
1846+ factory_.NewCall (factory_.NextId ({}), " cel.@annotated" , std::move (args));
1847+
1848+ return result;
1849+ }
1850+
16411851absl::StatusOr<ParseResult> ParseImpl (const cel::Source& source,
16421852 const cel::MacroRegistry& registry,
16431853 const ParserOptions& options) {
@@ -1656,10 +1866,10 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
16561866 if (options.enable_hidden_accumulator_var ) {
16571867 accu_var = cel::kHiddenAccumulatorVariableName ;
16581868 }
1659- ParserVisitor visitor (source, options. max_recursion_depth , accu_var,
1660- registry , options.add_macro_calls ,
1661- options.enable_optional_syntax ,
1662- options.enable_quoted_identifiers );
1869+ ParserVisitor visitor (
1870+ source , options.max_recursion_depth , accu_var, registry ,
1871+ options. add_macro_calls , options.enable_optional_syntax ,
1872+ options.enable_quoted_identifiers , options. enable_annotations );
16631873
16641874 lexer.removeErrorListeners ();
16651875 parser.removeErrorListeners ();
@@ -1686,7 +1896,9 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
16861896 if (visitor.HasErrored ()) {
16871897 return absl::InvalidArgumentError (visitor.ErrorMessage ());
16881898 }
1689-
1899+ if (options.enable_annotations ) {
1900+ expr = visitor.PackAnnotations (std::move (expr));
1901+ }
16901902 return {
16911903 ParseResult{.expr = std::move (expr),
16921904 .source_info = visitor.GetSourceInfo (),
0 commit comments