Skip to content

Commit d2ef27c

Browse files
TristonianJoneskyessenov
authored andcommitted
Ensure the lifecycle of rewritten expressions is preserved in the output CelExpression
PiperOrigin-RevId: 347652258
1 parent 1f35b77 commit d2ef27c

3 files changed

Lines changed: 74 additions & 13 deletions

File tree

eval/compiler/flat_expr_builder.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "eval/compiler/flat_expr_builder.h"
22

3+
#include <memory>
4+
35
#include "google/api/expr/v1alpha1/checked.pb.h"
46
#include "stack"
57
#include "absl/container/node_hash_map.h"
@@ -658,8 +660,8 @@ void ComprehensionVisitor::PreVisit(const Expr*) {
658660

659661
void ComprehensionVisitor::PostVisitArg(int arg_num, const Expr* expr) {
660662
const Comprehension* comprehension = &expr->comprehension_expr();
661-
const auto accu_var = comprehension->accu_var();
662-
const auto iter_var = comprehension->iter_var();
663+
const auto& accu_var = comprehension->accu_var();
664+
const auto& iter_var = comprehension->iter_var();
663665
// TODO(issues/20): Consider refactoring the comprehension prologue step.
664666
switch (arg_num) {
665667
case ITER_RANGE: {
@@ -732,7 +734,7 @@ FlatExprBuilder::CreateExpressionImpl(
732734

733735
const Expr* effective_expr = expr;
734736
// transformed expression preserving expression IDs
735-
Expr rewrite_buffer;
737+
std::unique_ptr<Expr> rewrite_buffer = nullptr;
736738
// TODO(issues/98): A type checker may perform these rewrites, but there
737739
// currently isn't a signal to expose that in an expression. If that becomes
738740
// available, we can skip the reference resolve step here if it's already
@@ -745,19 +747,19 @@ FlatExprBuilder::CreateExpressionImpl(
745747
return rewritten.status();
746748
}
747749
if (rewritten.value().has_value()) {
748-
rewrite_buffer = std::move(rewritten)->value();
749-
effective_expr = &rewrite_buffer;
750+
rewrite_buffer =
751+
std::make_unique<Expr>(std::move(rewritten).value().value());
752+
effective_expr = rewrite_buffer.get();
750753
}
751754
// TODO(issues/99): we could setup a check step here that confirms all of
752755
// references are defined before actually evaluating.
753756
}
754757

758+
Expr const_fold_buffer;
755759
if (constant_folding_) {
756-
Expr buffer;
757760
FoldConstants(*effective_expr, *this->GetRegistry(), constant_arena_,
758-
idents, &buffer);
759-
rewrite_buffer = std::move(buffer);
760-
effective_expr = &rewrite_buffer;
761+
idents, &const_fold_buffer);
762+
effective_expr = &const_fold_buffer;
761763
}
762764

763765
std::set<std::string> iter_variable_names;
@@ -776,7 +778,8 @@ FlatExprBuilder::CreateExpressionImpl(
776778
absl::make_unique<CelExpressionFlatImpl>(
777779
expr, std::move(execution_path), comprehension_max_iterations_,
778780
std::move(iter_variable_names), enable_unknowns_,
779-
enable_unknown_function_results_, enable_missing_attribute_errors_);
781+
enable_unknown_function_results_, enable_missing_attribute_errors_,
782+
std::move(rewrite_buffer));
780783

781784
if (warnings != nullptr) {
782785
*warnings = std::move(warnings_builder).warnings();

eval/compiler/flat_expr_builder_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,60 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) {
614614
EXPECT_FALSE(result.BoolOrDie());
615615
}
616616

617+
TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) {
618+
CheckedExpr expr;
619+
// {`var1`: 'hello'}
620+
google::protobuf::TextFormat::ParseFromString(R"(
621+
reference_map {
622+
key: 3
623+
value {
624+
name: "var1"
625+
value {
626+
int64_value: 1
627+
}
628+
}
629+
}
630+
expr {
631+
id: 1
632+
struct_expr {
633+
entries {
634+
id: 2
635+
map_key {
636+
id: 3
637+
ident_expr {
638+
name: "var1"
639+
}
640+
}
641+
value {
642+
id: 4
643+
const_expr {
644+
string_value: "hello"
645+
}
646+
}
647+
}
648+
}
649+
})",
650+
&expr);
651+
652+
FlatExprBuilder builder;
653+
google::protobuf::Arena arena;
654+
builder.set_constant_folding(true, &arena);
655+
ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
656+
auto build_status = builder.CreateExpression(&expr);
657+
ASSERT_OK(build_status);
658+
659+
auto cel_expr = std::move(build_status.value());
660+
661+
Activation activation;
662+
auto result_or = cel_expr->Evaluate(activation, &arena);
663+
ASSERT_OK(result_or);
664+
CelValue result = result_or.value();
665+
ASSERT_TRUE(result.IsMap());
666+
auto m = result.MapOrDie();
667+
auto v = (*m)[CelValue::CreateInt64(1L)];
668+
EXPECT_THAT(v.value().StringOrDie().value(), Eq("hello"));
669+
}
670+
617671
TEST(FlatExprBuilderTest, ComprehensionWorksForError) {
618672
Expr expr;
619673
// {}[0].all(x, x) should evaluate OK but return an error value

eval/eval/evaluator_core.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using ExecutionPath = std::vector<std::unique_ptr<const ExpressionStep>>;
6767
// stack as Span<>.
6868
class ValueStack {
6969
public:
70-
ValueStack(size_t max_size) : current_size_(0) {
70+
explicit ValueStack(size_t max_size) : current_size_(0) {
7171
stack_.resize(max_size);
7272
attribute_stack_.resize(max_size);
7373
}
@@ -336,8 +336,10 @@ class CelExpressionFlatImpl : public CelExpression {
336336
std::set<std::string> iter_variable_names,
337337
bool enable_unknowns = false,
338338
bool enable_unknown_function_results = false,
339-
bool enable_missing_attribute_errors = false)
340-
: path_(std::move(path)),
339+
bool enable_missing_attribute_errors = false,
340+
std::unique_ptr<Expr> rewritten_expr = nullptr)
341+
: rewritten_expr_(std::move(rewritten_expr)),
342+
path_(std::move(path)),
341343
max_iterations_(max_iterations),
342344
iter_variable_names_(std::move(iter_variable_names)),
343345
enable_unknowns_(enable_unknowns),
@@ -372,6 +374,8 @@ class CelExpressionFlatImpl : public CelExpression {
372374
CelEvaluationListener callback) const override;
373375

374376
private:
377+
// Maintain lifecycle of a modified expression.
378+
std::unique_ptr<Expr> rewritten_expr_;
375379
const ExecutionPath path_;
376380
const int max_iterations_;
377381
const std::set<std::string> iter_variable_names_;

0 commit comments

Comments
 (0)