Skip to content

Commit 876756d

Browse files
authored
Merge pull request #2339 from Zhupku/mengzezhu/ut2
test: add UT for ListenEndpoint in utils.go
2 parents e719c1b + e1f59a6 commit 876756d

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

pkg/csi-common/utils.go

+11-3
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,32 @@ func parseEndpoint(ep string) (string, string, error) {
4141
}
4242
return "", "", fmt.Errorf("Invalid endpoint: %v", ep)
4343
}
44+
45+
var klogFatalf = func(format string, args ...interface{}) {
46+
klog.Fatalf(format, args...)
47+
}
48+
4449
func ListenEndpoint(endpoint string) (net.Listener, error) {
4550
proto, addr, err := parseEndpoint(endpoint)
4651
if err != nil {
47-
klog.Fatal(err.Error())
52+
klogFatalf("Invalid endpoint: %v", err)
53+
return nil, err
4854
}
4955

5056
if proto == "unix" {
5157
if runtime.GOOS != "windows" {
5258
addr = "/" + addr
5359
}
5460
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
55-
klog.Fatalf("Failed to remove %s, error: %s", addr, err.Error())
61+
klogFatalf("Failed to remove %s, error: %s", addr, err.Error())
62+
return nil, err
5663
}
5764
}
5865

5966
listener, err := net.Listen(proto, addr)
6067
if err != nil {
61-
klog.Fatalf("Failed to listen: %v", err)
68+
klogFatalf("Failed to listen: %v", err)
69+
return nil, err
6270
}
6371
return listener, err
6472
}

pkg/csi-common/utils_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"bytes"
2121
"context"
2222
"flag"
23+
"os"
24+
"runtime"
2325
"testing"
2426

2527
"google.golang.org/grpc"
@@ -302,3 +304,65 @@ func TestGetLogLevel(t *testing.T) {
302304
}
303305
}
304306
}
307+
308+
func TestListenEndpoint(t *testing.T) {
309+
if runtime.GOOS == "windows" {
310+
t.Skip("Skip test on Windows")
311+
}
312+
313+
originalKlogFatalf := klogFatalf
314+
klogFatalf = func(_ string, _ ...interface{}) {}
315+
defer func() { klogFatalf = originalKlogFatalf }()
316+
317+
tests := []struct {
318+
name string
319+
endpoint string
320+
filePath string
321+
wantErr bool
322+
}{
323+
{
324+
name: "unix socket",
325+
endpoint: "unix:///tmp/csi.sock",
326+
filePath: "/tmp/csi.sock",
327+
wantErr: false,
328+
},
329+
{
330+
name: "tcp socket",
331+
endpoint: "tcp://127.0.0.1:0",
332+
wantErr: false,
333+
},
334+
{
335+
name: "invalid endpoint",
336+
endpoint: "invalid://",
337+
wantErr: true,
338+
},
339+
{
340+
name: "invalid unix socket",
341+
endpoint: "unix://does/not/exist",
342+
wantErr: true,
343+
},
344+
}
345+
for _, tt := range tests {
346+
t.Run(tt.name, func(t *testing.T) {
347+
defer func() {
348+
if r := recover(); r != nil {
349+
if !tt.wantErr {
350+
t.Errorf("ListenEndpoint() panicked unexpectedly: %v", r)
351+
}
352+
}
353+
}()
354+
355+
got, err := ListenEndpoint(tt.endpoint)
356+
if (err != nil) != tt.wantErr {
357+
t.Errorf("Listen() error = %v, wantErr %v", err, tt.wantErr)
358+
return
359+
}
360+
if err == nil {
361+
got.Close()
362+
if tt.filePath != "" {
363+
os.Remove(tt.filePath)
364+
}
365+
}
366+
})
367+
}
368+
}

0 commit comments

Comments
 (0)