Skip to content

Commit 3209fd2

Browse files
jnthntatumcopybara-github
authored andcommitted
Extract implementation details for NavigableAst to internal template base class.
PiperOrigin-RevId: 800176850
1 parent 158ac71 commit 3209fd2

3 files changed

Lines changed: 231 additions & 131 deletions

File tree

common/ast/navigable_ast_internal.h

Lines changed: 159 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,46 +103,47 @@ class NavigableAstRange {
103103

104104
explicit NavigableAstRange(SpanType span) : span_(span) {}
105105

106-
Iterator begin() { return Iterator(span_, 0); }
107-
Iterator end() { return Iterator(span_, span_.size()); }
106+
Iterator begin() const { return Iterator(span_, 0); }
107+
Iterator end() const { return Iterator(span_, span_.size()); }
108108

109109
explicit operator bool() const { return !span_.empty(); }
110110

111111
private:
112112
SpanType span_;
113113
};
114114

115-
template <typename AstNode>
115+
template <typename AstTraits>
116116
struct NavigableAstMetadata;
117117

118118
// Internal implementation for data-structures handling cross-referencing nodes.
119119
//
120120
// This is exposed separately to allow building up the AST relationships
121121
// without exposing too much mutable state on the client facing classes.
122-
template <typename AstNode>
122+
template <typename AstTraits>
123123
struct NavigableAstNodeData {
124-
AstNode* parent;
125-
const typename AstNode::ExprType* expr;
124+
typename AstTraits::NodeType* parent;
125+
const typename AstTraits::ExprType* expr;
126126
ChildKind parent_relation;
127127
NodeKind node_kind;
128-
const NavigableAstMetadata<AstNode>* absl_nonnull metadata;
128+
const NavigableAstMetadata<AstTraits>* absl_nonnull metadata;
129129
size_t index;
130130
size_t tree_size;
131131
size_t height;
132132
int child_index;
133-
std::vector<AstNode*> children;
133+
std::vector<typename AstTraits::NodeType* absl_nonnull> children;
134134
};
135135

136-
template <typename AstNode>
136+
template <typename AstTraits>
137137
struct NavigableAstMetadata {
138138
// The nodes in the AST in preorder.
139139
//
140140
// unique_ptr is used to guarantee pointer stability in the other tables.
141-
std::vector<std::unique_ptr<AstNode>> nodes;
142-
std::vector<const AstNode* absl_nonnull> postorder;
143-
absl::flat_hash_map<int64_t, const AstNode* absl_nonnull> id_to_node;
144-
absl::flat_hash_map<const typename AstNode::ExprType*,
145-
const AstNode* absl_nonnull>
141+
std::vector<std::unique_ptr<typename AstTraits::NodeType>> nodes;
142+
std::vector<const typename AstTraits::NodeType* absl_nonnull> postorder;
143+
absl::flat_hash_map<int64_t, const typename AstTraits::NodeType* absl_nonnull>
144+
id_to_node;
145+
absl::flat_hash_map<const typename AstTraits::ExprType*,
146+
const typename AstTraits::NodeType* absl_nonnull>
146147
expr_to_node;
147148
};
148149

@@ -161,6 +162,150 @@ struct PreorderTraits {
161162
}
162163
};
163164

165+
// Base class for NavigableAstNode and NavigableProtoAstNode.
166+
template <typename AstTraits>
167+
class NavigableAstNodeBase {
168+
private:
169+
using MetadataType = NavigableAstMetadata<AstTraits>;
170+
using NodeDataType = NavigableAstNodeData<AstTraits>;
171+
using Derived = typename AstTraits::NodeType;
172+
using ExprType = typename AstTraits::ExprType;
173+
174+
public:
175+
using PreorderRange = NavigableAstRange<PreorderTraits<Derived>>;
176+
using PostorderRange = NavigableAstRange<PostorderTraits<Derived>>;
177+
178+
// The parent of this node or nullptr if it is a root.
179+
const Derived* absl_nullable parent() const { return data_.parent; }
180+
181+
const ExprType* absl_nonnull expr() const { return data_.expr; }
182+
183+
// The index of this node in the parent's children. -1 if this is a root.
184+
int child_index() const { return data_.child_index; }
185+
186+
// The type of traversal from parent to this node.
187+
ChildKind parent_relation() const { return data_.parent_relation; }
188+
189+
// The type of this node, analogous to Expr::ExprKindCase.
190+
NodeKind node_kind() const { return data_.node_kind; }
191+
192+
// The number of nodes in the tree rooted at this node (including self).
193+
size_t tree_size() const { return data_.tree_size; }
194+
195+
// The height of this node in the tree (the number of descendants including
196+
// self on the longest path).
197+
size_t height() const { return data_.height; }
198+
199+
absl::Span<const Derived* const> children() const {
200+
return absl::MakeConstSpan(data_.children);
201+
}
202+
203+
// Range over the descendants of this node (including self) using preorder
204+
// semantics. Each node is visited immediately before all of its descendants.
205+
PreorderRange DescendantsPreorder() const {
206+
return PreorderRange(absl::MakeConstSpan(data_.metadata->nodes)
207+
.subspan(data_.index, data_.tree_size));
208+
}
209+
210+
// Range over the descendants of this node (including self) using postorder
211+
// semantics. Each node is visited immediately after all of its descendants.
212+
PostorderRange DescendantsPostorder() const {
213+
return PostorderRange(absl::MakeConstSpan(data_.metadata->postorder)
214+
.subspan(data_.index, data_.tree_size));
215+
}
216+
217+
private:
218+
friend Derived;
219+
220+
NavigableAstNodeBase() = default;
221+
NavigableAstNodeBase(const NavigableAstNodeBase&) = delete;
222+
NavigableAstNodeBase& operator=(const NavigableAstNodeBase&) = delete;
223+
224+
protected:
225+
NodeDataType data_;
226+
};
227+
228+
// Shared implementation for NavigableAst and NavigableProtoAst.
229+
//
230+
// AstTraits provides type info for the derived classes that implement building
231+
// the traversal metadata. It provides the following types:
232+
//
233+
// ExprType is the expression node type of the source AST.
234+
//
235+
// AstType is the subclass of NavigableAstBase for the implementation.
236+
//
237+
// NodeType is the subclass of NavigableAstNodeBase for the implementation.
238+
template <class AstTraits>
239+
class NavigableAstBase {
240+
private:
241+
using MetadataType = NavigableAstMetadata<AstTraits>;
242+
using Derived = typename AstTraits::AstType;
243+
using NodeType = typename AstTraits::NodeType;
244+
using ExprType = typename AstTraits::ExprType;
245+
246+
public:
247+
NavigableAstBase(const NavigableAstBase&) = delete;
248+
NavigableAstBase& operator=(const NavigableAstBase&) = delete;
249+
NavigableAstBase(NavigableAstBase&&) = default;
250+
NavigableAstBase& operator=(NavigableAstBase&&) = default;
251+
252+
// Return ptr to the AST node with id if present. Otherwise returns nullptr.
253+
//
254+
// If ids are non-unique, the first pre-order node encountered with id is
255+
// returned.
256+
const NodeType* absl_nullable FindId(int64_t id) const {
257+
auto it = metadata_->id_to_node.find(id);
258+
if (it == metadata_->id_to_node.end()) {
259+
return nullptr;
260+
}
261+
return it->second;
262+
}
263+
264+
// Return ptr to the AST node representing the given Expr protobuf node.
265+
const NodeType* absl_nullable FindExpr(
266+
const ExprType* absl_nonnull expr) const {
267+
auto it = metadata_->expr_to_node.find(expr);
268+
if (it == metadata_->expr_to_node.end()) {
269+
return nullptr;
270+
}
271+
return it->second;
272+
}
273+
274+
// The root of the AST.
275+
const NodeType& Root() const { return *metadata_->nodes[0]; }
276+
277+
// Check whether the source AST used unique IDs for each node.
278+
//
279+
// This is typically the case, but older versions of the parsers didn't
280+
// guarantee uniqueness for nodes generated by some macros and ASTs modified
281+
// outside of CEL's parse/type check may not have unique IDs.
282+
bool IdsAreUnique() const {
283+
return metadata_->id_to_node.size() == metadata_->nodes.size();
284+
}
285+
286+
// Equality operators test for identity. They are intended to distinguish
287+
// moved-from or uninitialized instances from initialized.
288+
bool operator==(const NavigableAstBase& other) const {
289+
return metadata_ == other.metadata_;
290+
}
291+
292+
bool operator!=(const NavigableAstBase& other) const {
293+
return metadata_ != other.metadata_;
294+
}
295+
296+
// Return true if this instance is initialized.
297+
explicit operator bool() const { return metadata_ != nullptr; }
298+
299+
private:
300+
friend Derived;
301+
302+
NavigableAstBase() = default;
303+
explicit NavigableAstBase(std::unique_ptr<MetadataType> metadata)
304+
: metadata_(std::move(metadata)) {}
305+
306+
std::unique_ptr<MetadataType> metadata_;
307+
};
308+
164309
} // namespace cel::common_internal
165310

166311
#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_

tools/navigable_ast.cc

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ using ::cel::expr::Expr;
3939
using ::google::api::expr::runtime::AstTraverse;
4040
using ::google::api::expr::runtime::SourcePosition;
4141

42+
using NavigableAstNodeData =
43+
common_internal::NavigableAstNodeData<common_internal::ProtoAstTraits>;
44+
using NavigableAstMetadata =
45+
common_internal::NavigableAstMetadata<common_internal::ProtoAstTraits>;
46+
4247
NodeKind GetNodeKind(const Expr& expr) {
4348
switch (expr.expr_kind_case()) {
4449
case Expr::kConstExpr:
@@ -67,8 +72,7 @@ NodeKind GetNodeKind(const Expr& expr) {
6772

6873
// Get the traversal relationship from parent to the given node.
6974
// Note: these depend on the ast_visitor utility's traversal ordering.
70-
ChildKind GetChildKind(const common_internal::NavigableAstNodeData<
71-
NavigableProtoAstNode>& parent_node,
75+
ChildKind GetChildKind(const NavigableAstNodeData& parent_node,
7276
size_t child_index) {
7377
constexpr size_t kComprehensionRangeArgIndex =
7478
google::api::expr::runtime::ITER_RANGE;
@@ -122,17 +126,13 @@ class NavigableExprBuilderVisitor
122126
: public google::api::expr::runtime::AstVisitorBase {
123127
public:
124128
NavigableExprBuilderVisitor(
125-
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory,
126-
absl::AnyInvocable<common_internal::NavigableAstNodeData<
127-
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
128-
node_data_accessor)
129+
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory,
130+
absl::AnyInvocable<NavigableAstNodeData&(AstNode&)> node_data_accessor)
129131
: node_factory_(std::move(node_factory)),
130132
node_data_accessor_(std::move(node_data_accessor)),
131-
metadata_(std::make_unique<common_internal::NavigableAstMetadata<
132-
NavigableProtoAstNode>>()) {}
133+
metadata_(std::make_unique<NavigableAstMetadata>()) {}
133134

134-
common_internal::NavigableAstNodeData<NavigableProtoAstNode>& NodeDataAt(
135-
size_t index) {
135+
NavigableAstNodeData& NodeDataAt(size_t index) {
136136
return node_data_accessor_(*metadata_->nodes[index]);
137137
}
138138

@@ -171,8 +171,7 @@ class NavigableExprBuilderVisitor
171171
size_t idx = parent_stack_.back();
172172
parent_stack_.pop_back();
173173
metadata_->postorder.push_back(metadata_->nodes[idx].get());
174-
common_internal::NavigableAstNodeData<NavigableProtoAstNode>& node =
175-
NodeDataAt(idx);
174+
NavigableAstNodeData& node = NodeDataAt(idx);
176175
if (!parent_stack_.empty()) {
177176
auto& parent_node_data = NodeDataAt(parent_stack_.back());
178177
parent_node_data.tree_size += node.tree_size;
@@ -181,30 +180,23 @@ class NavigableExprBuilderVisitor
181180
}
182181
}
183182

184-
std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
185-
Consume() && {
183+
std::unique_ptr<NavigableAstMetadata> Consume() && {
186184
return std::move(metadata_);
187185
}
188186

189187
private:
190-
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory_;
191-
absl::AnyInvocable<common_internal::NavigableAstNodeData<
192-
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
193-
node_data_accessor_;
194-
std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
195-
metadata_;
188+
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory_;
189+
absl::AnyInvocable<NavigableAstNodeData&(AstNode&)> node_data_accessor_;
190+
std::unique_ptr<NavigableAstMetadata> metadata_;
196191
std::vector<size_t> parent_stack_;
197192
};
198193

199194
} // namespace
200195

201196
NavigableProtoAst NavigableProtoAst::Build(const Expr& expr) {
202197
NavigableExprBuilderVisitor visitor(
203-
[]() { return absl::WrapUnique(new NavigableProtoAstNode()); },
204-
[](NavigableProtoAstNode& node)
205-
-> common_internal::NavigableAstNodeData<NavigableProtoAstNode>& {
206-
return node.data_;
207-
});
198+
[]() { return absl::WrapUnique(new AstNode()); },
199+
[](AstNode& node) -> NavigableAstNodeData& { return node.data_; });
208200
AstTraverse(&expr, /*source_info=*/nullptr, &visitor);
209201
return NavigableProtoAst(std::move(visitor).Consume());
210202
}

0 commit comments

Comments
 (0)