Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p2p: fix race in dialScheduler #29235

Merged
merged 2 commits into from
Mar 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions p2p/dial.go
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ import (
mrand "math/rand"
"net"
"sync"
"sync/atomic"
"time"

"github.com/ethereum/go-ethereum/common/mclock"
@@ -248,7 +249,7 @@ loop:
}

case task := <-d.doneCh:
id := task.dest.ID()
id := task.dest().ID()
delete(d.dialing, id)
d.updateStaticPool(id)
d.doneSinceLastLog++
@@ -410,7 +411,7 @@ func (d *dialScheduler) startStaticDials(n int) (started int) {
// updateStaticPool attempts to move the given static dial back into staticPool.
func (d *dialScheduler) updateStaticPool(id enode.ID) {
task, ok := d.static[id]
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil {
if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil {
d.addToStaticPool(task)
}
}
@@ -437,10 +438,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {

// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags)
hkey := string(task.dest.ID().Bytes())
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[task.dest.ID()] = task
d.dialing[node.ID()] = task
go func() {
task.run(d)
d.doneCh <- task
@@ -451,39 +453,46 @@ func (d *dialScheduler) startDial(task *dialTask) {
type dialTask struct {
staticPoolIndex int
flags connFlag

// These fields are private to the task and should not be
// accessed by dialScheduler while the task is running.
dest *enode.Node
destPtr atomic.Pointer[enode.Node]
lastResolved mclock.AbsTime
resolveDelay time.Duration
}

func newDialTask(dest *enode.Node, flags connFlag) *dialTask {
return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1}
t := &dialTask{flags: flags, staticPoolIndex: -1}
t.destPtr.Store(dest)
return t
}

// dialError wraps the error of a failed connection attempt. dialTask.dial
// returns it when the dialer itself fails, and dialTask.run checks for this
// type to decide whether a static node should be resolved and dialed again.
type dialError struct {
error
}

// dest returns the task's current destination node, loaded atomically from
// destPtr. The result can change over the task's lifetime: resolve stores a
// new node there when resolution succeeds.
func (t *dialTask) dest() *enode.Node {
return t.destPtr.Load()
}

func (t *dialTask) run(d *dialScheduler) {
if t.needResolve() && !t.resolve(d) {
return
}

err := t.dial(d, t.dest)
err := t.dial(d, t.dest())
if err != nil {
// For static nodes, resolve one more time if dialing fails.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(d) {
t.dial(d, t.dest)
t.dial(d, t.dest())
}
}
}
}

func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest.IP() == nil
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
}

// resolve attempts to find the current endpoint for the destination
@@ -502,38 +511,41 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay {
return false
}
resolved := d.resolver.Resolve(t.dest)

node := t.dest()
resolved := d.resolver.Resolve(node)
t.lastResolved = d.clock.Now()
if resolved == nil {
t.resolveDelay *= 2
if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay
}
d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
d.log.Debug("Resolving node failed", "id", node.ID(), "newdelay", t.resolveDelay)
return false
}
// The node was found.
t.resolveDelay = initialResolveDelay
t.dest = resolved
d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
return true
}

// dial performs the actual connection attempt.
func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, t.dest)
fd, err := d.dialer.Dial(d.ctx, dest)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this was a bit funky before this change: it mostly ignored the incoming parameter and just went with t.dest. Good change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although, as far as I can tell, it's only ever called like t.dial(d, t.dest) (now t.dial(d, t.dest())), so I guess we could just drop the second parameter and make it t.dial(d)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a holdover from when resolve() returned the new node or something.

if err != nil {
d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err))
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1)
return &dialError{err}
}
return d.setupFunc(newMeteredConn(fd), t.flags, dest)
}

func (t *dialTask) String() string {
id := t.dest.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
node := t.dest()
id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
}

func cleanupDialErr(err error) error {