Skip to content

Commit 062710e

Browse files
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit 062710e

File tree

2 files changed

+622
-0
lines changed

2 files changed

+622
-0
lines changed

net/multi_listen.go

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"net"
23+
"sync"
24+
)
25+
26+
// connErrPair pairs conn and error which is returned by accept on sub-listeners.
27+
type connErrPair struct {
28+
conn net.Conn
29+
err error
30+
}
31+
32+
// multiListener implements net.Listener
33+
type multiListener struct {
34+
listeners []net.Listener
35+
36+
wg sync.WaitGroup
37+
mu sync.Mutex
38+
closed bool
39+
40+
// connErrQueue holds the connections accepted by sub-listeners
41+
connErrQueue []connErrPair
42+
43+
// acceptReadyCh is used as a semaphore to wake up the waiting
44+
// multiListener.Accept() when new connections are available
45+
acceptReadyCh chan any
46+
}
47+
48+
// compile time check to ensure *multiListener implements net.Listener
49+
var _ net.Listener = &multiListener{}
50+
51+
// MultiListen returns net.Listener which can listen on and accept connections for
52+
// the given network on multiple addresses. Internally it uses stdlib to create
53+
// sub-listener and multiplexes connection requests using go-routines.
54+
// The network must be "tcp", "tcp4" or "tcp6".
55+
// It follows the semantics of net.Listen that primarily means:
56+
// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen
57+
// listens on all available unicast and anycast IP addresses of the local system.
58+
// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively.
59+
// 3. The host can accept names (e.g, localhost) and it will create a listener for at
60+
// most one of the host's IP.
61+
func MultiListen(ctx context.Context, network string, addrs []string) (net.Listener, error) {
62+
return multiListen(
63+
ctx,
64+
network,
65+
addrs,
66+
func(ctx context.Context, network, address string) (net.Listener, error) {
67+
var lc net.ListenConfig
68+
return lc.Listen(ctx, network, address)
69+
})
70+
}
71+
72+
// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
73+
// mocking for unit-testing.
74+
func multiListen(
75+
ctx context.Context,
76+
network string,
77+
addrs []string,
78+
listenFunc func(ctx context.Context, network, address string) (net.Listener, error),
79+
) (net.Listener, error) {
80+
if !(network == "tcp" || network == "tcp4" || network == "tcp6") {
81+
return nil, fmt.Errorf("network '%s' not supported", network)
82+
}
83+
if len(addrs) == 0 {
84+
return nil, fmt.Errorf("no address provided to listen on")
85+
}
86+
87+
ml := &multiListener{
88+
acceptReadyCh: make(chan any),
89+
}
90+
91+
for _, addr := range addrs {
92+
l, err := listenFunc(ctx, network, addr)
93+
if err != nil {
94+
// close all the sub-listeners and exit
95+
_ = ml.Close()
96+
return nil, err
97+
}
98+
ml.listeners = append(ml.listeners, l)
99+
}
100+
101+
for _, l := range ml.listeners {
102+
ml.wg.Add(1)
103+
go func(l net.Listener) {
104+
defer ml.wg.Done()
105+
for {
106+
conn, err := l.Accept()
107+
ml.mu.Lock()
108+
if ml.closed {
109+
ml.mu.Unlock()
110+
return
111+
}
112+
// enqueue the accepted connection
113+
ml.connErrQueue = append(ml.connErrQueue, connErrPair{conn: conn, err: err})
114+
115+
// signal the waiting ml.Accept() to consume accepted connection from the queue.
116+
select {
117+
case ml.acceptReadyCh <- struct{}{}:
118+
default:
119+
}
120+
ml.mu.Unlock()
121+
}
122+
}(l)
123+
}
124+
return ml, nil
125+
}
126+
127+
// Accept implements net.Listener.
128+
// It waits for and returns a connection from any of the sub-listener.
129+
func (ml *multiListener) Accept() (net.Conn, error) {
130+
for {
131+
// atomically return and remove the first element of the queue if it's not empty
132+
ml.mu.Lock()
133+
if len(ml.connErrQueue) > 0 {
134+
connErr := ml.connErrQueue[0]
135+
ml.connErrQueue = ml.connErrQueue[1:]
136+
ml.mu.Unlock()
137+
return connErr.conn, connErr.err
138+
}
139+
ml.mu.Unlock()
140+
141+
// wait for any sub-listener to enqueue an accepted connection
142+
_, ok := <-ml.acceptReadyCh
143+
if !ok {
144+
// The "acceptReadyCh" channel will be closed only when Close() is called on the multiListener.
145+
// Closing of this channel implies that all sub-listeners are also closed, which causes a
146+
// "use of closed network connection" error on their Accept() calls. We return the same error
147+
// for multiListener.Accept() if multiListener.Close() has already been called.
148+
return nil, fmt.Errorf("use of closed network connection")
149+
}
150+
}
151+
}
152+
153+
// Close implements net.Listener.
154+
// It will close all sub-listeners and wait for the go-routines to exit.
155+
func (ml *multiListener) Close() error {
156+
ml.mu.Lock()
157+
if ml.closed {
158+
ml.mu.Unlock()
159+
return fmt.Errorf("use of closed network connection")
160+
}
161+
ml.closed = true
162+
close(ml.acceptReadyCh)
163+
ml.mu.Unlock()
164+
165+
// Closing the listeners causes Accept() to immediately return an error,
166+
// which serves as the exit condition for the sub-listener go-routines.
167+
for _, l := range ml.listeners {
168+
_ = l.Close()
169+
}
170+
171+
// Wait for all the sub-listener go-routines to exit.
172+
ml.wg.Wait()
173+
return nil
174+
}
175+
176+
// Addr is an implementation of the net.Listener interface.
177+
// It always returns the address of the first listener.
178+
// Callers should use conn.LocalAddr() to obtain the actual
179+
// local address of the sub-listener.
180+
func (ml *multiListener) Addr() net.Addr {
181+
return ml.listeners[0].Addr()
182+
}

0 commit comments

Comments
 (0)