Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Version history
This library adheres to
`Semantic Versioning 2.0 <https://semver.org/#semantic-versioning-200>`_.

**UNRELEASED**

- Fixed basic support for intersection protocols
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)

**4.3.0** (2024-05-27)

- Added support for checking against static protocols
Expand Down
22 changes: 8 additions & 14 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,19 +654,13 @@ def check_protocol(
else:
return

# Collect a set of methods and non-method attributes present in the protocol
ignored_attrs = set(dir(typing.Protocol)) | {
"__annotations__",
"__non_callable_proto_members__",
}
expected_methods: dict[str, tuple[Any, Any]] = {}
expected_noncallable_members: dict[str, Any] = {}
for attrname in dir(origin_type):
# Skip attributes present in typing.Protocol
if attrname in ignored_attrs:
continue
origin_annotations = typing.get_type_hints(origin_type)

for attrname in typing_extensions.get_protocol_members(origin_type):
member = getattr(origin_type, attrname, None)

member = getattr(origin_type, attrname)
if callable(member):
signature = inspect.signature(member)
argtypes = [
Expand All @@ -681,10 +675,10 @@ def check_protocol(
)
expected_methods[attrname] = argtypes, return_annotation
else:
expected_noncallable_members[attrname] = member

for attrname, annotation in typing.get_type_hints(origin_type).items():
expected_noncallable_members[attrname] = annotation
try:
expected_noncallable_members[attrname] = origin_annotations[attrname]
except KeyError:
expected_noncallable_members[attrname] = member

subject_annotations = typing.get_type_hints(subject)

Expand Down
83 changes: 83 additions & 0 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
Dict,
ForwardRef,
FrozenSet,
Iterable,
Iterator,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Protocol,
Sequence,
Set,
Sized,
TextIO,
Tuple,
Type,
Expand Down Expand Up @@ -995,6 +998,86 @@ def test_text_real_file(self, tmp_path: Path):
check_type(f, TextIO)


class TestIntersectingProtocol:
SIT = TypeVar("SIT", covariant=True)

class SizedIterable(
Sized,
Iterable[SIT],
Protocol[SIT],
): ...

@pytest.mark.parametrize(
"subject, predicate_type",
(
pytest.param(
(),
SizedIterable,
id="empty_tuple_unspecialized",
),
pytest.param(
range(2),
SizedIterable,
id="range",
),
pytest.param(
(),
SizedIterable[int],
id="empty_tuple_int_specialized",
),
pytest.param(
(1, 2, 3),
SizedIterable[int],
id="tuple_int_specialized",
),
pytest.param(
("1", "2", "3"),
SizedIterable[str],
id="tuple_str_specialized",
),
),
)
def test_valid_member_passes(self, subject: object, predicate_type: type) -> None:
for _ in range(2): # Makes sure that the cache is also exercised
check_type(subject, predicate_type)

xfail_nested_protocol_checks = pytest.mark.xfail(
reason="false negative due to missing support for nested protocol checks",
)

@pytest.mark.parametrize(
"subject, predicate_type",
(
pytest.param(
(1 for _ in ()),
SizedIterable,
id="generator",
),
pytest.param(
range(2),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
id="range_str_specialized",
),
pytest.param(
(1, 2, 3),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
id="int_tuple_str_specialized",
),
pytest.param(
("1", "2", "3"),
SizedIterable[int],
marks=xfail_nested_protocol_checks,
id="str_tuple_int_specialized",
),
),
)
def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None:
with pytest.raises(TypeCheckError):
check_type(subject, predicate_type)


@pytest.mark.parametrize(
"instantiate, annotation",
[
Expand Down