Skip to content

Commit f008d34

Browse files
committed
refactoring
1 parent 9790468 commit f008d34

File tree

6 files changed

+190
-162
lines changed

6 files changed

+190
-162
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ go get
1818
go test
1919
```
2020

21+
## Notice for Windows
22+
23+
The `make` step can be for example `mingw32-make.exe -B static`
24+
2125
## Usage
2226

2327
- import "github.com/deemru/go-msspi"

msspi.go

+178-40
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package msspi
33
/*
44
#cgo windows LDFLAGS: -Lmsspi/build_linux -lmsspi -lstdc++ -lcrypt32
55
#cgo linux LDFLAGS: -Lmsspi/build_linux -lmsspi-capix -lstdc++ -ldl
6+
#define NO_MSSPI_CERT
67
#include "msspi/src/msspi.h"
78
extern int cgo_msspi_read( void * goPointer, void * buf, int len );
89
extern int cgo_msspi_write( void * goPointer, void * buf, int len );
@@ -13,46 +14,45 @@ MSSPI_HANDLE cgo_msspi_open( void * goPointer ) {
1314
import "C"
1415

1516
import (
16-
"crypto/tls"
17-
"io"
17+
"errors"
1818
"net"
1919
"runtime"
20-
"sync"
2120
"unsafe"
2221

23-
"github.com/mattn/go-pointer"
22+
"go-pointer"
2423
)
2524

25+
const ByDefault = true
26+
27+
//const ByDefault = false
28+
2629
// Conn with MSSPI
27-
type Conn struct {
28-
conn net.Conn
29-
tls *tls.Conn
30-
// MSSPI
30+
type Handler struct {
31+
conn *net.Conn
3132
handle C.MSSPI_HANDLE
3233
rerr error
3334
werr error
3435
isClient bool
3536
goPointer unsafe.Pointer
36-
mu sync.Mutex
3737
}
3838

39-
func (c *Conn) error() (err error) {
40-
state := C.msspi_state(c.handle)
41-
if state&C.MSSPI_ERROR != 0 || state&(C.MSSPI_SENT_SHUTDOWN|C.MSSPI_RECEIVED_SHUTDOWN) != 0 {
42-
err = io.EOF
43-
}
44-
return nil
39+
func (c *Handler) Read(b []byte) (int, error) {
40+
n := (int)(C.msspi_read(c.handle, unsafe.Pointer(&b[0]), C.int(len(b))))
41+
return n, c.rerr
4542
}
4643

47-
func (c *Conn) Read(b []byte) (int, error) {
48-
n := (int)(C.msspi_read(c.handle, unsafe.Pointer(&b[0]), C.int(len(b))))
49-
if n > 0 {
50-
return n, nil
44+
func (c *Handler) State(val int) bool {
45+
state := C.msspi_state(c.handle)
46+
if val == 1 {
47+
return state&C.MSSPI_ERROR != 0
48+
}
49+
if val == 2 {
50+
return state&(C.MSSPI_SENT_SHUTDOWN|C.MSSPI_RECEIVED_SHUTDOWN) != 0
5151
}
52-
return 0, c.error()
52+
return false
5353
}
5454

55-
func (c *Conn) Write(b []byte) (int, error) {
55+
func (c *Handler) Write(b []byte) (int, error) {
5656
len := len(b)
5757
sent := 0
5858
for len > 0 {
@@ -63,47 +63,185 @@ func (c *Conn) Write(b []byte) (int, error) {
6363
continue
6464
}
6565

66-
return sent, c.error()
66+
break
67+
}
68+
69+
return sent, c.werr
70+
}
71+
72+
func (h *Handler) VersionTLS() uint16 {
73+
info := C.msspi_get_cipherinfo(h.handle)
74+
return uint16(info.dwProtocol)
75+
}
76+
77+
func (h *Handler) CipherSuite() uint16 {
78+
info := C.msspi_get_cipherinfo(h.handle)
79+
return uint16(info.dwCipherSuite)
80+
}
81+
82+
func (c *Handler) PeerCertificates() (certificates [][]byte) {
83+
count := C.size_t(0)
84+
85+
if 0 == C.msspi_get_peercerts(c.handle, nil, nil, &count) {
86+
return nil
87+
}
88+
89+
gocount := int(count)
90+
bufs := make([]*C.char, count)
91+
lens := make([]C.int, count)
92+
93+
if 0 == C.msspi_get_peercerts(c.handle, &bufs[0], &lens[0], &count) {
94+
return nil
95+
}
96+
97+
for i := 0; i < gocount; i++ {
98+
certificates = append(certificates, C.GoBytes(unsafe.Pointer(bufs[i]), lens[i]))
99+
}
100+
101+
return certificates
102+
}
103+
104+
func (c *Handler) VerifiedChains() (certificates [][]byte) {
105+
if C.MSSPI_VERIFY_OK != C.msspi_verify(c.handle) {
106+
return nil
107+
}
108+
109+
count := C.size_t(0)
110+
111+
if 0 == C.msspi_get_peerchain(c.handle, 0, nil, nil, &count) {
112+
return nil
113+
}
114+
115+
gocount := int(count)
116+
bufs := make([]*C.char, count)
117+
lens := make([]C.int, count)
118+
119+
if 0 == C.msspi_get_peerchain(c.handle, 0, &bufs[0], &lens[0], &count) {
120+
return nil
121+
}
122+
123+
for i := 0; i < gocount; i++ {
124+
certificates = append(certificates, C.GoBytes(unsafe.Pointer(bufs[i]), lens[i]))
125+
}
126+
127+
return certificates
128+
}
129+
130+
func (c *Handler) Handshake() error {
131+
n := -1
132+
for n < 0 {
133+
if c.isClient {
134+
n = (int)(C.msspi_connect(c.handle))
135+
} else {
136+
n = (int)(C.msspi_accept(c.handle))
137+
}
138+
}
139+
140+
if n == 1 {
141+
return nil
142+
}
143+
144+
if c.rerr != nil {
145+
return c.rerr
67146
}
68-
return sent, nil
147+
if c.werr != nil {
148+
return c.werr
149+
}
150+
return net.ErrClosed
69151
}
70152

71153
// Close with MSSPI
72-
func (c *Conn) Close() (err error) {
154+
func (c *Handler) Close() (err error) {
73155
if c.handle != nil {
74156
C.msspi_shutdown(c.handle)
157+
}
158+
159+
if c.goPointer != nil {
75160
pointer.Unref(c.goPointer)
161+
c.goPointer = nil
76162
}
77-
return c.conn.Close()
163+
164+
return (*c.conn).Close()
165+
}
166+
167+
// Shutdown with MSSPI
168+
func (c *Handler) Shutdown() (err error) {
169+
if c.handle != nil {
170+
C.msspi_shutdown(c.handle)
171+
}
172+
173+
if !c.State(1) && c.State(2) {
174+
return nil
175+
}
176+
177+
return net.ErrClosed
78178
}
79179

80180
// Finalizer with MSSPI
81-
func (c *Conn) Finalizer() {
181+
func (c *Handler) Finalizer() {
82182
if c.handle != nil {
83183
C.msspi_close(c.handle)
184+
c.handle = nil
84185
}
85186
}
86187

87188
// Client with MSSPI
88-
func Client(conn net.Conn, config *tls.Config) *Conn {
89-
c := &Conn{conn: conn, isClient: true}
90-
c.tls = tls.Client(conn, config)
189+
func Client(conn *net.Conn, CertificateBytes [][]byte, hostname string) (c *Handler, err error) {
190+
c = &Handler{conn: conn, isClient: true}
191+
runtime.SetFinalizer(c, (*Handler).Finalizer)
91192

92193
c.goPointer = pointer.Save(c)
93194
c.handle = C.cgo_msspi_open(c.goPointer)
94195

95-
if c.handle != nil {
96-
C.msspi_set_client(c.handle)
97-
runtime.SetFinalizer(c, (*Conn).Finalizer)
98-
} else {
99-
pointer.Unref(c.goPointer)
196+
if c.handle == nil {
197+
return nil, errors.New("Client msspi_open() failed")
100198
}
101-
return c
199+
200+
C.msspi_set_client(c.handle)
201+
202+
if hostname != "" {
203+
hostnameBytes := []byte(hostname)
204+
hostnameBytes = append(hostnameBytes, 0)
205+
C.msspi_set_hostname(c.handle, (*C.char)(unsafe.Pointer(&hostnameBytes[0])))
206+
}
207+
208+
for _, cbs := range CertificateBytes {
209+
ok := int(C.msspi_add_mycert(c.handle, (*C.char)(unsafe.Pointer(&cbs[0])), C.int(len(cbs))))
210+
if ok != 1 {
211+
return nil, errors.New("Client msspi_add_mycert() failed")
212+
}
213+
break
214+
}
215+
216+
return c, nil
102217
}
103218

104-
// Server with MSSPI (not implemented)
105-
func Server(conn net.Conn, config *tls.Config) *Conn {
106-
c := &Conn{conn: conn, isClient: false}
107-
c.tls = tls.Server(conn, config)
108-
return nil
219+
// Server with MSSPI
220+
func Server(conn *net.Conn, CertificateBytes [][]byte, clientAuth bool) (c *Handler, err error) {
221+
c = &Handler{conn: conn, isClient: false}
222+
runtime.SetFinalizer(c, (*Handler).Finalizer)
223+
224+
c.goPointer = pointer.Save(c)
225+
c.handle = C.cgo_msspi_open(c.goPointer)
226+
227+
if c.handle == nil {
228+
return nil, errors.New("Server msspi_open() failed")
229+
}
230+
231+
if clientAuth {
232+
C.msspi_set_peerauth(c.handle, 1)
233+
}
234+
235+
srv := []byte("srv")
236+
C.msspi_set_hostname(c.handle, (*C.char)(unsafe.Pointer(&srv[0])))
237+
238+
for _, cbs := range CertificateBytes {
239+
ok := int(C.msspi_add_mycert(c.handle, (*C.char)(unsafe.Pointer(&cbs[0])), C.int(len(cbs))))
240+
if ok != 1 {
241+
return nil, errors.New("Server msspi_add_mycert() failed")
242+
}
243+
break
244+
}
245+
246+
return c, nil
109247
}

msspi_cgo.go

+7-32
Original file line numberDiff line numberDiff line change
@@ -8,70 +8,45 @@ int cgo_msspi_write( void * goPointer, void * buf, int len );
88
import "C"
99

1010
import (
11-
"io"
12-
"os"
1311
"unsafe"
1412

15-
"github.com/mattn/go-pointer"
13+
"go-pointer"
1614
)
1715

1816
//export cgo_msspi_read
1917
func cgo_msspi_read(goPointer, buffer unsafe.Pointer, length C.int) C.int {
20-
c := pointer.Restore(goPointer).(*Conn)
18+
c := pointer.Restore(goPointer).(*Handler)
2119
if c == nil {
2220
return 0
2321
}
2422

2523
b := make([]byte, length)
26-
n, err := c.conn.Read(b)
27-
28-
// Read can be made to time out and return a net.Error with Timeout() == true
29-
// after a fixed time limit; see SetDeadline and SetReadDeadline.
24+
n, err := (*c.conn).Read(b)
25+
c.rerr = err
3026

3127
if n > 0 {
32-
c.rerr = nil
3328
C.memcpy(buffer, unsafe.Pointer(&b[0]), C.size_t(n))
3429
b = C.GoBytes(buffer, C.int(n))
3530
return C.int(n)
3631
}
3732

38-
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
39-
// is an error, but popular web sites seem to do this, so we accept it
40-
// if and only if at the record boundary.
41-
if err == io.ErrUnexpectedEOF {
42-
err = io.EOF
43-
}
44-
c.rerr = err
45-
if os.IsTimeout(err) {
46-
return -1
47-
}
4833
return 0
4934
}
5035

5136
//export cgo_msspi_write
5237
func cgo_msspi_write(goPointer, buffer unsafe.Pointer, length C.int) C.int {
53-
c := pointer.Restore(goPointer).(*Conn)
38+
c := pointer.Restore(goPointer).(*Handler)
5439
if c == nil {
5540
return 0
5641
}
5742

5843
b := C.GoBytes(buffer, length)
59-
n, err := c.conn.Write(b)
44+
n, err := (*c.conn).Write(b)
45+
c.werr = err
6046

6147
if n > 0 {
62-
c.werr = nil
6348
return C.int(n)
6449
}
6550

66-
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
67-
// is an error, but popular web sites seem to do this, so we accept it
68-
// if and only if at the record boundary.
69-
if err == io.ErrUnexpectedEOF {
70-
err = io.EOF
71-
}
72-
c.werr = err
73-
if os.IsTimeout(err) {
74-
return -1
75-
}
7651
return 0
7752
}

0 commit comments

Comments
 (0)