diff --git a/pkg/git/libgit2/managed/ssh.go b/pkg/git/libgit2/managed/ssh.go index 986efd937..1c11afe86 100644 --- a/pkg/git/libgit2/managed/ssh.go +++ b/pkg/git/libgit2/managed/ssh.go @@ -54,6 +54,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "golang.org/x/crypto/ssh" @@ -80,10 +81,12 @@ func registerManagedSSH() error { } func sshSmartSubtransportFactory(remote *git2go.Remote, transport *git2go.Transport) (git2go.SmartSubtransport, error) { + var closed int32 = 0 return &sshSmartSubtransport{ - transport: transport, - ctx: context.Background(), - logger: logr.Discard(), + transport: transport, + ctx: context.Background(), + logger: logr.Discard(), + closedSessions: &closed, }, nil } @@ -109,15 +112,12 @@ type sshSmartSubtransport struct { stdin io.WriteCloser stdout io.Reader - con connection -} + closedSessions *int32 -type connection struct { client *ssh.Client session *ssh.Session currentStream *sshSmartSubtransportStream connected bool - m sync.RWMutex } func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) { @@ -151,17 +151,17 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. var cmd string switch action { case git2go.SmartServiceActionUploadpackLs, git2go.SmartServiceActionUploadpack: - if t.con.currentStream != nil { + if t.currentStream != nil { if t.lastAction == git2go.SmartServiceActionUploadpackLs { - return t.con.currentStream, nil + return t.currentStream, nil } } cmd = fmt.Sprintf("git-upload-pack '%s'", uPath) case git2go.SmartServiceActionReceivepackLs, git2go.SmartServiceActionReceivepack: - if t.con.currentStream != nil { + if t.currentStream != nil { if t.lastAction == git2go.SmartServiceActionReceivepackLs { - return t.con.currentStream, nil + return t.currentStream, nil } } cmd = fmt.Sprintf("git-receive-pack '%s'", uPath) @@ -208,13 +208,11 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. return nil } - t.con.m.RLock() - if t.con.connected == true { + if t.connected { // The connection is no longer shared across actions, so ensures // all has been released before starting a new connection. _ = t.Close() } - t.con.m.RUnlock() err = t.createConn(addr, sshConfig) if err != nil { @@ -222,18 +220,18 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. } t.logger.V(logger.TraceLevel).Info("creating new ssh session") - if t.con.session, err = t.con.client.NewSession(); err != nil { + if t.session, err = t.client.NewSession(); err != nil { return nil, err } - if t.stdin, err = t.con.session.StdinPipe(); err != nil { + if t.stdin, err = t.session.StdinPipe(); err != nil { return nil, err } var w *io.PipeWriter var reader io.Reader t.stdout, w = io.Pipe() - if reader, err = t.con.session.StdoutPipe(); err != nil { + if reader, err = t.session.StdoutPipe(); err != nil { return nil, err } @@ -251,7 +249,6 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. "recovered from libgit2 ssh smart subtransport panic") } }() - var cancel context.CancelFunc ctx := t.ctx @@ -261,6 +258,7 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. defer cancel() } + closedAlready := atomic.LoadInt32(t.closedSessions) for { select { case <-ctx.Done(): @@ -268,12 +266,9 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. return nil default: - t.con.m.RLock() - if !t.con.connected { - t.con.m.RUnlock() + if atomic.LoadInt32(t.closedSessions) > closedAlready { return nil } - t.con.m.RUnlock() _, err := io.Copy(w, reader) if err != nil { @@ -285,16 +280,16 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. }() t.logger.V(logger.TraceLevel).Info("run on remote", "cmd", cmd) - if err := t.con.session.Start(cmd); err != nil { + if err := t.session.Start(cmd); err != nil { return nil, err } t.lastAction = action - t.con.currentStream = &sshSmartSubtransportStream{ + t.currentStream = &sshSmartSubtransportStream{ owner: t, } - return t.con.currentStream, nil + return t.currentStream, nil } func (t *sshSmartSubtransport) createConn(addr string, sshConfig *ssh.ClientConfig) error { @@ -311,10 +306,8 @@ func (t *sshSmartSubtransport) createConn(addr string, sshConfig *ssh.ClientConf return err } - t.con.m.Lock() - t.con.connected = true - t.con.client = ssh.NewClient(c, chans, reqs) - t.con.m.Unlock() + t.connected = true + t.client = ssh.NewClient(c, chans, reqs) return nil } @@ -330,27 +323,27 @@ func (t *sshSmartSubtransport) createConn(addr string, sshConfig *ssh.ClientConf // SmartSubTransport (i.e. unreleased resources, staled connections). func (t *sshSmartSubtransport) Close() error { t.logger.V(logger.TraceLevel).Info("sshSmartSubtransport.Close()") - t.con.m.Lock() - defer t.con.m.Unlock() - t.con.currentStream = nil - if t.con.client != nil && t.stdin != nil { + t.currentStream = nil + if t.client != nil && t.stdin != nil { _ = t.stdin.Close() } t.stdin = nil - if t.con.session != nil { + if t.session != nil { t.logger.V(logger.TraceLevel).Info("session.Close()") - _ = t.con.session.Close() + _ = t.session.Close() } - t.con.session = nil + t.session = nil - if t.con.client != nil { - _ = t.con.client.Close() + if t.client != nil { + _ = t.client.Close() t.logger.V(logger.TraceLevel).Info("close client") } + t.client = nil - t.con.connected = false + t.connected = false + atomic.AddInt32(t.closedSessions, 1) return nil }