1
1
import unittest
2
2
import mock
3
3
import itertools
4
+ import time
5
+ from nose .tools import *
4
6
from collections import namedtuple
5
7
from kafka_influxdb .reader import kafka_reader
6
8
from kafka .client import KafkaClient
7
9
from kafka .common import ConnectionError
8
10
from kafka .common import Message
11
+ from kafka_influxdb .tests .helpers .timeout import timeout
9
12
10
13
class TestKafkaReader (unittest .TestCase ):
11
14
@@ -31,16 +34,22 @@ def sample_messages(self, payload, count):
31
34
def test_handle_read (self ):
32
35
sample_messages , extracted_messages = self .sample_messages ("hello" , 3 )
33
36
self .reader .consumer .__iter__ .return_value = sample_messages
34
- real_messages = list (self .reader .handle_read ())
35
- self .assertEquals (real_messages , extracted_messages )
37
+ received_messages = list (self .reader .handle_read ())
38
+ self .assertEquals (received_messages , extracted_messages )
36
39
37
- # TODO: Run in separate process
38
- #def test_reconnect(self):
39
- #""" In case of a connection error, the client should reconnect and
40
- #start receiving messages again without interruption """
41
- # sample_messages1, extracted_messages1 = self.sample_messages("hi", 3)
42
- # sample_messages2, extracted_messages2 = self.sample_messages("world", 3)
43
- # sample_messages = sample_messages1 + [ConnectionError] + sample_messages2
44
- # self.reader.consumer.__iter__.return_value = sample_messages
45
- # real_messages = list(self.reader.handle_read())
46
- # self.assertEquals(real_messages, extracted_messages1 + extracted_messages2)
40
+ @timeout (0.1 )
41
+ def test_reconnect (self ):
42
+ """
43
+ In case of a connection error, the client should reconnect and
44
+ start receiving messages again without interruption
45
+ """
46
+ sample_messages1 , extracted_messages1 = self .sample_messages ("hi" , 3 )
47
+ sample_messages2 , extracted_messages2 = self .sample_messages ("world" , 3 )
48
+ sample_messages = sample_messages1 + [ConnectionError ] + sample_messages2
49
+ self .reader .consumer .__iter__ .return_value = sample_messages
50
+ received_messages = list (self .receive_messages ())
51
+ self .assertEquals (received_messages , extracted_messages1 + extracted_messages2 )
52
+
53
+ def receive_messages (self ):
54
+ for message in self .reader .read ():
55
+ yield message
0 commit comments