Skip to content

Commit 86dc7fe

Browse files
committed
Pre-compute the signature of each overload before checking via the O(n^2) consistency check. Add a fast-path at the top of the inner loop.
1 parent 272a1ea commit 86dc7fe

File tree

2 files changed

+179
-6
lines changed

2 files changed

+179
-6
lines changed

mypy/checker.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@
213213
FunctionLike,
214214
Instance,
215215
LiteralType,
216+
LiteralValue,
216217
NoneType,
217218
Overloaded,
218219
PartialType,
@@ -911,23 +912,42 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
911912
impl_type = self.extract_callable_type(inner_type, defn.impl)
912913

913914
is_descriptor_get = defn.info and defn.name == "__get__"
915+
916+
# Pre-extract callable types and literal fingerprints for each overload item.
917+
item_sigs: list[CallableType | None] = []
918+
item_literal_fingerprints: list[LiteralFingerprint] = []
919+
for item in defn.items:
920+
assert isinstance(item, Decorator)
921+
sig = self.extract_callable_type(item.var.type, item)
922+
item_sigs.append(sig)
923+
item_literal_fingerprints.append(
924+
build_literal_fingerprint(sig) if sig is not None else {}
925+
)
926+
914927
for i, item in enumerate(defn.items):
915928
assert isinstance(item, Decorator)
916-
sig1 = self.extract_callable_type(item.var.type, item)
929+
sig1 = item_sigs[i]
917930
if sig1 is None:
918931
continue
919932

920-
for j, item2 in enumerate(defn.items[i + 1 :]):
933+
for j, item2 in enumerate(defn.items[i + 1 :], i + 1):
921934
assert isinstance(item2, Decorator)
922-
sig2 = self.extract_callable_type(item2.var.type, item2)
935+
sig2 = item_sigs[j]
923936
if sig2 is None:
924937
continue
925938

926939
if not are_argument_counts_overlapping(sig1, sig2):
927940
continue
928941

942+
# If there is any argument position where both overloads
943+
# carry a LiteralType with different values they are disjoint.
944+
if literal_args_are_disjoint(
945+
item_literal_fingerprints[i], item_literal_fingerprints[j]
946+
):
947+
continue
948+
929949
if overload_can_never_match(sig1, sig2):
930-
self.msg.overloaded_signature_will_never_match(i + 1, i + j + 2, item2.func)
950+
self.msg.overloaded_signature_will_never_match(i + 1, j + 1, item2.func)
931951
elif not is_descriptor_get:
932952
# Note: we force mypy to check overload signatures in strict-optional mode
933953
# so we don't incorrectly report errors when a user tries typing an overload
@@ -947,14 +967,14 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
947967
with state.strict_optional_set(True):
948968
if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
949969
flip_note = (
950-
j == 0
970+
j == i + 1
951971
and not is_unsafe_overlapping_overload_signatures(
952972
sig2, sig1, type_vars
953973
)
954974
and not overload_can_never_match(sig2, sig1)
955975
)
956976
self.msg.overloaded_signatures_overlap(
957-
i + 1, i + j + 2, flip_note, item.func
977+
i + 1, j + 1, flip_note, item.func
958978
)
959979

960980
if impl_type is not None:
@@ -8958,6 +8978,61 @@ def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -
89588978
return typ.copy_modified(variables=list(typ.variables) + class_type_vars)
89598979

89608980

8981+
# Fingerprint type for literal-disjointness checks: maps argument index to
8982+
# the set of (Python type of value, value) pairs present at that position.
8983+
# Using type(value) as part of the key means Literal[1] (int) and
8984+
# Literal[True] (bool) are kept distinct even though 1 == True in Python.
8985+
# A union such as Literal["a", "b"] or Literal["a"] | Literal["b"] produces
8986+
# a frozenset of two entries; a plain Literal["a"] produces a singleton set.
8987+
LiteralFingerprint = dict[int, frozenset[tuple[type, LiteralValue]]]
8988+
8989+
8990+
def literal_args_are_disjoint(fp1: LiteralFingerprint, fp2: LiteralFingerprint) -> bool:
8991+
"""Return True if two overloads are provably disjoint via a Literal argument.
8992+
8993+
If there is any argument position where both carry only LiteralType values
8994+
and those value sets are disjoint, no single call can match both overloads
8995+
and the pairwise overlap check can be skipped entirely.
8996+
"""
8997+
for idx, vals1 in fp1.items():
8998+
vals2 = fp2.get(idx)
8999+
if vals2 is not None and vals1.isdisjoint(vals2):
9000+
return True
9001+
return False
9002+
9003+
9004+
def build_literal_fingerprint(sig: CallableType) -> LiteralFingerprint:
9005+
"""Build a LiteralFingerprint for one overload signature.
9006+
9007+
Each argument position that carries only LiteralType values (including
9008+
unions such as ``Literal["a", "b"]``) is recorded as a frozenset of
9009+
``(type(value), value)`` pairs. Positions with any non-literal type are
9010+
omitted so the disjointness check is conservative.
9011+
"""
9012+
fingerprint: LiteralFingerprint = {}
9013+
for idx, arg_type in enumerate(sig.arg_types):
9014+
proper = get_proper_type(arg_type)
9015+
if isinstance(proper, LiteralType):
9016+
fingerprint[idx] = frozenset([(type(proper.value), proper.value)])
9017+
elif isinstance(proper, UnionType):
9018+
# Literal["a", "b"] and Literal["a"] | Literal["b"] are both
9019+
# represented as a UnionType of LiteralTypes. Collect all the
9020+
# literal values; if any member is not a LiteralType the whole
9021+
# position is skipped (a non-literal in the union makes it too
9022+
# broad to prove disjointness).
9023+
vals: set[tuple[type, LiteralValue]] = set()
9024+
for member in proper.items:
9025+
m = get_proper_type(member)
9026+
if isinstance(m, LiteralType):
9027+
vals.add((type(m.value), m.value))
9028+
else:
9029+
vals = set()
9030+
break
9031+
if vals:
9032+
fingerprint[idx] = frozenset(vals)
9033+
return fingerprint
9034+
9035+
89619036
def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:
89629037
"""Check if the 'other' method can never be matched due to 'signature'.
89639038

test-data/unit/check-overloading.test

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6865,3 +6865,101 @@ if isinstance(headers, dict):
68656865

68666866
reveal_type(headers) # N: Revealed type is "__main__.Headers | typing.Iterable[tuple[builtins.bytes, builtins.bytes]]"
68676867
[builtins fixtures/isinstancelist.pyi]
6868+
6869+
-- Tests for literal-disjointness fast path in check_overlapping_overloads
6870+
6871+
[case testOverloadLiteralDistinctStringsNoError]
6872+
# Overloads with distinct Literal[str] arguments are provably disjoint; no
6873+
# overlap or never-match errors should be reported.
6874+
from typing import overload, Literal
6875+
@overload
6876+
def f(x: Literal["a"]) -> int: ...
6877+
@overload
6878+
def f(x: Literal["b"]) -> str: ...
6879+
@overload
6880+
def f(x: Literal["c"]) -> float: ...
6881+
def f(x: str) -> object:
6882+
return x
6883+
[builtins fixtures/tuple.pyi]
6884+
6885+
[case testOverloadLiteralDuplicateStillErrors]
6886+
# Two overloads sharing the same Literal value should still trigger an error.
6887+
# Signature 1 covers all inputs of type Literal["a"], so signature 2 is unreachable.
6888+
from typing import overload, Literal
6889+
@overload
6890+
def f(x: Literal["a"]) -> int: ...
6891+
@overload
6892+
def f(x: Literal["a"]) -> str: ... # E: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader
6893+
def f(x: str) -> object:
6894+
return x
6895+
[builtins fixtures/tuple.pyi]
6896+
6897+
[case testOverloadLiteralWithBroadCatchAll]
6898+
# Distinct Literal overloads followed by a broad catch-all should produce no
6899+
# overlap errors. The broad type must come last (correct ordering).
6900+
from typing import overload, Literal, Any
6901+
@overload
6902+
def f(x: Literal["a"]) -> int: ...
6903+
@overload
6904+
def f(x: Literal["b"]) -> str: ...
6905+
@overload
6906+
def f(x: str) -> Any: ...
6907+
def f(x: str) -> object:
6908+
return x
6909+
[builtins fixtures/tuple.pyi]
6910+
6911+
[case testOverloadLiteralBroadBeforeLiteralErrors]
6912+
# A broad type before a specific Literal means the Literal can never match.
6913+
from typing import overload, Literal
6914+
@overload
6915+
def f(x: str) -> int: ...
6916+
@overload
6917+
def f(x: Literal["a"]) -> str: ... # E: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader
6918+
def f(x: str) -> object:
6919+
return x
6920+
[builtins fixtures/tuple.pyi]
6921+
6922+
[case testOverloadLiteralImplErrorsNotSuppressed]
6923+
# The literal fast path must not suppress implementation-body consistency errors.
6924+
# Use bytes as the impl return type — incompatible with both int and str.
6925+
from typing import overload, Literal
6926+
@overload
6927+
def f(x: Literal["a"]) -> int: ...
6928+
@overload
6929+
def f(x: Literal["b"]) -> str: ...
6930+
def f(x: str) -> bytes: # E: Overloaded function implementation cannot produce return type of signature 1 # E: Overloaded function implementation cannot produce return type of signature 2
6931+
return b""
6932+
[builtins fixtures/tuple.pyi]
6933+
6934+
[case testOverloadLiteralUnionDistinctNoError]
6935+
# Literal unions with disjoint value sets are provably disjoint; no errors.
6936+
from typing import overload, Literal, Union
6937+
@overload
6938+
def f(x: Literal["a", "b"]) -> int: ...
6939+
@overload
6940+
def f(x: Literal["c", "d"]) -> str: ...
6941+
def f(x: str) -> object:
6942+
return x
6943+
[builtins fixtures/tuple.pyi]
6944+
6945+
[case testOverloadLiteralUnionOverlapErrors]
6946+
# Literal unions that share a value are NOT disjoint and should be flagged.
6947+
from typing import overload, Literal
6948+
@overload
6949+
def f(x: Literal["a", "b"]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
6950+
@overload
6951+
def f(x: Literal["b", "c"]) -> str: ...
6952+
def f(x: str) -> object:
6953+
return x
6954+
[builtins fixtures/tuple.pyi]
6955+
6956+
[case testOverloadLiteralUnionMixedNoFastPath]
6957+
# A union with a non-Literal member is not fingerprinted, so the full check runs.
6958+
from typing import overload, Literal, Union
6959+
@overload
6960+
def f(x: Literal["a"]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
6961+
@overload
6962+
def f(x: Union[Literal["b"], str]) -> str: ...
6963+
def f(x: str) -> object:
6964+
return x
6965+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)