Skip to content

Commit 241d120

Browse files
committed
Added support for checking against static protocols
Fixes #457.
1 parent d539190 commit 241d120

File tree

5 files changed

+199
-48
lines changed

5 files changed

+199
-48
lines changed

docs/features.rst

+15-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ The following type checks are not yet supported in Typeguard:
2323
* Types of values assigned to global or nonlocal variables
2424
* Stubs defined with :func:`@overload <typing.overload>` (the implementation is checked
2525
if instrumented)
26-
* ``yield_from`` statements in generator functions
26+
* ``yield from`` statements in generator functions
2727
* ``ParamSpec`` and ``Concatenate`` are currently ignored
2828
* Types where they are shadowed by arguments with the same name (e.g.
2929
``def foo(x: type, type: str): ...``)
@@ -58,6 +58,20 @@ target function should be switched to a new one. To work around this limitation,
5858
place :func:`@typechecked <typechecked>` at the bottom of the decorator stack, or use
5959
the import hook instead.
6060

61+
Protocol checking
62+
+++++++++++++++++
63+
64+
As of version 4.3.0, Typeguard can check instances and classes against Protocols,
65+
regardless of whether they were annotated with :decorator:`typing.runtime_checkable`.
66+
67+
There are several limitations on the checks performed, however:
68+
69+
* For non-callable members, only presence is checked for; no type compatibility checks
70+
are performed
71+
* For methods, only the number of positional arguments are checked against, so any added
72+
keyword-only arguments without defaults don't currently trip the checker
73+
* Likewise, argument types are not checked for compatibility
74+
6175
Special considerations for ``if TYPE_CHECKING:``
6276
------------------------------------------------
6377

docs/versionhistory.rst

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ This library adheres to
66

77
**UNRELEASED**
88

9+
- Added support for checking against static protocols
910
- Fixed some compatibility problems when running on Python 3.13
1011
(`#460 <https://github.com/agronholm/typeguard/issues/460>`_; PR by @JelleZijlstra)
1112
- Fixed test suite incompatibility with pytest 8.2

src/typeguard/_checkers.py

+92-11
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Union,
3333
)
3434
from unittest.mock import Mock
35+
from weakref import WeakKeyDictionary
3536

3637
try:
3738
import typing_extensions
@@ -88,6 +89,9 @@
8889
if sys.version_info >= (3, 9):
8990
generic_alias_types += (types.GenericAlias,)
9091

92+
protocol_check_cache: WeakKeyDictionary[
93+
type[Any], dict[type[Any], TypeCheckError | None]
94+
] = WeakKeyDictionary()
9195

9296
# Sentinel
9397
_missing = object()
@@ -650,19 +654,96 @@ def check_protocol(
650654
args: tuple[Any, ...],
651655
memo: TypeCheckMemo,
652656
) -> None:
653-
# TODO: implement proper compatibility checking and support non-runtime protocols
654-
if getattr(origin_type, "_is_runtime_protocol", False):
655-
if not isinstance(value, origin_type):
656-
raise TypeCheckError(
657-
f"is not compatible with the {origin_type.__qualname__} protocol"
657+
subject: type[Any] = value if isclass(value) else type(value)
658+
659+
if subject in protocol_check_cache:
660+
result_map = protocol_check_cache[subject]
661+
if origin_type in result_map:
662+
if exc := result_map[origin_type]:
663+
raise exc
664+
else:
665+
return
666+
667+
# Collect a set of methods and non-method attributes present in the protocol
668+
ignored_attrs = set(dir(typing.Protocol)) | {
669+
"__annotations__",
670+
"__non_callable_proto_members__",
671+
}
672+
expected_methods: dict[str, tuple[Any, Any]] = {}
673+
expected_noncallable_members: dict[str, Any] = {}
674+
for attrname in dir(origin_type):
675+
# Skip attributes present in typing.Protocol
676+
if attrname in ignored_attrs:
677+
continue
678+
679+
member = getattr(origin_type, attrname)
680+
if callable(member):
681+
signature = inspect.signature(member)
682+
argtypes = [
683+
(p.annotation if p.annotation is not Parameter.empty else Any)
684+
for p in signature.parameters.values()
685+
if p.kind is not Parameter.KEYWORD_ONLY
686+
] or Ellipsis
687+
return_annotation = (
688+
signature.return_annotation
689+
if signature.return_annotation is not Parameter.empty
690+
else Any
658691
)
692+
expected_methods[attrname] = argtypes, return_annotation
693+
else:
694+
expected_noncallable_members[attrname] = member
695+
696+
for attrname, annotation in typing.get_type_hints(origin_type).items():
697+
expected_noncallable_members[attrname] = annotation
698+
699+
subject_annotations = typing.get_type_hints(subject)
700+
701+
# Check that all required methods are present and their signatures are compatible
702+
result_map = protocol_check_cache.setdefault(subject, {})
703+
try:
704+
for attrname, callable_args in expected_methods.items():
705+
try:
706+
method = getattr(subject, attrname)
707+
except AttributeError:
708+
if attrname in subject_annotations:
709+
raise TypeCheckError(
710+
f"is not compatible with the {origin_type.__qualname__} protocol "
711+
f"because its {attrname!r} attribute is not a method"
712+
) from None
713+
else:
714+
raise TypeCheckError(
715+
f"is not compatible with the {origin_type.__qualname__} protocol "
716+
f"because it has no method named {attrname!r}"
717+
) from None
718+
719+
if not callable(method):
720+
raise TypeCheckError(
721+
f"is not compatible with the {origin_type.__qualname__} protocol "
722+
f"because its {attrname!r} attribute is not a callable"
723+
)
724+
725+
# TODO: raise exception on added keyword-only arguments without defaults
726+
try:
727+
check_callable(method, Callable, callable_args, memo)
728+
except TypeCheckError as exc:
729+
raise TypeCheckError(
730+
f"is not compatible with the {origin_type.__qualname__} protocol "
731+
f"because its {attrname!r} method {exc}"
732+
) from None
733+
734+
# Check that all required non-callable members are present
735+
for attrname in expected_noncallable_members:
736+
# TODO: implement assignability checks for non-callable members
737+
if attrname not in subject_annotations and not hasattr(subject, attrname):
738+
raise TypeCheckError(
739+
f"is not compatible with the {origin_type.__qualname__} protocol "
740+
f"because it has no attribute named {attrname!r}"
741+
)
742+
except TypeCheckError as exc:
743+
result_map[origin_type] = exc
744+
raise
659745
else:
660-
warnings.warn(
661-
f"Typeguard cannot check the {origin_type.__qualname__} protocol because "
662-
f"it is a non-runtime protocol. If you would like to type check this "
663-
f"protocol, please use @typing.runtime_checkable",
664-
stacklevel=get_stacklevel(),
665-
)
746+
result_map[origin_type] = None
666747

667748

668749
def check_byteslike(

tests/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ def method(self, a: int) -> None:
4747

4848

4949
class StaticProtocol(Protocol):
50-
def meth(self) -> None: ...
50+
member: int
51+
52+
def meth(self, x: str) -> None: ...
5153

5254

5355
@runtime_checkable
5456
class RuntimeProtocol(Protocol):
5557
member: int
5658

57-
def meth(self) -> None: ...
59+
def meth(self, x: str) -> None: ...

tests/test_checkers.py

+87-34
Original file line numberDiff line numberDiff line change
@@ -995,66 +995,119 @@ def test_text_real_file(self, tmp_path: Path):
995995
check_type(f, TextIO)
996996

997997

998+
@pytest.mark.parametrize(
999+
"instantiate, annotation",
1000+
[
1001+
pytest.param(True, RuntimeProtocol, id="instance_runtime"),
1002+
pytest.param(False, Type[RuntimeProtocol], id="class_runtime"),
1003+
pytest.param(True, StaticProtocol, id="instance_static"),
1004+
pytest.param(False, Type[StaticProtocol], id="class_static"),
1005+
],
1006+
)
9981007
class TestProtocol:
999-
def test_protocol(self):
1008+
def test_member_defaultval(self, instantiate, annotation):
10001009
class Foo:
10011010
member = 1
10021011

1003-
def meth(self) -> None:
1012+
def meth(self, x: str) -> None:
10041013
pass
10051014

1006-
check_type(Foo(), RuntimeProtocol)
1007-
check_type(Foo, Type[RuntimeProtocol])
1015+
subject = Foo() if instantiate else Foo
1016+
for _ in range(2): # Makes sure that the cache is also exercised
1017+
check_type(subject, annotation)
10081018

1009-
def test_protocol_warns_on_static(self):
1019+
def test_member_annotation(self, instantiate, annotation):
10101020
class Foo:
1011-
member = 1
1021+
member: int
10121022

1013-
def meth(self) -> None:
1023+
def meth(self, x: str) -> None:
10141024
pass
10151025

1016-
with pytest.warns(
1017-
UserWarning, match=r"Typeguard cannot check the StaticProtocol protocol.*"
1018-
) as warning:
1019-
check_type(Foo(), StaticProtocol)
1026+
subject = Foo() if instantiate else Foo
1027+
for _ in range(2):
1028+
check_type(subject, annotation)
10201029

1021-
assert warning.list[0].filename == __file__
1030+
def test_attribute_missing(self, instantiate, annotation):
1031+
class Foo:
1032+
val = 1
10221033

1023-
with pytest.warns(
1024-
UserWarning, match=r"Typeguard cannot check the StaticProtocol protocol.*"
1025-
) as warning:
1026-
check_type(Foo, Type[StaticProtocol])
1034+
def meth(self, x: str) -> None:
1035+
pass
10271036

1028-
assert warning.list[0].filename == __file__
1037+
clsname = f"{__name__}.TestProtocol.test_attribute_missing.<locals>.Foo"
1038+
subject = Foo() if instantiate else Foo
1039+
for _ in range(2):
1040+
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
1041+
f"{clsname} is not compatible with the (Runtime|Static)Protocol "
1042+
f"protocol because it has no attribute named 'member'"
1043+
)
10291044

1030-
def test_fail_non_method_members(self):
1045+
def test_method_missing(self, instantiate, annotation):
10311046
class Foo:
1032-
val = 1
1047+
member: int
10331048

1034-
def meth(self) -> None:
1035-
pass
1049+
pattern = (
1050+
f"{__name__}.TestProtocol.test_method_missing.<locals>.Foo is not "
1051+
f"compatible with the (Runtime|Static)Protocol protocol because it has no "
1052+
f"method named 'meth'"
1053+
)
1054+
subject = Foo() if instantiate else Foo
1055+
for _ in range(2):
1056+
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
1057+
pattern
1058+
)
1059+
1060+
def test_attribute_is_not_method_1(self, instantiate, annotation):
1061+
class Foo:
1062+
member: int
1063+
meth: str
10361064

1037-
clsname = f"{__name__}.TestProtocol.test_fail_non_method_members.<locals>.Foo"
1038-
pytest.raises(TypeCheckError, check_type, Foo(), RuntimeProtocol).match(
1039-
f"{clsname} is not compatible with the RuntimeProtocol protocol"
1065+
pattern = (
1066+
f"{__name__}.TestProtocol.test_attribute_is_not_method_1.<locals>.Foo is "
1067+
f"not compatible with the (Runtime|Static)Protocol protocol because its "
1068+
f"'meth' attribute is not a method"
10401069
)
1041-
pytest.raises(TypeCheckError, check_type, Foo, Type[RuntimeProtocol]).match(
1042-
f"class {clsname} is not compatible with the RuntimeProtocol protocol"
1070+
subject = Foo() if instantiate else Foo
1071+
for _ in range(2):
1072+
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
1073+
pattern
1074+
)
1075+
1076+
def test_attribute_is_not_method_2(self, instantiate, annotation):
1077+
class Foo:
1078+
member: int
1079+
meth = "foo"
1080+
1081+
pattern = (
1082+
f"{__name__}.TestProtocol.test_attribute_is_not_method_2.<locals>.Foo is "
1083+
f"not compatible with the (Runtime|Static)Protocol protocol because its "
1084+
f"'meth' attribute is not a callable"
10431085
)
1086+
subject = Foo() if instantiate else Foo
1087+
for _ in range(2):
1088+
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
1089+
pattern
1090+
)
10441091

1045-
def test_fail(self):
1092+
def test_method_signature_mismatch(self, instantiate, annotation):
10461093
class Foo:
1047-
def meth2(self) -> None:
1094+
member: int
1095+
1096+
def meth(self, x: str, y: int) -> None:
10481097
pass
10491098

10501099
pattern = (
1051-
f"{__name__}.TestProtocol.test_fail.<locals>.Foo is not compatible with "
1052-
f"the RuntimeProtocol protocol"
1053-
)
1054-
pytest.raises(TypeCheckError, check_type, Foo(), RuntimeProtocol).match(pattern)
1055-
pytest.raises(TypeCheckError, check_type, Foo, Type[RuntimeProtocol]).match(
1056-
pattern
1100+
rf"(class )?{__name__}.TestProtocol.test_method_signature_mismatch."
1101+
rf"<locals>.Foo is not compatible with the (Runtime|Static)Protocol "
1102+
rf"protocol because its 'meth' method has too many mandatory positional "
1103+
rf"arguments in its declaration; expected 2 but 3 mandatory positional "
1104+
rf"argument\(s\) declared"
10571105
)
1106+
subject = Foo() if instantiate else Foo
1107+
for _ in range(2):
1108+
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
1109+
pattern
1110+
)
10581111

10591112

10601113
class TestRecursiveType:

0 commit comments

Comments
 (0)