Refactor comprehension planning in flat expr builder to use comprehen… · google/cel-cpp@fe53e95

@@ -75,6 +75,7 @@ using ::cel::TypeManager;

7575

using ::cel::Value;

7676

using ::cel::ValueFactory;

7777

using ::cel::ast_internal::AstImpl;

78+

using ::cel::ast_internal::AstTraverse;

78797980

constexpr int64_t kExprIdNotFromAst = -1;

8081

@@ -169,7 +170,7 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor {

169170

};

170171171172

// Visitor Comprehension expression.

172-

class ComprehensionVisitor : public CondVisitor {

173+

class ComprehensionVisitor {

173174

public:

174175

explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting,

175176

bool enable_vulnerability_check)

@@ -179,9 +180,10 @@ class ComprehensionVisitor : public CondVisitor {

179180

short_circuiting_(short_circuiting),

180181

enable_vulnerability_check_(enable_vulnerability_check) {}

181182182-

void PreVisit(const cel::ast_internal::Expr* expr) override;

183-

void PostVisitArg(int arg_num, const cel::ast_internal::Expr* expr) override;

184-

void PostVisit(const cel::ast_internal::Expr* expr) override;

183+

void PreVisit(const cel::ast_internal::Expr* expr);

184+

void PostVisitArg(cel::ast_internal::ComprehensionArg arg_num,

185+

const cel::ast_internal::Expr* comprehension_expr);

186+

void PostVisit(const cel::ast_internal::Expr* expr);

185187186188

private:

187189

FlatExprVisitor* visitor_;

@@ -585,15 +587,13 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {

585587

ValidateOrError(comprehension->has_result(),

586588

"Invalid comprehension: 'result' must be set");

587589

comprehension_stack_.push(

588-

{comprehension,

590+

{expr, comprehension,

589591

IsOptimizableListAppend(comprehension,

590-

options_.enable_comprehension_list_append)});

591-

cond_visitor_stack_.push(

592-

{expr, std::make_unique<ComprehensionVisitor>(

593-

this, options_.short_circuiting,

594-

enable_comprehension_vulnerability_check_)});

595-

auto cond_visitor = FindCondVisitor(expr);

596-

cond_visitor->PreVisit(expr);

592+

options_.enable_comprehension_list_append),

593+

std::make_unique<ComprehensionVisitor>(

594+

this, options_.short_circuiting,

595+

enable_comprehension_vulnerability_check_)});

596+

comprehension_stack_.top().visitor->PreVisit(expr);

597597

}

598598599599

// Invoked after all child nodes are processed.

@@ -604,11 +604,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {

604604

if (!progress_status_.ok()) {

605605

return;

606606

}

607+608+

if (comprehension_stack_.empty() ||

609+

comprehension_stack_.top().comprehension != comprehension_expr) {

610+

return;

611+

}

612+613+

comprehension_stack_.top().visitor->PostVisit(expr);

607614

comprehension_stack_.pop();

615+

}

608616609-

auto cond_visitor = FindCondVisitor(expr);

610-

cond_visitor->PostVisit(expr);

611-

cond_visitor_stack_.pop();

617+

void PostVisitComprehensionSubexpression(

618+

const cel::ast_internal::Expr* subexpr,

619+

const cel::ast_internal::Comprehension* compr,

620+

cel::ast_internal::ComprehensionArg comprehension_arg,

621+

const cel::ast_internal::SourcePosition*) override {

622+

if (!progress_status_.ok()) {

623+

return;

624+

}

625+626+

if (comprehension_stack_.empty() ||

627+

comprehension_stack_.top().comprehension != compr) {

628+

return;

629+

}

630+631+

comprehension_stack_.top().visitor->PostVisitArg(

632+

comprehension_arg, comprehension_stack_.top().expr);

612633

}

613634614635

// Invoked after each argument node processed.

@@ -739,8 +760,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {

739760740761

private:

741762

struct ComprehensionStackRecord {

763+

const cel::ast_internal::Expr* expr;

742764

const cel::ast_internal::Comprehension* comprehension;

743765

bool is_optimizable_list_append;

766+

std::unique_ptr<ComprehensionVisitor> visitor;

744767

};

745768746769

const Resolver& resolver_;

@@ -1089,8 +1112,9 @@ void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr*) {

10891112

kExprIdNotFromAst, false));

10901113

}

109111141092-

void ComprehensionVisitor::PostVisitArg(int arg_num,

1093-

const cel::ast_internal::Expr* expr) {

1115+

void ComprehensionVisitor::PostVisitArg(

1116+

cel::ast_internal::ComprehensionArg arg_num,

1117+

const cel::ast_internal::Expr* expr) {

10941118

const auto* comprehension = &expr->comprehension_expr();

10951119

const auto& accu_var = comprehension->accu_var();

10961120

const auto& iter_var = comprehension->iter_var();

@@ -1204,7 +1228,9 @@ absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(

12041228

ast_impl.reference_map(), execution_path, value_factory, warnings_builder,

12051229

program_tree, extension_context);

120612301207-

AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor);

1231+

cel::ast_internal::TraversalOptions opts;

1232+

opts.use_comprehension_callbacks = true;

1233+

AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor, opts);

1208123412091235

if (!visitor.progress_status().ok()) {

12101236

return visitor.progress_status();