Skip to content
Prev Previous commit
Next Next commit
Add main functionality
  • Loading branch information
ilevkivskyi committed Aug 21, 2022
commit 8d107392cb1e66a94b9b2e8ef579c01a7ff94c4f
7 changes: 4 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,9 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# needs to be compatible in.
if impl_type.variables:
impl = unify_generic_callable(
impl_type,
sig1,
# Normalize both before unifying
impl_type.with_unpacked_kwargs(),
sig1.with_unpacked_kwargs(),
ignore_return=False,
return_constraint_direction=SUPERTYPE_OF,
)
Expand Down Expand Up @@ -1166,7 +1167,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
if not isinstance(arg_type, ParamSpecType):
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = self.named_generic_type(
"builtins.dict", [self.str_type(), arg_type]
)
Expand Down
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,8 @@ def check_callable_call(

See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2057,6 +2059,8 @@ def check_overload_call(
context: Context,
) -> tuple[Type, Type]:
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
Expand Down
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members(
return res

def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Normalize callables before matching against each other.
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
18 changes: 17 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Tuple

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:

def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type."""
# TODO: check infinite recursion for aliases here.
# TODO: check infinite recursion for aliases here?
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
Expand Down Expand Up @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
Expand Down Expand Up @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
elif isinstance(t, PlaceholderType):
return AnyType(TypeOfAny.from_error)

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))

Expand Down Expand Up @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
return False


def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
Expand Down
4 changes: 4 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = join.normalize_callables(s, t)

return t.accept(TypeMeetVisitor(s))


Expand Down
5 changes: 4 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None
name = tp.arg_names[i]
if name:
s += name + ": "
s += format_type_bare(tp.arg_types[i])
type_str = format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if tp.arg_kinds[i].is_optional():
s += " = ..."

Expand Down
26 changes: 26 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
get_proper_types,
invalid_recursive_alias,
is_named_instance,
UnpackType,
)
from mypy.typevars import fill_typevars
from mypy.util import (
Expand Down Expand Up @@ -830,6 +831,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.defer(defn)
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
result = self.remove_unpack_kwargs(defn, result)
defn.type = result
self.add_type_alias_deps(analyzer.aliases_used)
self.check_function_signature(defn)
Expand Down Expand Up @@ -872,6 +875,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=ret_type)
self.wrapped_coro_return_types[defn] = defn.type

def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
if not isinstance(last_type, TypedDictType):
self.fail("Unpack item in ** argument must be a TypedDict", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
overlap = set(typ.arg_names) & set(last_type.items)
# It is OK for TypedDict to have a key named 'kwargs'.
overlap.discard(typ.arg_names[-1])
if overlap:
overlapped = ", ".join([f'"{name}"' for name in overlap])
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
# OK, everything looks right now, mark the callable type as using unpack.
new_arg_types = typ.arg_types[:-1] + [last_type]
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
"""Check basic signature validity and tweak annotation of self/cls argument."""
# Only non-static methods are special.
Expand Down
4 changes: 4 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,10 @@ def g(x: int) -> int: ...
If the 'some_check' function is also symmetric, the two calls would be equivalent
whether or not we check the args covariantly.
"""
# Normalize both types before comparing them.
left = left.with_unpacked_kwargs()
right = right.with_unpacked_kwargs()

if is_compat_return is None:
is_compat_return = is_compat

Expand Down
34 changes: 33 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,7 @@ class CallableType(FunctionLike):
"type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case).
"from_concatenate", # whether this callable is from a concatenate object
# (this is used for error messages)
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
)

def __init__(
Expand All @@ -1613,6 +1614,7 @@ def __init__(
def_extras: dict[str, Any] | None = None,
type_guard: Type | None = None,
from_concatenate: bool = False,
unpack_kwargs: bool = False,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1653,6 +1655,7 @@ def __init__(
else:
self.def_extras = {}
self.type_guard = type_guard
self.unpack_kwargs = unpack_kwargs

def copy_modified(
self,
Expand All @@ -1674,6 +1677,7 @@ def copy_modified(
def_extras: Bogus[dict[str, Any]] = _dummy,
type_guard: Bogus[Type | None] = _dummy,
from_concatenate: Bogus[bool] = _dummy,
unpack_kwargs: Bogus[bool] = _dummy,
) -> CallableType:
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
Expand All @@ -1698,6 +1702,7 @@ def copy_modified(
from_concatenate=(
from_concatenate if from_concatenate is not _dummy else self.from_concatenate
),
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
)

def var_arg(self) -> FormalArgument | None:
Expand Down Expand Up @@ -1889,6 +1894,25 @@ def expand_param_spec(
variables=[*variables, *self.variables],
)

def with_unpacked_kwargs(self) -> CallableType:
if not self.unpack_kwargs:
return self.copy_modified()
last_type = get_proper_type(self.arg_types[-1])
assert isinstance(last_type, ProperType) and isinstance(last_type, TypedDictType)
extra_kinds = [
ArgKind.ARG_NAMED if name in last_type.required_keys else ArgKind.ARG_NAMED_OPT
for name in last_type.items
]
new_arg_kinds = self.arg_kinds[:-1] + extra_kinds
new_arg_names = self.arg_names[:-1] + list(last_type.items)
new_arg_types = self.arg_types[:-1] + list(last_type.items.values())
return self.copy_modified(
arg_kinds=new_arg_kinds,
arg_names=new_arg_names,
arg_types=new_arg_types,
unpack_kwargs=False,
)

def __hash__(self) -> int:
# self.is_type_obj() will fail if self.fallback.type is a FakeInfo
if isinstance(self.fallback.type, FakeInfo):
Expand Down Expand Up @@ -1940,6 +1964,7 @@ def serialize(self) -> JsonDict:
"def_extras": dict(self.def_extras),
"type_guard": self.type_guard.serialize() if self.type_guard is not None else None,
"from_concatenate": self.from_concatenate,
"unpack_kwargs": self.unpack_kwargs,
}

@classmethod
Expand All @@ -1962,6 +1987,7 @@ def deserialize(cls, data: JsonDict) -> CallableType:
deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None
),
from_concatenate=data["from_concatenate"],
unpack_kwargs=data["unpack_kwargs"],
)


Expand Down Expand Up @@ -2009,6 +2035,9 @@ def with_name(self, name: str) -> Overloaded:
def get_name(self) -> str | None:
return self._items[0].name

def with_unpacked_kwargs(self) -> Overloaded:
return Overloaded([i.with_unpacked_kwargs() for i in self.items])

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_overloaded(self)

Expand Down Expand Up @@ -2917,7 +2946,10 @@ def visit_callable_type(self, t: CallableType) -> str:
name = t.arg_names[i]
if name:
s += name + ": "
s += t.arg_types[i].accept(self)
type_str = t.arg_types[i].accept(self)
if t.arg_kinds[i] == ARG_STAR2 and t.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if t.arg_kinds[i].is_optional():
s += " ="

Expand Down
Loading