Skip to content

Commit 32b2fda

Browse files
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]> Signed-off-by: Tim Hockin <[email protected]>
1 parent fe8a2dd commit 32b2fda

File tree

2 files changed

+693
-0
lines changed

2 files changed

+693
-0
lines changed

net/multi_listen.go

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"net"
23+
"sync"
24+
)
25+
26+
// connErrPair pairs conn and error which is returned by accept on sub-listeners.
27+
type connErrPair struct {
28+
conn net.Conn
29+
err error
30+
}
31+
32+
// multiListener implements net.Listener
33+
type multiListener struct {
34+
listeners []net.Listener
35+
wg sync.WaitGroup
36+
37+
// connCh passes accepted connections, from child listeners to parent.
38+
connCh chan connErrPair
39+
// stopCh communicates from parent to child listeners.
40+
stopCh chan struct{}
41+
}
42+
43+
// compile time check to ensure *multiListener implements net.Listener
44+
var _ net.Listener = &multiListener{}
45+
46+
// MultiListen returns net.Listener which can listen on and accept connections for
47+
// the given network on multiple addresses. Internally it uses stdlib to create
48+
// sub-listener and multiplexes connection requests using go-routines.
49+
// The network must be "tcp", "tcp4" or "tcp6".
50+
// It follows the semantics of net.Listen that primarily means:
51+
// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen
52+
// listens on all available unicast and anycast IP addresses of the local system.
53+
// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively.
54+
// 3. The host can accept names (e.g, localhost) and it will create a listener for at
55+
// most one of the host's IP.
56+
func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) {
57+
var lc net.ListenConfig
58+
return multiListen(
59+
ctx,
60+
network,
61+
addrs,
62+
func(ctx context.Context, network, address string) (net.Listener, error) {
63+
return lc.Listen(ctx, network, address)
64+
})
65+
}
66+
67+
// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
68+
// mocking for unit-testing.
69+
func multiListen(
70+
ctx context.Context,
71+
network string,
72+
addrs []string,
73+
listenFunc func(ctx context.Context, network, address string) (net.Listener, error),
74+
) (net.Listener, error) {
75+
if !(network == "tcp" || network == "tcp4" || network == "tcp6") {
76+
return nil, fmt.Errorf("network %q not supported", network)
77+
}
78+
if len(addrs) == 0 {
79+
return nil, fmt.Errorf("no address provided to listen on")
80+
}
81+
82+
ml := &multiListener{
83+
connCh: make(chan connErrPair),
84+
stopCh: make(chan struct{}),
85+
}
86+
for _, addr := range addrs {
87+
l, err := listenFunc(ctx, network, addr)
88+
if err != nil {
89+
// close all the sub-listeners and exit
90+
_ = ml.Close()
91+
return nil, err
92+
}
93+
ml.listeners = append(ml.listeners, l)
94+
}
95+
96+
for _, l := range ml.listeners {
97+
ml.wg.Add(1)
98+
go func(l net.Listener) {
99+
defer ml.wg.Done()
100+
for {
101+
// Accept() is blocking, unless ml.Close() is called, in which
102+
// case it will return immediately with an error.
103+
conn, err := l.Accept()
104+
// This assumes that ANY error from Accept() will terminate the
105+
// sub-listener. We could maybe be more precise, but it
106+
// doesn't seem necessary.
107+
terminate := err != nil
108+
109+
select {
110+
case ml.connCh <- connErrPair{conn: conn, err: err}:
111+
case <-ml.stopCh:
112+
// In case we accepted a connection AND were stopped, and
113+
// this select-case was chosen, just throw away the
114+
// connection. This avoids potentially blocking on connCh
115+
// or leaking a connection.
116+
if conn != nil {
117+
_ = conn.Close()
118+
}
119+
terminate = true
120+
}
121+
// Make sure we don't loop on Accept() returning an error and
122+
// the select choosing the channel case.
123+
if terminate {
124+
return
125+
}
126+
}
127+
}(l)
128+
}
129+
return ml, nil
130+
}
131+
132+
// Accept implements net.Listener. It waits for and returns a connection from
133+
// any of the sub-listener.
134+
func (ml *multiListener) Accept() (net.Conn, error) {
135+
// wait for any sub-listener to enqueue an accepted connection
136+
connErr, ok := <-ml.connCh
137+
if !ok {
138+
// The channel will be closed only when Close() is called on the
139+
// multiListener. Closing of this channel implies that all
140+
// sub-listeners are also closed, which causes a "use of closed
141+
// network connection" error on their Accept() calls. We return the
142+
// same error for multiListener.Accept() if multiListener.Close()
143+
// has already been called.
144+
return nil, fmt.Errorf("use of closed network connection")
145+
}
146+
return connErr.conn, connErr.err
147+
}
148+
149+
// Close implements net.Listener. It will close all sub-listeners and wait for
150+
// the go-routines to exit.
151+
func (ml *multiListener) Close() error {
152+
// Make sure this can be called repeatedly without explosions.
153+
select {
154+
case <-ml.stopCh:
155+
return fmt.Errorf("use of closed network connection")
156+
default:
157+
}
158+
159+
// Tell all sub-listeners to stop.
160+
close(ml.stopCh)
161+
162+
// Closing the listeners causes Accept() to immediately return an error in
163+
// the sub-listener go-routines.
164+
for _, l := range ml.listeners {
165+
_ = l.Close()
166+
}
167+
168+
// Wait for all the sub-listener go-routines to exit.
169+
ml.wg.Wait()
170+
close(ml.connCh)
171+
172+
// Drain any already-queued connections.
173+
for connErr := range ml.connCh {
174+
if connErr.conn != nil {
175+
_ = connErr.conn.Close()
176+
}
177+
}
178+
return nil
179+
}
180+
181+
// Addr is an implementation of the net.Listener interface. It always returns
182+
// the address of the first listener. Callers should use conn.LocalAddr() to
183+
// obtain the actual local address of the sub-listener.
184+
func (ml *multiListener) Addr() net.Addr {
185+
return ml.listeners[0].Addr()
186+
}
187+
188+
// Addrs is like Addr, but returns the address for all registered listeners.
189+
func (ml *multiListener) Addrs() []net.Addr {
190+
var ret []net.Addr
191+
for _, l := range ml.listeners {
192+
ret = append(ret, l.Addr())
193+
}
194+
return ret
195+
}

0 commit comments

Comments
 (0)