Skip to content

Commit d6eb366

Browse files
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit d6eb366

File tree

2 files changed

+661
-0
lines changed

2 files changed

+661
-0
lines changed

net/multi_listen.go

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"net"
23+
"sync"
24+
)
25+
26+
// connErrPair pairs conn and error which is returned by accept on sub-listeners.
27+
type connErrPair struct {
28+
conn net.Conn
29+
err error
30+
}
31+
32+
// multiListener implements net.Listener
33+
type multiListener struct {
34+
listeners []net.Listener
35+
wg sync.WaitGroup
36+
37+
// connCh passes accepted connections, from child listeners to parent.
38+
connCh chan connErrPair
39+
// stopCh communicates from parent to child listeners.
40+
stopCh chan struct{}
41+
}
42+
43+
// compile time check to ensure *multiListener implements net.Listener
44+
var _ net.Listener = &multiListener{}
45+
46+
// MultiListen returns net.Listener which can listen on and accept connections for
47+
// the given network on multiple addresses. Internally it uses stdlib to create
48+
// sub-listener and multiplexes connection requests using go-routines.
49+
// The network must be "tcp", "tcp4" or "tcp6".
50+
// It follows the semantics of net.Listen that primarily means:
51+
// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen
52+
// listens on all available unicast and anycast IP addresses of the local system.
53+
// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively.
54+
// 3. The host can accept names (e.g, localhost) and it will create a listener for at
55+
// most one of the host's IP.
56+
func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) {
57+
var lc net.ListenConfig
58+
return multiListen(
59+
ctx,
60+
network,
61+
addrs,
62+
func(ctx context.Context, network, address string) (net.Listener, error) {
63+
return lc.Listen(ctx, network, address)
64+
})
65+
}
66+
67+
// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
68+
// mocking for unit-testing.
69+
func multiListen(
70+
ctx context.Context,
71+
network string,
72+
addrs []string,
73+
listenFunc func(ctx context.Context, network, address string) (net.Listener, error),
74+
) (net.Listener, error) {
75+
if !(network == "tcp" || network == "tcp4" || network == "tcp6") {
76+
return nil, fmt.Errorf("network %q not supported", network)
77+
}
78+
if len(addrs) == 0 {
79+
return nil, fmt.Errorf("no address provided to listen on")
80+
}
81+
82+
ml := &multiListener{
83+
connCh: make(chan connErrPair),
84+
stopCh: make(chan struct{}),
85+
}
86+
for _, addr := range addrs {
87+
l, err := listenFunc(ctx, network, addr)
88+
if err != nil {
89+
// close all the sub-listeners and exit
90+
_ = ml.Close()
91+
return nil, err
92+
}
93+
ml.listeners = append(ml.listeners, l)
94+
}
95+
96+
for _, l := range ml.listeners {
97+
ml.wg.Add(1)
98+
go func(l net.Listener) {
99+
defer ml.wg.Done()
100+
for {
101+
// Accept() is blocking, unless ml.Close() is called, in which
102+
// case it will return immediately with an error.
103+
conn, err := l.Accept()
104+
105+
select {
106+
case ml.connCh <- connErrPair{conn: conn, err: err}:
107+
case <-ml.stopCh:
108+
return
109+
}
110+
}
111+
}(l)
112+
}
113+
return ml, nil
114+
}
115+
116+
// Accept implements net.Listener. It waits for and returns a connection from
117+
// any of the sub-listener.
118+
func (ml *multiListener) Accept() (net.Conn, error) {
119+
// wait for any sub-listener to enqueue an accepted connection
120+
connErr, ok := <-ml.connCh
121+
if !ok {
122+
// The channel will be closed only when Close() is called on the
123+
// multiListener. Closing of this channel implies that all
124+
// sub-listeners are also closed, which causes a "use of closed
125+
// network connection" error on their Accept() calls. We return the
126+
// same error for multiListener.Accept() if multiListener.Close()
127+
// has already been called.
128+
return nil, fmt.Errorf("use of closed network connection")
129+
}
130+
return connErr.conn, connErr.err
131+
}
132+
133+
// Close implements net.Listener. It will close all sub-listeners and wait for
134+
// the go-routines to exit.
135+
func (ml *multiListener) Close() error {
136+
select {
137+
case <-ml.stopCh:
138+
// return error is multiListener is already closed
139+
return fmt.Errorf("use of closed network connection")
140+
default:
141+
}
142+
143+
// Tell all sub-listeners to stop.
144+
close(ml.stopCh)
145+
146+
// Closing the listeners causes Accept() to immediately return an error in
147+
// the sub-listener go-routines.
148+
for _, l := range ml.listeners {
149+
_ = l.Close()
150+
}
151+
152+
// Wait for all the sub-listener go-routines to exit.
153+
ml.wg.Wait()
154+
close(ml.connCh)
155+
156+
// Drain any already-queued connections.
157+
for connErr := range ml.connCh {
158+
if connErr.conn != nil {
159+
_ = connErr.conn.Close()
160+
}
161+
}
162+
return nil
163+
}
164+
165+
// Addr is an implementation of the net.Listener interface. It always returns
166+
// the address of the first listener. Callers should use conn.LocalAddr() to
167+
// obtain the actual local address of the sub-listener.
168+
func (ml *multiListener) Addr() net.Addr {
169+
return ml.listeners[0].Addr()
170+
}
171+
172+
// Addrs is like Addr, but returns the address for all registered listeners.
173+
func (ml *multiListener) Addrs() []net.Addr {
174+
var ret []net.Addr
175+
for _, l := range ml.listeners {
176+
ret = append(ret, l.Addr())
177+
}
178+
return ret
179+
}

0 commit comments

Comments
 (0)