From 53fc8119f703f5dacb1642695639cea0c4ce8218 Mon Sep 17 00:00:00 2001 From: Alexander Ding Date: Thu, 13 Oct 2022 14:15:29 +0000 Subject: [PATCH 1/2] feat: migrate smb API group to library model --- integrationtests/smb_test.go | 175 +++++++++++++---------------------- integrationtests/utils.go | 110 ++++++++++++++++++++++ pkg/smb/api/api.go | 83 +++++++++++++++++ pkg/smb/smb.go | 150 ++++++++++++++++++++++++++++++ pkg/smb/smb_test.go | 152 ++++++++++++++++++++++++++++++ pkg/smb/types.go | 20 ++++ 6 files changed, 578 insertions(+), 112 deletions(-) create mode 100644 pkg/smb/api/api.go create mode 100644 pkg/smb/smb.go create mode 100644 pkg/smb/smb_test.go create mode 100644 pkg/smb/types.go diff --git a/integrationtests/smb_test.go b/integrationtests/smb_test.go index 65a359d4..5b3c3fd6 100644 --- a/integrationtests/smb_test.go +++ b/integrationtests/smb_test.go @@ -1,138 +1,89 @@ package integrationtests import ( + "context" "fmt" - "io/ioutil" - "math/rand" "os" - "os/exec" - "strings" - "time" - "testing" -) -const letterset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" -var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) + fs "github.com/kubernetes-csi/csi-proxy/pkg/filesystem" + fsapi "github.com/kubernetes-csi/csi-proxy/pkg/filesystem/api" + "github.com/kubernetes-csi/csi-proxy/pkg/smb" + smbapi "github.com/kubernetes-csi/csi-proxy/pkg/smb/api" +) -func stringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) +func TestSmbAPIGroup(t *testing.T) { + t.Run("v1alpha1SmbTests", func(t *testing.T) { + v1alpha1SmbTests(t) + }) + t.Run("v1beta1SmbTests", func(t *testing.T) { + v1beta1SmbTests(t) + }) + t.Run("v1beta2SmbTests", func(t *testing.T) { + v1beta2SmbTests(t) + }) + t.Run("v1SmbTests", func(t *testing.T) { + v1SmbTests(t) + }) } -// RandomString generates a random string with specified length -func randomString(length int) string { - return stringWithCharset(length, letterset) -} +func TestSmb(t *testing.T) { + fsClient, err := fs.New(fsapi.New()) + require.Nil(t, err) + client, err := smb.New(smbapi.New(), fsClient) + require.Nil(t, err) -func setupUser(username, password string) error { - cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString $Env:password -AsPlainText -Force` + - `;New-Localuser -name $Env:username -accountneverexpires -password $PWord`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("username=%s", username), - fmt.Sprintf("password=%s", password)) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("setupUser failed: %v, output: %q", err, string(output)) - } - return nil -} + username := randomString(5) + password := randomString(10) + "!" + sharePath := fmt.Sprintf("C:\\smbshare%s", randomString(5)) + smbShare := randomString(6) -func removeUser(t *testing.T, username string) { - cmdLine := fmt.Sprintf(`Remove-Localuser -name $Env:username`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("username=%s", username)) - if output, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("setupUser failed: %v, output: %q", err, string(output)) - } -} + localPath := fmt.Sprintf("C:\\localpath%s", randomString(5)) -func setupSmbShare(shareName, localPath, username string) error { - if err := os.MkdirAll(localPath, 0755); err != nil { - return fmt.Errorf("setupSmbShare failed to create local path %q: %v", localPath, err) + if err = setupUser(username, password); err != nil { + t.Fatalf("TestSmbAPIGroup %v", err) } - cmdLine := fmt.Sprintf(`New-SMBShare -Name $Env:sharename -Path $Env:path -fullaccess $Env:username`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("sharename=%s", shareName), - fmt.Sprintf("path=%s", localPath), - fmt.Sprintf("username=%s", username)) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("setupSmbShare failed: %v, output: %q", err, string(output)) - } - - return nil -} + defer removeUser(t, username) -func removeSmbShare(t *testing.T, shareName string) { - cmdLine := fmt.Sprintf(`Remove-SMBShare -Name $Env:sharename -Force`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("sharename=%s", shareName)) - if output, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("setupSmbShare failed: %v, output: %q", err, string(output)) + if err = setupSmbShare(smbShare, sharePath, username); err != nil { + t.Fatalf("TestSmbAPIGroup %v", err) } - return -} + defer removeSmbShare(t, smbShare) -func getSmbGlobalMapping(remotePath string) error { - // use PowerShell Environment Variables to store user input string to prevent command line injection - // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 - cmdLine := fmt.Sprintf(`(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath).Status`) + hostname, err := os.Hostname() + assert.Nil(t, err) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("smbremotepath=%s", remotePath)) - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("Get-SmbGlobalMapping failed: %v, output: %q", err, string(output)) + username = "domain\\" + username + remotePath := "\\\\" + hostname + "\\" + smbShare + // simulate Mount SMB operations around staging a volume on a node + mountSmbShareReq := &smb.NewSmbGlobalMappingRequest{ + RemotePath: remotePath, + Username: username, + Password: password, } - if !strings.Contains(string(output), "OK") { - return fmt.Errorf("Get-SmbGlobalMapping return status %q instead of OK", string(output)) - } - return nil -} - -func writeReadFile(path string) error { - fileName := path + "\\hello.txt" - f, err := os.Create(fileName) + _, err = client.NewSmbGlobalMapping(context.Background(), mountSmbShareReq) if err != nil { - return fmt.Errorf("create file %q failed: %v", fileName, err) - } - defer f.Close() - fileContent := "Hello World" - if _, err = f.WriteString(fileContent); err != nil { - return fmt.Errorf("write to file %q failed: %v", fileName, err) + t.Fatalf("TestSmbAPIGroup %v", err) } - if err = f.Sync(); err != nil { - return fmt.Errorf("sync file %q failed: %v", fileName, err) + + err = getSmbGlobalMapping(remotePath) + assert.Nil(t, err) + + err = writeReadFile(remotePath) + assert.Nil(t, err) + + unmountSmbShareReq := &smb.RemoveSmbGlobalMappingRequest{ + RemotePath: remotePath, } - dat, err := ioutil.ReadFile(fileName) + _, err = client.RemoveSmbGlobalMapping(context.Background(), unmountSmbShareReq) if err != nil { - return fmt.Errorf("read file %q failed: %v", fileName, err) - } - if fileContent != string(dat) { - return fmt.Errorf("read content of file %q failed: expected %q, got %q", fileName, fileContent, string(dat)) + t.Fatalf("TestSmbAPIGroup %v", err) } - return nil -} - -func TestSmbAPIGroup(t *testing.T) { - t.Run("v1alpha1SmbTests", func(t *testing.T) { - v1alpha1SmbTests(t) - }) - t.Run("v1beta1SmbTests", func(t *testing.T) { - v1beta1SmbTests(t) - }) - t.Run("v1beta2SmbTests", func(t *testing.T) { - v1beta2SmbTests(t) - }) - t.Run("v1SmbTests", func(t *testing.T) { - v1SmbTests(t) - }) + err = getSmbGlobalMapping(remotePath) + assert.NotNil(t, err) + err = writeReadFile(localPath) + assert.NotNil(t, err) } diff --git a/integrationtests/utils.go b/integrationtests/utils.go index 895ccb37..b0f17576 100644 --- a/integrationtests/utils.go +++ b/integrationtests/utils.go @@ -346,3 +346,113 @@ func volumeInit(volumeClient volume.Interface, t *testing.T) (*VirtualHardDisk, } return vhd, volumeID, vhdCleanup } + +const letterset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) + +func stringWithCharset(length int, charset string) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + +// RandomString generates a random string with specified length +func randomString(length int) string { + return stringWithCharset(length, letterset) +} + +func setupUser(username, password string) error { + cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString $Env:password -AsPlainText -Force` + + `;New-Localuser -name $Env:username -accountneverexpires -password $PWord`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("username=%s", username), + fmt.Sprintf("password=%s", password)) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("setupUser failed: %v, output: %q", err, string(output)) + } + return nil +} + +func removeUser(t *testing.T, username string) { + cmdLine := fmt.Sprintf(`Remove-Localuser -name $Env:username`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("username=%s", username)) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("setupUser failed: %v, output: %q", err, string(output)) + } +} + +func setupSmbShare(shareName, localPath, username string) error { + if err := os.MkdirAll(localPath, 0755); err != nil { + return fmt.Errorf("setupSmbShare failed to create local path %q: %v", localPath, err) + } + cmdLine := fmt.Sprintf(`New-SMBShare -Name $Env:sharename -Path $Env:path -fullaccess $Env:username`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("sharename=%s", shareName), + fmt.Sprintf("path=%s", localPath), + fmt.Sprintf("username=%s", username)) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("setupSmbShare failed: %v, output: %q", err, string(output)) + } + + return nil +} + +func removeSmbShare(t *testing.T, shareName string) { + cmdLine := fmt.Sprintf(`Remove-SMBShare -Name $Env:sharename -Force`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("sharename=%s", shareName)) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("setupSmbShare failed: %v, output: %q", err, string(output)) + } + return +} + +func getSmbGlobalMapping(remotePath string) error { + // use PowerShell Environment Variables to store user input string to prevent command line injection + // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 + cmdLine := fmt.Sprintf(`(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath).Status`) + + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("smbremotepath=%s", remotePath)) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("Get-SmbGlobalMapping failed: %v, output: %q", err, string(output)) + } + if !strings.Contains(string(output), "OK") { + return fmt.Errorf("Get-SmbGlobalMapping return status %q instead of OK", string(output)) + } + return nil +} + +func writeReadFile(path string) error { + fileName := path + "\\hello.txt" + f, err := os.Create(fileName) + if err != nil { + return fmt.Errorf("create file %q failed: %v", fileName, err) + } + defer f.Close() + fileContent := "Hello World" + if _, err = f.WriteString(fileContent); err != nil { + return fmt.Errorf("write to file %q failed: %v", fileName, err) + } + if err = f.Sync(); err != nil { + return fmt.Errorf("sync file %q failed: %v", fileName, err) + } + dat, err := ioutil.ReadFile(fileName) + if err != nil { + return fmt.Errorf("read file %q failed: %v", fileName, err) + } + if fileContent != string(dat) { + return fmt.Errorf("read content of file %q failed: expected %q, got %q", fileName, fileContent, string(dat)) + } + return nil +} diff --git a/pkg/smb/api/api.go b/pkg/smb/api/api.go new file mode 100644 index 00000000..ce21c40f --- /dev/null +++ b/pkg/smb/api/api.go @@ -0,0 +1,83 @@ +package api + +import ( + "fmt" + "strings" + + "github.com/kubernetes-csi/csi-proxy/pkg/utils" +) + +type API interface { + IsSmbMapped(remotePath string) (bool, error) + NewSmbLink(remotePath, localPath string) error + NewSmbGlobalMapping(remotePath, username, password string) error + RemoveSmbGlobalMapping(remotePath string) error +} + +type smbAPI struct{} + +var _ API = &smbAPI{} + +func New() API { + return smbAPI{} +} + +func (smbAPI) IsSmbMapped(remotePath string) (bool, error) { + cmdLine := `$(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath -ErrorAction Stop).Status ` + cmdEnv := fmt.Sprintf("smbremotepath=%s", remotePath) + out, err := utils.RunPowershellCmd(cmdLine, cmdEnv) + if err != nil { + return false, fmt.Errorf("error checking smb mapping. cmd %s, output: %s, err: %v", remotePath, string(out), err) + } + + if len(out) == 0 || !strings.EqualFold(strings.TrimSpace(string(out)), "OK") { + return false, nil + } + return true, nil +} + +// NewSmbLink - creates a directory symbolic link to the remote share. +// The os.Symlink was having issue for cases where the destination was an SMB share - the container +// runtime would complain stating "Access Denied". Because of this, we had to perform +// this operation with powershell commandlet creating an directory softlink. +// Since os.Symlink is currently being used in working code paths, no attempt is made in +// alpha to merge the paths. +// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path. +func (smbAPI) NewSmbLink(remotePath, localPath string) error { + if !strings.HasSuffix(remotePath, "\\") { + // Golang has issues resolving paths mapped to file shares if they do not end in a trailing \ + // so add one if needed. + remotePath = remotePath + "\\" + } + + cmdLine := `New-Item -ItemType SymbolicLink $Env:smblocalPath -Target $Env:smbremotepath` + output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbremotepath=%s", remotePath), fmt.Sprintf("smblocalpath=%s", localPath)) + if err != nil { + return fmt.Errorf("error linking %s to %s. output: %s, err: %v", remotePath, localPath, string(output), err) + } + + return nil +} + +func (smbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { + // use PowerShell Environment Variables to store user input string to prevent command line injection + // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 + cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString -String $Env:smbpassword -AsPlainText -Force` + + `;$Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $Env:smbuser, $PWord` + + `;New-SmbGlobalMapping -RemotePath $Env:smbremotepath -Credential $Credential -RequirePrivacy $true`) + + if output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbuser=%s", username), + fmt.Sprintf("smbpassword=%s", password), + fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil { + return fmt.Errorf("NewSmbGlobalMapping failed. output: %q, err: %v", string(output), err) + } + return nil +} + +func (smbAPI) RemoveSmbGlobalMapping(remotePath string) error { + cmd := `Remove-SmbGlobalMapping -RemotePath $Env:smbremotepath -Force` + if output, err := utils.RunPowershellCmd(cmd, fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil { + return fmt.Errorf("UnmountSmbShare failed. output: %q, err: %v", string(output), err) + } + return nil +} diff --git a/pkg/smb/smb.go b/pkg/smb/smb.go new file mode 100644 index 00000000..4ac2ff5d --- /dev/null +++ b/pkg/smb/smb.go @@ -0,0 +1,150 @@ +package smb + +import ( + "context" + "fmt" + "strings" + + fs "github.com/kubernetes-csi/csi-proxy/pkg/filesystem" + smbapi "github.com/kubernetes-csi/csi-proxy/pkg/smb/api" + "k8s.io/klog/v2" +) + +type Smb struct { + hostAPI smbapi.API + fs fs.Interface +} + +type Interface interface { + NewSmbGlobalMapping(context.Context, *NewSmbGlobalMappingRequest) (*NewSmbGlobalMappingResponse, error) + RemoveSmbGlobalMapping(context.Context, *RemoveSmbGlobalMappingRequest) (*RemoveSmbGlobalMappingResponse, error) +} + +// check that Smb implements the Interface +var _ Interface = &Smb{} + +func normalizeWindowsPath(path string) string { + normalizedPath := strings.Replace(path, "/", "\\", -1) + return normalizedPath +} + +func getRootMappingPath(path string) (string, error) { + items := strings.Split(path, "\\") + parts := []string{} + for _, s := range items { + if len(s) > 0 { + parts = append(parts, s) + if len(parts) == 2 { + break + } + } + } + if len(parts) != 2 { + klog.Errorf("remote path (%s) is invalid", path) + return "", fmt.Errorf("remote path (%s) is invalid", path) + } + // parts[0] is a smb host name + // parts[1] is a smb share name + return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil +} + +func New(hostAPI smbapi.API, fsClient fs.Interface) (*Smb, error) { + return &Smb{ + hostAPI: hostAPI, + fs: fsClient, + }, nil +} + +func (s *Smb) NewSmbGlobalMapping(context context.Context, request *NewSmbGlobalMappingRequest) (*NewSmbGlobalMappingResponse, error) { + klog.V(2).Infof("calling NewSmbGlobalMapping with remote path %q", request.RemotePath) + response := &NewSmbGlobalMappingResponse{} + remotePath := normalizeWindowsPath(request.RemotePath) + localPath := request.LocalPath + + if remotePath == "" { + klog.Errorf("remote path is empty") + return response, fmt.Errorf("remote path is empty") + } + + mappingPath, err := getRootMappingPath(remotePath) + if err != nil { + return response, err + } + + isMapped, err := s.hostAPI.IsSmbMapped(mappingPath) + if err != nil { + isMapped = false + } + + if isMapped { + klog.V(4).Infof("Remote %s already mapped. Validating...", mappingPath) + + validResp, err := s.fs.PathValid(context, &fs.PathValidRequest{Path: mappingPath}) + if err != nil { + klog.Warningf("PathValid(%s) failed with %v, ignore error", mappingPath, err) + } + + if !validResp.Valid { + klog.V(4).Infof("RemotePath %s is not valid, removing now", mappingPath) + err := s.hostAPI.RemoveSmbGlobalMapping(mappingPath) + if err != nil { + klog.Errorf("RemoveSmbGlobalMapping(%s) failed with %v", mappingPath, err) + return response, err + } + isMapped = false + } else { + klog.V(4).Infof("RemotePath %s is valid", mappingPath) + } + } + + if !isMapped { + klog.V(4).Infof("Remote %s not mapped. Mapping now!", mappingPath) + err = s.hostAPI.NewSmbGlobalMapping(mappingPath, request.Username, request.Password) + if err != nil { + klog.Errorf("failed NewSmbGlobalMapping %v", err) + return response, err + } + } + + if len(localPath) != 0 { + klog.V(4).Infof("ValidatePathWindows: '%s'", localPath) + err = fs.ValidatePathWindows(localPath) + if err != nil { + klog.Errorf("failed validate plugin path %v", err) + return response, err + } + err = s.hostAPI.NewSmbLink(remotePath, localPath) + if err != nil { + klog.Errorf("failed NewSmbLink %v", err) + return response, fmt.Errorf("creating link %s to %s failed with error: %v", localPath, remotePath, err) + } + } + + klog.V(2).Infof("NewSmbGlobalMapping on remote path %q is completed", request.RemotePath) + return response, nil +} + +func (s *Smb) RemoveSmbGlobalMapping(context context.Context, request *RemoveSmbGlobalMappingRequest) (*RemoveSmbGlobalMappingResponse, error) { + klog.V(2).Infof("calling RemoveSmbGlobalMapping with remote path %q", request.RemotePath) + response := &RemoveSmbGlobalMappingResponse{} + remotePath := normalizeWindowsPath(request.RemotePath) + + if remotePath == "" { + klog.Errorf("remote path is empty") + return response, fmt.Errorf("remote path is empty") + } + + mappingPath, err := getRootMappingPath(remotePath) + if err != nil { + return response, err + } + + err = s.hostAPI.RemoveSmbGlobalMapping(mappingPath) + if err != nil { + klog.Errorf("failed RemoveSmbGlobalMapping %v", err) + return response, err + } + + klog.V(2).Infof("RemoveSmbGlobalMapping on remote path %q is completed", request.RemotePath) + return response, nil +} diff --git a/pkg/smb/smb_test.go b/pkg/smb/smb_test.go new file mode 100644 index 00000000..03fe635f --- /dev/null +++ b/pkg/smb/smb_test.go @@ -0,0 +1,152 @@ +package smb + +import ( + "context" + "testing" + + fs "github.com/kubernetes-csi/csi-proxy/pkg/filesystem" + fsapi "github.com/kubernetes-csi/csi-proxy/pkg/filesystem/api" + smbapi "github.com/kubernetes-csi/csi-proxy/pkg/smb/api" +) + +type fakeSmbAPI struct{} + +var _ smbapi.API = &fakeSmbAPI{} + +func (fakeSmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { + return nil +} + +func (fakeSmbAPI) RemoveSmbGlobalMapping(remotePath string) error { + return nil +} + +func (fakeSmbAPI) IsSmbMapped(remotePath string) (bool, error) { + return false, nil +} + +func (fakeSmbAPI) NewSmbLink(remotePath, localPath string) error { + return nil +} + +type fakeFileSystemAPI struct{} + +var _ fsapi.API = &fakeFileSystemAPI{} + +func (fakeFileSystemAPI) PathExists(path string) (bool, error) { + return true, nil +} +func (fakeFileSystemAPI) PathValid(path string) (bool, error) { + return true, nil +} +func (fakeFileSystemAPI) Mkdir(path string) error { + return nil +} +func (fakeFileSystemAPI) Rmdir(path string, force bool) error { + return nil +} +func (fakeFileSystemAPI) RmdirContents(path string) error { + return nil +} +func (fakeFileSystemAPI) CreateSymlink(tgt string, src string) error { + return nil +} + +func (fakeFileSystemAPI) IsSymlink(path string) (bool, error) { + return true, nil +} + +func TestNewSmbGlobalMapping(t *testing.T) { + testCases := []struct { + remote string + local string + username string + password string + expectError bool + }{ + { + remote: "", + username: "", + password: "", + expectError: true, + }, + { + remote: "\\\\hostname\\path", + username: "", + password: "", + expectError: false, + }, + } + fsClient, err := fs.New(&fakeFileSystemAPI{}) + if err != nil { + t.Fatalf("FileSystem client could not be initialized for testing: %v", err) + } + + client, err := New(&fakeSmbAPI{}, fsClient) + if err != nil { + t.Fatalf("Smb client could not be initialized for testing: %v", err) + } + for _, tc := range testCases { + req := &NewSmbGlobalMappingRequest{ + LocalPath: tc.local, + RemotePath: tc.remote, + Username: tc.username, + Password: tc.password, + } + _, err := client.NewSmbGlobalMapping(context.TODO(), req) + if tc.expectError && err == nil { + t.Errorf("Expected error but NewSmbGlobalMapping returned a nil error") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no errors but NewSmbGlobalMapping returned error: %v", err) + } + } +} + +func TestGetRootMappingPath(t *testing.T) { + testCases := []struct { + remote string + expectResult string + expectError bool + }{ + { + remote: "", + expectResult: "", + expectError: true, + }, + { + remote: "hostname", + expectResult: "", + expectError: true, + }, + { + remote: "\\\\hostname\\path", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\subpath", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + } + for _, tc := range testCases { + result, err := getRootMappingPath(tc.remote) + if tc.expectError && err == nil { + t.Errorf("Expected error but getRootMappingPath returned a nil error") + } + if !tc.expectError { + if err != nil { + t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err) + } + if tc.expectResult != result { + t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result) + } + } + } +} diff --git a/pkg/smb/types.go b/pkg/smb/types.go new file mode 100644 index 00000000..d38ab64a --- /dev/null +++ b/pkg/smb/types.go @@ -0,0 +1,20 @@ +package smb + +type NewSmbGlobalMappingRequest struct { + RemotePath string + LocalPath string + Username string + Password string +} + +type NewSmbGlobalMappingResponse struct { + // Intentionally empty. +} + +type RemoveSmbGlobalMappingRequest struct { + RemotePath string +} + +type RemoveSmbGlobalMappingResponse struct { + // Intentionally empty. +} From 4f702842142acfe3d84c178e7528e193c1883777 Mon Sep 17 00:00:00 2001 From: Alexander Ding Date: Thu, 13 Oct 2022 18:20:18 +0000 Subject: [PATCH 2/2] refactor: move SMB specific test utils into smb_test.go --- integrationtests/smb_test.go | 115 +++++++++++++++++++++++++++++++++++ integrationtests/utils.go | 110 --------------------------------- 2 files changed, 115 insertions(+), 110 deletions(-) diff --git a/integrationtests/smb_test.go b/integrationtests/smb_test.go index 5b3c3fd6..885dc3de 100644 --- a/integrationtests/smb_test.go +++ b/integrationtests/smb_test.go @@ -3,8 +3,13 @@ package integrationtests import ( "context" "fmt" + "io/ioutil" + "math/rand" "os" + "os/exec" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -87,3 +92,113 @@ func TestSmb(t *testing.T) { err = writeReadFile(localPath) assert.NotNil(t, err) } + +const letterset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) + +func stringWithCharset(length int, charset string) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + +// RandomString generates a random string with specified length +func randomString(length int) string { + return stringWithCharset(length, letterset) +} + +func setupUser(username, password string) error { + cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString $Env:password -AsPlainText -Force` + + `;New-Localuser -name $Env:username -accountneverexpires -password $PWord`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("username=%s", username), + fmt.Sprintf("password=%s", password)) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("setupUser failed: %v, output: %q", err, string(output)) + } + return nil +} + +func removeUser(t *testing.T, username string) { + cmdLine := fmt.Sprintf(`Remove-Localuser -name $Env:username`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("username=%s", username)) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("setupUser failed: %v, output: %q", err, string(output)) + } +} + +func setupSmbShare(shareName, localPath, username string) error { + if err := os.MkdirAll(localPath, 0755); err != nil { + return fmt.Errorf("setupSmbShare failed to create local path %q: %v", localPath, err) + } + cmdLine := fmt.Sprintf(`New-SMBShare -Name $Env:sharename -Path $Env:path -fullaccess $Env:username`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("sharename=%s", shareName), + fmt.Sprintf("path=%s", localPath), + fmt.Sprintf("username=%s", username)) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("setupSmbShare failed: %v, output: %q", err, string(output)) + } + + return nil +} + +func removeSmbShare(t *testing.T, shareName string) { + cmdLine := fmt.Sprintf(`Remove-SMBShare -Name $Env:sharename -Force`) + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("sharename=%s", shareName)) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("setupSmbShare failed: %v, output: %q", err, string(output)) + } + return +} + +func getSmbGlobalMapping(remotePath string) error { + // use PowerShell Environment Variables to store user input string to prevent command line injection + // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 + cmdLine := fmt.Sprintf(`(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath).Status`) + + cmd := exec.Command("powershell", "/c", cmdLine) + cmd.Env = append(os.Environ(), + fmt.Sprintf("smbremotepath=%s", remotePath)) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("Get-SmbGlobalMapping failed: %v, output: %q", err, string(output)) + } + if !strings.Contains(string(output), "OK") { + return fmt.Errorf("Get-SmbGlobalMapping return status %q instead of OK", string(output)) + } + return nil +} + +func writeReadFile(path string) error { + fileName := path + "\\hello.txt" + f, err := os.Create(fileName) + if err != nil { + return fmt.Errorf("create file %q failed: %v", fileName, err) + } + defer f.Close() + fileContent := "Hello World" + if _, err = f.WriteString(fileContent); err != nil { + return fmt.Errorf("write to file %q failed: %v", fileName, err) + } + if err = f.Sync(); err != nil { + return fmt.Errorf("sync file %q failed: %v", fileName, err) + } + dat, err := ioutil.ReadFile(fileName) + if err != nil { + return fmt.Errorf("read file %q failed: %v", fileName, err) + } + if fileContent != string(dat) { + return fmt.Errorf("read content of file %q failed: expected %q, got %q", fileName, fileContent, string(dat)) + } + return nil +} diff --git a/integrationtests/utils.go b/integrationtests/utils.go index b0f17576..895ccb37 100644 --- a/integrationtests/utils.go +++ b/integrationtests/utils.go @@ -346,113 +346,3 @@ func volumeInit(volumeClient volume.Interface, t *testing.T) (*VirtualHardDisk, } return vhd, volumeID, vhdCleanup } - -const letterset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) - -func stringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) -} - -// RandomString generates a random string with specified length -func randomString(length int) string { - return stringWithCharset(length, letterset) -} - -func setupUser(username, password string) error { - cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString $Env:password -AsPlainText -Force` + - `;New-Localuser -name $Env:username -accountneverexpires -password $PWord`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("username=%s", username), - fmt.Sprintf("password=%s", password)) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("setupUser failed: %v, output: %q", err, string(output)) - } - return nil -} - -func removeUser(t *testing.T, username string) { - cmdLine := fmt.Sprintf(`Remove-Localuser -name $Env:username`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("username=%s", username)) - if output, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("setupUser failed: %v, output: %q", err, string(output)) - } -} - -func setupSmbShare(shareName, localPath, username string) error { - if err := os.MkdirAll(localPath, 0755); err != nil { - return fmt.Errorf("setupSmbShare failed to create local path %q: %v", localPath, err) - } - cmdLine := fmt.Sprintf(`New-SMBShare -Name $Env:sharename -Path $Env:path -fullaccess $Env:username`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("sharename=%s", shareName), - fmt.Sprintf("path=%s", localPath), - fmt.Sprintf("username=%s", username)) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("setupSmbShare failed: %v, output: %q", err, string(output)) - } - - return nil -} - -func removeSmbShare(t *testing.T, shareName string) { - cmdLine := fmt.Sprintf(`Remove-SMBShare -Name $Env:sharename -Force`) - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("sharename=%s", shareName)) - if output, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("setupSmbShare failed: %v, output: %q", err, string(output)) - } - return -} - -func getSmbGlobalMapping(remotePath string) error { - // use PowerShell Environment Variables to store user input string to prevent command line injection - // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 - cmdLine := fmt.Sprintf(`(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath).Status`) - - cmd := exec.Command("powershell", "/c", cmdLine) - cmd.Env = append(os.Environ(), - fmt.Sprintf("smbremotepath=%s", remotePath)) - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("Get-SmbGlobalMapping failed: %v, output: %q", err, string(output)) - } - if !strings.Contains(string(output), "OK") { - return fmt.Errorf("Get-SmbGlobalMapping return status %q instead of OK", string(output)) - } - return nil -} - -func writeReadFile(path string) error { - fileName := path + "\\hello.txt" - f, err := os.Create(fileName) - if err != nil { - return fmt.Errorf("create file %q failed: %v", fileName, err) - } - defer f.Close() - fileContent := "Hello World" - if _, err = f.WriteString(fileContent); err != nil { - return fmt.Errorf("write to file %q failed: %v", fileName, err) - } - if err = f.Sync(); err != nil { - return fmt.Errorf("sync file %q failed: %v", fileName, err) - } - dat, err := ioutil.ReadFile(fileName) - if err != nil { - return fmt.Errorf("read file %q failed: %v", fileName, err) - } - if fileContent != string(dat) { - return fmt.Errorf("read content of file %q failed: expected %q, got %q", fileName, fileContent, string(dat)) - } - return nil -}