Skip to content

Commit 9def857

Browse files
kardianosbradfitz
authored andcommitted
database/sql: prevent Tx.rollback from racing Tx.close
Previously Tx.done was being set in close, but in a Tx rollback and Commit are the real closing methods, and Tx.close is just a helper common to both. Prior to this change a multiple rollback statements could be called, one would enter close and begin closing it while the other was still in rollback breaking it. Fix that by setting done in rollback and Commit, not in Tx.close. Fixes #18429 Change-Id: Ie274f60c2aa6a4a5aa38e55109c05ea9d4fe0223 Reviewed-on: https://go-review.googlesource.com/34716 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
1 parent f78cd56 commit 9def857

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/database/sql/sql.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,10 +1421,9 @@ func (tx *Tx) isDone() bool {
14211421
// that has already been committed or rolled back.
14221422
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
14231423

1424+
// close returns the connection to the pool and
1425+
// must only be called by Tx.rollback or Tx.Commit.
14241426
func (tx *Tx) close(err error) {
1425-
if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
1426-
panic("double close") // internal error
1427-
}
14281427
tx.db.putConn(tx.dc, err)
14291428
tx.cancel()
14301429
tx.dc = nil
@@ -1449,7 +1448,7 @@ func (tx *Tx) closePrepared() {
14491448

14501449
// Commit commits the transaction.
14511450
func (tx *Tx) Commit() error {
1452-
if tx.isDone() {
1451+
if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
14531452
return ErrTxDone
14541453
}
14551454
select {
@@ -1471,7 +1470,7 @@ func (tx *Tx) Commit() error {
14711470
// rollback aborts the transaction and optionally forces the pool to discard
14721471
// the connection.
14731472
func (tx *Tx) rollback(discardConn bool) error {
1474-
if tx.isDone() {
1473+
if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
14751474
return ErrTxDone
14761475
}
14771476
var err error

src/database/sql/sql_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,6 +2607,54 @@ func TestIssue6081(t *testing.T) {
26072607
}
26082608
}
26092609

2610+
// TestIssue18429 attempts to stress rolling back the transaction from a context
2611+
// cancel while simultaneously calling Tx.Rollback. Rolling back from a context
2612+
// happens concurrently so tx.rollback and tx.Commit must gaurded to not
2613+
// be entered twice.
2614+
//
2615+
// The test is composed of a context that is canceled while the query is in process
2616+
// so the internal rollback will run concurrently with the explicitly called
2617+
// Tx.Rollback.
2618+
func TestIssue18429(t *testing.T) {
2619+
db := newTestDB(t, "people")
2620+
defer closeDB(t, db)
2621+
2622+
ctx := context.Background()
2623+
sem := make(chan bool, 20)
2624+
var wg sync.WaitGroup
2625+
2626+
const milliWait = 30
2627+
2628+
for i := 0; i < 100; i++ {
2629+
sem <- true
2630+
wg.Add(1)
2631+
go func() {
2632+
defer func() {
2633+
<-sem
2634+
wg.Done()
2635+
}()
2636+
qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()
2637+
2638+
ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
2639+
defer cancel()
2640+
2641+
tx, err := db.BeginTx(ctx, nil)
2642+
if err != nil {
2643+
return
2644+
}
2645+
rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
2646+
if rows != nil {
2647+
rows.Close()
2648+
}
2649+
// This call will race with the context cancel rollback to complete
2650+
// if the rollback itself isn't guarded.
2651+
tx.Rollback()
2652+
}()
2653+
}
2654+
wg.Wait()
2655+
time.Sleep(milliWait * 3 * time.Millisecond)
2656+
}
2657+
26102658
func TestConcurrency(t *testing.T) {
26112659
doConcurrentTest(t, new(concurrentDBQueryTest))
26122660
doConcurrentTest(t, new(concurrentDBExecTest))

0 commit comments

Comments
 (0)