diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go index 1071446227..bd652b5462 100644 --- a/src/database/sql/ctxutil.go +++ b/src/database/sql/ctxutil.go @@ -35,15 +35,12 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda return nil, err } - resi, err := execer.Exec(query, dargs) - if err == nil { - select { - default: - case <-ctx.Done(): - return resi, ctx.Err() - } + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() } - return resi, err + return execer.Exec(query, dargs) } func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { @@ -56,16 +53,12 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, n return nil, err } - rowsi, err := queryer.Query(query, dargs) - if err == nil { - select { - default: - case <-ctx.Done(): - rowsi.Close() - return nil, ctx.Err() - } + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() } - return rowsi, err + return queryer.Query(query, dargs) } func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { @@ -77,15 +70,12 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam return nil, err } - resi, err := si.Exec(dargs) - if err == nil { - select { - default: - case <-ctx.Done(): - return resi, ctx.Err() - } + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() } - return resi, err + return si.Exec(dargs) } func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { @@ -97,16 +87,12 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na return nil, err } - rowsi, err := si.Query(dargs) - if err == nil { - select { - default: - case <-ctx.Done(): - rowsi.Close() - return nil, ctx.Err() - } + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() } - return rowsi, err + return si.Query(dargs) } var errLevelNotSupported = errors.New("sql: selected isolation level is not supported") diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index feb91223a9..214c13acde 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -2071,14 +2071,21 @@ type Rows struct { dc *driverConn // owned; must call releaseConn when closed to release releaseConn func(error) rowsi driver.Rows + cancel func() // called when Rows is closed, may be nil. + closeStmt *driverStmt // if non-nil, statement to Close on close - // closed value is 1 when the Rows is closed. - // Use atomic operations on value when checking value. - closed int32 - cancel func() // called when Rows is closed, may be nil. - lastcols []driver.Value - lasterr error // non-nil only if closed is true - closeStmt *driverStmt // if non-nil, statement to Close on close + // closemu prevents Rows from closing while there + // is an active streaming result. It is held for read during non-close operations + // and exclusively during close. + // + // closemu guards lasterr and closed. + closemu sync.RWMutex + closed bool + lasterr error // non-nil only if closed is true + + // lastcols is only used in Scan, Next, and NextResultSet which are expected + // not not be called concurrently. + lastcols []driver.Value } func (rs *Rows) initContextClose(ctx context.Context) { @@ -2089,7 +2096,7 @@ func (rs *Rows) initContextClose(ctx context.Context) { // awaitDone blocks until the rows are closed or the context canceled. func (rs *Rows) awaitDone(ctx context.Context) { <-ctx.Done() - rs.Close() + rs.close(ctx.Err()) } // Next prepares the next result row for reading with the Scan method. It @@ -2099,8 +2106,19 @@ func (rs *Rows) awaitDone(ctx context.Context) { // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { - if rs.isClosed() { - return false + var doClose, ok bool + withLock(rs.closemu.RLocker(), func() { + doClose, ok = rs.nextLocked() + }) + if doClose { + rs.Close() + } + return ok +} + +func (rs *Rows) nextLocked() (doClose, ok bool) { + if rs.closed { + return false, false } if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) @@ -2109,23 +2127,21 @@ func (rs *Rows) Next() bool { if rs.lasterr != nil { // Close the connection if there is a driver error. if rs.lasterr != io.EOF { - rs.Close() - return false + return true, false } nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) if !ok { - rs.Close() - return false + return true, false } // The driver is at the end of the current result set. // Test to see if there is another result set after the current one. // Only close Rows if there is no further result sets to read. if !nextResultSet.HasNextResultSet() { - rs.Close() + doClose = true } - return false + return doClose, false } - return true + return false, true } // NextResultSet prepares the next result set for reading. It returns true if @@ -2137,18 +2153,28 @@ func (rs *Rows) Next() bool { // scanning. If there are further result sets they may not have rows in the result // set. func (rs *Rows) NextResultSet() bool { - if rs.isClosed() { + var doClose bool + defer func() { + if doClose { + rs.Close() + } + }() + rs.closemu.RLock() + defer rs.closemu.RUnlock() + + if rs.closed { return false } + rs.lastcols = nil nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) if !ok { - rs.Close() + doClose = true return false } rs.lasterr = nextResultSet.NextResultSet() if rs.lasterr != nil { - rs.Close() + doClose = true return false } return true @@ -2157,6 +2183,8 @@ func (rs *Rows) NextResultSet() bool { // Err returns the error, if any, that was encountered during iteration. // Err may be called after an explicit or implicit Close. func (rs *Rows) Err() error { + rs.closemu.RLock() + defer rs.closemu.RUnlock() if rs.lasterr == io.EOF { return nil } @@ -2167,7 +2195,9 @@ func (rs *Rows) Err() error { // Columns returns an error if the rows are closed, or if the rows // are from QueryRow and there was a deferred error. func (rs *Rows) Columns() ([]string, error) { - if rs.isClosed() { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { return nil, errors.New("sql: Rows are closed") } if rs.rowsi == nil { @@ -2179,7 +2209,9 @@ func (rs *Rows) Columns() ([]string, error) { // ColumnTypes returns column information such as column type, length, // and nullable. Some information may not be available from some drivers. func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { - if rs.isClosed() { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { return nil, errors.New("sql: Rows are closed") } if rs.rowsi == nil { @@ -2329,9 +2361,13 @@ func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. func (rs *Rows) Scan(dest ...interface{}) error { - if rs.isClosed() { + rs.closemu.RLock() + if rs.closed { + rs.closemu.RUnlock() return errors.New("sql: Rows are closed") } + rs.closemu.RUnlock() + if rs.lastcols == nil { return errors.New("sql: Scan called without calling Next") } @@ -2351,20 +2387,28 @@ func (rs *Rows) Scan(dest ...interface{}) error { // hook throug a test only mutex. var rowsCloseHook = func() func(*Rows, *error) { return nil } -func (rs *Rows) isClosed() bool { - return atomic.LoadInt32(&rs.closed) != 0 -} - // Close closes the Rows, preventing further enumeration. If Next is called // and returns false and there are no further result sets, // the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { - if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) { + return rs.close(nil) +} + +func (rs *Rows) close(err error) error { + rs.closemu.Lock() + defer rs.closemu.Unlock() + + if rs.closed { return nil } + rs.closed = true - err := rs.rowsi.Close() + if rs.lasterr == nil { + rs.lasterr = err + } + + err = rs.rowsi.Close() if fn := rowsCloseHook(); fn != nil { fn(rs, &err) } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 898df3b455..1f2b915816 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -313,9 +313,13 @@ func TestQueryContext(t *testing.T) { got = append(got, r) index++ } - err = rows.Err() - if err != nil { - t.Fatalf("Err: %v", err) + select { + case <-ctx.Done(): + if err := ctx.Err(); err != context.Canceled { + t.Fatalf("context err = %v; want context.Canceled") + } + default: + t.Fatalf("context err = nil; want context.Canceled") } want := []row{ {age: 1, name: "Alice"},