Commit 4cba502

add typing to aiokafka/coordinator/* (#1006)
* add typing to aiokafka/record/*
* add some annotations to tests/record
* fix almost all errors
* test w/o protocols
* Revert "test w/o protocols"; this reverts commit 7fa1efa
* use TypeIs
* use dataclass
* remove timestamp/timestamp_type from cython DefaultRecord
* sync cython stubs with code
* simplify types
* add typing to aiokafka/coordinator/*
* fix review
* fix format
* fix review
* fix type errors
* fix review
* fix review
* assert consumer is not None
* fix review (continue if consumer is None)
1 parent 14aa358 commit 4cba502

File tree: 11 files changed (+382 -179 lines)

Makefile (+2)

@@ -7,6 +7,7 @@ DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION)
 DIFF_BRANCH=origin/master
 FORMATTED_AREAS=\
     aiokafka/codec.py \
+    aiokafka/coordinator/ \
     aiokafka/errors.py \
     aiokafka/helpers.py \
     aiokafka/structs.py \
@@ -17,6 +18,7 @@ FORMATTED_AREAS=\
     tests/test_helpers.py \
     tests/test_protocol.py \
     tests/test_protocol_object_conversion.py \
+    tests/coordinator/ \
     tests/record/

 .PHONY: setup

aiokafka/cluster.py (+2 -1)

@@ -4,6 +4,7 @@
 import threading
 import time
 from concurrent.futures import Future
+from typing import Optional, Set

 from aiokafka import errors as Errors
 from aiokafka.conn import collect_hosts
@@ -103,7 +104,7 @@ def broker_metadata(self, broker_id):
             or self._coordinator_brokers.get(broker_id)
         )

-    def partitions_for_topic(self, topic):
+    def partitions_for_topic(self, topic: str) -> Optional[Set[int]]:
         """Return set of all partitions for topic (whether available or not)

         Arguments:
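
With the new annotation, ClusterMetadata.partitions_for_topic() advertises Optional[Set[int]], so typed callers have to narrow away the None case before using the result. A minimal sketch; the partition_count() helper below is hypothetical and not part of the commit:

from typing import Optional, Set

from aiokafka.cluster import ClusterMetadata


def partition_count(cluster: ClusterMetadata, topic: str) -> int:
    # partitions_for_topic() returns None when metadata for the topic
    # has not been fetched yet, so narrow the Optional before len().
    partitions: Optional[Set[int]] = cluster.partitions_for_topic(topic)
    if partitions is None:
        return 0
    return len(partitions)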

aiokafka/coordinator/assignors/abstract.py (+21 -6)

@@ -1,20 +1,33 @@
 import abc
 import logging
+from typing import Dict, Iterable, Mapping
+
+from aiokafka.cluster import ClusterMetadata
+from aiokafka.coordinator.protocol import (
+    ConsumerProtocolMemberAssignment,
+    ConsumerProtocolMemberMetadata,
+)

 log = logging.getLogger(__name__)


-class AbstractPartitionAssignor:
+class AbstractPartitionAssignor(abc.ABC):
     """Abstract assignor implementation which does some common grunt work (in particular
     collecting partition counts which are always needed in assignors).
     """

-    @abc.abstractproperty
-    def name(self):
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
         """.name should be a string identifying the assignor"""

+    @classmethod
     @abc.abstractmethod
-    def assign(self, cluster, members):
+    def assign(
+        cls,
+        cluster: ClusterMetadata,
+        members: Mapping[str, ConsumerProtocolMemberMetadata],
+    ) -> Dict[str, ConsumerProtocolMemberAssignment]:
         """Perform group assignment given cluster metadata and member subscriptions

         Arguments:
@@ -26,8 +39,9 @@ def assign(self, cluster, members):
             dict: {member_id: MemberAssignment}
         """

+    @classmethod
     @abc.abstractmethod
-    def metadata(self, topics):
+    def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
         """Generate ProtocolMetadata to be submitted via JoinGroupRequest.

         Arguments:
@@ -37,8 +51,9 @@ def metadata(self, topics):
             MemberMetadata struct
         """

+    @classmethod
     @abc.abstractmethod
-    def on_assignment(self, assignment):
+    def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
         """Callback that runs on each assignment.

         This method can be used to update internal state, if any, of the
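
Because assign(), metadata() and on_assignment() are now abstract classmethods on an abc.ABC, a concrete assignor has to override them with matching signatures. The SingleMemberAssignor below is a hypothetical sketch against the typed interface, not part of this commit; it simply hands every subscribed partition to the lexicographically first member.

from typing import Dict, Iterable, Mapping

from aiokafka.cluster import ClusterMetadata
from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from aiokafka.coordinator.protocol import (
    ConsumerProtocolMemberAssignment,
    ConsumerProtocolMemberMetadata,
)


class SingleMemberAssignor(AbstractPartitionAssignor):
    """Hypothetical assignor: give all partitions to the first sorted member."""

    version = 0

    @property
    def name(self) -> str:
        return "single-member"

    @classmethod
    def assign(
        cls,
        cluster: ClusterMetadata,
        members: Mapping[str, ConsumerProtocolMemberMetadata],
    ) -> Dict[str, ConsumerProtocolMemberAssignment]:
        # Assumes at least one member; everyone else gets an empty assignment.
        first = min(members)
        topics = sorted(members[first].subscription)
        owned = {
            topic: sorted(cluster.partitions_for_topic(topic) or ())
            for topic in topics
        }
        result = {
            member_id: ConsumerProtocolMemberAssignment(cls.version, [], b"")
            for member_id in members
        }
        result[first] = ConsumerProtocolMemberAssignment(
            cls.version, sorted(owned.items()), b""
        )
        return result

    @classmethod
    def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
        return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"")

    @classmethod
    def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
        pass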

aiokafka/coordinator/assignors/range.py (+18 -12)

@@ -1,6 +1,8 @@
 import collections
 import logging
+from typing import Dict, Iterable, List, Mapping

+from aiokafka.cluster import ClusterMetadata
 from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
 from aiokafka.coordinator.protocol import (
     ConsumerProtocolMemberAssignment,
@@ -32,45 +34,49 @@ class RangePartitionAssignor(AbstractPartitionAssignor):
     version = 0

     @classmethod
-    def assign(cls, cluster, member_metadata):
-        consumers_per_topic = collections.defaultdict(list)
-        for member, metadata in member_metadata.items():
+    def assign(
+        cls,
+        cluster: ClusterMetadata,
+        members: Mapping[str, ConsumerProtocolMemberMetadata],
+    ) -> Dict[str, ConsumerProtocolMemberAssignment]:
+        consumers_per_topic: Dict[str, List[str]] = collections.defaultdict(list)
+        for member, metadata in members.items():
             for topic in metadata.subscription:
                 consumers_per_topic[topic].append(member)

         # construct {member_id: {topic: [partition, ...]}}
-        assignment = collections.defaultdict(dict)
+        assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict(dict)

         for topic, consumers_for_topic in consumers_per_topic.items():
             partitions = cluster.partitions_for_topic(topic)
             if partitions is None:
                 log.warning("No partition metadata for topic %s", topic)
                 continue
-            partitions = sorted(partitions)
+            partitions_list = sorted(partitions)
             consumers_for_topic.sort()

-            partitions_per_consumer = len(partitions) // len(consumers_for_topic)
-            consumers_with_extra = len(partitions) % len(consumers_for_topic)
+            partitions_per_consumer = len(partitions_list) // len(consumers_for_topic)
+            consumers_with_extra = len(partitions_list) % len(consumers_for_topic)

             for i, member in enumerate(consumers_for_topic):
                 start = partitions_per_consumer * i
                 start += min(i, consumers_with_extra)
                 length = partitions_per_consumer
                 if not i + 1 > consumers_with_extra:
                     length += 1
-                assignment[member][topic] = partitions[start : start + length]
+                assignment[member][topic] = partitions_list[start : start + length]

-        protocol_assignment = {}
-        for member_id in member_metadata:
+        protocol_assignment: Dict[str, ConsumerProtocolMemberAssignment] = {}
+        for member_id in members:
             protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
                 cls.version, sorted(assignment[member_id].items()), b""
             )
         return protocol_assignment

     @classmethod
-    def metadata(cls, topics):
+    def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
         return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"")

     @classmethod
-    def on_assignment(cls, assignment):
+    def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
         pass
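
A quick usage sketch (illustrative only: StubCluster, the consumer ids and the "events" topic are made up, and a type checker would expect a real ClusterMetadata here) showing how the range strategy splits five partitions between two sorted consumers:

from aiokafka.coordinator.assignors.range import RangePartitionAssignor
from aiokafka.coordinator.protocol import ConsumerProtocolMemberMetadata


class StubCluster:
    """Duck-typed stand-in exposing only partitions_for_topic()."""

    def partitions_for_topic(self, topic):
        return {0, 1, 2, 3, 4}


members = {
    "consumer-a": ConsumerProtocolMemberMetadata(0, ["events"], b""),
    "consumer-b": ConsumerProtocolMemberMetadata(0, ["events"], b""),
}
assignment = RangePartitionAssignor.assign(StubCluster(), members)
# "consumer-a" (first in sorted order) gets partitions [0, 1, 2],
# "consumer-b" gets [3, 4].
print({member: a.assignment for member, a in assignment.items()})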

aiokafka/coordinator/assignors/roundrobin.py (+17 -9)

@@ -1,7 +1,9 @@
 import collections
 import itertools
 import logging
+from typing import Dict, Iterable, List, Mapping

+from aiokafka.cluster import ClusterMetadata
 from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
 from aiokafka.coordinator.protocol import (
     ConsumerProtocolMemberAssignment,
@@ -49,12 +51,16 @@ class RoundRobinPartitionAssignor(AbstractPartitionAssignor):
     version = 0

     @classmethod
-    def assign(cls, cluster, member_metadata):
+    def assign(
+        cls,
+        cluster: ClusterMetadata,
+        members: Mapping[str, ConsumerProtocolMemberMetadata],
+    ) -> Dict[str, ConsumerProtocolMemberAssignment]:
         all_topics = set()
-        for metadata in member_metadata.values():
+        for metadata in members.values():
             all_topics.update(metadata.subscription)

-        all_topic_partitions = []
+        all_topic_partitions: List[TopicPartition] = []
         for topic in all_topics:
             partitions = cluster.partitions_for_topic(topic)
             if partitions is None:
@@ -66,31 +72,33 @@ def assign(cls, cluster, member_metadata):
         all_topic_partitions.sort()

         # construct {member_id: {topic: [partition, ...]}}
-        assignment = collections.defaultdict(lambda: collections.defaultdict(list))
+        assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict(
+            lambda: collections.defaultdict(list)
+        )

-        member_iter = itertools.cycle(sorted(member_metadata.keys()))
+        member_iter = itertools.cycle(sorted(members.keys()))
         for partition in all_topic_partitions:
             member_id = next(member_iter)

             # Because we constructed all_topic_partitions from the set of
             # member subscribed topics, we should be safe assuming that
             # each topic in all_topic_partitions is in at least one member
             # subscription; otherwise this could yield an infinite loop
-            while partition.topic not in member_metadata[member_id].subscription:
+            while partition.topic not in members[member_id].subscription:
                 member_id = next(member_iter)
             assignment[member_id][partition.topic].append(partition.partition)

         protocol_assignment = {}
-        for member_id in member_metadata:
+        for member_id in members:
             protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
                 cls.version, sorted(assignment[member_id].items()), b""
             )
         return protocol_assignment

     @classmethod
-    def metadata(cls, topics):
+    def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
         return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"")

     @classmethod
-    def on_assignment(cls, assignment):
+    def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
         pass
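
The same kind of stub illustrates the round-robin strategy, which interleaves ownership over the sorted (topic, partition) list; again, StubCluster and the member/topic names are made up for the example:

from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from aiokafka.coordinator.protocol import ConsumerProtocolMemberMetadata


class StubCluster:
    """Duck-typed stand-in exposing only partitions_for_topic()."""

    def partitions_for_topic(self, topic):
        return {0, 1, 2}


members = {
    "consumer-a": ConsumerProtocolMemberMetadata(0, ["events", "logs"], b""),
    "consumer-b": ConsumerProtocolMemberMetadata(0, ["events", "logs"], b""),
}
assignment = RoundRobinPartitionAssignor.assign(StubCluster(), members)
# Partitions alternate between the two members:
# "consumer-a" gets events [0, 2] and logs [1],
# "consumer-b" gets events [1] and logs [0, 2].
print({member: a.assignment for member, a in assignment.items()})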

aiokafka/coordinator/assignors/sticky/partition_movements.py (+33 -16)

@@ -1,11 +1,18 @@
 import logging
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from copy import deepcopy
+from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple
+
+from aiokafka.structs import TopicPartition

 log = logging.getLogger(__name__)


-ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"])
+class ConsumerPair(NamedTuple):
+    src_member_id: str
+    dst_member_id: str
+
+
 """
 Represents a pair of Kafka consumer ids involved in a partition reassignment.
 Each ConsumerPair corresponds to a particular partition or topic, indicates that the
@@ -16,7 +23,7 @@
 """


-def is_sublist(source, target):
+def is_sublist(source: Sequence[Any], target: Sequence[Any]) -> bool:
     """Checks if one list is a sublist of another.

     Arguments:
@@ -40,11 +47,13 @@ class PartitionMovements:
     form a ConsumerPair object) for each partition.
     """

-    def __init__(self):
-        self.partition_movements_by_topic = defaultdict(lambda: defaultdict(set))
-        self.partition_movements = {}
+    def __init__(self) -> None:
+        self.partition_movements_by_topic: Dict[str, Dict[ConsumerPair, Set[TopicPartition]]] = defaultdict(lambda: defaultdict(set))  # fmt: skip # noqa: E501
+        self.partition_movements: Dict[TopicPartition, ConsumerPair] = {}

-    def move_partition(self, partition, old_consumer, new_consumer):
+    def move_partition(
+        self, partition: TopicPartition, old_consumer: str, new_consumer: str
+    ) -> None:
         pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer)
         if partition in self.partition_movements:
             # this partition has previously moved
@@ -62,7 +71,9 @@ def move_partition(self, partition, old_consumer, new_consumer):
         else:
             self._add_partition_movement_record(partition, pair)

-    def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
+    def get_partition_to_be_moved(
+        self, partition: TopicPartition, old_consumer: str, new_consumer: str
+    ) -> TopicPartition:
         if partition.topic not in self.partition_movements_by_topic:
             return partition
         if partition in self.partition_movements:
@@ -79,7 +90,7 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
             iter(self.partition_movements_by_topic[partition.topic][reverse_pair])
         )

-    def are_sticky(self):
+    def are_sticky(self) -> bool:
         for topic, movements in self.partition_movements_by_topic.items():
             movement_pairs = set(movements.keys())
             if self._has_cycles(movement_pairs):
@@ -93,7 +104,9 @@ def are_sticky(self):
                 return False
         return True

-    def _remove_movement_record_of_partition(self, partition):
+    def _remove_movement_record_of_partition(
+        self, partition: TopicPartition
+    ) -> ConsumerPair:
         pair = self.partition_movements[partition]
         del self.partition_movements[partition]

@@ -105,16 +118,18 @@ def _remove_movement_record_of_partition(self, partition):

         return pair

-    def _add_partition_movement_record(self, partition, pair):
+    def _add_partition_movement_record(
+        self, partition: TopicPartition, pair: ConsumerPair
+    ) -> None:
         self.partition_movements[partition] = pair
         self.partition_movements_by_topic[partition.topic][pair].add(partition)

-    def _has_cycles(self, consumer_pairs):
-        cycles = set()
+    def _has_cycles(self, consumer_pairs: Set[ConsumerPair]) -> bool:
+        cycles: Set[Tuple[str, ...]] = set()
         for pair in consumer_pairs:
             reduced_pairs = deepcopy(consumer_pairs)
             reduced_pairs.remove(pair)
-            path = [pair.src_member_id]
+            path: List[str] = [pair.src_member_id]
             if self._is_linked(
                 pair.dst_member_id, pair.src_member_id, reduced_pairs, path
             ) and not self._is_subcycle(path, cycles):
@@ -132,7 +147,7 @@ def _has_cycles(self, consumer_pairs):
         )

     @staticmethod
-    def _is_subcycle(cycle, cycles):
+    def _is_subcycle(cycle: List[str], cycles: Set[Tuple[str, ...]]) -> bool:
         super_cycle = deepcopy(cycle)
         super_cycle = super_cycle[:-1]
         super_cycle.extend(cycle)
@@ -141,7 +156,9 @@ def _is_subcycle(cycle, cycles):
                 return True
         return False

-    def _is_linked(self, src, dst, pairs, current_path):
+    def _is_linked(
+        self, src: str, dst: str, pairs: Set[ConsumerPair], current_path: List[str]
+    ) -> bool:
         if src == dst:
             return False
         if not pairs: