diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go index ddc4b7228f..7c05ce2448 100644 --- a/src/database/sql/ctxutil.go +++ b/src/database/sql/ctxutil.go @@ -14,40 +14,16 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver if ciCtx, is := ci.(driver.ConnPrepareContext); is { return ciCtx.PrepareContext(ctx, query) } - if ctx.Done() == context.Background().Done() { - return ci.Prepare(query) - } - - type R struct { - err error - panic interface{} - si driver.Stmt - } - - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.si, r.err = ci.Prepare(query) - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + si, err := ci.Prepare(query) + if err == nil { + select { + default: + case <-ctx.Done(): + si.Close() + return nil, ctx.Err() } - return r.si, r.err } + return si, err } func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { @@ -58,84 +34,38 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda if err != nil { return nil, err } - if ctx.Done() == context.Background().Done() { - return execer.Exec(query, dargs) - } - type R struct { - err error - panic interface{} - resi driver.Result - } - - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.resi, r.err = execer.Exec(query, dargs) - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + resi, err := execer.Exec(query, dargs) + if err == nil { + select { + default: + case <-ctx.Done(): + return resi, ctx.Err() } - return r.resi, r.err } + return resi, err } func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { if queryerCtx, is := queryer.(driver.QueryerContext); is { - return queryerCtx.QueryContext(ctx, query, nvdargs) + ret, err := queryerCtx.QueryContext(ctx, query, nvdargs) + return ret, err } dargs, err := namedValueToValue(nvdargs) if err != nil { return nil, err } - if ctx.Done() == context.Background().Done() { - return queryer.Query(query, dargs) - } - type R struct { - err error - panic interface{} - rowsi driver.Rows - } - - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.rowsi, r.err = queryer.Query(query, dargs) - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + rowsi, err := queryer.Query(query, dargs) + if err == nil { + select { + default: + case <-ctx.Done(): + rowsi.Close() + return nil, ctx.Err() } - return r.rowsi, r.err } + return rowsi, err } func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { @@ -146,40 +76,16 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam if err != nil { return nil, err } - if ctx.Done() == context.Background().Done() { - return si.Exec(dargs) - } - type R struct { - err error - panic interface{} - resi driver.Result - } - - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.resi, r.err = si.Exec(dargs) - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + resi, err := si.Exec(dargs) + if err == nil { + select { + default: + case <-ctx.Done(): + return resi, ctx.Err() } - return r.resi, r.err } + return resi, err } func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { @@ -190,40 +96,17 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na if err != nil { return nil, err } - if ctx.Done() == context.Background().Done() { - return si.Query(dargs) - } - type R struct { - err error - panic interface{} - rowsi driver.Rows - } - - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.rowsi, r.err = si.Query(dargs) - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + rowsi, err := si.Query(dargs) + if err == nil { + select { + default: + case <-ctx.Done(): + rowsi.Close() + return nil, ctx.Err() } - return r.rowsi, r.err } + return rowsi, err } var errLevelNotSupported = errors.New("sql: selected isolation level is not supported") @@ -249,35 +132,16 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { return nil, errors.New("sql: driver does not support read-only transactions") } - type R struct { - err error - panic interface{} - txi driver.Tx - } - rc := make(chan R, 1) - go func() { - r := R{} - defer func() { - if v := recover(); v != nil { - r.panic = v - } - rc <- r - }() - r.txi, r.err = ci.Begin() - }() - select { - case <-ctx.Done(): - go func() { - <-rc - close(rc) - }() - return nil, ctx.Err() - case r := <-rc: - if r.panic != nil { - panic(r.panic) + txi, err := ci.Begin() + if err == nil { + select { + default: + case <-ctx.Done(): + txi.Rollback() + return nil, ctx.Err() } - return r.txi, r.err } + return txi, err } func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index e2ee7a9b28..c8cbbf0696 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -87,12 +87,21 @@ type Pinger interface { // statement. // // Exec may return ErrSkip. +// +// Deprecated: Drivers should implement ExecerContext instead (or additionally). type Execer interface { Exec(query string, args []Value) (Result, error) } -// ExecerContext is like execer, but must honor the context timeout and return -// when the context is cancelled. +// ExecerContext is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement ExecerContext, the sql package's DB.Exec will +// first prepare a query, execute the statement, and then close the +// statement. +// +// ExecerContext may return ErrSkip. +// +// ExecerContext must honor the context timeout and return when the context is canceled. type ExecerContext interface { ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error) } @@ -104,12 +113,21 @@ type ExecerContext interface { // statement. // // Query may return ErrSkip. +// +// Deprecated: Drivers should implement QueryerContext instead (or additionally). type Queryer interface { Query(query string, args []Value) (Rows, error) } -// QueryerContext is like Queryer, but most honor the context timeout and return -// when the context is cancelled. +// QueryerContext is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement QueryerContext, the sql package's DB.Query will +// first prepare a query, execute the statement, and then close the +// statement. +// +// QueryerContext may return ErrSkip. +// +// QueryerContext must honor the context timeout and return when the context is canceled. type QueryerContext interface { QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error) } @@ -133,6 +151,8 @@ type Conn interface { Close() error // Begin starts and returns a new transaction. + // + // Deprecated: Drivers should implement ConnBeginContext instead (or additionally). Begin() (Tx, error) } @@ -167,8 +187,8 @@ func ReadOnlyFromContext(ctx context.Context) (readonly bool) { // ConnBeginContext enhances the Conn interface with context. type ConnBeginContext interface { // BeginContext starts and returns a new transaction. - // The provided context should be used to roll the transaction back - // if it is cancelled. + // If the context is canceled by the user the sql package will + // call Tx.Rollback before discarding and closing the connection. // // This must call IsolationFromContext to determine if there is a set // isolation level. If the driver does not support setting the isolation @@ -215,22 +235,32 @@ type Stmt interface { // Exec executes a query that doesn't return rows, such // as an INSERT or UPDATE. + // + // Deprecated: Drivers should implement StmtExecContext instead (or additionally). Exec(args []Value) (Result, error) // Query executes a query that may return rows, such as a // SELECT. + // + // Deprecated: Drivers should implement StmtQueryContext instead (or additionally). Query(args []Value) (Rows, error) } // StmtExecContext enhances the Stmt interface by providing Exec with context. type StmtExecContext interface { - // ExecContext must honor the context timeout and return when it is cancelled. + // ExecContext executes a query that doesn't return rows, such + // as an INSERT or UPDATE. + // + // ExecContext must honor the context timeout and return when it is canceled. ExecContext(ctx context.Context, args []NamedValue) (Result, error) } // StmtQueryContext enhances the Stmt interface by providing Query with context. type StmtQueryContext interface { - // QueryContext must honor the context timeout and return when it is cancelled. + // QueryContext executes a query that may return rows, such as a + // SELECT. + // + // QueryContext must honor the context timeout and return when it is canceled. QueryContext(ctx context.Context, args []NamedValue) (Rows, error) } diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index c42f23208f..9de9289644 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -39,6 +39,9 @@ var _ = log.Printf // Any of these can be preceded by PANIC||, to cause the // named method on fakeStmt to panic. // +// Any of these can be proceeded by WAIT||, to cause the +// named method on fakeStmt to sleep for the specified duration. +// // Multiple of these can be combined when separated with a semicolon. // // When opening a fakeDriver's database, it starts empty with no @@ -119,6 +122,7 @@ type fakeStmt struct { cmd string table string panic string + wait time.Duration next *fakeStmt // used for returning multiple results. @@ -526,14 +530,28 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { if firstStmt == nil { firstStmt = stmt } - if len(parts) >= 3 && parts[0] == "PANIC" { - stmt.panic = parts[1] - parts = parts[2:] + if len(parts) >= 3 { + switch parts[0] { + case "PANIC": + stmt.panic = parts[1] + parts = parts[2:] + case "WAIT": + wait, err := time.ParseDuration(parts[1]) + if err != nil { + return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) + } + parts = parts[2:] + stmt.wait = wait + } } cmd := parts[0] stmt.cmd = cmd parts = parts[1:] + if stmt.wait > 0 { + time.Sleep(stmt.wait) + } + c.incrStat(&c.stmtsMade) var err error switch cmd { @@ -619,6 +637,16 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, err } + if s.wait > 0 { + time.Sleep(s.wait) + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + db := s.c.db switch s.cmd { case "WIPE": diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 2a9ae0b95a..4ef0fa7221 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn return nil, errDBClosed } // Check if the context is expired. - if err := ctx.Err(); err != nil { + select { + default: + case <-ctx.Done(): db.mu.Unlock() - return nil, err + return nil, ctx.Err() } lifetime := db.maxLifetime @@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row { // BeginContext starts a transaction. // +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginContext is canceled. +// // An isolation level may be set by setting the value in the context // before calling this. If a non-default isolation level is used // that the driver doesn't support an error will be returned. Different drivers @@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er dc: dc, txi: txi, cancel: cancel, + ctx: ctx, } - go func() { + go func(tx *Tx) { select { - case <-ctx.Done(): - if !tx.done { - tx.Rollback() + case <-tx.ctx.Done(): + if !tx.isDone() { + // Discard and close the connection used to ensure the transaction + // is closed and the resources are released. + tx.rollback(true) } } - }() + }(tx) return tx, nil } @@ -1370,10 +1380,11 @@ type Tx struct { dc *driverConn txi driver.Tx - // done transitions from false to true exactly once, on Commit + // done transitions from 0 to 1 exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. - done bool + // Use atomic operations on value when checking value. + done int32 // All Stmts prepared for this transaction. These will be closed after the // transaction has been committed or rolled back. @@ -1384,6 +1395,13 @@ type Tx struct { // cancel is called after done transitions from false to true. cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +func (tx *Tx) isDone() bool { + return atomic.LoadInt32(&tx.done) != 0 } // ErrTxDone is returned by any operation that is performed on a transaction @@ -1391,10 +1409,9 @@ type Tx struct { var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") func (tx *Tx) close(err error) { - if tx.done { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { panic("double close") // internal error } - tx.done = true tx.db.putConn(tx.dc, err) tx.cancel() tx.dc = nil @@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) { } func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { - if tx.done { + if tx.isDone() { return nil, ErrTxDone } return tx.dc, nil @@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() { // Commit commits the transaction. func (tx *Tx) Commit() error { - if tx.done { + select { + default: + case <-tx.ctx.Done(): + return tx.ctx.Err() + } + if tx.isDone() { return ErrTxDone } var err error @@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error { return err } -// Rollback aborts the transaction. -func (tx *Tx) Rollback() error { - if tx.done { +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if tx.isDone() { return ErrTxDone } var err error @@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error { if err != driver.ErrBadConn { tx.closePrepared() } + if discardConn { + err = driver.ErrBadConn + } tx.close(err) return err } +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + // Prepare creates a prepared statement for use within a transaction. // // The returned statement operates within the transaction and will be closed @@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var si driver.Stmt withLock(dc, func() { - si, err = dc.ci.Prepare(query) + si, err = ctxDriverPrepare(ctx, dc.ci, query) }) if err != nil { return nil, err @@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { } var si driver.Stmt withLock(dc, func() { - si, err = dc.ci.Prepare(stmt.query) + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) }) txs := &Stmt{ db: tx.db, diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index b64d4dda5a..d94ef5cad3 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -141,10 +141,7 @@ func closeDB(t testing.TB, db *DB) { if err != nil { t.Fatalf("error closing DB: %v", err) } - db.mu.Lock() - count := db.numOpen - db.mu.Unlock() - if count != 0 { + if count := db.numOpenConns(); count != 0 { t.Fatalf("%d connections still open after closing DB", count) } } @@ -183,6 +180,12 @@ func (db *DB) numFreeConns() int { return len(db.freeConn) } +func (db *DB) numOpenConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return db.numOpen +} + // clearAllConns closes all connections in db. func (db *DB) clearAllConns(t *testing.T) { db.SetMaxIdleConns(0) @@ -320,6 +323,75 @@ func TestQueryContext(t *testing.T) { } } +func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} + +func TestQueryContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15) + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err := db.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + // Verify closed rows connection after error condition. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestTxContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15) + + tx, err := db.BeginContext(ctx) + if err != nil { + t.Fatal(err) + } + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + var numFree int + if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + numFree = db.numFreeConns() + return numFree == 0 + }) { + t.Fatalf("free conns after hitting EOF = %d; want 0", numFree) + } + + // Ensure the dropped connection allows more connections to be made. + // Checked on DB Close. + waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + return db.numOpenConns() == 0 + }) +} + func TestMultiResultSetQuery(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)