Skip to content

Commit b317f18

Browse files
CEL Dev Teamcopybara-github
authored andcommitted
Support custom variable bindings (Activations) via cel test context
PiperOrigin-RevId: 802840715
1 parent 0c92b65 commit b317f18

4 files changed

Lines changed: 154 additions & 15 deletions

File tree

testing/testrunner/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ cc_library(
1717
"//eval/public:cel_expression",
1818
"//runtime",
1919
"@com_google_absl//absl/base:nullability",
20+
"@com_google_absl//absl/container:flat_hash_map",
2021
"@com_google_absl//absl/memory",
2122
"@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
23+
"@com_google_cel_spec//proto/cel/expr:value_cc_proto",
2224
],
2325
)
2426

@@ -96,6 +98,7 @@ cc_test(
9698
"//runtime",
9799
"//runtime:runtime_builder",
98100
"//runtime:standard_runtime_builder_factory",
101+
"@com_google_absl//absl/container:flat_hash_map",
99102
"@com_google_absl//absl/flags:flag",
100103
"@com_google_absl//absl/log:absl_check",
101104
"@com_google_absl//absl/status:status_matchers",

testing/testrunner/cel_test_context.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
#include <memory>
1919
#include <optional>
20+
#include <string>
2021
#include <utility>
2122

2223
#include "cel/expr/checked.pb.h"
24+
#include "cel/expr/value.pb.h"
2325
#include "absl/base/nullability.h"
26+
#include "absl/container/flat_hash_map.h"
2427
#include "absl/memory/memory.h"
2528
#include "compiler/compiler.h"
2629
#include "eval/public/cel_expression.h"
@@ -37,6 +40,16 @@ struct CelTestContextOptions {
3740
// input or output values are themselves CEL expressions that need to be
3841
// resolved at runtime or cel expression source is raw string or cel file.
3942
std::unique_ptr<const cel::Compiler> compiler = nullptr;
43+
44+
// A map of variable names to values that provides default bindings for the
45+
// evaluation.
46+
//
47+
// These bindings can be considered context-wide defaults. If a variable name
48+
// exists in both these custom bindings and in a specific TestCase's input,
49+
// the value from the TestCase will take precedence and override this one.
50+
// This logic is handled by the test runner when it constructs the final
51+
// activation.
52+
absl::flat_hash_map<std::string, cel::expr::Value> custom_bindings;
4053
};
4154

4255
// The context class for a CEL test, holding configurations needed to evaluate
@@ -97,6 +110,11 @@ class CelTestContext {
97110
: nullptr;
98111
}
99112

113+
const absl::flat_hash_map<std::string, cel::expr::Value>&
114+
custom_bindings() const {
115+
return cel_test_context_options_.custom_bindings;
116+
}
117+
100118
private:
101119
// Delete copy and move constructors.
102120
CelTestContext(const CelTestContext&) = delete;

testing/testrunner/runner_lib.cc

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ using ::cel::expr::conformance::test::TestCase;
5656
using ::cel::expr::conformance::test::TestOutput;
5757
using ::cel::expr::CheckedExpr;
5858
using ::google::api::expr::runtime::CelExpression;
59-
using ::google::api::expr::runtime::CelValue;
6059
using ::google::api::expr::runtime::ValueToCelValue;
61-
using ValueProto = ::cel::expr::Value;
6260
using ::google::api::expr::runtime::Activation;
6361

62+
using LegacyCelValue = ::google::api::expr::runtime::CelValue;
63+
using ValueProto = ::cel::expr::Value;
64+
6465
absl::StatusOr<std::string> ReadFileToString(absl::string_view file_path) {
6566
std::ifstream file_stream{std::string(file_path)};
6667
if (!file_stream.is_open()) {
@@ -131,7 +132,7 @@ absl::StatusOr<cel::Value> EvalWithLegacyBindings(
131132
CEL_ASSIGN_OR_RETURN(std::unique_ptr<CelExpression> sub_expression,
132133
builder->CreateExpression(&checked_expr));
133134

134-
CEL_ASSIGN_OR_RETURN(CelValue legacy_result,
135+
CEL_ASSIGN_OR_RETURN(LegacyCelValue legacy_result,
135136
sub_expression->Evaluate(activation, arena));
136137

137138
ValueProto result_proto;
@@ -177,23 +178,60 @@ absl::StatusOr<cel::Value> ResolveInputValue(const InputValue& input_value,
177178
}
178179
}
179180

180-
absl::StatusOr<cel::Activation> CreateModernActivationFromBindings(
181+
absl::Status AddCustomBindingsToModernActivation(const CelTestContext& context,
182+
cel::Activation& activation,
183+
google::protobuf::Arena* arena) {
184+
for (const auto& binding : context.custom_bindings()) {
185+
CEL_ASSIGN_OR_RETURN(cel::Value value,
186+
FromExprValue(/*value_proto=*/binding.second,
187+
GetDescriptorPool(context),
188+
GetMessageFactory(context), arena));
189+
activation.InsertOrAssignValue(/*name=*/binding.first, value);
190+
}
191+
return absl::OkStatus();
192+
}
193+
194+
absl::Status AddTestCaseBindingsToModernActivation(
181195
const TestCase& test_case, const CelTestContext& context,
182-
google::protobuf::Arena* arena) {
183-
cel::Activation activation;
196+
cel::Activation& activation, google::protobuf::Arena* arena) {
184197
for (const auto& binding : test_case.input()) {
185198
CEL_ASSIGN_OR_RETURN(
186-
Value value,
199+
cel::Value value,
187200
ResolveInputValue(/*input_value=*/binding.second, context, arena));
188201
activation.InsertOrAssignValue(/*name=*/binding.first, std::move(value));
189202
}
190-
return activation;
203+
return absl::OkStatus();
191204
}
192205

193-
absl::StatusOr<Activation> CreateLegacyActivationFromBindings(
206+
absl::StatusOr<cel::Activation> CreateModernActivationFromBindings(
194207
const TestCase& test_case, const CelTestContext& context,
195208
google::protobuf::Arena* arena) {
196-
Activation activation;
209+
cel::Activation activation;
210+
211+
CEL_RETURN_IF_ERROR(
212+
AddCustomBindingsToModernActivation(context, activation, arena));
213+
214+
CEL_RETURN_IF_ERROR(AddTestCaseBindingsToModernActivation(test_case, context,
215+
activation, arena));
216+
217+
return activation;
218+
}
219+
220+
absl::Status AddCustomBindingsToLegacyActivation(const CelTestContext& context,
221+
Activation& activation,
222+
google::protobuf::Arena* arena) {
223+
for (const auto& binding : context.custom_bindings()) {
224+
CEL_ASSIGN_OR_RETURN(
225+
LegacyCelValue value,
226+
ValueToCelValue(/*value_proto=*/binding.second, arena));
227+
activation.InsertValue(/*name=*/binding.first, value);
228+
}
229+
return absl::OkStatus();
230+
}
231+
232+
absl::Status AddTestCaseBindingsToLegacyActivation(
233+
const TestCase& test_case, const CelTestContext& context,
234+
Activation& activation, google::protobuf::Arena* arena) {
197235
auto* message_factory = GetMessageFactory(context);
198236
auto* descriptor_pool = GetDescriptorPool(context);
199237
for (const auto& binding : test_case.input()) {
@@ -203,9 +241,24 @@ absl::StatusOr<Activation> CreateLegacyActivationFromBindings(
203241
CEL_ASSIGN_OR_RETURN(ValueProto value_proto,
204242
ToExprValue(resolved_cel_value, descriptor_pool,
205243
message_factory, arena));
206-
CEL_ASSIGN_OR_RETURN(CelValue value, ValueToCelValue(value_proto, arena));
244+
CEL_ASSIGN_OR_RETURN(LegacyCelValue value,
245+
ValueToCelValue(value_proto, arena));
207246
activation.InsertValue(/*name=*/binding.first, value);
208247
}
248+
return absl::OkStatus();
249+
}
250+
251+
absl::StatusOr<Activation> CreateLegacyActivationFromBindings(
252+
const TestCase& test_case, const CelTestContext& context,
253+
google::protobuf::Arena* arena) {
254+
Activation activation;
255+
256+
CEL_RETURN_IF_ERROR(
257+
AddCustomBindingsToLegacyActivation(context, activation, arena));
258+
259+
CEL_RETURN_IF_ERROR(AddTestCaseBindingsToLegacyActivation(test_case, context,
260+
activation, arena));
261+
209262
return activation;
210263
}
211264

testing/testrunner/runner_lib_test.cc

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <utility>
1919

2020
#include "gtest/gtest-spi.h"
21+
#include "absl/container/flat_hash_map.h"
2122
#include "absl/flags/flag.h"
2223
#include "absl/log/absl_check.h"
2324
#include "absl/status/status_matchers.h"
@@ -57,6 +58,7 @@ using ::cel::expr::conformance::proto3::TestAllTypes;
5758
using ::cel::expr::conformance::test::TestCase;
5859
using ::cel::expr::CheckedExpr;
5960
using ::google::api::expr::runtime::CelExpressionBuilder;
61+
using ValueProto = ::cel::expr::Value;
6062

6163
template <typename T>
6264
T ParseTextProtoOrDie(absl::string_view text_proto) {
@@ -190,7 +192,8 @@ TEST_P(TestRunnerParamTest, BasicTestReportsFailure) {
190192
CelExpressionSource::FromCheckedExpr(
191193
std::move(checked_expr))}));
192194
TestRunner test_runner(std::move(context));
193-
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "bool_value: true");
195+
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
196+
"bool_value: true"); // expected true got false
194197
}
195198

196199
TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsSuccess) {
@@ -248,7 +251,8 @@ TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsFailure) {
248251
std::move(checked_expr)),
249252
.compiler = std::move(compiler)}));
250253
TestRunner test_runner(std::move(context));
251-
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 5");
254+
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
255+
"int64_value: 5"); // expected 5 got 10
252256
}
253257

254258
TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsSuccess) {
@@ -296,7 +300,8 @@ TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsFailure) {
296300
CelExpressionSource::FromRawExpression("x - y"),
297301
.compiler = std::move(compiler)}));
298302
TestRunner test_runner(std::move(context));
299-
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 7");
303+
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
304+
"int64_value: 7"); // expected 7 got 100
300305
}
301306

302307
TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsSuccess) {
@@ -350,7 +355,67 @@ TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsFailure) {
350355
CelExpressionSource::FromCelFile(cel_file_path),
351356
.compiler = std::move(compiler)}));
352357
TestRunner test_runner(std::move(context));
353-
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 7");
358+
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
359+
"int64_value: 7"); // expected 7 got 123
360+
}
361+
362+
TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsSucceeds) {
363+
ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
364+
DefaultCompiler().Compile("x + y"));
365+
CheckedExpr checked_expr;
366+
ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
367+
absl_testing::IsOk());
368+
369+
TestCase test_case = ParseTextProtoOrDie<TestCase>(R"pb(
370+
input {
371+
key: "x"
372+
value { value { int64_value: 10 } }
373+
}
374+
output { result_value { int64_value: 15 } }
375+
)pb");
376+
377+
absl::flat_hash_map<std::string, ValueProto> bindings;
378+
bindings["y"] = ParseTextProtoOrDie<ValueProto>(R"pb(int64_value: 5)pb");
379+
380+
ASSERT_OK_AND_ASSIGN(
381+
auto context, CreateTestContext(
382+
/*options=*/{.expression_source =
383+
CelExpressionSource::FromCheckedExpr(
384+
std::move(checked_expr)),
385+
.custom_bindings = std::move(bindings)}));
386+
TestRunner test_runner(std::move(context));
387+
388+
EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
389+
}
390+
391+
TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsReportsFailure) {
392+
ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
393+
DefaultCompiler().Compile("x + y"));
394+
CheckedExpr checked_expr;
395+
ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
396+
absl_testing::IsOk());
397+
398+
TestCase test_case = ParseTextProtoOrDie<TestCase>(R"pb(
399+
input {
400+
key: "x"
401+
value { value { int64_value: 10 } }
402+
}
403+
output { result_value { int64_value: 999 } }
404+
)pb");
405+
406+
absl::flat_hash_map<std::string, ValueProto> bindings;
407+
bindings["y"] = ParseTextProtoOrDie<ValueProto>(R"pb(int64_value: 5)pb");
408+
409+
ASSERT_OK_AND_ASSIGN(
410+
auto context, CreateTestContext(
411+
/*options=*/{.expression_source =
412+
CelExpressionSource::FromCheckedExpr(
413+
std::move(checked_expr)),
414+
.custom_bindings = std::move(bindings)}));
415+
TestRunner test_runner(std::move(context));
416+
417+
EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
418+
"int64_value: 15"); // expected 15 got 999.
354419
}
355420

356421
INSTANTIATE_TEST_SUITE_P(TestRunnerTests, TestRunnerParamTest,

0 commit comments

Comments
 (0)