Skip to content

Commit e47ba71

Browse files
jnthntatumcopybara-github
authored andcommitted
Fix case where type checker would infer a recursively defined type.
PiperOrigin-RevId: 688181907
1 parent 7e46ae9 commit e47ba71

8 files changed

Lines changed: 157 additions & 36 deletions

File tree

checker/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ cc_test(
143143
"//checker/internal:test_ast_helpers",
144144
"//common:ast",
145145
"//common:constant",
146+
"//common:decl",
147+
"//common:type",
146148
"//internal:testing",
147149
"@com_google_absl//absl/status",
148150
"@com_google_absl//absl/status:status_matchers",
151+
"@com_google_protobuf//:protobuf",
149152
],
150153
)
151154

checker/internal/type_checker_impl_test.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,53 @@ TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) {
759759
Contains(Pair(_, IsVariableReference("x.y"))));
760760
}
761761

762+
TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) {
763+
TypeCheckEnv env;
764+
google::protobuf::Arena arena;
765+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
766+
767+
TypeCheckerImpl impl(std::move(env));
768+
// This is valid because the list construction in the transform will resolve
769+
// to list(dyn) since candidates E1 -> E2 and list(E1) -> E2 don't agree.
770+
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[].map(c, [ c, [c] ])"));
771+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
772+
773+
EXPECT_TRUE(result.IsValid());
774+
775+
EXPECT_THAT(result.GetIssues(), IsEmpty());
776+
777+
// Remainder are conceptually the same, but confirm generality.
778+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, [[c]] ])"));
779+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
780+
781+
EXPECT_TRUE(result.IsValid());
782+
783+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], [[c]] ])"));
784+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
785+
786+
EXPECT_TRUE(result.IsValid());
787+
788+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, c ])"));
789+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
790+
791+
EXPECT_TRUE(result.IsValid());
792+
793+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], c ])"));
794+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
795+
796+
EXPECT_TRUE(result.IsValid());
797+
798+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [[c]], c ])"));
799+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
800+
801+
EXPECT_TRUE(result.IsValid());
802+
803+
ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, type(c) ])"));
804+
ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast)));
805+
806+
EXPECT_TRUE(result.IsValid());
807+
}
808+
762809
struct PrimitiveLiteralsTestCase {
763810
std::string expr;
764811
ast_internal::PrimitiveType expected_type;

checker/internal/type_inference_context.cc

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,21 +128,6 @@ FunctionOverloadInstance InstantiateFunctionOverload(
128128
return result;
129129
}
130130

131-
bool OccursWithin(absl::string_view var_name, Type t) {
132-
// This is difficult to trigger without lambdas in CEL, but we still check
133-
// to guarantee that we don't introduce a recursive type definition (a cycle
134-
// in the substitution map).
135-
if (t.kind() == TypeKind::kTypeParam && t.AsTypeParam()->name() == var_name) {
136-
return true;
137-
}
138-
for (const auto& param : t.GetParameters()) {
139-
if (OccursWithin(var_name, param)) {
140-
return true;
141-
}
142-
}
143-
return false;
144-
}
145-
146131
// Converts a wrapper type to its corresponding primitive type.
147132
// Returns nullopt if the type is not a wrapper type.
148133
absl::optional<Type> WrapperToPrimitive(const Type& t) {
@@ -205,7 +190,7 @@ Type TypeInferenceContext::InstantiateTypeParams(
205190
if (auto it = substitutions.find(name); it != substitutions.end()) {
206191
return TypeParamType(it->second);
207192
}
208-
absl::string_view substitution = NewTypeVar();
193+
absl::string_view substitution = NewTypeVar(name);
209194
substitutions[type.AsTypeParam()->name()] = substitution;
210195
return TypeParamType(substitution);
211196
}
@@ -360,8 +345,8 @@ Type TypeInferenceContext::Substitute(
360345
}
361346
if (auto it = type_parameter_bindings_.find(t.name());
362347
it != type_parameter_bindings_.end()) {
363-
if (it->second.has_value()) {
364-
subs = *it->second;
348+
if (it->second.type.has_value()) {
349+
subs = *it->second.type;
365350
continue;
366351
}
367352
}
@@ -370,6 +355,33 @@ Type TypeInferenceContext::Substitute(
370355
return subs;
371356
}
372357

358+
bool TypeInferenceContext::OccursWithin(
359+
absl::string_view var_name, const Type& type,
360+
const SubstitutionMap& substitutions) const {
361+
// This is difficult to trigger in normal CEL expressions, but may
362+
// happen with comprehensions where we can potentially reference a variable
363+
// with a free type var in different ways.
364+
//
365+
// This check guarantees that we don't introduce a recursive type definition
366+
// (a cycle in the substitution map).
367+
if (type.kind() == TypeKind::kTypeParam) {
368+
if (type.AsTypeParam()->name() == var_name) {
369+
return true;
370+
}
371+
auto typeSubs = Substitute(type, substitutions);
372+
if (typeSubs != type && OccursWithin(var_name, typeSubs, substitutions)) {
373+
return true;
374+
}
375+
}
376+
377+
for (const auto& param : type.GetParameters()) {
378+
if (OccursWithin(var_name, param, substitutions)) {
379+
return true;
380+
}
381+
}
382+
return false;
383+
}
384+
373385
bool TypeInferenceContext::IsAssignableWithConstraints(
374386
const Type& from, const Type& to,
375387
SubstitutionMap& prospective_substitutions) {
@@ -384,16 +396,16 @@ bool TypeInferenceContext::IsAssignableWithConstraints(
384396

385397
if (to.kind() == TypeKind::kTypeParam) {
386398
absl::string_view name = to.AsTypeParam()->name();
387-
if (!OccursWithin(name, from)) {
388-
prospective_substitutions[to.AsTypeParam()->name()] = from;
399+
if (!OccursWithin(name, from, prospective_substitutions)) {
400+
prospective_substitutions[name] = from;
389401
return true;
390402
}
391403
}
392404

393405
if (from.kind() == TypeKind::kTypeParam) {
394406
absl::string_view name = from.AsTypeParam()->name();
395-
if (!OccursWithin(name, to)) {
396-
prospective_substitutions[from.AsTypeParam()->name()] = to;
407+
if (!OccursWithin(name, to, prospective_substitutions)) {
408+
prospective_substitutions[name] = to;
397409
return true;
398410
}
399411
}
@@ -465,7 +477,7 @@ void TypeInferenceContext::UpdateTypeParameterBindings(
465477
iter != prospective_substitutions.end(); ++iter) {
466478
if (auto binding_iter = type_parameter_bindings_.find(iter->first);
467479
binding_iter != type_parameter_bindings_.end()) {
468-
binding_iter->second = iter->second;
480+
binding_iter->second.type = iter->second;
469481
} else {
470482
ABSL_LOG(WARNING) << "Uninstantiated type parameter: " << iter->first;
471483
}

checker/internal/type_inference_context.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,14 @@ class TypeInferenceContext {
8787
std::string DebugString() const {
8888
return absl::StrCat(
8989
"type_parameter_bindings: ",
90-
absl::StrJoin(type_parameter_bindings_, "\n ",
91-
[](std::string* out, const auto& binding) {
92-
absl::StrAppend(
93-
out, binding.first, " -> ",
94-
binding.second.value_or(Type(TypeParamType("none")))
95-
.DebugString());
96-
}));
90+
absl::StrJoin(
91+
type_parameter_bindings_, "\n ",
92+
[](std::string* out, const auto& binding) {
93+
absl::StrAppend(
94+
out, binding.first, " (", binding.second.name, ") -> ",
95+
binding.second.type.value_or(Type(TypeParamType("none")))
96+
.DebugString());
97+
}));
9798
}
9899

99100
private:
@@ -102,10 +103,15 @@ class TypeInferenceContext {
102103
// Used for prospective substitutions during type inference.
103104
using SubstitutionMap = absl::flat_hash_map<absl::string_view, Type>;
104105

105-
absl::string_view NewTypeVar() {
106+
struct TypeVar {
107+
absl::optional<Type> type;
108+
absl::string_view name;
109+
};
110+
111+
absl::string_view NewTypeVar(absl::string_view name = "") {
106112
next_type_parameter_id_++;
107113
auto inserted = type_parameter_bindings_.insert(
108-
{absl::StrCat("T%", next_type_parameter_id_), absl::nullopt});
114+
{absl::StrCat("T%", next_type_parameter_id_), {absl::nullopt, name}});
109115
ABSL_DCHECK(inserted.second);
110116
return inserted.first->first;
111117
}
@@ -134,6 +140,9 @@ class TypeInferenceContext {
134140

135141
Type Substitute(const Type& type, const SubstitutionMap& substitutions) const;
136142

143+
bool OccursWithin(absl::string_view var_name, const Type& type,
144+
const SubstitutionMap& substitutions) const;
145+
137146
void UpdateTypeParameterBindings(
138147
const SubstitutionMap& prospective_substitutions);
139148

@@ -150,8 +159,7 @@ class TypeInferenceContext {
150159
// instance.
151160
//
152161
// nullopt signifies a free type variable.
153-
absl::node_hash_map<std::string, absl::optional<Type>>
154-
type_parameter_bindings_;
162+
absl::node_hash_map<std::string, TypeVar> type_parameter_bindings_;
155163
int64_t next_type_parameter_id_ = 0;
156164
google::protobuf::Arena* arena_;
157165
bool enable_legacy_null_assignment_;

checker/internal/type_inference_context_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ TEST(TypeInferenceContextTest, DebugString) {
485485
ASSERT_TRUE(resolution.has_value());
486486
EXPECT_TRUE(resolution->result_type.IsList());
487487

488-
EXPECT_EQ(context.DebugString(), "type_parameter_bindings: T%1 -> int");
488+
EXPECT_EQ(context.DebugString(), "type_parameter_bindings: T%1 (A) -> int");
489489
}
490490

491491
struct TypeInferenceContextWrapperTypesTestCase {

checker/standard_library_test.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
#include "checker/validation_result.h"
2929
#include "common/ast.h"
3030
#include "common/constant.h"
31+
#include "common/decl.h"
32+
#include "common/type.h"
3133
#include "internal/testing.h"
34+
#include "google/protobuf/arena.h"
3235

3336
namespace cel {
3437
namespace {
@@ -56,6 +59,43 @@ TEST(StandardLibraryTest, StandardLibraryErrorsIfAddedTwice) {
5659
StatusIs(absl::StatusCode::kAlreadyExists));
5760
}
5861

62+
TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) {
63+
google::protobuf::Arena arena;
64+
TypeCheckerBuilder builder;
65+
ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk());
66+
67+
// Note: this is atypical -- parameterized variables aren't well supported
68+
// outside of built-in syntax.
69+
// e.g. `list : Type(List(A))` is instantiated per reference to bind A to
70+
// the concrete type of a list in the same assignability context.
71+
//
72+
// Validate that parameterization is sanitized to be contextual
73+
// List(V) -> List(T%1)
74+
// Map(K, V) -> Map(T%2, T%3)
75+
Type list_type = ListType(&arena, TypeParamType("V"));
76+
Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V"));
77+
78+
ASSERT_THAT(builder.AddVariable(MakeVariableDecl("list_var", list_type)),
79+
IsOk());
80+
ASSERT_THAT(builder.AddVariable(MakeVariableDecl("map_var", map_type)),
81+
IsOk());
82+
83+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> type_checker,
84+
std::move(builder).Build());
85+
86+
ASSERT_OK_AND_ASSIGN(
87+
auto ast, checker_internal::MakeTestParsedAst(
88+
"list_var.exists(v,"
89+
" map_var.filter(k, map_var[k] > 1.0).size() > int(v)"
90+
")"));
91+
ASSERT_OK_AND_ASSIGN(ValidationResult result,
92+
type_checker->Check(std::move(ast)));
93+
94+
EXPECT_TRUE(result.IsValid());
95+
96+
EXPECT_THAT(result.GetIssues(), IsEmpty());
97+
}
98+
5999
class StandardLibraryDefinitionsTest : public ::testing::Test {
60100
public:
61101
void SetUp() override {

common/types/type_type.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
#include "common/type.h"
1616

17+
#include <string>
18+
1719
#include "absl/base/nullability.h"
20+
#include "absl/strings/str_cat.h"
1821
#include "absl/types/span.h"
1922
#include "google/protobuf/arena.h"
2023

@@ -41,6 +44,14 @@ struct TypeTypeData final {
4144

4245
} // namespace common_internal
4346

47+
std::string TypeType::DebugString() const {
48+
std::string s(name());
49+
if (!GetParameters().empty()) {
50+
absl::StrAppend(&s, "(", GetParameters().front().DebugString(), ")");
51+
}
52+
return s;
53+
}
54+
4455
TypeType::TypeType(absl::Nonnull<google::protobuf::Arena*> arena, const Type& parameter)
4556
: TypeType(common_internal::TypeTypeData::Create(arena, parameter)) {}
4657

common/types/type_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class TypeType final {
5757

5858
TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND;
5959

60-
std::string DebugString() const { return std::string(name()); }
60+
std::string DebugString() const;
6161

6262
Type GetType() const;
6363

0 commit comments

Comments
 (0)