diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index a549e859a4..a02aa35b7b 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -330,7 +330,7 @@ type driverConn struct { ci driver.Conn closed bool finalClosed bool // ci.Close has been called - openStmt map[driver.Stmt]bool + openStmt map[*driverStmt]bool // guarded by db.mu inUse bool @@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) { dc.db.putConn(dc, err) } -func (dc *driverConn) removeOpenStmt(si driver.Stmt) { +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { dc.Lock() defer dc.Unlock() - delete(dc.openStmt, si) + delete(dc.openStmt, ds) } func (dc *driverConn) expired(timeout time.Duration) bool { @@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) { +func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { si, err := ctxDriverPrepare(ctx, dc.ci, query) - if err == nil { - // Track each driverConn's open statements, so we can close them - // before closing the conn. - // - // TODO(bradfitz): let drivers opt out of caring about - // stmt closes if the conn is about to close anyway? For now - // do the safe thing, in case stmts need to be closed. - // - // TODO(bradfitz): after Go 1.2, closing driver.Stmts - // should be moved to driverStmt, using unique - // *driverStmts everywhere (including from - // *Stmt.connStmt, instead of returning a - // driver.Stmt), using driverStmt as a pointer - // everywhere, and making it a finalCloser. - if dc.openStmt == nil { - dc.openStmt = make(map[driver.Stmt]bool) - } - dc.openStmt[si] = true + if err != nil { + return nil, err } - return si, err + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + ds := &driverStmt{Locker: dc, si: si} + dc.openStmt[ds] = true + + return ds, nil } // the dc.db's Mutex is held. @@ -409,17 +404,24 @@ func (dc *driverConn) Close() error { func (dc *driverConn) finalClose() error { var err error - withLock(dc, func() { - defer func() { // In case si.Close panics. - dc.openStmt = nil - dc.finalClosed = true - err = dc.ci.Close() - dc.ci = nil - }() - for si := range dc.openStmt { - si.Close() + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt + withLock(dc, func() { + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() + } + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil }) dc.db.mu.Lock() @@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error { type driverStmt struct { sync.Locker // the *driverConn si driver.Stmt + closed bool + closeErr error // return value of previous Close call } +// Close ensures dirver.Stmt is only closed once any always returns the same +// result. func (ds *driverStmt) Close() error { ds.Lock() defer ds.Unlock() - return ds.si.Close() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr } // depSet is a finalCloser's outstanding dependencies @@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn // putConnHook is a hook for testing. var putConnHook func(*DB, *driverConn) -// noteUnusedDriverStatement notes that si is no longer used and should +// noteUnusedDriverStatement notes that ds is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { db.mu.Lock() defer db.mu.Unlock() if c.inUse { c.onPut = append(c.onPut, func() { - si.Close() + ds.Close() }) } else { c.Lock() - defer c.Unlock() - if !c.finalClosed { - si.Close() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() } } } @@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat if err != nil { return nil, err } - var si driver.Stmt + var ds *driverStmt withLock(dc, func() { - si, err = dc.prepareLocked(ctx, query) + ds, err = dc.prepareLocked(ctx, query) }) if err != nil { db.putConn(dc, err) @@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat stmt := &Stmt{ db: db, query: query, - css: []connStmt{{dc, si}}, + css: []connStmt{{dc, ds}}, lastNumClosed: atomic.LoadUint64(&db.numClosed), } db.addDep(stmt, stmt) @@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) - return resultFromStatement(ctx, driverStmt{dc, si}, args...) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, ds, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er return nil, err } - ds := driverStmt{dc, si} + ds := &driverStmt{Locker: dc, si: si} rowsi, err := rowsiFromStatement(ctx, ds, args...) if err != nil { - withLock(dc, func() { - si.Close() - }) + ds.Close() releaseConn(err) return nil, err } @@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: si, + closeStmt: ds, } rows.initContextClose(ctx) return rows, nil @@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { stmt := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { txs := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() - return resultFromStatement(ctx, driverStmt{dc, si}, args...) + return resultFromStatement(ctx, ds, args...) } // Exec executes a query that doesn't return rows. @@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { // connStmt is a prepared statement on a particular connection. type connStmt struct { dc *driverConn - si driver.Stmt + ds *driverStmt } // Stmt is a prepared statement. @@ -1650,7 +1662,7 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi *driverStmt + txds *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er var res Result for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt(ctx) + _, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er return nil, err } - res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...) + res, err = resultFromStatement(ctx, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return s.ExecContext(context.Background(), args...) } -func driverNumInput(ds driverStmt) int { +func driverNumInput(ds *driverStmt) int { ds.Lock() defer ds.Unlock() // in case NumInput panics return ds.si.NumInput() } -func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) { +func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of @@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{} return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } @@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e return } releaseConn = func(error) {} - return ci, releaseConn, s.txsi.si, nil + return ci, releaseConn, s.txds, nil } s.removeClosedStmtLocked() @@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e for _, v := range s.css { if v.dc == dc { s.mu.Unlock() - return dc, dc.releaseConn, v.si, nil + return dc, dc.releaseConn, v.ds, nil } } s.mu.Unlock() // No luck; we need to prepare the statement on this connection withLock(dc, func() { - si, err = dc.prepareLocked(ctx, s.query) + ds, err = dc.prepareLocked(ctx, s.query) }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } s.mu.Lock() - cs := connStmt{dc, si} + cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - return dc, dc.releaseConn, si, nil + return dc, dc.releaseConn, ds, nil } // QueryContext executes a prepared query statement with the given arguments @@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er var rowsi driver.Rows for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt(ctx) + dc, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return nil, err } - rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...) + rowsi, err = rowsiFromStatement(ctx, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } -func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { var want int withLock(ds, func() { want = ds.si.NumInput() @@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } @@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error { return nil } s.closed = true + s.mu.Unlock() if s.tx != nil { - defer s.mu.Unlock() - return s.txsi.Close() + return s.txds.Close() } - s.mu.Unlock() return s.db.removeDep(s, s) } @@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error { defer s.mu.Unlock() if s.css != nil { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.dc, v.si) - v.dc.removeOpenStmt(v.si) + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) } s.css = nil } @@ -1985,7 +1996,7 @@ type Rows struct { ctxClose chan struct{} // closed when Rows is closed, may be null. lastcols []driver.Value lasterr error // non-nil only if closed is true - closeStmt driver.Stmt // if non-nil, statement to Close on close + closeStmt *driverStmt // if non-nil, statement to Close on close } func (rs *Rows) initContextClose(ctx context.Context) { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index ea86264ae6..c46aaf60f8 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -672,7 +672,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want {