Refactor comprehension planning in flat expr builder to use comprehension specific visitor callbacks. by copybara-service[bot] · Pull Request #293 · google/cel-cpp

Expand Up @@ -75,6 +75,7 @@ using ::cel::TypeManager; using ::cel::Value; using ::cel::ValueFactory; using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::AstTraverse;
constexpr int64_t kExprIdNotFromAst = -1;
Expand Down Expand Up @@ -169,7 +170,7 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor { };
// Visitor Comprehension expression. class ComprehensionVisitor : public CondVisitor { class ComprehensionVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, bool enable_vulnerability_check) Expand All @@ -179,9 +180,10 @@ class ComprehensionVisitor : public CondVisitor { short_circuiting_(short_circuiting), enable_vulnerability_check_(enable_vulnerability_check) {}
void PreVisit(const cel::ast_internal::Expr* expr) override; void PostVisitArg(int arg_num, const cel::ast_internal::Expr* expr) override; void PostVisit(const cel::ast_internal::Expr* expr) override; void PreVisit(const cel::ast_internal::Expr* expr); void PostVisitArg(cel::ast_internal::ComprehensionArg arg_num, const cel::ast_internal::Expr* comprehension_expr); void PostVisit(const cel::ast_internal::Expr* expr);
private: FlatExprVisitor* visitor_; Expand Down Expand Up @@ -585,15 +587,13 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { ValidateOrError(comprehension->has_result(), "Invalid comprehension: 'result' must be set"); comprehension_stack_.push( {comprehension, {expr, comprehension, IsOptimizableListAppend(comprehension, options_.enable_comprehension_list_append)}); cond_visitor_stack_.push( {expr, std::make_unique<ComprehensionVisitor>( this, options_.short_circuiting, enable_comprehension_vulnerability_check_)}); auto cond_visitor = FindCondVisitor(expr); cond_visitor->PreVisit(expr); options_.enable_comprehension_list_append), std::make_unique<ComprehensionVisitor>( this, options_.short_circuiting, enable_comprehension_vulnerability_check_)}); comprehension_stack_.top().visitor->PreVisit(expr); }
// Invoked after all child nodes are processed. Expand All @@ -604,11 +604,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { if (!progress_status_.ok()) { return; }
if (comprehension_stack_.empty() || comprehension_stack_.top().comprehension != comprehension_expr) { return; }
comprehension_stack_.top().visitor->PostVisit(expr); comprehension_stack_.pop(); }
auto cond_visitor = FindCondVisitor(expr); cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); void PostVisitComprehensionSubexpression( const cel::ast_internal::Expr* subexpr, const cel::ast_internal::Comprehension* compr, cel::ast_internal::ComprehensionArg comprehension_arg, const cel::ast_internal::SourcePosition*) override { if (!progress_status_.ok()) { return; }
if (comprehension_stack_.empty() || comprehension_stack_.top().comprehension != compr) { return; }
comprehension_stack_.top().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.top().expr); }
// Invoked after each argument node processed. Expand Down Expand Up @@ -739,8 +760,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
private: struct ComprehensionStackRecord { const cel::ast_internal::Expr* expr; const cel::ast_internal::Comprehension* comprehension; bool is_optimizable_list_append; std::unique_ptr<ComprehensionVisitor> visitor; };
const Resolver& resolver_; Expand Down Expand Up @@ -1089,8 +1112,9 @@ void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr*) { kExprIdNotFromAst, false)); }
void ComprehensionVisitor::PostVisitArg(int arg_num, const cel::ast_internal::Expr* expr) { void ComprehensionVisitor::PostVisitArg( cel::ast_internal::ComprehensionArg arg_num, const cel::ast_internal::Expr* expr) { const auto* comprehension = &expr->comprehension_expr(); const auto& accu_var = comprehension->accu_var(); const auto& iter_var = comprehension->iter_var(); Expand Down Expand Up @@ -1204,7 +1228,9 @@ absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl( ast_impl.reference_map(), execution_path, value_factory, warnings_builder, program_tree, extension_context);
AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor); cel::ast_internal::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor, opts);
if (!visitor.progress_status().ok()) { return visitor.progress_status(); Expand Down