server.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Snowflake-specific websocket server plugin. It reports the transport name as
  2. // "snowflake".
  3. package main
  import (
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	pt "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/ptutil/safelog"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/version"
	sf "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/server/lib"
	"golang.org/x/crypto/acme/autocert"
)
  // ptMethodName is the pluggable transport method name this server serves.
// main rejects bindaddrs for any other method with pt.SmethodError, and
// handleConn passes this name to pt.DialOrWithDialer when dialing the ORPort.
const ptMethodName = "snowflake"

// ptInfo holds the server configuration returned by pt.ServerSetup. main
// assigns it once during startup; handleConn reads it when dialing the ORPort.
var ptInfo pt.ServerInfo
  // usage writes the command-line help to standard error: a short description
// of how the server must be deployed (managed proxy, ACME/TLS, port 80
// requirement), followed by the defaults of every registered flag.
func usage() {
	const helpText = `Usage: %s [OPTIONS]
WebSocket server pluggable transport for Snowflake. Works only as a managed
proxy. Uses TLS with ACME (Let's Encrypt) by default. Set the certificate
hostnames with the --acme-hostnames option. Use ServerTransportListenAddr in
torrc to choose the listening port. When using TLS, this program will open an
additional HTTP listener on port 80 to work with ACME.
`
	programName := os.Args[0]
	fmt.Fprintf(os.Stderr, helpText, programName)
	flag.PrintDefaults()
}
  // proxy relays bytes in both directions between the ORPort connection (local)
// and the client connection (conn), returning only after both directions have
// finished.
//
// When the ORPort stops sending, the read half of local is shut down; when the
// client stops sending, the write half of local is shut down so the ORPort
// observes EOF. Each direction closes conn when it finishes, which (per the
// net.Conn contract) unblocks any operation still pending on conn in the other
// direction. Closing local itself is left to the caller.
func proxy(local *net.TCPConn, conn net.Conn) {
	var pending sync.WaitGroup
	pending.Add(2)

	// ORPort -> WebSocket.
	go func() {
		defer pending.Done()
		_, copyErr := io.Copy(conn, local)
		if copyErr != nil && !errors.Is(copyErr, io.ErrClosedPipe) {
			log.Printf("error copying ORPort to WebSocket %v", copyErr)
		}
		local.CloseRead()
		conn.Close()
	}()

	// WebSocket -> ORPort.
	go func() {
		defer pending.Done()
		_, copyErr := io.Copy(local, conn)
		ignorable := errors.Is(copyErr, io.EOF) || errors.Is(copyErr, io.ErrClosedPipe)
		if copyErr != nil && !ignorable {
			log.Printf("error copying WebSocket to ORPort %v", copyErr)
		}
		local.CloseWrite()
		conn.Close()
	}()

	pending.Wait()
}
  59. // handleConn bidirectionally connects a client snowflake connection with the
  60. // ORPort. If orPortSrcAddr is not nil, addresses from the given range are used
  61. // when dialing the ORPOrt.
  62. func handleConn(conn net.Conn, orPortSrcAddr *net.IPNet) error {
  63. addr := conn.RemoteAddr().String()
  64. statsChannel <- addr != ""
  65. dialer := net.Dialer{
  66. Control: dialerControl,
  67. }
  68. if orPortSrcAddr != nil {
  69. // Use a random source IP address in the given range.
  70. ip, err := randIPAddr(orPortSrcAddr)
  71. if err != nil {
  72. return err
  73. }
  74. dialer.LocalAddr = &net.TCPAddr{IP: ip}
  75. }
  76. or, err := pt.DialOrWithDialer(&dialer, &ptInfo, addr, ptMethodName)
  77. if err != nil {
  78. return fmt.Errorf("failed to connect to ORPort: %s", err)
  79. }
  80. defer or.Close()
  81. proxy(or.(*net.TCPConn), conn)
  82. return nil
  83. }
  84. // acceptLoop accepts incoming client snowflake connections and passes them to
  85. // handleConn. If orPortSrcAddr is not nil, addresses from the given range are
  86. // used when dialing the ORPOrt.
  87. func acceptLoop(ln net.Listener, orPortSrcAddr *net.IPNet) {
  88. for {
  89. conn, err := ln.Accept()
  90. if err != nil {
  91. if err, ok := err.(net.Error); ok && err.Temporary() {
  92. continue
  93. }
  94. log.Printf("Snowflake accept error: %s", err)
  95. break
  96. }
  97. go func() {
  98. defer conn.Close()
  99. err := handleConn(conn, orPortSrcAddr)
  100. if err != nil {
  101. log.Printf("handleConn: %v", err)
  102. }
  103. }()
  104. }
  105. }
  106. func getCertificateCacheDir() (string, error) {
  107. stateDir, err := pt.MakeStateDir()
  108. if err != nil {
  109. return "", err
  110. }
  111. return filepath.Join(stateDir, "snowflake-certificate-cache"), nil
  112. }
  113. func main() {
  114. var acmeEmail string
  115. var acmeHostnamesCommas string
  116. var disableTLS bool
  117. var logFilename string
  118. var unsafeLogging bool
  119. var versionFlag bool
  120. flag.Usage = usage
  121. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  122. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  123. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  124. flag.StringVar(&logFilename, "log", "", "log file to write to")
  125. flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
  126. flag.BoolVar(&versionFlag, "version", false, "display version info to stderr and quit")
  127. flag.Parse()
  128. if versionFlag {
  129. fmt.Fprintf(os.Stderr, "snowflake-server %s", version.ConstructResult())
  130. os.Exit(0)
  131. }
  132. log.SetFlags(log.LstdFlags | log.LUTC)
  133. var logOutput io.Writer = os.Stderr
  134. if logFilename != "" {
  135. f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
  136. if err != nil {
  137. log.Fatalf("can't open log file: %s", err)
  138. }
  139. defer f.Close()
  140. logOutput = f
  141. }
  142. if unsafeLogging {
  143. log.SetOutput(logOutput)
  144. } else {
  145. // We want to send the log output through our scrubber first
  146. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  147. }
  148. log.Printf("snowflake-server %s\n", version.GetVersion())
  149. if !disableTLS && acmeHostnamesCommas == "" {
  150. log.Fatal("the --acme-hostnames option is required")
  151. }
  152. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  153. log.Printf("starting")
  154. var err error
  155. ptInfo, err = pt.ServerSetup(nil)
  156. if err != nil {
  157. log.Fatalf("error in setup: %s", err)
  158. }
  159. pt.ReportVersion("snowflake-server", version.GetVersion())
  160. go statsThread()
  161. var certManager *autocert.Manager
  162. if !disableTLS {
  163. log.Printf("ACME hostnames: %q", acmeHostnames)
  164. var cache autocert.Cache
  165. var cacheDir string
  166. cacheDir, err = getCertificateCacheDir()
  167. if err == nil {
  168. log.Printf("caching ACME certificates in directory %q", cacheDir)
  169. cache = autocert.DirCache(cacheDir)
  170. } else {
  171. log.Printf("disabling ACME certificate cache: %s", err)
  172. }
  173. certManager = &autocert.Manager{
  174. Prompt: autocert.AcceptTOS,
  175. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  176. Email: acmeEmail,
  177. Cache: cache,
  178. }
  179. }
  180. // The ACME HTTP-01 responder only works when it is running on port 80.
  181. // We actually open the port in the loop below, so that any errors can
  182. // be reported in the SMETHOD-ERROR of some bindaddr.
  183. // https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
  184. needHTTP01Listener := !disableTLS
  185. listeners := make([]net.Listener, 0)
  186. for _, bindaddr := range ptInfo.Bindaddrs {
  187. if bindaddr.MethodName != ptMethodName {
  188. pt.SmethodError(bindaddr.MethodName, "no such method")
  189. continue
  190. }
  191. if needHTTP01Listener {
  192. addr := *bindaddr.Addr
  193. addr.Port = 80
  194. log.Printf("Starting HTTP-01 ACME listener")
  195. var lnHTTP01 *net.TCPListener
  196. lnHTTP01, err = net.ListenTCP("tcp", &addr)
  197. if err != nil {
  198. log.Printf("error opening HTTP-01 ACME listener: %s", err)
  199. pt.SmethodError(bindaddr.MethodName, "HTTP-01 ACME listener: "+err.Error())
  200. continue
  201. }
  202. server := &http.Server{
  203. Addr: addr.String(),
  204. Handler: certManager.HTTPHandler(nil),
  205. }
  206. go func() {
  207. log.Fatal(server.Serve(lnHTTP01))
  208. }()
  209. listeners = append(listeners, lnHTTP01)
  210. needHTTP01Listener = false
  211. }
  212. // We're not capable of listening on port 0 (i.e., an ephemeral port
  213. // unknown in advance). The reason is that while the net/http package
  214. // exposes ListenAndServe and ListenAndServeTLS, those functions never
  215. // return, so there's no opportunity to find out what the port number
  216. // is, in between the Listen and Serve steps.
  217. // https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
  218. if bindaddr.Addr.Port == 0 {
  219. err := fmt.Errorf(
  220. "cannot listen on port %d; configure a port using ServerTransportListenAddr",
  221. bindaddr.Addr.Port)
  222. log.Printf("error opening listener: %s", err)
  223. pt.SmethodError(bindaddr.MethodName, err.Error())
  224. continue
  225. }
  226. var transport *sf.Transport
  227. args := pt.Args{}
  228. if disableTLS {
  229. args.Add("tls", "no")
  230. transport = sf.NewSnowflakeServer(nil)
  231. } else {
  232. args.Add("tls", "yes")
  233. for _, hostname := range acmeHostnames {
  234. args.Add("hostname", hostname)
  235. }
  236. transport = sf.NewSnowflakeServer(certManager.GetCertificate)
  237. }
  238. // Are we requested to use source addresses from a particular
  239. // range when dialing the ORPort for this transport?
  240. var orPortSrcAddr *net.IPNet
  241. if orPortSrcAddrCIDR, ok := bindaddr.Options.Get("orport-srcaddr"); ok {
  242. ipnet, err := parseIPCIDR(orPortSrcAddrCIDR)
  243. if err != nil {
  244. err = fmt.Errorf("parsing srcaddr: %w", err)
  245. log.Println(err)
  246. pt.SmethodError(bindaddr.MethodName, err.Error())
  247. continue
  248. }
  249. orPortSrcAddr = ipnet
  250. }
  251. numKCPInstances := 1
  252. // Are we requested to run a certain number of KCP state
  253. // machines?
  254. if value, ok := bindaddr.Options.Get("num-turbotunnel"); ok {
  255. n, err := strconv.Atoi(value)
  256. if err == nil && n < 1 {
  257. err = fmt.Errorf("cannot be less than 1")
  258. }
  259. if err != nil {
  260. err = fmt.Errorf("parsing num-turbotunnel: %w", err)
  261. log.Println(err)
  262. pt.SmethodError(bindaddr.MethodName, err.Error())
  263. continue
  264. }
  265. numKCPInstances = n
  266. }
  267. ln, err := transport.Listen(bindaddr.Addr, numKCPInstances)
  268. if err != nil {
  269. log.Printf("error opening listener: %s", err)
  270. pt.SmethodError(bindaddr.MethodName, err.Error())
  271. continue
  272. }
  273. defer ln.Close()
  274. go acceptLoop(ln, orPortSrcAddr)
  275. pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
  276. listeners = append(listeners, ln)
  277. }
  278. pt.SmethodsDone()
  279. sigChan := make(chan os.Signal, 1)
  280. signal.Notify(sigChan, syscall.SIGTERM)
  281. if os.Getenv("TOR_PT_EXIT_ON_STDIN_CLOSE") == "1" {
  282. // This environment variable means we should treat EOF on stdin
  283. // just like SIGTERM: https://bugs.torproject.org/15435.
  284. go func() {
  285. if _, err := io.Copy(io.Discard, os.Stdin); err != nil {
  286. log.Printf("error copying os.Stdin to io.Discard: %v", err)
  287. }
  288. log.Printf("synthesizing SIGTERM because of stdin close")
  289. sigChan <- syscall.SIGTERM
  290. }()
  291. }
  292. // Wait for a signal.
  293. sig := <-sigChan
  294. // Signal received, shut down.
  295. log.Printf("caught signal %q, exiting", sig)
  296. for _, ln := range listeners {
  297. ln.Close()
  298. }
  299. }