diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
new file mode 100644
index 0000000000000000000000000000000000000000..17e2adb73b25436153f85c73ab889d1bca31be68
--- /dev/null
+++ b/internal/sshd/connection.go
@@ -0,0 +1,52 @@
+package sshd
+
+import (
+	"context"
+
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/crypto/ssh"
+	"golang.org/x/sync/semaphore"
+
+	"gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type connection struct {
+	// State set up by the sshd
+	cfg *config.Config
+	gitlabKeyId string
+	remoteAddr  string
+}
+
+func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel) {
+	concurrentSessions := semaphore.NewWeighted(c.cfg.Server.ConcurrentSessionsLimit)
+
+	for newChannel := range chans {
+		if newChannel.ChannelType() != "session" {
+			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+			continue
+		}
+		if !concurrentSessions.TryAcquire(1) {
+			newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
+			sshdHitMaxSessions.Inc()
+			continue
+		}
+		channel, requests, err := newChannel.Accept()
+		if err != nil {
+			log.Infof("Could not accept channel: %v", err)
+			concurrentSessions.Release(1)
+			continue
+		}
+
+		go func() {
+			defer concurrentSessions.Release(1)
+			session := &session{
+				cfg:         c.cfg,
+				channel:     channel,
+				gitlabKeyId: c.gitlabKeyId,
+				remoteAddr:  c.remoteAddr,
+			}
+
+			session.handle(ctx, requests)
+		}()
+	}
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index e178fe81883db16203d7be4d85276ea6dcf4d2fe..22cb715747ad82dce031858d062940f2fbc894b0 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -3,7 +3,6 @@ package sshd
 import (
 	"context"
 	"fmt"
-	"net"
 
 	"golang.org/x/crypto/ssh"
 
@@ -15,11 +14,11 @@ import (
 )
 
 type session struct {
-	// State set up by handleConn
-	cfg     *config.Config
-	channel ssh.Channel
-	sconn   *ssh.ServerConn
-	nconn   net.Conn
+	// State set up by the connection
+	cfg         *config.Config
+	channel     ssh.Channel
+	gitlabKeyId string
+	remoteAddr  string
 
 	// State managed by the session
 	execCmd            string
@@ -106,12 +105,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
 	}
 
 	args := &commandargs.Shell{
-		GitlabKeyId: s.sconn.Permissions.Extensions["key-id"],
+		GitlabKeyId: s.gitlabKeyId,
 		Env: sshenv.Env{
 			IsSSHConnection:    true,
 			OriginalCommand:    s.execCmd,
 			GitProtocolVersion: s.gitProtocolVersion,
-			RemoteAddr:         s.nconn.RemoteAddr().(*net.TCPAddr).String(),
+			RemoteAddr:         s.remoteAddr,
 		},
 	}
 
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 7906f0d53104af965338ac85f11ebab928edd98e..35952686f07b948ba4dcc1add09236ea6b3f58a7 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -16,7 +16,6 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/sync/semaphore"
 
 	"gitlab.com/gitlab-org/gitlab-shell/internal/config"
 	"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
@@ -81,7 +80,7 @@ func Run(cfg *config.Config) error {
 
 	log.Infof("Listening on %v", sshListener.Addr().String())
 
-	config := &ssh.ServerConfig{
+	sshCfg := &ssh.ServerConfig{
 		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
 			if conn.User() != cfg.User {
 				return nil, errors.New("unknown user")
@@ -118,7 +117,7 @@ func Run(cfg *config.Config) error {
 			continue
 		}
 		loadedHostKeys++
-		config.AddHostKey(key)
+		sshCfg.AddHostKey(key)
 	}
 	if loadedHostKeys == 0 {
 		return fmt.Errorf("No host keys could be loaded, aborting")
@@ -131,55 +130,31 @@ func Run(cfg *config.Config) error {
 			continue
 		}
 
-		go handleConn(nconn, config, cfg)
+		go acceptConn(cfg, sshCfg, nconn)
 	}
 }
 
-func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) {
+func acceptConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
 	begin := time.Now()
-	defer func() {
-		sshdConnectionDuration.Observe(time.Since(begin).Seconds())
-	}()
-
 	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+
+	defer sshdConnectionDuration.Observe(time.Since(begin).Seconds())
 	defer nconn.Close()
-	conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
+	defer cancel()
+
+	sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
 	if err != nil {
 		log.Infof("Failed to initialize SSH connection: %v", err)
 		return
 	}
 
-	concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit)
-
 	go ssh.DiscardRequests(reqs)
-	for newChannel := range chans {
-		if newChannel.ChannelType() != "session" {
-			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
-			continue
-		}
-		if !concurrentSessions.TryAcquire(1) {
-			newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
-			sshdHitMaxSessions.Inc()
-			continue
-		}
-		ch, requests, err := newChannel.Accept()
-		if err != nil {
-			log.Infof("Could not accept channel: %v", err)
-			concurrentSessions.Release(1)
-			continue
-		}
-
-		go func() {
-			defer concurrentSessions.Release(1)
-			session := &session{
-				cfg:     cfg,
-				channel: ch,
-				sconn:   conn,
-				nconn:   nconn,
-			}
 
-			session.handle(ctx, requests)
-		}()
+	conn := &connection{
+		cfg:         cfg,
+		gitlabKeyId: sconn.Permissions.Extensions["key-id"],
+		remoteAddr:  nconn.RemoteAddr().(*net.TCPAddr).String(),
 	}
+
+	conn.handle(ctx, chans)
 }