Skip to content

Commit 4c683b1

Browse files
danielnelson authored and rgitzel committed
Rework mqtt_consumer connect/reconnect (influxdata#4846)
1 parent 2b853a6 commit 4c683b1

File tree

2 files changed

+68
-161
lines changed

2 files changed

+68
-161
lines changed

plugins/inputs/mqtt_consumer/mqtt_consumer.go

+54-69
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
package mqtt_consumer
22

33
import (
4+
"errors"
45
"fmt"
56
"log"
67
"strings"
7-
"sync"
88
"time"
99

1010
"github.com/influxdata/telegraf"
@@ -19,6 +19,14 @@ import (
1919
// 30 Seconds is the default used by paho.mqtt.golang
2020
var defaultConnectionTimeout = internal.Duration{Duration: 30 * time.Second}
2121

22+
type ConnectionState int
23+
24+
const (
25+
Disconnected ConnectionState = iota
26+
Connecting
27+
Connected
28+
)
29+
2230
type MQTTConsumer struct {
2331
Servers []string
2432
Topics []string
@@ -36,16 +44,10 @@ type MQTTConsumer struct {
3644
ClientID string `toml:"client_id"`
3745
tls.ClientConfig
3846

39-
sync.Mutex
40-
client mqtt.Client
41-
// channel of all incoming raw mqtt messages
42-
in chan mqtt.Message
43-
done chan struct{}
44-
45-
// keep the accumulator internally:
46-
acc telegraf.Accumulator
47-
48-
connected bool
47+
client mqtt.Client
48+
acc telegraf.Accumulator
49+
state ConnectionState
50+
subscribed bool
4951
}
5052

5153
var sampleConfig = `
@@ -110,22 +112,19 @@ func (m *MQTTConsumer) SetParser(parser parsers.Parser) {
110112
}
111113

112114
func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
113-
m.Lock()
114-
defer m.Unlock()
115-
m.connected = false
115+
m.state = Disconnected
116116

117117
if m.PersistentSession && m.ClientID == "" {
118-
return fmt.Errorf("ERROR MQTT Consumer: When using persistent_session" +
119-
" = true, you MUST also set client_id")
118+
return errors.New("persistent_session requires client_id")
120119
}
121120

122121
m.acc = acc
123122
if m.QoS > 2 || m.QoS < 0 {
124-
return fmt.Errorf("MQTT Consumer, invalid QoS value: %d", m.QoS)
123+
return fmt.Errorf("qos value must be 0, 1, or 2: %d", m.QoS)
125124
}
126125

127126
if m.ConnectionTimeout.Duration < 1*time.Second {
128-
return fmt.Errorf("MQTT Consumer, invalid connection_timeout value: %s", m.ConnectionTimeout.Duration)
127+
return fmt.Errorf("connection_timeout must be greater than 1s: %s", m.ConnectionTimeout.Duration)
129128
}
130129

131130
opts, err := m.createOpts()
@@ -134,9 +133,7 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
134133
}
135134

136135
m.client = mqtt.NewClient(opts)
137-
m.in = make(chan mqtt.Message, 1000)
138-
m.done = make(chan struct{})
139-
136+
m.state = Connecting
140137
m.connect()
141138

142139
return nil
@@ -145,80 +142,68 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
145142
func (m *MQTTConsumer) connect() error {
146143
if token := m.client.Connect(); token.Wait() && token.Error() != nil {
147144
err := token.Error()
148-
log.Printf("D! MQTT Consumer, connection error - %v", err)
149-
145+
m.state = Disconnected
150146
return err
151147
}
152148

153-
go m.receiver()
149+
log.Printf("I! [inputs.mqtt_consumer]: connected %v", m.Servers)
150+
m.state = Connected
154151

155-
return nil
156-
}
157-
158-
func (m *MQTTConsumer) onConnect(c mqtt.Client) {
159-
log.Printf("I! MQTT Client Connected")
160-
if !m.PersistentSession || !m.connected {
152+
// Only subscribe on first connection when using persistent sessions. On
153+
// subsequent connections the subscriptions should be stored in the
154+
// session, but the proper way to do this is to check the connection
155+
// response to ensure a session was found.
156+
if !m.PersistentSession || !m.subscribed {
161157
topics := make(map[string]byte)
162158
for _, topic := range m.Topics {
163159
topics[topic] = byte(m.QoS)
164160
}
165-
subscribeToken := c.SubscribeMultiple(topics, m.recvMessage)
161+
subscribeToken := m.client.SubscribeMultiple(topics, m.recvMessage)
166162
subscribeToken.Wait()
167163
if subscribeToken.Error() != nil {
168-
m.acc.AddError(fmt.Errorf("E! MQTT Subscribe Error\ntopics: %s\nerror: %s",
164+
m.acc.AddError(fmt.Errorf("subscription error: topics: %s: %v",
169165
strings.Join(m.Topics[:], ","), subscribeToken.Error()))
170166
}
171-
m.connected = true
167+
m.subscribed = true
172168
}
173-
return
169+
170+
return nil
174171
}
175172

176173
func (m *MQTTConsumer) onConnectionLost(c mqtt.Client, err error) {
177-
m.acc.AddError(fmt.Errorf("E! MQTT Connection lost\nerror: %s\nMQTT Client will try to reconnect", err.Error()))
174+
m.acc.AddError(fmt.Errorf("connection lost: %v", err))
175+
log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers)
176+
m.state = Disconnected
178177
return
179178
}
180179

181-
// receiver() reads all incoming messages from the consumer, and parses them into
182-
// influxdb metric points.
183-
func (m *MQTTConsumer) receiver() {
184-
for {
185-
select {
186-
case <-m.done:
187-
return
188-
case msg := <-m.in:
189-
topic := msg.Topic()
190-
metrics, err := m.parser.Parse(msg.Payload())
191-
if err != nil {
192-
m.acc.AddError(fmt.Errorf("E! MQTT Parse Error\nmessage: %s\nerror: %s",
193-
string(msg.Payload()), err.Error()))
194-
}
195-
196-
for _, metric := range metrics {
197-
tags := metric.Tags()
198-
tags["topic"] = topic
199-
m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time())
200-
}
201-
}
180+
func (m *MQTTConsumer) recvMessage(c mqtt.Client, msg mqtt.Message) {
181+
topic := msg.Topic()
182+
metrics, err := m.parser.Parse(msg.Payload())
183+
if err != nil {
184+
m.acc.AddError(err)
202185
}
203-
}
204186

205-
func (m *MQTTConsumer) recvMessage(_ mqtt.Client, msg mqtt.Message) {
206-
m.in <- msg
187+
for _, metric := range metrics {
188+
tags := metric.Tags()
189+
tags["topic"] = topic
190+
m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time())
191+
}
207192
}
208193

209194
func (m *MQTTConsumer) Stop() {
210-
m.Lock()
211-
defer m.Unlock()
212-
213-
if m.connected {
214-
close(m.done)
195+
if m.state == Connected {
196+
log.Printf("D! [inputs.mqtt_consumer]: disconnecting %v", m.Servers)
215197
m.client.Disconnect(200)
216-
m.connected = false
198+
log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers)
199+
m.state = Disconnected
217200
}
218201
}
219202

220203
func (m *MQTTConsumer) Gather(acc telegraf.Accumulator) error {
221-
if !m.connected {
204+
if m.state == Disconnected {
205+
m.state = Connecting
206+
log.Printf("D! [inputs.mqtt_consumer]: connecting %v", m.Servers)
222207
m.connect()
223208
}
224209

@@ -261,7 +246,7 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) {
261246
for _, server := range m.Servers {
262247
// Preserve support for host:port style servers; deprecated in Telegraf 1.4.4
263248
if !strings.Contains(server, "://") {
264-
log.Printf("W! mqtt_consumer server %q should be updated to use `scheme://host:port` format", server)
249+
log.Printf("W! [inputs.mqtt_consumer] server %q should be updated to use `scheme://host:port` format", server)
265250
if tlsCfg == nil {
266251
server = "tcp://" + server
267252
} else {
@@ -271,10 +256,9 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) {
271256

272257
opts.AddBroker(server)
273258
}
274-
opts.SetAutoReconnect(true)
259+
opts.SetAutoReconnect(false)
275260
opts.SetKeepAlive(time.Second * 60)
276261
opts.SetCleanSession(!m.PersistentSession)
277-
opts.SetOnConnectHandler(m.onConnect)
278262
opts.SetConnectionLostHandler(m.onConnectionLost)
279263

280264
return opts, nil
@@ -284,6 +268,7 @@ func init() {
284268
inputs.Add("mqtt_consumer", func() telegraf.Input {
285269
return &MQTTConsumer{
286270
ConnectionTimeout: defaultConnectionTimeout,
271+
state: Disconnected,
287272
}
288273
})
289274
}

plugins/inputs/mqtt_consumer/mqtt_consumer_test.go

+14-92
Original file line number | Diff line number | Diff line change
@@ -12,24 +12,17 @@ import (
1212
)
1313

1414
const (
15-
testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n"
16-
testMsgNeg = "cpu_load_short,host=server01 value=-23422.0 1422568543702900257\n"
17-
testMsgGraphite = "cpu.load.short.graphite 23422 1454780029"
18-
testMsgJSON = "{\"a\": 5, \"b\": {\"c\": 6}}\n"
19-
invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n"
15+
testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n"
16+
invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n"
2017
)
2118

22-
func newTestMQTTConsumer() (*MQTTConsumer, chan mqtt.Message) {
23-
in := make(chan mqtt.Message, 100)
19+
func newTestMQTTConsumer() *MQTTConsumer {
2420
n := &MQTTConsumer{
25-
Topics: []string{"telegraf"},
26-
Servers: []string{"localhost:1883"},
27-
in: in,
28-
done: make(chan struct{}),
29-
connected: true,
21+
Topics: []string{"telegraf"},
22+
Servers: []string{"localhost:1883"},
3023
}
3124

32-
return n, in
25+
return n
3326
}
3427

3528
// Test that default client has random ID
@@ -79,31 +72,12 @@ func TestPersistentClientIDFail(t *testing.T) {
7972
}
8073

8174
func TestRunParser(t *testing.T) {
82-
n, in := newTestMQTTConsumer()
75+
n := newTestMQTTConsumer()
8376
acc := testutil.Accumulator{}
8477
n.acc = &acc
85-
defer close(n.done)
86-
8778
n.parser, _ = parsers.NewInfluxParser()
88-
go n.receiver()
89-
in <- mqttMsg(testMsgNeg)
90-
acc.Wait(1)
91-
92-
if a := acc.NFields(); a != 1 {
93-
t.Errorf("got %v, expected %v", a, 1)
94-
}
95-
}
9679

97-
func TestRunParserNegativeNumber(t *testing.T) {
98-
n, in := newTestMQTTConsumer()
99-
acc := testutil.Accumulator{}
100-
n.acc = &acc
101-
defer close(n.done)
102-
103-
n.parser, _ = parsers.NewInfluxParser()
104-
go n.receiver()
105-
in <- mqttMsg(testMsg)
106-
acc.Wait(1)
80+
n.recvMessage(nil, mqttMsg(testMsg))
10781

10882
if a := acc.NFields(); a != 1 {
10983
t.Errorf("got %v, expected %v", a, 1)
@@ -112,84 +86,32 @@ func TestRunParserNegativeNumber(t *testing.T) {
11286

11387
// Test that the parser ignores invalid messages
11488
func TestRunParserInvalidMsg(t *testing.T) {
115-
n, in := newTestMQTTConsumer()
89+
n := newTestMQTTConsumer()
11690
acc := testutil.Accumulator{}
11791
n.acc = &acc
118-
defer close(n.done)
119-
12092
n.parser, _ = parsers.NewInfluxParser()
121-
go n.receiver()
122-
in <- mqttMsg(invalidMsg)
123-
acc.WaitError(1)
93+
94+
n.recvMessage(nil, mqttMsg(invalidMsg))
12495

12596
if a := acc.NFields(); a != 0 {
12697
t.Errorf("got %v, expected %v", a, 0)
12798
}
128-
assert.Contains(t, acc.Errors[0].Error(), "MQTT Parse Error")
99+
assert.Len(t, acc.Errors, 1)
129100
}
130101

131102
// Test that the parser parses line format messages into metrics
132103
func TestRunParserAndGather(t *testing.T) {
133-
n, in := newTestMQTTConsumer()
104+
n := newTestMQTTConsumer()
134105
acc := testutil.Accumulator{}
135106
n.acc = &acc
136-
137-
defer close(n.done)
138-
139107
n.parser, _ = parsers.NewInfluxParser()
140-
go n.receiver()
141-
in <- mqttMsg(testMsg)
142-
acc.Wait(1)
143108

144-
n.Gather(&acc)
109+
n.recvMessage(nil, mqttMsg(testMsg))
145110

146111
acc.AssertContainsFields(t, "cpu_load_short",
147112
map[string]interface{}{"value": float64(23422)})
148113
}
149114

150-
// Test that the parser parses graphite format messages into metrics
151-
func TestRunParserAndGatherGraphite(t *testing.T) {
152-
n, in := newTestMQTTConsumer()
153-
acc := testutil.Accumulator{}
154-
n.acc = &acc
155-
defer close(n.done)
156-
157-
n.parser, _ = parsers.NewGraphiteParser("_", []string{}, nil)
158-
go n.receiver()
159-
in <- mqttMsg(testMsgGraphite)
160-
161-
n.Gather(&acc)
162-
acc.Wait(1)
163-
164-
acc.AssertContainsFields(t, "cpu_load_short_graphite",
165-
map[string]interface{}{"value": float64(23422)})
166-
}
167-
168-
// Test that the parser parses json format messages into metrics
169-
func TestRunParserAndGatherJSON(t *testing.T) {
170-
n, in := newTestMQTTConsumer()
171-
acc := testutil.Accumulator{}
172-
n.acc = &acc
173-
defer close(n.done)
174-
175-
n.parser, _ = parsers.NewParser(&parsers.Config{
176-
DataFormat: "json",
177-
MetricName: "nats_json_test",
178-
})
179-
go n.receiver()
180-
in <- mqttMsg(testMsgJSON)
181-
182-
n.Gather(&acc)
183-
184-
acc.Wait(1)
185-
186-
acc.AssertContainsFields(t, "nats_json_test",
187-
map[string]interface{}{
188-
"a": float64(5),
189-
"b_c": float64(6),
190-
})
191-
}
192-
193115
func mqttMsg(val string) mqtt.Message {
194116
return &message{
195117
topic: "telegraf/unit_test",

0 commit comments

Comments
 (0)