diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go new file mode 100644 index 0000000000000000000000000000000000000000..17e2adb73b25436153f85c73ab889d1bca31be68 --- /dev/null +++ b/internal/sshd/connection.go @@ -0,0 +1,52 @@ +package sshd + +import ( + "context" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type connection struct { + // State set up by the sshd + cfg *config.Config + gitlabKeyId string + remoteAddr string +} + +func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel) { + concurrentSessions := semaphore.NewWeighted(c.cfg.Server.ConcurrentSessionsLimit) + + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + if !concurrentSessions.TryAcquire(1) { + newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") + sshdHitMaxSessions.Inc() + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Infof("Could not accept channel: %v", err) + concurrentSessions.Release(1) + continue + } + + go func() { + defer concurrentSessions.Release(1) + session := &session{ + cfg: c.cfg, + channel: channel, + gitlabKeyId: c.gitlabKeyId, + remoteAddr: c.remoteAddr, + } + + session.handle(ctx, requests) + }() + } +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go index e178fe81883db16203d7be4d85276ea6dcf4d2fe..22cb715747ad82dce031858d062940f2fbc894b0 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -3,7 +3,6 @@ package sshd import ( "context" "fmt" - "net" "golang.org/x/crypto/ssh" @@ -15,11 +14,11 @@ import ( ) type session struct { - // State set up by handleConn - cfg *config.Config - channel ssh.Channel - sconn *ssh.ServerConn - nconn net.Conn + // State set up by the connection + cfg *config.Config + channel ssh.Channel + gitlabKeyId string + remoteAddr string // State managed by the session execCmd string @@ -106,12 +105,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { } args := &commandargs.Shell{ - GitlabKeyId: s.sconn.Permissions.Extensions["key-id"], + GitlabKeyId: s.gitlabKeyId, Env: sshenv.Env{ IsSSHConnection: true, OriginalCommand: s.execCmd, GitProtocolVersion: s.gitProtocolVersion, - RemoteAddr: s.nconn.RemoteAddr().(*net.TCPAddr).String(), + RemoteAddr: s.remoteAddr, }, } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 7906f0d53104af965338ac85f11ebab928edd98e..35952686f07b948ba4dcc1add09236ea6b3f58a7 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -16,7 +16,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "golang.org/x/crypto/ssh" - "golang.org/x/sync/semaphore" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" @@ -81,7 +80,7 @@ func Run(cfg *config.Config) error { log.Infof("Listening on %v", sshListener.Addr().String()) - config := &ssh.ServerConfig{ + sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if conn.User() != cfg.User { return nil, errors.New("unknown user") @@ -118,7 +117,7 @@ func Run(cfg *config.Config) error { continue } loadedHostKeys++ - config.AddHostKey(key) + sshCfg.AddHostKey(key) } if loadedHostKeys == 0 { return fmt.Errorf("No host keys could be loaded, aborting") @@ -131,55 +130,31 @@ func Run(cfg *config.Config) error { continue } - go handleConn(nconn, config, cfg) + go acceptConn(cfg, sshCfg, nconn) } } -func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) { +func acceptConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { begin := time.Now() - defer func() { - sshdConnectionDuration.Observe(time.Since(begin).Seconds()) - }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + + defer sshdConnectionDuration.Observe(time.Since(begin).Seconds()) defer nconn.Close() - conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + defer cancel() + + sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) if err != nil { log.Infof("Failed to initialize SSH connection: %v", err) return } - concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit) - go ssh.DiscardRequests(reqs) - for newChannel := range chans { - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - continue - } - if !concurrentSessions.TryAcquire(1) { - newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") - sshdHitMaxSessions.Inc() - continue - } - ch, requests, err := newChannel.Accept() - if err != nil { - log.Infof("Could not accept channel: %v", err) - concurrentSessions.Release(1) - continue - } - - go func() { - defer concurrentSessions.Release(1) - session := &session{ - cfg: cfg, - channel: ch, - sconn: conn, - nconn: nconn, - } - session.handle(ctx, requests) - }() + conn := &connection{ + cfg: cfg, + gitlabKeyId: sconn.Permissions.Extensions["key-id"], + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(), } + + conn.handle(ctx, chans) }