Skip to content

Commit bedb150

Browse files
jnthntatumcopybara-github
authored andcommitted
Memoize enum lookup table in FlatExprBuilder.
PiperOrigin-RevId: 739351773
1 parent ad75340 commit bedb150

13 files changed

+449
-201
lines changed

eval/compiler/BUILD

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ cc_library(
143143
"@com_google_absl//absl/container:node_hash_map",
144144
"@com_google_absl//absl/functional:any_invocable",
145145
"@com_google_absl//absl/log:absl_check",
146+
"@com_google_absl//absl/log:absl_log",
146147
"@com_google_absl//absl/log:check",
147148
"@com_google_absl//absl/status",
148149
"@com_google_absl//absl/status:statusor",
149150
"@com_google_absl//absl/strings",
151+
"@com_google_absl//absl/synchronization",
150152
"@com_google_absl//absl/types:optional",
151153
"@com_google_absl//absl/types:span",
152154
"@com_google_absl//absl/types:variant",
@@ -368,7 +370,9 @@ cc_test(
368370
"//runtime/internal:issue_collector",
369371
"//runtime/internal:runtime_env",
370372
"//runtime/internal:runtime_env_testing",
373+
"@com_google_absl//absl/base:no_destructor",
371374
"@com_google_absl//absl/base:nullability",
375+
"@com_google_absl//absl/container:flat_hash_map",
372376
"@com_google_absl//absl/status",
373377
"@com_google_absl//absl/status:statusor",
374378
"@com_google_absl//absl/strings",
@@ -417,11 +421,12 @@ cc_library(
417421
"//internal:status_macros",
418422
"//runtime:function_overload_reference",
419423
"//runtime:function_registry",
420-
"//runtime:type_registry",
424+
"@com_google_absl//absl/base:no_destructor",
421425
"@com_google_absl//absl/container:flat_hash_map",
422426
"@com_google_absl//absl/status:statusor",
423427
"@com_google_absl//absl/strings",
424428
"@com_google_absl//absl/types:optional",
429+
"@com_google_absl//absl/types:span",
425430
],
426431
)
427432

@@ -436,19 +441,22 @@ cc_test(
436441
"//base:ast",
437442
"//base:builtins",
438443
"//common:expr",
444+
"//common:value",
439445
"//common/ast:ast_impl",
440446
"//common/ast:expr",
441447
"//common/ast:expr_proto",
442448
"//eval/public:builtin_func_registrar",
443449
"//eval/public:cel_function",
444450
"//eval/public:cel_function_registry",
451+
"//eval/public:cel_value",
445452
"//extensions/protobuf:ast_converters",
446453
"//internal:casts",
447454
"//internal:proto_matchers",
448455
"//internal:testing",
449456
"//runtime:runtime_issue",
450457
"//runtime:type_registry",
451458
"//runtime/internal:issue_collector",
459+
"@com_google_absl//absl/base:no_destructor",
452460
"@com_google_absl//absl/container:flat_hash_map",
453461
"@com_google_absl//absl/log:absl_check",
454462
"@com_google_absl//absl/memory",
@@ -496,7 +504,6 @@ cc_test(
496504
"//eval/testutil:test_message_cc_proto",
497505
"//internal:testing",
498506
"@com_google_absl//absl/status",
499-
"@com_google_absl//absl/types:optional",
500507
"@com_google_absl//absl/types:span",
501508
"@com_google_protobuf//:protobuf",
502509
],

eval/compiler/cel_expression_builder_flat_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder {
7474
FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; }
7575

7676
void set_container(std::string container) override {
77+
flat_expr_builder_.InvalidateResolverIndex();
7778
flat_expr_builder_.set_container(std::move(container));
7879
}
7980

@@ -87,6 +88,7 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder {
8788
// CelValue instances, and to extend the set of types and enums known to
8889
// expressions by registering them ahead of time.
8990
CelTypeRegistry* GetTypeRegistry() const override {
91+
flat_expr_builder_.InvalidateResolverIndex();
9092
return &env_->legacy_type_registry;
9193
}
9294

eval/compiler/constant_folding_test.cc

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
#include "eval/compiler/constant_folding.h"
1616

1717
#include <memory>
18+
#include <string>
1819
#include <utility>
20+
#include <vector>
1921

2022
#include "cel/expr/syntax.pb.h"
23+
#include "absl/base/no_destructor.h"
2124
#include "absl/base/nullability.h"
25+
#include "absl/container/flat_hash_map.h"
2226
#include "absl/status/status.h"
2327
#include "absl/status/statusor.h"
2428
#include "absl/strings/string_view.h"
@@ -68,16 +72,28 @@ using ::google::api::expr::runtime::ProgramOptimizerFactory;
6872
using ::google::api::expr::runtime::Resolver;
6973
using ::testing::SizeIs;
7074

75+
const std::vector<std::string>& EmptyNamespacePrefixes() {
76+
static const absl::NoDestructor<std::vector<std::string>> kEmptyPrefixes(
77+
{""});
78+
return *kEmptyPrefixes;
79+
}
80+
81+
const absl::flat_hash_map<std::string, cel::Value>& EmptyEnumValueMap() {
82+
static const absl::NoDestructor<absl::flat_hash_map<std::string, cel::Value>>
83+
kEmptyEnumValueMap({});
84+
return *kEmptyEnumValueMap;
85+
}
86+
7187
class UpdatedConstantFoldingTest : public testing::Test {
7288
public:
7389
UpdatedConstantFoldingTest()
7490
: env_(NewTestingRuntimeEnv()),
7591
function_registry_(env_->function_registry),
7692
type_registry_(env_->type_registry),
7793
issue_collector_(RuntimeIssue::Severity::kError),
78-
resolver_("", function_registry_, type_registry_,
79-
type_registry_.GetComposedTypeProvider(),
80-
type_registry_.resolveable_enums()) {}
94+
resolver_(EmptyNamespacePrefixes(), EmptyEnumValueMap(),
95+
function_registry_,
96+
type_registry_.GetComposedTypeProvider()) {}
8197

8298
protected:
8399
absl::Nonnull<std::shared_ptr<RuntimeEnv>> env_;

eval/compiler/flat_expr_builder.cc

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "absl/strings/match.h"
4242
#include "absl/strings/numbers.h"
4343
#include "absl/strings/str_cat.h"
44+
#include "absl/strings/str_split.h"
4445
#include "absl/strings/string_view.h"
4546
#include "absl/strings/strip.h"
4647
#include "absl/types/optional.h"
@@ -2459,14 +2460,50 @@ std::vector<ExecutionPathView> FlattenExpressionTable(
24592460

24602461
} // namespace
24612462

2463+
std::shared_ptr<FlatExprBuilder::ResolverIndex>
2464+
FlatExprBuilder::BuildResolverIndex() const {
2465+
auto result = std::make_shared<ResolverIndex>();
2466+
auto& namespace_prefixes = result->namespace_prefixes;
2467+
auto& enum_value_map = result->enum_value_map;
2468+
const auto& resolveable_enums = type_registry_.resolveable_enums();
2469+
2470+
std::string prefix = "";
2471+
namespace_prefixes.push_back(prefix);
2472+
auto container_elements = absl::StrSplit(container_, '.');
2473+
for (const auto& elem : container_elements) {
2474+
// Tolerate trailing / leading '.'.
2475+
if (elem.empty()) {
2476+
continue;
2477+
}
2478+
absl::StrAppend(&prefix, elem, ".");
2479+
// longest prefix first.
2480+
namespace_prefixes.insert(namespace_prefixes.begin(), prefix);
2481+
}
2482+
2483+
for (auto iter = resolveable_enums.begin(); iter != resolveable_enums.end();
2484+
++iter) {
2485+
absl::string_view enum_name = iter->first;
2486+
const auto& enum_type = iter->second;
2487+
2488+
for (const auto& enumerator : enum_type.enumerators) {
2489+
auto key = absl::StrCat(enum_name, ".", enumerator.name);
2490+
enum_value_map[key] = cel::IntValue(enumerator.number);
2491+
}
2492+
}
2493+
2494+
return result;
2495+
}
2496+
24622497
absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(
24632498
std::unique_ptr<Ast> ast, std::vector<RuntimeIssue>* issues) const {
24642499
RuntimeIssue::Severity max_severity = options_.fail_on_warnings
24652500
? RuntimeIssue::Severity::kWarning
24662501
: RuntimeIssue::Severity::kError;
24672502
IssueCollector issue_collector(max_severity);
2468-
Resolver resolver(container_, function_registry_, type_registry_,
2469-
GetTypeProvider(), type_registry_.resolveable_enums(),
2503+
auto resolver_index = GetResolverIndex();
2504+
Resolver resolver(resolver_index->namespace_prefixes,
2505+
resolver_index->enum_value_map, function_registry_,
2506+
GetTypeProvider(),
24702507
options_.enable_qualified_type_identifiers);
24712508

24722509
std::shared_ptr<google::protobuf::Arena> arena;

eval/compiler/flat_expr_builder.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,15 @@
2323
#include <vector>
2424

2525
#include "absl/base/nullability.h"
26+
#include "absl/base/thread_annotations.h"
27+
#include "absl/container/flat_hash_map.h"
28+
#include "absl/log/absl_log.h"
2629
#include "absl/status/statusor.h"
2730
#include "absl/strings/string_view.h"
31+
#include "absl/synchronization/mutex.h"
2832
#include "base/ast.h"
2933
#include "base/type_provider.h"
34+
#include "common/value.h"
3035
#include "eval/compiler/flat_expr_builder_extensions.h"
3136
#include "eval/eval/evaluator_core.h"
3237
#include "runtime/function_registry.h"
@@ -93,11 +98,55 @@ class FlatExprBuilder {
9398
// `optional_type` handling is needed.
9499
void enable_optional_types() { enable_optional_types_ = true; }
95100

101+
// Note: this is a temporary solution to support the legacy expression
102+
// builder which remains mutable after building expressions. This is not
103+
// correct for thread safety (e.g. storing a reference to the type
104+
// registry), however the builder is otherwise thread hostile if used in this
105+
// way so only mitigates some cases.
106+
void InvalidateResolverIndex() const {
107+
absl::MutexLock lock(&resolver_index_mutex_);
108+
if (resolver_index_ != nullptr) {
109+
ABSL_LOG(WARNING)
110+
<< "attempted to update CEL expression builder after use";
111+
}
112+
resolver_index_.reset();
113+
}
114+
96115
private:
97116
const cel::TypeProvider& GetTypeProvider() const;
98117

99118
const absl::Nonnull<std::shared_ptr<const cel::runtime_internal::RuntimeEnv>>
100119
env_;
120+
121+
struct ResolverIndex {
122+
std::vector<std::string> namespace_prefixes;
123+
absl::flat_hash_map<std::string, cel::Value> enum_value_map;
124+
};
125+
126+
std::shared_ptr<ResolverIndex> BuildResolverIndex() const;
127+
128+
std::shared_ptr<ResolverIndex> GetResolverIndex() const {
129+
std::shared_ptr<ResolverIndex> result;
130+
{
131+
absl::ReaderMutexLock lock(&resolver_index_mutex_);
132+
result = resolver_index_;
133+
}
134+
135+
if (result != nullptr) {
136+
return result;
137+
}
138+
// Slow path: build the resolver index.
139+
absl::MutexLock lock(&resolver_index_mutex_);
140+
result = resolver_index_;
141+
if (result != nullptr) {
142+
return result;
143+
}
144+
145+
result = BuildResolverIndex();
146+
resolver_index_ = result;
147+
return result;
148+
}
149+
101150
cel::RuntimeOptions options_;
102151
std::string container_;
103152
bool enable_optional_types_ = false;
@@ -108,6 +157,11 @@ class FlatExprBuilder {
108157
bool use_legacy_type_provider_;
109158
std::vector<std::unique_ptr<AstTransform>> ast_transforms_;
110159
std::vector<ProgramOptimizerFactory> program_optimizers_;
160+
161+
// See note on accessor.
162+
mutable std::shared_ptr<ResolverIndex> resolver_index_
163+
ABSL_GUARDED_BY(resolver_index_mutex_);
164+
mutable absl::Mutex resolver_index_mutex_;
111165
};
112166

113167
} // namespace google::api::expr::runtime

eval/compiler/flat_expr_builder_extensions_test.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,26 @@ using ::testing::Optional;
5555

5656
using Subexpression = ProgramBuilder::Subexpression;
5757

58+
const std::vector<std::string>& EmptyNamespacePrefixes() {
59+
static const absl::NoDestructor<std::vector<std::string>> kEmptyPrefixes(
60+
{""});
61+
return *kEmptyPrefixes;
62+
}
63+
64+
const absl::flat_hash_map<std::string, cel::Value>& EmptyEnumValueMap() {
65+
static const absl::NoDestructor<absl::flat_hash_map<std::string, cel::Value>>
66+
kEmptyEnumValueMap({});
67+
return *kEmptyEnumValueMap;
68+
}
69+
5870
class PlannerContextTest : public testing::Test {
5971
public:
6072
PlannerContextTest()
6173
: env_(NewTestingRuntimeEnv()),
6274
type_registry_(env_->type_registry),
6375
function_registry_(env_->function_registry),
64-
resolver_("", function_registry_, type_registry_,
65-
type_registry_.GetComposedTypeProvider(),
66-
type_registry_.resolveable_enums()),
76+
resolver_(EmptyNamespacePrefixes(), EmptyEnumValueMap(),
77+
function_registry_, type_registry_.GetComposedTypeProvider()),
6778
issue_collector_(RuntimeIssue::Severity::kError) {}
6879

6980
protected:

0 commit comments

Comments
 (0)