diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index afedbaa1..64e33d87 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,6 +4,11 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Fixed basic support for intersection protocols + (`#490 `_; PR by @antonagestam) + **4.3.0** (2024-05-27) - Added support for checking against static protocols diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 485bcb74..52ec2b82 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -654,19 +654,13 @@ def check_protocol( else: return - # Collect a set of methods and non-method attributes present in the protocol - ignored_attrs = set(dir(typing.Protocol)) | { - "__annotations__", - "__non_callable_proto_members__", - } expected_methods: dict[str, tuple[Any, Any]] = {} expected_noncallable_members: dict[str, Any] = {} - for attrname in dir(origin_type): - # Skip attributes present in typing.Protocol - if attrname in ignored_attrs: - continue + origin_annotations = typing.get_type_hints(origin_type) + + for attrname in typing_extensions.get_protocol_members(origin_type): + member = getattr(origin_type, attrname, None) - member = getattr(origin_type, attrname) if callable(member): signature = inspect.signature(member) argtypes = [ @@ -681,10 +675,10 @@ def check_protocol( ) expected_methods[attrname] = argtypes, return_annotation else: - expected_noncallable_members[attrname] = member - - for attrname, annotation in typing.get_type_hints(origin_type).items(): - expected_noncallable_members[attrname] = annotation + try: + expected_noncallable_members[attrname] = origin_annotations[attrname] + except KeyError: + expected_noncallable_members[attrname] = member subject_annotations = typing.get_type_hints(subject) diff --git a/tests/test_checkers.py b/tests/test_checkers.py index f8b21d69..d9237a96 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -16,14 +16,17 @@ Dict, ForwardRef, FrozenSet, + Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, + Protocol, Sequence, Set, + Sized, TextIO, Tuple, Type, @@ -995,6 +998,86 @@ def test_text_real_file(self, tmp_path: Path): check_type(f, TextIO) +class TestIntersectingProtocol: + SIT = TypeVar("SIT", covariant=True) + + class SizedIterable( + Sized, + Iterable[SIT], + Protocol[SIT], + ): ... + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (), + SizedIterable, + id="empty_tuple_unspecialized", + ), + pytest.param( + range(2), + SizedIterable, + id="range", + ), + pytest.param( + (), + SizedIterable[int], + id="empty_tuple_int_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[int], + id="tuple_int_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[str], + id="tuple_str_specialized", + ), + ), + ) + def test_valid_member_passes(self, subject: object, predicate_type: type) -> None: + for _ in range(2): # Makes sure that the cache is also exercised + check_type(subject, predicate_type) + + xfail_nested_protocol_checks = pytest.mark.xfail( + reason="false negative due to missing support for nested protocol checks", + ) + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (1 for _ in ()), + SizedIterable, + id="generator", + ), + pytest.param( + range(2), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="range_str_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="int_tuple_str_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[int], + marks=xfail_nested_protocol_checks, + id="str_tuple_int_specialized", + ), + ), + ) + def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None: + with pytest.raises(TypeCheckError): + check_type(subject, predicate_type) + + @pytest.mark.parametrize( "instantiate, annotation", [