diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 316e7cea37..a2b844d71f 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -255,15 +255,23 @@ type ConnBeginTx interface { // SessionResetter may be implemented by Conn to allow drivers to reset the // session state associated with the connection and to signal a bad connection. type SessionResetter interface { - // ResetSession is called while a connection is in the connection - // pool. No queries will run on this connection until this method returns. - // - // If the connection is bad this should return driver.ErrBadConn to prevent - // the connection from being returned to the connection pool. Any other - // error will be discarded. + // ResetSession is called prior to executing a query on the connection + // if the connection has been used before. If the driver returns ErrBadConn + // the connection is discarded. ResetSession(ctx context.Context) error } +// ConnectionValidator may be implemented by Conn to allow drivers to +// signal if a connection is valid or if it should be discarded. +// +// If implemented, drivers may return the underlying error from queries, +// even if the connection should be discarded by the connection pool. +type ConnectionValidator interface { + // ValidConnection is called prior to placing the connection into the + // connection pool. The connection will be discarded if false is returned. + ValidConnection() bool +} + // Result is the result of a query execution. type Result interface { // LastInsertId returns the database's auto-generated ID diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index a0028be0e5..73dab101b7 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -396,6 +396,12 @@ func (c *fakeConn) ResetSession(ctx context.Context) error { return nil } +var _ driver.ConnectionValidator = (*fakeConn)(nil) + +func (c *fakeConn) ValidConnection() bool { + return !c.isBad() +} + func (c *fakeConn) Close() (err error) { drv := fdriver.(*fakeDriver) defer func() { diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 550b58753f..1bf3731b00 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -421,7 +421,6 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} - resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -460,10 +459,10 @@ type driverConn struct { sync.Mutex // guards following ci driver.Conn + needReset bool // The connection session should be reset before use if true. closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool - lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -489,6 +488,36 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } +// resetSession checks if the driver connection needs the +// session to be reset and if required, resets it. +func (dc *driverConn) resetSession(ctx context.Context) error { + dc.Lock() + defer dc.Unlock() + + if !dc.needReset { + return nil + } + if cr, ok := dc.ci.(driver.SessionResetter); ok { + return cr.ResetSession(ctx) + } + return nil +} + +// validateConnection checks if the connection is valid and can +// still be used. It also marks the session for reset if required. +func (dc *driverConn) validateConnection(needsReset bool) bool { + dc.Lock() + defer dc.Unlock() + + if needsReset { + dc.needReset = true + } + if cv, ok := dc.ci.(driver.ConnectionValidator); ok { + return cv.ValidConnection() + } + return true +} + // prepareLocked prepares the query on dc. When cg == nil the dc must keep track of // the prepared statements in a pool. func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { @@ -514,19 +543,6 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } -// resetSession resets the connection session and sets the lastErr -// that is checked before returning the connection to another query. -// -// resetSession assumes that the embedded mutex is locked when the connection -// was returned to the pool. This unlocks the mutex. -func (dc *driverConn) resetSession(ctx context.Context) { - defer dc.Unlock() // In case of panic. - if dc.closed { // Check if the database has been closed. - return - } - dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx) -} - // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -716,14 +732,12 @@ func OpenDB(c driver.Connector) *DB { db := &DB{ connector: c, openerCh: make(chan struct{}, connectionRequestQueueSize), - resetterCh: make(chan *driverConn, 50), lastPut: make(map[*driverConn]string), connRequests: make(map[uint64]chan connRequest), stop: cancel, } go db.connectionOpener(ctx) - go db.connectionResetter(ctx) return db } @@ -1118,23 +1132,6 @@ func (db *DB) connectionOpener(ctx context.Context) { } } -// connectionResetter runs in a separate goroutine to reset connections async -// to exported API. -func (db *DB) connectionResetter(ctx context.Context) { - for { - select { - case <-ctx.Done(): - close(db.resetterCh) - for dc := range db.resetterCh { - dc.Unlock() - } - return - case dc := <-db.resetterCh: - dc.resetSession(ctx) - } - } -} - // Open one new connection func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent @@ -1216,14 +1213,13 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } - // Lock around reading lastErr to ensure the session resetter finished. - conn.Lock() - err := conn.lastErr - conn.Unlock() - if err == driver.ErrBadConn { + + // Reset the session if required. + if err := conn.resetSession(ctx); err == driver.ErrBadConn { conn.Close() return nil, driver.ErrBadConn } + return conn, nil } @@ -1272,11 +1268,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn if ret.conn == nil { return nil, ret.err } - // Lock around reading lastErr to ensure the session resetter finished. - ret.conn.Lock() - err := ret.conn.lastErr - ret.conn.Unlock() - if err == driver.ErrBadConn { + + // Reset the session if required. + if err := ret.conn.resetSession(ctx); err == driver.ErrBadConn { ret.conn.Close() return nil, driver.ErrBadConn } @@ -1337,6 +1331,11 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { + if err != driver.ErrBadConn { + if !dc.validateConnection(resetSession) { + err = driver.ErrBadConn + } + } db.mu.Lock() if !dc.inUse { if debugGetPut { @@ -1368,41 +1367,13 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { if putConnHook != nil { putConnHook(db, dc) } - if db.closed { - // Connections do not need to be reset if they will be closed. - // Prevents writing to resetterCh after the DB has closed. - resetSession = false - } - if resetSession { - if _, resetSession = dc.ci.(driver.SessionResetter); resetSession { - // Lock the driverConn here so it isn't released until - // the connection is reset. - // The lock must be taken before the connection is put into - // the pool to prevent it from being taken out before it is reset. - dc.Lock() - } - } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { - if resetSession { - dc.Unlock() - } dc.Close() return } - if !resetSession { - return - } - select { - default: - // If the resetterCh is blocking then mark the connection - // as bad and continue on. - dc.lastErr = driver.ErrBadConn - dc.Unlock() - case db.resetterCh <- dc: - } } // Satisfy a connRequest or put the driverConn in the idle pool and return true