Refactor comprehension planning in flat expr builder to use comprehen… · google/cel-cpp@fe53e95
@@ -75,6 +75,7 @@ using ::cel::TypeManager;
7575using ::cel::Value;
7676using ::cel::ValueFactory;
7777using ::cel::ast_internal::AstImpl;
78+using ::cel::ast_internal::AstTraverse;
78797980constexpr int64_t kExprIdNotFromAst = -1;
8081@@ -169,7 +170,7 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor {
169170};
170171171172// Visitor Comprehension expression.
172-class ComprehensionVisitor : public CondVisitor {
173+class ComprehensionVisitor {
173174public:
174175explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting,
175176bool enable_vulnerability_check)
@@ -179,9 +180,10 @@ class ComprehensionVisitor : public CondVisitor {
179180 short_circuiting_(short_circuiting),
180181 enable_vulnerability_check_(enable_vulnerability_check) {}
181182182-void PreVisit(const cel::ast_internal::Expr* expr) override;
183-void PostVisitArg(int arg_num, const cel::ast_internal::Expr* expr) override;
184-void PostVisit(const cel::ast_internal::Expr* expr) override;
183+void PreVisit(const cel::ast_internal::Expr* expr);
184+void PostVisitArg(cel::ast_internal::ComprehensionArg arg_num,
185+const cel::ast_internal::Expr* comprehension_expr);
186+void PostVisit(const cel::ast_internal::Expr* expr);
185187186188private:
187189 FlatExprVisitor* visitor_;
@@ -585,15 +587,13 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
585587ValidateOrError(comprehension->has_result(),
586588"Invalid comprehension: 'result' must be set");
587589 comprehension_stack_.push(
588- {comprehension,
590+ {expr, comprehension,
589591IsOptimizableListAppend(comprehension,
590- options_.enable_comprehension_list_append)});
591- cond_visitor_stack_.push(
592- {expr, std::make_unique<ComprehensionVisitor>(
593-this, options_.short_circuiting,
594- enable_comprehension_vulnerability_check_)});
595-auto cond_visitor = FindCondVisitor(expr);
596- cond_visitor->PreVisit(expr);
592+ options_.enable_comprehension_list_append),
593+ std::make_unique<ComprehensionVisitor>(
594+this, options_.short_circuiting,
595+ enable_comprehension_vulnerability_check_)});
596+ comprehension_stack_.top().visitor->PreVisit(expr);
597597 }
598598599599// Invoked after all child nodes are processed.
@@ -604,11 +604,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
604604if (!progress_status_.ok()) {
605605return;
606606 }
607+608+if (comprehension_stack_.empty() ||
609+ comprehension_stack_.top().comprehension != comprehension_expr) {
610+return;
611+ }
612+613+ comprehension_stack_.top().visitor->PostVisit(expr);
607614 comprehension_stack_.pop();
615+ }
608616609-auto cond_visitor = FindCondVisitor(expr);
610- cond_visitor->PostVisit(expr);
611- cond_visitor_stack_.pop();
617+void PostVisitComprehensionSubexpression(
618+const cel::ast_internal::Expr* subexpr,
619+const cel::ast_internal::Comprehension* compr,
620+ cel::ast_internal::ComprehensionArg comprehension_arg,
621+const cel::ast_internal::SourcePosition*) override {
622+if (!progress_status_.ok()) {
623+return;
624+ }
625+626+if (comprehension_stack_.empty() ||
627+ comprehension_stack_.top().comprehension != compr) {
628+return;
629+ }
630+631+ comprehension_stack_.top().visitor->PostVisitArg(
632+ comprehension_arg, comprehension_stack_.top().expr);
612633 }
613634614635// Invoked after each argument node processed.
@@ -739,8 +760,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
739760740761private:
741762struct ComprehensionStackRecord {
763+const cel::ast_internal::Expr* expr;
742764const cel::ast_internal::Comprehension* comprehension;
743765bool is_optimizable_list_append;
766+ std::unique_ptr<ComprehensionVisitor> visitor;
744767 };
745768746769const Resolver& resolver_;
@@ -1089,8 +1112,9 @@ void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr*) {
10891112kExprIdNotFromAst, false));
10901113}
109111141092-void ComprehensionVisitor::PostVisitArg(int arg_num,
1093-const cel::ast_internal::Expr* expr) {
1115+void ComprehensionVisitor::PostVisitArg(
1116+ cel::ast_internal::ComprehensionArg arg_num,
1117+const cel::ast_internal::Expr* expr) {
10941118const auto* comprehension = &expr->comprehension_expr();
10951119const auto& accu_var = comprehension->accu_var();
10961120const auto& iter_var = comprehension->iter_var();
@@ -1204,7 +1228,9 @@ absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(
12041228 ast_impl.reference_map(), execution_path, value_factory, warnings_builder,
12051229 program_tree, extension_context);
120612301207-AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor);
1231+ cel::ast_internal::TraversalOptions opts;
1232+ opts.use_comprehension_callbacks = true;
1233+AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor, opts);
1208123412091235if (!visitor.progress_status().ok()) {
12101236return visitor.progress_status();