diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index adf964992d..04986a28ea 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1554,19 +1554,6 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { tx.closemu.RLock() defer tx.closemu.RUnlock() - // TODO(bradfitz): We could be more efficient here and either - // provide a method to take an existing Stmt (created on - // perhaps a different Conn), and re-create it on this Conn if - // necessary. Or, better: keep a map in DB of query string to - // Stmts, and have Stmt.Execute do the right thing and - // re-prepare if the Conn in use doesn't have that prepared - // statement. But we'll want to avoid caching the statement - // in the case where we only call conn.Prepare implicitly - // (such as in db.Exec or tx.Exec), but the caller package - // can't be holding a reference to the returned statement. - // Perhaps just looking at the reference count (by noting - // Stmt.Close) would be enough. We might also want a finalizer - // on Stmt to drop the reference count. dc, err := tx.grabConn(ctx) if err != nil { return nil, err @@ -1621,11 +1608,6 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { tx.closemu.RLock() defer tx.closemu.RUnlock() - // TODO(bradfitz): optimize this. Currently this re-prepares - // each time. This is fine for now to illustrate the API but - // we should really cache already-prepared statements - // per-Conn. See also the big comment in Tx.Prepare. - if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } @@ -1634,9 +1616,45 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { return &Stmt{stickyErr: err} } var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) - }) + var parentStmt *Stmt + stmt.mu.Lock() + if stmt.closed || stmt.tx != nil { + // If the statement has been closed or already belongs to a + // transaction, we can't reuse it in this connection. + // Since tx.StmtContext should never need to be called with a + // Stmt already belonging to tx, we ignore this edge case and + // re-prepare the statement in this case. No need to add + // code-complexity for this. + stmt.mu.Unlock() + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + } else { + stmt.removeClosedStmtLocked() + // See if the statement has already been prepared on this connection, + // and reuse it if possible. + for _, v := range stmt.css { + if v.dc == dc { + si = v.ds.si + break + } + } + + stmt.mu.Unlock() + + if si == nil { + cs, err := stmt.prepareOnConnLocked(ctx, dc) + if err != nil { + return &Stmt{stickyErr: err} + } + si = cs.si + } + parentStmt = stmt + } + txs := &Stmt{ db: tx.db, tx: tx, @@ -1644,8 +1662,11 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { Locker: dc, si: si, }, - query: stmt.query, - stickyErr: err, + parentStmt: parentStmt, + query: stmt.query, + } + if parentStmt != nil { + tx.db.addDep(parentStmt, txs) } tx.stmts.Lock() tx.stmts.v = append(tx.stmts.v, txs) @@ -1769,13 +1790,21 @@ type Stmt struct { tx *Tx txds *driverStmt + // parentStmt is set when a transaction-specific statement + // is requested from an identical statement prepared on the same + // conn. parentStmt is used to track the dependency of this statement + // on its originating ("parent") statement so that parentStmt may + // be closed by the user without them having to know whether or not + // any transactions are still using it. + parentStmt *Stmt + mu sync.Mutex // protects the rest of the fields closed bool // css is a list of underlying driver statement interfaces // that are valid on particular connections. This is only // used if tx == nil and one is found that has idle - // connections. If tx != nil, txsi is always used. + // connections. If tx != nil, txds is always used. css []connStmt // lastNumClosed is copied from db.numClosed when Stmt is created @@ -1916,20 +1945,30 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e // No luck; we need to prepare the statement on this connection withLock(dc, func() { - ds, err = dc.prepareLocked(ctx, s.query) + ds, err = s.prepareOnConnLocked(ctx, dc) }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } - s.mu.Lock() - cs := connStmt{dc, ds} - s.css = append(s.css, cs) - s.mu.Unlock() return dc, dc.releaseConn, ds, nil } +// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of +// open connStmt on the statement. It assumes the caller is holding the lock on dc. +func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { + si, err := dc.prepareLocked(ctx, s.query) + if err != nil { + return nil, err + } + cs := connStmt{dc, si} + s.mu.Lock() + s.css = append(s.css, cs) + s.mu.Unlock() + return cs.ds, nil +} + // QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { @@ -2056,11 +2095,16 @@ func (s *Stmt) Close() error { s.closed = true s.mu.Unlock() - if s.tx != nil { - return s.txds.Close() + if s.tx == nil { + return s.db.removeDep(s, s) } - return s.db.removeDep(s, s) + if s.parentStmt != nil { + // If parentStmt is set, we must not close s.txds since it's stored + // in the css array of the parentStmt. + return s.db.removeDep(s.parentStmt, s) + } + return s.txds.Close() } func (s *Stmt) finalClose() error { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 8441aff4bc..79732d4703 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -1024,6 +1024,196 @@ func TestTxStmt(t *testing.T) { } } +func TestTxStmtPreparedOnce(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + + txs1 := tx.Stmt(stmt) + txs2 := tx.Stmt(stmt) + + _, err = txs1.Exec("Go", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs1.Close() + + _, err = txs2.Exec("Gopher", 8) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs2.Close() + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestTxStmtClosedRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("stmt.Close() = %v", err) + } + // tx.Stmt increments numPrepares because stmt is closed. + txs := tx.Stmt(stmt) + if txs.stickyErr != nil { + t.Fatal(txs.stickyErr) + } + if txs.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs.Exec(`Eric`, 82) + if err != nil { + t.Fatalf("txs.Exec = %v", err) + } + + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close = %v", err) + } + + tx.Rollback() + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + +func TestParentStmtOutlivesTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + // Make sure everything happens on the same connection. + db.SetMaxOpenConns(1) + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + if len(stmt.css) != 1 { + t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css)) + } + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close() = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback() = %v", err) + } + // txs must not be valid. + _, err = txs.Exec("Suzan", 30) + if err == nil { + t.Fatalf("txs.Exec(), expected err") + } + // Stmt must still be valid. + _, err = stmt.Exec("Janina", 25) + if err != nil { + t.Fatalf("stmt.Exec() = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +// Test that tx.Stmt called with a statment already +// associated with tx as argument re-prepares the same +// statement again. +func TestTxStmtFromTxStmtRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + prepares0 := numPrepares(t, db) + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs1 := tx.Stmt(stmt) + + // tx.Stmt(txs1) increments numPrepares because txs1 already + // belongs to a transaction (albeit the same transaction). + txs2 := tx.Stmt(txs1) + if txs2.stickyErr != nil { + t.Fatal(txs2.stickyErr) + } + if txs2.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs2.Exec(`Eric`, 82) + if err != nil { + t.Fatal(err) + } + + err = txs1.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = txs2.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + // Issue: https://golang.org/issue/2784 // This test didn't fail before because we got lucky with the fakedb driver. // It was failing, and now not, in github.com/bradfitz/go-sql-test