diff --git a/pkg/git/libgit2/managed/ssh.go b/pkg/git/libgit2/managed/ssh.go index ee8f580b6..b990dd0af 100644 --- a/pkg/git/libgit2/managed/ssh.go +++ b/pkg/git/libgit2/managed/ssh.go @@ -95,12 +95,11 @@ type sshSmartSubtransport struct { } type connection struct { - conn net.Conn client *ssh.Client session *ssh.Session currentStream *sshSmartSubtransportStream connected bool - m sync.Mutex + m sync.RWMutex } func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) { @@ -155,11 +154,6 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. return nil, fmt.Errorf("unexpected action: %v", action) } - if t.con.connected { - // Disregard errors from previous stream, futher details inside Close(). - _ = t.Close() - } - port := "22" if u.Port() != "" { port = u.Port() @@ -189,13 +183,18 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. return nil } + t.con.m.RLock() + if t.con.connected == true { + // The connection is no longer shared across actions, so ensures + // all has been released before starting a new connection. + _ = t.Close() + } + t.con.m.RUnlock() + err = t.createConn(t.addr, sshConfig) if err != nil { return nil, err } - t.con.m.Lock() - t.con.connected = true - t.con.m.Unlock() traceLog.Info("[ssh]: creating new ssh session") if t.con.session, err = t.con.client.NewSession(); err != nil { @@ -244,12 +243,12 @@ func (t *sshSmartSubtransport) Action(transportOptionsURL string, action git2go. return nil default: - t.con.m.Lock() + t.con.m.RLock() if !t.con.connected { - t.con.m.Unlock() + t.con.m.RUnlock() return nil } - t.con.m.Unlock() + t.con.m.RUnlock() _, err := io.Copy(w, reader) if err != nil { @@ -286,8 +285,10 @@ func (t *sshSmartSubtransport) createConn(addr string, sshConfig *ssh.ClientConf return err } - t.con.conn = conn + t.con.m.Lock() + t.con.connected = true t.con.client = ssh.NewClient(c, chans, reqs) + t.con.m.Unlock() return nil } @@ -309,7 +310,7 @@ func (t *sshSmartSubtransport) Close() error { if t.con.client != nil && t.stdin != nil { _ = t.stdin.Close() } - t.con.client = nil + t.stdin = nil if t.con.session != nil { traceLog.Info("[ssh]: session.Close()", "server", t.addr) @@ -317,21 +318,16 @@ func (t *sshSmartSubtransport) Close() error { } t.con.session = nil - return nil -} - -func (t *sshSmartSubtransport) Free() { - traceLog.Info("[ssh]: sshSmartSubtransport.Free()") if t.con.client != nil { _ = t.con.client.Close() } - if t.con.conn != nil { - _ = t.con.conn.Close() - } - t.con.m.Lock() t.con.connected = false - t.con.m.Unlock() + + return nil +} + +func (t *sshSmartSubtransport) Free() { } type sshSmartSubtransportStream struct {