sshserver_unix.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. //+build !windows
  2. package sshserver
  3. import (
  4. "crypto/ecdsa"
  5. "crypto/elliptic"
  6. "crypto/rand"
  7. "encoding/binary"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net"
  12. "runtime"
  13. "strings"
  14. "time"
  15. "github.com/cloudflare/cloudflared/logger"
  16. "github.com/cloudflare/cloudflared/sshgen"
  17. "github.com/cloudflare/cloudflared/sshlog"
  18. "github.com/gliderlabs/ssh"
  19. "github.com/google/uuid"
  20. "github.com/pkg/errors"
  21. gossh "golang.org/x/crypto/ssh"
  22. )
  // Audit event types, session-context keys and preamble/destination constants
  // shared by the SSH proxy.
  const (
  	// Values written to the event_type field of an audit event.
  	auditEventStart  = "session_start"
  	auditEventStop   = "session_stop"
  	auditEventExec   = "exec"
  	auditEventScp    = "scp"
  	auditEventResize = "resize"
  	auditEventShell  = "shell"

  	// Keys under which per-connection values are stored in the ssh.Context.
  	sshContextSessionID   = "sessionID"
  	sshContextEventLogger = "eventLogger"
  	sshContextPreamble    = "sshPreamble"
  	sshContextSSHClient   = "sshClient"

  	// SSHPreambleLength is the size, in bytes, of the length prefix that
  	// precedes the preamble payload on the wire.
  	SSHPreambleLength = 2

  	// defaultSSHPort is appended to destinations that do not name a port.
  	defaultSSHPort = "22"
  )
  // auditEvent is a single record of the per-session audit log. It is
  // serialized as one JSON object per line; empty fields are omitted.
  type auditEvent struct {
  	// Event carries the event detail (the command line for exec requests).
  	Event string `json:"event,omitempty"`
  	// EventType is one of the auditEvent* constants.
  	EventType string `json:"event_type,omitempty"`
  	// SessionID identifies the session the event belongs to.
  	SessionID string `json:"session_id,omitempty"`
  	// User and Login are both filled from the incoming connection's user name.
  	User  string `json:"user,omitempty"`
  	Login string `json:"login,omitempty"`
  	// Datetime is the UTC time of the event formatted as RFC 3339.
  	Datetime string `json:"datetime,omitempty"`
  	// Hostname is the hostname the proxy was configured with.
  	Hostname string `json:"hostname,omitempty"`
  	// Destination is the canonicalized destination taken from the preamble.
  	Destination string `json:"destination,omitempty"`
  }
  // sshConn pairs the incoming net.Conn with a cleanup hook so that resources
  // tied to the connection (the outgoing SSH client) can be released when the
  // connection itself is closed.
  type sshConn struct {
  	net.Conn
  	// cleanupFunc is invoked on every Close call, before the wrapped
  	// connection is closed. It must not be nil.
  	cleanupFunc func()
  }

  // Close runs the cleanup hook and then closes the wrapped connection,
  // returning the wrapped connection's close error.
  func (c sshConn) Close() error {
  	c.cleanupFunc()
  	closeErr := c.Conn.Close()
  	return closeErr
  }
  58. type SSHProxy struct {
  59. ssh.Server
  60. hostname string
  61. logger logger.Service
  62. shutdownC chan struct{}
  63. caCert ssh.PublicKey
  64. logManager sshlog.Manager
  65. }
  // SSHPreamble is the JSON message read from the connection before any SSH
  // traffic is exchanged.
  type SSHPreamble struct {
  	// Destination is the address of the SSH server to proxy to. After
  	// readPreamble it always includes a port.
  	Destination string
  	// JWT is the token submitted when requesting a short-lived certificate.
  	JWT string
  }
  70. // New creates a new SSHProxy and configures its host keys and authentication by the data provided
  71. func New(logManager sshlog.Manager, logger logger.Service, version, localAddress, hostname, hostKeyDir string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) {
  72. sshProxy := SSHProxy{
  73. hostname: hostname,
  74. logger: logger,
  75. shutdownC: shutdownC,
  76. logManager: logManager,
  77. }
  78. sshProxy.Server = ssh.Server{
  79. Addr: localAddress,
  80. MaxTimeout: maxTimeout,
  81. IdleTimeout: idleTimeout,
  82. Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
  83. PublicKeyHandler: sshProxy.proxyAuthCallback,
  84. ConnCallback: sshProxy.connCallback,
  85. ChannelHandlers: map[string]ssh.ChannelHandler{
  86. "default": sshProxy.channelHandler,
  87. },
  88. }
  89. if err := sshProxy.configureHostKeys(hostKeyDir); err != nil {
  90. return nil, err
  91. }
  92. return &sshProxy, nil
  93. }
  94. // Start the SSH proxy listener to start handling SSH connections from clients
  95. func (s *SSHProxy) Start() error {
  96. s.logger.Infof("Starting SSH server at %s", s.Addr)
  97. go func() {
  98. <-s.shutdownC
  99. if err := s.Close(); err != nil {
  100. s.logger.Errorf("Cannot close SSH server: %s", err)
  101. }
  102. }()
  103. return s.ListenAndServe()
  104. }
  105. // proxyAuthCallback attempts to connect to ultimate SSH destination. If successful, it allows the incoming connection
  106. // to connect to the proxy and saves the outgoing SSH client to the context. Otherwise, no connection to the
  107. // the proxy is allowed.
  108. func (s *SSHProxy) proxyAuthCallback(ctx ssh.Context, key ssh.PublicKey) bool {
  109. client, err := s.dialDestination(ctx)
  110. if err != nil {
  111. return false
  112. }
  113. ctx.SetValue(sshContextSSHClient, client)
  114. return true
  115. }
  116. // connCallback reads the preamble sent from the proxy server and saves an audit event logger to the context.
  117. // If any errors occur, the connection is terminated by returning nil from the callback.
  118. func (s *SSHProxy) connCallback(ctx ssh.Context, conn net.Conn) net.Conn {
  119. // AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing.
  120. // TODO: Remove this
  121. time.Sleep(10 * time.Millisecond)
  122. preamble, err := s.readPreamble(conn)
  123. if err != nil {
  124. if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  125. s.logger.Info("Could not establish session. Client likely does not have --destination set and is using old-style ssh config")
  126. } else if err != io.EOF {
  127. s.logger.Errorf("failed to read SSH preamble: %s", err)
  128. }
  129. return nil
  130. }
  131. ctx.SetValue(sshContextPreamble, preamble)
  132. logger, sessionID, err := s.auditLogger()
  133. if err != nil {
  134. s.logger.Errorf("failed to configure logger: %s", err)
  135. return nil
  136. }
  137. ctx.SetValue(sshContextEventLogger, logger)
  138. ctx.SetValue(sshContextSessionID, sessionID)
  139. // attempts to retrieve and close the outgoing ssh client when the incoming conn is closed.
  140. // If no client exists, the conn is being closed before the PublicKeyCallback was called (where the client is created).
  141. cleanupFunc := func() {
  142. client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
  143. if ok && client != nil {
  144. client.Close()
  145. }
  146. }
  147. return sshConn{conn, cleanupFunc}
  148. }
  149. // channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel
  150. func (s *SSHProxy) channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
  151. if newChan.ChannelType() != "session" && newChan.ChannelType() != "direct-tcpip" {
  152. msg := fmt.Sprintf("channel type %s is not supported", newChan.ChannelType())
  153. s.logger.Info(msg)
  154. if err := newChan.Reject(gossh.UnknownChannelType, msg); err != nil {
  155. s.logger.Errorf("Error rejecting SSH channel: %s", err)
  156. }
  157. return
  158. }
  159. localChan, localChanReqs, err := newChan.Accept()
  160. if err != nil {
  161. s.logger.Errorf("Failed to accept session channel: %s", err)
  162. return
  163. }
  164. defer localChan.Close()
  165. // client will be closed when the sshConn is closed
  166. client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
  167. if !ok {
  168. s.logger.Error("Could not retrieve client from context")
  169. return
  170. }
  171. remoteChan, remoteChanReqs, err := client.OpenChannel(newChan.ChannelType(), newChan.ExtraData())
  172. if err != nil {
  173. s.logger.Errorf("Failed to open remote channel: %s", err)
  174. return
  175. }
  176. defer remoteChan.Close()
  177. // Proxy ssh traffic back and forth between client and destination
  178. s.proxyChannel(localChan, remoteChan, localChanReqs, remoteChanReqs, conn, ctx)
  179. }
  180. // proxyChannel couples two SSH channels and proxies SSH traffic and channel requests back and forth.
  181. func (s *SSHProxy) proxyChannel(localChan, remoteChan gossh.Channel, localChanReqs, remoteChanReqs <-chan *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
  182. done := make(chan struct{}, 2)
  183. go func() {
  184. if _, err := io.Copy(localChan, remoteChan); err != nil {
  185. s.logger.Errorf("remote to local copy error: %s", err)
  186. }
  187. done <- struct{}{}
  188. }()
  189. go func() {
  190. if _, err := io.Copy(remoteChan, localChan); err != nil {
  191. s.logger.Errorf("local to remote copy error: %s", err)
  192. }
  193. done <- struct{}{}
  194. }()
  195. // stderr streams are used non-pty sessions since they have distinct IO streams.
  196. remoteStderr := remoteChan.Stderr()
  197. localStderr := localChan.Stderr()
  198. go func() {
  199. if _, err := io.Copy(remoteStderr, localStderr); err != nil {
  200. s.logger.Errorf("stderr local to remote copy error: %s", err)
  201. }
  202. }()
  203. go func() {
  204. if _, err := io.Copy(localStderr, remoteStderr); err != nil {
  205. s.logger.Errorf("stderr remote to local copy error: %s", err)
  206. }
  207. }()
  208. s.logAuditEvent(conn, "", auditEventStart, ctx)
  209. defer s.logAuditEvent(conn, "", auditEventStop, ctx)
  210. // Proxy channel requests
  211. for {
  212. select {
  213. case req := <-localChanReqs:
  214. if req == nil {
  215. return
  216. }
  217. if err := s.forwardChannelRequest(remoteChan, req); err != nil {
  218. s.logger.Errorf("Failed to forward request: %s", err)
  219. return
  220. }
  221. s.logChannelRequest(req, conn, ctx)
  222. case req := <-remoteChanReqs:
  223. if req == nil {
  224. return
  225. }
  226. if err := s.forwardChannelRequest(localChan, req); err != nil {
  227. s.logger.Errorf("Failed to forward request: %s", err)
  228. return
  229. }
  230. case <-done:
  231. return
  232. }
  233. }
  234. }
  235. // readPreamble reads a preamble from the SSH connection before any SSH traffic is sent.
  236. // This preamble is a JSON encoded struct containing the users JWT and ultimate destination.
  237. // The first 4 bytes contain the length of the preamble which follows immediately.
  238. func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) {
  239. // Set conn read deadline while reading preamble to prevent hangs if preamble wasnt sent.
  240. if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
  241. return nil, errors.Wrap(err, "failed to set conn deadline")
  242. }
  243. defer func() {
  244. if err := conn.SetReadDeadline(time.Time{}); err != nil {
  245. s.logger.Errorf("Failed to unset conn read deadline: %s", err)
  246. }
  247. }()
  248. size := make([]byte, SSHPreambleLength)
  249. if _, err := io.ReadFull(conn, size); err != nil {
  250. return nil, err
  251. }
  252. payloadLength := binary.BigEndian.Uint16(size)
  253. payload := make([]byte, payloadLength)
  254. if _, err := io.ReadFull(conn, payload); err != nil {
  255. return nil, err
  256. }
  257. var preamble SSHPreamble
  258. err := json.Unmarshal(payload, &preamble)
  259. if err != nil {
  260. return nil, err
  261. }
  262. preamble.Destination, err = canonicalizeDest(preamble.Destination)
  263. if err != nil {
  264. return nil, err
  265. }
  266. return &preamble, nil
  267. }
  268. // canonicalizeDest adds a default port if one doesnt exist
  269. func canonicalizeDest(dest string) (string, error) {
  270. _, _, err := net.SplitHostPort(dest)
  271. // if host and port are split without error, a port exists.
  272. if err != nil {
  273. addrErr, ok := err.(*net.AddrError)
  274. if !ok {
  275. return "", err
  276. }
  277. // If the port is missing, append it.
  278. if addrErr.Err == "missing port in address" {
  279. return fmt.Sprintf("%s:%s", dest, defaultSSHPort), nil
  280. }
  281. // If there are too many colons and address is IPv6, wrap in brackets and append port. Otherwise invalid address
  282. ip := net.ParseIP(dest)
  283. if addrErr.Err == "too many colons in address" && ip != nil && ip.To4() == nil {
  284. return fmt.Sprintf("[%s]:%s", dest, defaultSSHPort), nil
  285. }
  286. return "", addrErr
  287. }
  288. return dest, nil
  289. }
  290. // dialDestination creates a new SSH client and dials the destination server
  291. func (s *SSHProxy) dialDestination(ctx ssh.Context) (*gossh.Client, error) {
  292. preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
  293. if !ok {
  294. msg := "failed to retrieve SSH preamble from context"
  295. s.logger.Error(msg)
  296. return nil, errors.New(msg)
  297. }
  298. signer, err := s.genSSHSigner(preamble.JWT)
  299. if err != nil {
  300. s.logger.Errorf("Failed to generate signed short lived cert: %s", err)
  301. return nil, err
  302. }
  303. s.logger.Debugf("Short lived certificate for %s connecting to %s:\n\n%s", ctx.User(), preamble.Destination, gossh.MarshalAuthorizedKey(signer.PublicKey()))
  304. clientConfig := &gossh.ClientConfig{
  305. User: ctx.User(),
  306. // AUTH-2103 TODO: proper host key check
  307. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  308. Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)},
  309. ClientVersion: ctx.ServerVersion(),
  310. }
  311. client, err := gossh.Dial("tcp", preamble.Destination, clientConfig)
  312. if err != nil {
  313. s.logger.Errorf("Failed to connect to destination SSH server: %s", err)
  314. return nil, err
  315. }
  316. return client, nil
  317. }
  318. // Generates a key pair and sends public key to get signed by CA
  319. func (s *SSHProxy) genSSHSigner(jwt string) (gossh.Signer, error) {
  320. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  321. if err != nil {
  322. return nil, errors.Wrap(err, "failed to generate ecdsa key pair")
  323. }
  324. pub, err := gossh.NewPublicKey(&key.PublicKey)
  325. if err != nil {
  326. return nil, errors.Wrap(err, "failed to convert ecdsa public key to SSH public key")
  327. }
  328. pubBytes := gossh.MarshalAuthorizedKey(pub)
  329. signedCertBytes, err := sshgen.SignCert(jwt, string(pubBytes))
  330. if err != nil {
  331. return nil, errors.Wrap(err, "failed to retrieve cert from SSHCAAPI")
  332. }
  333. signedPub, _, _, _, err := gossh.ParseAuthorizedKey([]byte(signedCertBytes))
  334. if err != nil {
  335. return nil, errors.Wrap(err, "failed to parse SSH public key")
  336. }
  337. cert, ok := signedPub.(*gossh.Certificate)
  338. if !ok {
  339. return nil, errors.Wrap(err, "failed to assert public key as certificate")
  340. }
  341. signer, err := gossh.NewSignerFromKey(key)
  342. if err != nil {
  343. return nil, errors.Wrap(err, "failed to create signer")
  344. }
  345. certSigner, err := gossh.NewCertSigner(cert, signer)
  346. if err != nil {
  347. return nil, errors.Wrap(err, "failed to create cert signer")
  348. }
  349. return certSigner, nil
  350. }
  351. // forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back.
  352. func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error {
  353. reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload)
  354. if err != nil {
  355. return errors.Wrap(err, "Failed to send request")
  356. }
  357. if err := req.Reply(reply, nil); err != nil {
  358. return errors.Wrap(err, "Failed to reply to request")
  359. }
  360. return nil
  361. }
  362. // logChannelRequest creates an audit log for different types of channel requests
  363. func (s *SSHProxy) logChannelRequest(req *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
  364. var eventType string
  365. var event string
  366. switch req.Type {
  367. case "exec":
  368. var payload struct{ Value string }
  369. if err := gossh.Unmarshal(req.Payload, &payload); err != nil {
  370. s.logger.Errorf("Failed to unmarshal channel request payload: %s:%s with error: %s", req.Type, req.Payload, err)
  371. }
  372. event = payload.Value
  373. eventType = auditEventExec
  374. if strings.HasPrefix(string(req.Payload), "scp") {
  375. eventType = auditEventScp
  376. }
  377. case "shell":
  378. eventType = auditEventShell
  379. case "window-change":
  380. eventType = auditEventResize
  381. default:
  382. return
  383. }
  384. s.logAuditEvent(conn, event, eventType, ctx)
  385. }
  386. func (s *SSHProxy) auditLogger() (io.WriteCloser, string, error) {
  387. sessionUUID, err := uuid.NewRandom()
  388. if err != nil {
  389. return nil, "", errors.Wrap(err, "failed to create sessionID")
  390. }
  391. sessionID := sessionUUID.String()
  392. writer, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
  393. if err != nil {
  394. return nil, "", errors.Wrap(err, "failed to create logger")
  395. }
  396. return writer, sessionID, nil
  397. }
  398. func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) {
  399. sessionID, sessionIDOk := ctx.Value(sshContextSessionID).(string)
  400. writer, writerOk := ctx.Value(sshContextEventLogger).(io.WriteCloser)
  401. if !writerOk || !sessionIDOk {
  402. s.logger.Error("Failed to retrieve audit logger from context")
  403. return
  404. }
  405. var destination string
  406. preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
  407. if ok {
  408. destination = preamble.Destination
  409. } else {
  410. s.logger.Error("Failed to retrieve SSH preamble from context")
  411. }
  412. ae := auditEvent{
  413. Event: event,
  414. EventType: eventType,
  415. SessionID: sessionID,
  416. User: conn.User(),
  417. Login: conn.User(),
  418. Datetime: time.Now().UTC().Format(time.RFC3339),
  419. Hostname: s.hostname,
  420. Destination: destination,
  421. }
  422. data, err := json.Marshal(&ae)
  423. if err != nil {
  424. s.logger.Errorf("Failed to marshal audit event. malformed audit object: %s", err)
  425. return
  426. }
  427. line := string(data) + "\n"
  428. if _, err := writer.Write([]byte(line)); err != nil {
  429. s.logger.Errorf("Failed to write audit event: %s", err)
  430. }
  431. }