Skip to content

Commit 573423d

Browse files
authored
Fix: Connection not closed when database name is incorrect #173 fix (#224)
* connection not closed when database name is incorrect #173 fix * test for leaked connections (connection not closed when database name is incorrect #173 * Checking the number of open connections from local_net_address only * using sql.NullString for localNetAddr * handling local_net_address==NULL correctly
1 parent 02deabf commit 573423d

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

tds.go

+1
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,7 @@ initiate_connection:
13821382
if token.isError() {
13831383
tokenErr := token.getError()
13841384
tokenErr.Message = "login error: " + tokenErr.Message
1385+
conn.Close()
13851386
return nil, tokenErr
13861387
}
13871388
case error:

tds_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,72 @@ func TestBadCredentials(t *testing.T) {
694694
_ = testConnectionBad(t, params.URL().String())
695695
}
696696

697+
func TestLeakedConnections(t *testing.T) {
698+
goodParams := testConnParams(t)
699+
badParams := testConnParams(t)
700+
badParams.Database = "unknown_db"
701+
702+
// Connecting with good credentials should not fail
703+
goodConn, err := sql.Open("sqlserver", goodParams.URL().String())
704+
if err != nil {
705+
t.Fatal("Open connection failed:", err.Error())
706+
}
707+
err = goodConn.Ping()
708+
if err != nil {
709+
t.Fatal("Ping with good credentials should not fail, but got error:", err.Error())
710+
}
711+
712+
var localNetAddr sql.NullString
713+
err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr)
714+
if err != nil {
715+
t.Fatal("cannot scan local_net_address value", err)
716+
}
717+
718+
// Remember the number of open connections from local_net_address, excluding the current one
719+
// NULL value is possible, particularly for non-tcp local connections
720+
var openConnections int
721+
err = goodConn.QueryRow(`
722+
SELECT COUNT(*) AS openConnections
723+
FROM sys.dm_exec_connections
724+
WHERE session_id != @@SPID
725+
AND ((@p1 IS NULL AND local_net_address IS NULL)
726+
OR local_net_address = @p1)`,
727+
localNetAddr).Scan(&openConnections)
728+
if err != nil {
729+
t.Fatal("cannot scan value", err)
730+
}
731+
732+
// Open 10 connections to the unknown database, all should be closed immediately
733+
for i := 0; i < 10; i++ {
734+
conn, err := sql.Open("sqlserver", badParams.URL().String())
735+
if err != nil {
736+
// should not fail here
737+
t.Fatal("sql.Open failed:", err.Error())
738+
}
739+
err = conn.Ping()
740+
if err == nil {
741+
t.Fatalf("Pinging %s should fail, but it succeeded", badParams.Database)
742+
}
743+
conn.Close() // force close the connection
744+
}
745+
746+
// Check if the number of open connections is the same as before
747+
var newOpenConnections int
748+
err = goodConn.QueryRow(`
749+
SELECT COUNT(*) AS openConnections
750+
FROM sys.dm_exec_connections
751+
WHERE session_id != @@SPID
752+
AND ((@p1 IS NULL AND local_net_address IS NULL)
753+
OR local_net_address = @p1)`,
754+
localNetAddr).Scan(&newOpenConnections)
755+
if err != nil {
756+
t.Fatal("cannot scan value", err)
757+
}
758+
if openConnections != newOpenConnections {
759+
t.Fatalf("Number of open connections should be the same as before, %d leaked connections found", newOpenConnections-openConnections)
760+
}
761+
}
762+
697763
func TestBadHost(t *testing.T) {
698764
params := testConnParams(t)
699765
params.Host = "badhost"

0 commit comments

Comments
 (0)