Skip to content

Commit 8c437f2

Browse files
authored
Merge all of the code that calls __annotate__/evaluate_value (#86)
Now we only have one place where we replace the closure of a function.
1 parent 57f8da3 commit 8c437f2

5 files changed

Lines changed: 75 additions & 240 deletions

File tree

tests/test_type_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def test_getmember_01():
422422

423423
def test_getmember_02():
424424
class C:
425-
def f[T](self, x: T) -> OnlyIntToSet[T]: ...
425+
def f[TX](self, x: TX) -> OnlyIntToSet[TX]: ...
426426

427427
m = eval_typing(GetMember[C, Literal["f"]])
428428
assert eval_typing(GetName[m]) == Literal["f"]

typemap/type_eval/_apply_generic.py

Lines changed: 62 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -189,84 +189,67 @@ def make_func(
189189
return new_func
190190

191191

192-
def _get_closure_types(af: types.FunctionType) -> dict[str, type]:
193-
# Generate a fallback mapping of closure classes.
194-
# This is needed for locally defined generic types which reference
195-
# themselves in their type annotations.
196-
if not af.__closure__:
197-
return {}
198-
return {
199-
name: variable.cell_contents
200-
for name, variable in zip(
201-
af.__code__.co_freevars, af.__closure__, strict=True
202-
)
203-
}
204-
205-
206192
EXCLUDED_ATTRIBUTES = typing.EXCLUDED_ATTRIBUTES - {'__init__'} # type: ignore[attr-defined]
207193

208194

209-
def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
210-
annos: dict[str, Any] = {}
211-
dct: dict[str, Any] = {}
212-
213-
if af := typing.cast(
214-
types.FunctionType, getattr(boxed.cls, "__annotate__", None)
215-
):
216-
# Class has annotations, let's resolve generic arguments
217-
218-
closure_types = _get_closure_types(af)
219-
args = tuple(
220-
types.CellType(
221-
boxed.cls.__dict__
222-
if name == "__classdict__"
223-
else boxed.str_args[name]
224-
if name in boxed.str_args
225-
else closure_types[name]
195+
def get_annotations(
196+
obj: object,
197+
args: dict[str, object],
198+
key: str = '__annotate__',
199+
annos_ok: bool = True,
200+
) -> Any | None:
201+
"""Get the annotations on an object, substituting in type vars."""
202+
203+
rr = None
204+
globs = None
205+
if af := typing.cast(types.FunctionType, getattr(obj, key, None)):
206+
# Substitute in names that are provided but keep the existing
207+
# values for everything else.
208+
closure = tuple(
209+
types.CellType(args[name]) if name in args else orig_value
210+
for name, orig_value in zip(
211+
af.__code__.co_freevars, af.__closure__ or (), strict=True
226212
)
227-
for name in af.__code__.co_freevars
228213
)
229214

230-
ff = types.FunctionType(
231-
af.__code__, af.__globals__, af.__name__, None, args
232-
)
215+
globs = af.__globals__
216+
ff = types.FunctionType(af.__code__, globs, af.__name__, None, closure)
233217
rr = ff(annotationlib.Format.VALUE)
234-
235-
if rr:
236-
for k, v in rr.items():
218+
elif annos_ok and (rr := getattr(obj, "__annotations__", None)):
219+
globs = {}
220+
if mod := sys.modules.get(obj.__module__):
221+
globs.update(vars(mod))
222+
223+
if isinstance(rr, dict) and any(isinstance(v, str) for v in rr.values()):
224+
# Copy in any __type_params__ that aren't provided for, so that if
225+
# we have to eval, we have them.
226+
if params := getattr(obj, "__type_params__", None):
227+
args = args.copy()
228+
for param in params:
229+
if str(param) not in args:
230+
args[str(param)] = param
231+
232+
for k, v in rr.items():
233+
# Eval strings
234+
if isinstance(v, str):
235+
v = eval(v, globs, args)
236+
# Handle cases where annotation is explicitly a string,
237+
# e.g.:
238+
# class Foo[X]:
239+
# x: "Foo[X | None]"
237240
if isinstance(v, str):
238-
# Handle cases where annotation is explicitly a string,
239-
# e.g.:
240-
#
241-
# class Foo[X]:
242-
# x: "Foo[X | None]"
241+
v = eval(v, globs, args)
242+
rr[k] = v
243243

244-
annos[k] = eval(v, af.__globals__, boxed.str_args)
245-
else:
246-
annos[k] = v
247-
elif af := getattr(boxed.cls, "__annotations__", None):
248-
# TODO: substitute vars in this case
249-
_globals = {}
250-
if mod := sys.modules.get(boxed.cls.__module__):
251-
_globals.update(vars(mod))
252-
_globals.update(boxed.str_args)
253-
254-
_locals = dict(boxed.cls.__dict__)
255-
_locals.update(boxed.str_args)
256-
257-
for k, v in af.items():
258-
if isinstance(v, str):
259-
result = eval(v, _globals, _locals)
260-
# Handle cases where annotation is explicitly a string
261-
# e.g.
262-
# class Foo[T]:
263-
# x: "Bar[T]"
264-
if isinstance(result, str):
265-
result = eval(result, _globals, _locals)
266-
annos[k] = result
244+
return rr
267245

268-
else:
269-
annos[k] = v
246+
247+
def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
248+
annos: dict[str, Any] = {}
249+
dct: dict[str, Any] = {}
250+
251+
if (rr := get_annotations(boxed.cls, boxed.str_args)) is not None:
252+
annos.update(rr)
270253

271254
for name, orig in boxed.cls.__dict__.items():
272255
if name in EXCLUDED_ATTRIBUTES:
@@ -275,42 +258,18 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
275258
stuff = inspect.unwrap(orig)
276259

277260
if isinstance(stuff, types.FunctionType):
278-
local_fn: types.FunctionType | classmethod | staticmethod | None = (
279-
None
280-
)
281-
282-
if af := typing.cast(
283-
types.FunctionType, getattr(stuff, "__annotate__", None)
284-
):
285-
params = dict(
286-
zip(
287-
map(str, stuff.__type_params__),
288-
stuff.__type_params__,
289-
strict=True,
290-
)
291-
)
292-
293-
closure_types = _get_closure_types(af)
294-
args = tuple(
295-
types.CellType(
296-
boxed.cls.__dict__
297-
if name == "__classdict__"
298-
else params[name]
299-
if name in params
300-
else boxed.str_args[name]
301-
if name in boxed.str_args
302-
else closure_types[name]
303-
)
304-
for name in af.__code__.co_freevars
305-
)
306-
307-
ff = types.FunctionType(
308-
af.__code__, af.__globals__, af.__name__, None, args
309-
)
310-
rr = ff(annotationlib.Format.VALUE)
311-
261+
local_fn: Any = None
262+
263+
# TODO: This annos_ok thing is a hack because processing
264+
# __annotations__ on methods broke stuff and I didn't want
265+
# to chase it down yet.
266+
if (
267+
rr := get_annotations(stuff, boxed.str_args, annos_ok=False)
268+
) is not None:
312269
local_fn = make_func(orig, rr)
313-
elif af := getattr(stuff, "__annotations__", None):
270+
elif getattr(stuff, "__annotations__", None):
271+
# XXX: This is totally wrong; we still need to do
272+
# substitute in class vars
314273
local_fn = stuff
315274

316275
if local_fn is not None:

typemap/type_eval/_eval_call.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import annotationlib
21
import enum
32
import inspect
43
import types
@@ -12,7 +11,7 @@
1211
from . import _eval_typing
1312
from . import _typing_inspect
1413
from ._eval_operators import _callable_type_to_signature
15-
from ._apply_generic import substitute, _get_closure_types
14+
from ._apply_generic import substitute, get_annotations
1615

1716
RtType = Any
1817

@@ -39,7 +38,7 @@ def _get_bound_type_args(
3938
arg_types: tuple[RtType, ...],
4039
kwarg_types: dict[str, RtType],
4140
) -> dict[str, RtType]:
42-
sig = _eval_operators._resolved_function_signature(func)
41+
sig = inspect.signature(func)
4342

4443
bound = sig.bind(*arg_types, **kwarg_types)
4544

@@ -188,30 +187,12 @@ def _eval_call_with_type_vars(
188187
vars: dict[str, RtType],
189188
ctx: _eval_typing.EvalContext,
190189
) -> RtType:
191-
try:
192-
af = typing.cast(types.FunctionType, func.__annotate__)
193-
except AttributeError:
194-
raise ValueError("func has no __annotate__ attribute")
195-
if not af:
196-
raise ValueError("func has no __annotate__ attribute")
197-
198-
closure_types = _get_closure_types(af)
199-
for name, value in closure_types.items():
200-
if name not in vars:
201-
vars[name] = value
202-
203-
af_args = tuple(
204-
types.CellType(vars[name]) for name in af.__code__.co_freevars
205-
)
206-
207-
ff = types.FunctionType(
208-
af.__code__, af.__globals__, af.__name__, None, af_args
209-
)
210-
211190
old_obj = ctx.current_generic_alias
212191
ctx.current_generic_alias = func
213192
try:
214-
rr = ff(annotationlib.Format.VALUE)
193+
rr = get_annotations(func, vars)
194+
if rr is None:
195+
return Any
215196
return _eval_typing.eval_typing(rr["return"])
216197
finally:
217198
ctx.current_generic_alias = old_obj

typemap/type_eval/_eval_operators.py

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import re
88
import types
99
import typing
10-
import sys
1110

1211
from typing_extensions import _AnnotatedAlias as typing_AnnotatedAlias
1312

@@ -637,7 +636,7 @@ def _callable_type_to_method(name, typ, ctx):
637636

638637
def _function_type(func, *, receiver_type):
639638
root = inspect.unwrap(func)
640-
sig = _resolved_function_signature(root, receiver_type)
639+
sig = inspect.signature(root)
641640
# XXX: __type_params__!!!
642641

643642
empty = inspect.Parameter.empty
@@ -730,94 +729,6 @@ def _create_generic_callable_lambda(
730729
]
731730

732731

733-
def _resolved_function_signature(func, receiver_type=None):
734-
"""Get the signature of a function with type hints resolved.
735-
736-
This is used to deal with string annotations in the signature which are
737-
generated when using __future__ import annotations.
738-
"""
739-
740-
sig = inspect.signature(func)
741-
742-
_globals, _locals = _get_function_hint_namespaces(func, receiver_type)
743-
if hints := typing.get_type_hints(
744-
func, globalns=_globals, localns=_locals, include_extras=True
745-
):
746-
params = []
747-
for name, param in sig.parameters.items():
748-
annotation = hints.get(name, param.annotation)
749-
params.append(param.replace(annotation=annotation))
750-
751-
return_annotation = hints.get("return", sig.return_annotation)
752-
sig = sig.replace(
753-
parameters=params, return_annotation=return_annotation
754-
)
755-
756-
return sig
757-
758-
759-
def _get_class_type_hint_namespaces(
760-
obj: type,
761-
) -> tuple[dict[str, typing.Any], dict[str, typing.Any]]:
762-
globalns: dict[str, typing.Any] = {}
763-
localns: dict[str, typing.Any] = {}
764-
765-
# Get module globals
766-
if obj.__module__ and (module := sys.modules.get(obj.__module__)):
767-
globalns.update(module.__dict__)
768-
769-
# Annotations may use typevars defined in the class
770-
localns.update(obj.__dict__)
771-
772-
if _typing_inspect.is_generic_alias(obj):
773-
# We need the origin's type vars
774-
localns.update(obj.__origin__.__dict__)
775-
776-
# Extract type parameters from the class
777-
args = typing.get_args(obj)
778-
origin = typing.get_origin(obj)
779-
tps = getattr(obj, '__type_params__', ()) or getattr(
780-
origin, '__parameters__', ()
781-
)
782-
for tp, arg in zip(tps, args, strict=False):
783-
localns[tp.__name__] = arg
784-
785-
# Add the class itself for self-references
786-
localns[obj.__name__] = obj
787-
788-
return globalns, localns
789-
790-
791-
def _get_function_hint_namespaces(func, receiver_type=None):
792-
globalns = {}
793-
localns = {}
794-
795-
# module globals
796-
module = inspect.getmodule(func)
797-
if module:
798-
globalns |= module.__dict__
799-
800-
# If no receiver was specified, this might still be a method, try to get
801-
# the class from the qualname.
802-
if (
803-
not receiver_type
804-
and (qn := getattr(func, '__qualname__', None))
805-
and '.' in qn
806-
):
807-
class_name = qn.rsplit('.', 1)[0]
808-
receiver_type = getattr(module, class_name, None)
809-
810-
# Get the class's type hint namespaces
811-
if receiver_type:
812-
cls_globalns, cls_localns = _get_class_type_hint_namespaces(
813-
receiver_type
814-
)
815-
globalns.update(cls_globalns)
816-
localns.update(cls_localns)
817-
818-
return globalns, localns
819-
820-
821732
def _hint_to_member(n, t, qs, init, d, *, ctx):
822733
return Member[
823734
typing.Literal[n],

0 commit comments

Comments
 (0)