Skip to content
Snippets Groups Projects
Commit 7e55ecd6 authored by Nick Thomas's avatar Nick Thomas
Browse files

sshd: Extract connections into their own file

parent 31920be4
No related merge requests found
package sshd
import (
"context"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
)
type connection struct {
// State set up by the sshd
cfg *config.Config
gitlabKeyId string
remoteAddr string
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel) {
concurrentSessions := semaphore.NewWeighted(c.cfg.Server.ConcurrentSessionsLimit)
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !concurrentSessions.TryAcquire(1) {
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
sshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Infof("Could not accept channel: %v", err)
concurrentSessions.Release(1)
continue
}
go func() {
defer concurrentSessions.Release(1)
session := &session{
cfg: c.cfg,
channel: channel,
gitlabKeyId: c.gitlabKeyId,
remoteAddr: c.remoteAddr,
}
session.handle(ctx, requests)
}()
}
}
......@@ -3,7 +3,6 @@ package sshd
import (
"context"
"fmt"
"net"
"golang.org/x/crypto/ssh"
......@@ -15,11 +14,11 @@ import (
)
type session struct {
// State set up by handleConn
cfg *config.Config
channel ssh.Channel
sconn *ssh.ServerConn
nconn net.Conn
// State set up by the connection
cfg *config.Config
channel ssh.Channel
gitlabKeyId string
remoteAddr string
// State managed by the session
execCmd string
......@@ -106,12 +105,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
}
args := &commandargs.Shell{
GitlabKeyId: s.sconn.Permissions.Extensions["key-id"],
GitlabKeyId: s.gitlabKeyId,
Env: sshenv.Env{
IsSSHConnection: true,
OriginalCommand: s.execCmd,
GitProtocolVersion: s.gitProtocolVersion,
RemoteAddr: s.nconn.RemoteAddr().(*net.TCPAddr).String(),
RemoteAddr: s.remoteAddr,
},
}
......
......@@ -16,7 +16,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
......@@ -81,7 +80,7 @@ func Run(cfg *config.Config) error {
log.Infof("Listening on %v", sshListener.Addr().String())
config := &ssh.ServerConfig{
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if conn.User() != cfg.User {
return nil, errors.New("unknown user")
......@@ -118,7 +117,7 @@ func Run(cfg *config.Config) error {
continue
}
loadedHostKeys++
config.AddHostKey(key)
sshCfg.AddHostKey(key)
}
if loadedHostKeys == 0 {
return fmt.Errorf("No host keys could be loaded, aborting")
......@@ -131,55 +130,31 @@ func Run(cfg *config.Config) error {
continue
}
go handleConn(nconn, config, cfg)
go acceptConn(cfg, sshCfg, nconn)
}
}
func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) {
func acceptConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
begin := time.Now()
defer func() {
sshdConnectionDuration.Observe(time.Since(begin).Seconds())
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer sshdConnectionDuration.Observe(time.Since(begin).Seconds())
defer nconn.Close()
conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
defer cancel()
sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
if err != nil {
log.Infof("Failed to initialize SSH connection: %v", err)
return
}
concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit)
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !concurrentSessions.TryAcquire(1) {
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
sshdHitMaxSessions.Inc()
continue
}
ch, requests, err := newChannel.Accept()
if err != nil {
log.Infof("Could not accept channel: %v", err)
concurrentSessions.Release(1)
continue
}
go func() {
defer concurrentSessions.Release(1)
session := &session{
cfg: cfg,
channel: ch,
sconn: conn,
nconn: nconn,
}
session.handle(ctx, requests)
}()
conn := &connection{
cfg: cfg,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(),
}
conn.handle(ctx, chans)
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment