diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 214c13acde..1e27b17cf9 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -305,8 +305,9 @@ type DB struct { mu sync.Mutex // protects following fields freeConn []*driverConn - connRequests []chan connRequest - numOpen int // number of opened and pending open connections + connRequests map[uint64]chan connRequest + nextRequest uint64 // Next key to use in connRequests. + numOpen int // number of opened and pending open connections // Used to signal the need for new connections // a goroutine running connectionOpener() reads on this chan and // maybeOpenNewConnections sends on the chan (one send per needed connection) @@ -572,10 +573,11 @@ func Open(driverName, dataSourceName string) (*DB, error) { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } db := &DB{ - driver: driveri, - dsn: dataSourceName, - openerCh: make(chan struct{}, connectionRequestQueueSize), - lastPut: make(map[*driverConn]string), + driver: driveri, + dsn: dataSourceName, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), } go db.connectionOpener() return db, nil @@ -881,6 +883,14 @@ type connRequest struct { var errDBClosed = errors.New("sql: database is closed") +// nextRequestKeyLocked returns the next connection request key. +// It is assumed that nextRequest will not overflow. +func (db *DB) nextRequestKeyLocked() uint64 { + next := db.nextRequest + db.nextRequest++ + return next +} + // conn returns a newly-opened or cached *driverConn. func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { db.mu.Lock() @@ -918,12 +928,25 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn // Make the connRequest channel. It's buffered so that the // connectionOpener doesn't block while waiting for the req to be read. req := make(chan connRequest, 1) - db.connRequests = append(db.connRequests, req) + reqKey := db.nextRequestKeyLocked() + db.connRequests[reqKey] = req db.mu.Unlock() // Timeout the connection request with the context. select { case <-ctx.Done(): + // Remove the connection request and ensure no value has been sent + // on it after removing. + db.mu.Lock() + delete(db.connRequests, reqKey) + select { + default: + case ret, ok := <-req: + if ok { + db.putConnDBLocked(ret.conn, ret.err) + } + } + db.mu.Unlock() return nil, ctx.Err() case ret, ok := <-req: if !ok { @@ -1044,12 +1067,12 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { return false } if c := len(db.connRequests); c > 0 { - req := db.connRequests[0] - // This copy is O(n) but in practice faster than a linked list. - // TODO: consider compacting it down less often and - // moving the base instead? - copy(db.connRequests, db.connRequests[1:]) - db.connRequests = db.connRequests[:c-1] + var req chan connRequest + var reqKey uint64 + for reqKey, req = range db.connRequests { + break + } + delete(db.connRequests, reqKey) // Remove from pending requests. if err == nil { dc.inUse = true } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 1f2b915816..9a1fac8c83 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -531,6 +531,63 @@ func TestQueryNamedArg(t *testing.T) { } } +func TestPoolExhaustOnCancel(t *testing.T) { + if testing.Short() { + t.Skip("long test") + } + db := newTestDB(t, "people") + defer closeDB(t, db) + + max := 3 + + db.SetMaxOpenConns(max) + + // First saturate the connection pool. + // Then start new requests for a connection that is cancelled after it is requested. + + var saturate, saturateDone sync.WaitGroup + saturate.Add(max) + saturateDone.Add(max) + + for i := 0; i < max; i++ { + go func() { + saturate.Done() + rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|") + if err != nil { + t.Fatalf("Query: %v", err) + } + rows.Close() + saturateDone.Done() + }() + } + + saturate.Wait() + + // Now cancel the request while it is waiting. + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + for i := 0; i < max; i++ { + ctxReq, cancelReq := context.WithCancel(ctx) + go func() { + time.Sleep(time.Millisecond * 100) + cancelReq() + }() + err := db.PingContext(ctxReq) + if err != context.Canceled { + t.Fatalf("PingContext (Exhaust): %v", err) + } + } + + saturateDone.Wait() + + // Now try to open a normal connection. + err := db.PingContext(ctx) + if err != nil { + t.Fatalf("PingContext (Normal): %v", err) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)