@@ -128,21 +128,6 @@ FunctionOverloadInstance InstantiateFunctionOverload(
128128 return result;
129129}
130130
131- bool OccursWithin (absl::string_view var_name, Type t) {
132- // This is difficult to trigger without lambdas in CEL, but we still check
133- // to guarantee that we don't introduce a recursive type definition (a cycle
134- // in the substitution map).
135- if (t.kind () == TypeKind::kTypeParam && t.AsTypeParam ()->name () == var_name) {
136- return true ;
137- }
138- for (const auto & param : t.GetParameters ()) {
139- if (OccursWithin (var_name, param)) {
140- return true ;
141- }
142- }
143- return false ;
144- }
145-
146131// Converts a wrapper type to its corresponding primitive type.
147132// Returns nullopt if the type is not a wrapper type.
148133absl::optional<Type> WrapperToPrimitive (const Type& t) {
@@ -205,7 +190,7 @@ Type TypeInferenceContext::InstantiateTypeParams(
205190 if (auto it = substitutions.find (name); it != substitutions.end ()) {
206191 return TypeParamType (it->second );
207192 }
208- absl::string_view substitution = NewTypeVar ();
193+ absl::string_view substitution = NewTypeVar (name );
209194 substitutions[type.AsTypeParam ()->name ()] = substitution;
210195 return TypeParamType (substitution);
211196 }
@@ -360,8 +345,8 @@ Type TypeInferenceContext::Substitute(
360345 }
361346 if (auto it = type_parameter_bindings_.find (t.name ());
362347 it != type_parameter_bindings_.end ()) {
363- if (it->second .has_value ()) {
364- subs = *it->second ;
348+ if (it->second .type . has_value ()) {
349+ subs = *it->second . type ;
365350 continue ;
366351 }
367352 }
@@ -370,6 +355,33 @@ Type TypeInferenceContext::Substitute(
370355 return subs;
371356}
372357
358+ bool TypeInferenceContext::OccursWithin (
359+ absl::string_view var_name, const Type& type,
360+ const SubstitutionMap& substitutions) const {
361+ // This is difficult to trigger in normal CEL expressions, but may
362+ // happen with comprehensions where we can potentially reference a variable
363+ // with a free type var in different ways.
364+ //
365+ // This check guarantees that we don't introduce a recursive type definition
366+ // (a cycle in the substitution map).
367+ if (type.kind () == TypeKind::kTypeParam ) {
368+ if (type.AsTypeParam ()->name () == var_name) {
369+ return true ;
370+ }
371+ auto typeSubs = Substitute (type, substitutions);
372+ if (typeSubs != type && OccursWithin (var_name, typeSubs, substitutions)) {
373+ return true ;
374+ }
375+ }
376+
377+ for (const auto & param : type.GetParameters ()) {
378+ if (OccursWithin (var_name, param, substitutions)) {
379+ return true ;
380+ }
381+ }
382+ return false ;
383+ }
384+
373385bool TypeInferenceContext::IsAssignableWithConstraints (
374386 const Type& from, const Type& to,
375387 SubstitutionMap& prospective_substitutions) {
@@ -384,16 +396,16 @@ bool TypeInferenceContext::IsAssignableWithConstraints(
384396
385397 if (to.kind () == TypeKind::kTypeParam ) {
386398 absl::string_view name = to.AsTypeParam ()->name ();
387- if (!OccursWithin (name, from)) {
388- prospective_substitutions[to. AsTypeParam ()-> name () ] = from;
399+ if (!OccursWithin (name, from, prospective_substitutions )) {
400+ prospective_substitutions[name] = from;
389401 return true ;
390402 }
391403 }
392404
393405 if (from.kind () == TypeKind::kTypeParam ) {
394406 absl::string_view name = from.AsTypeParam ()->name ();
395- if (!OccursWithin (name, to)) {
396- prospective_substitutions[from. AsTypeParam ()-> name () ] = to;
407+ if (!OccursWithin (name, to, prospective_substitutions )) {
408+ prospective_substitutions[name] = to;
397409 return true ;
398410 }
399411 }
@@ -465,7 +477,7 @@ void TypeInferenceContext::UpdateTypeParameterBindings(
465477 iter != prospective_substitutions.end (); ++iter) {
466478 if (auto binding_iter = type_parameter_bindings_.find (iter->first );
467479 binding_iter != type_parameter_bindings_.end ()) {
468- binding_iter->second = iter->second ;
480+ binding_iter->second . type = iter->second ;
469481 } else {
470482 ABSL_LOG (WARNING) << " Uninstantiated type parameter: " << iter->first ;
471483 }
0 commit comments