diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 8cbbb29a7c..112f280ec5 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -699,13 +699,25 @@ func (s *fakeStmt) NumInput() int { return s.placeholders } +// hook to simulate broken connections +var hookCommitBadConn func() bool + func (tx *fakeTx) Commit() error { tx.c.currTx = nil + if hookCommitBadConn != nil && hookCommitBadConn() { + return driver.ErrBadConn + } return nil } +// hook to simulate broken connections +var hookRollbackBadConn func() bool + func (tx *fakeTx) Rollback() error { tx.c.currTx = nil + if hookRollbackBadConn != nil && hookRollbackBadConn() { + return driver.ErrBadConn + } return nil } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index aaa4ea28be..0120ae8abe 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1103,12 +1103,12 @@ type Tx struct { var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") -func (tx *Tx) close() { +func (tx *Tx) close(err error) { if tx.done { panic("double close") // internal error } tx.done = true - tx.db.putConn(tx.dc, nil) + tx.db.putConn(tx.dc, err) tx.dc = nil tx.txi = nil } @@ -1134,13 +1134,13 @@ func (tx *Tx) Commit() error { if tx.done { return ErrTxDone } - defer tx.close() tx.dc.Lock() err := tx.txi.Commit() tx.dc.Unlock() if err != driver.ErrBadConn { tx.closePrepared() } + tx.close(err) return err } @@ -1149,13 +1149,13 @@ func (tx *Tx) Rollback() error { if tx.done { return ErrTxDone } - defer tx.close() tx.dc.Lock() err := tx.txi.Rollback() tx.dc.Unlock() if err != driver.ErrBadConn { tx.closePrepared() } + tx.close(err) return err } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 432a641b85..b4135a3078 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -1564,6 +1564,77 @@ func TestErrBadConnReconnect(t *testing.T) { simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery) } +// golang.org/issue/11264 +func TestTxEndBadConn(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + db.SetMaxIdleConns(0) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + db.SetMaxIdleConns(1) + + simulateBadConn := func(name string, hook *func() bool, op func() error) { + broken := false + numOpen := db.numOpen + + *hook = func() bool { + if !broken { + broken = true + } + return broken + } + + if err := op(); err != driver.ErrBadConn { + t.Errorf(name+": %v", err) + return + } + + if !broken { + t.Error(name + ": Failed to simulate broken connection") + } + *hook = nil + + if numOpen != db.numOpen { + t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen) + } + } + + // db.Exec + dbExec := func(endTx func(tx *Tx) error) func() error { + return func() error { + tx, err := db.Begin() + if err != nil { + return err + } + _, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true) + if err != nil { + return err + } + return endTx(tx) + } + } + simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit)) + simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback)) + + // db.Query + dbQuery := func(endTx func(tx *Tx) error) func() error { + return func() error { + tx, err := db.Begin() + if err != nil { + return err + } + rows, err := tx.Query("SELECT|t1|age,name|") + if err == nil { + err = rows.Close() + } else { + return err + } + return endTx(tx) + } + } + simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit)) + simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback)) +} + type concurrentTest interface { init(t testing.TB, db *DB) finish(t testing.TB)