probetest.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. /*
  2. Probe test server to check the reachability of Snowflake proxies from
  3. clients with symmetric NATs.
  4. The probe server receives an offer from a proxy, returns an answer, and then
  5. attempts to establish a datachannel connection to that proxy. The proxy will
  6. self-determine whether the connection opened successfully.
  7. */
  8. package main
  9. import (
  10. "crypto/tls"
  11. "flag"
  12. "fmt"
  13. "io"
  14. "log"
  15. "net"
  16. "net/http"
  17. "os"
  18. "strings"
  19. "time"
  20. "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/ptutil/safelog"
  21. "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
  22. "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
  23. "github.com/pion/transport/v3/stdnet"
  24. "github.com/pion/webrtc/v4"
  25. "golang.org/x/crypto/acme/autocert"
  26. )
  27. const (
  28. // Maximum number of bytes to be read from an HTTP request
  29. readLimit = 100000
  30. // Time after which we assume proxy data channel will not open
  31. dataChannelOpenTimeout = 20 * time.Second
  32. // How long to wait after the data channel has been open before closing the peer connection.
  33. dataChannelCloseTimeout = 5 * time.Second
  34. // Default STUN URL
  35. defaultStunUrls = "stun:stun.l.google.com:19302,stun:stun.voip.blackberry.com:3478"
  36. )
  37. type ProbeHandler struct {
  38. stunURL string
  39. handle func(string, http.ResponseWriter, *http.Request)
  40. }
  41. func (h ProbeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  42. h.handle(h.stunURL, w, r)
  43. }
  44. // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
  45. // candidates is complete and the answer is available in LocalDescription.
  46. func makePeerConnectionFromOffer(stunURL string, sdp *webrtc.SessionDescription,
  47. dataChanOpen chan struct{}, dataChanClosed chan struct{}, iceGatheringTimeout time.Duration) (*webrtc.PeerConnection, error) {
  48. settingsEngine := webrtc.SettingEngine{}
  49. settingsEngine.SetIPFilter(func(ip net.IP) (keep bool) {
  50. // `IsLoopback()` and `IsUnspecified` are likely not neded here,
  51. // but let's keep them just in case.
  52. // FYI there is similar code in other files in this project.
  53. keep = !util.IsLocal(ip) && !ip.IsLoopback() && !ip.IsUnspecified()
  54. return
  55. })
  56. // FYI this is `false` by default anyway as of pion/webrtc@4
  57. settingsEngine.SetIncludeLoopbackCandidate(false)
  58. // Use the SetNet setting https://pkg.go.dev/github.com/pion/webrtc/v3#SettingEngine.SetNet
  59. // to functionally revert a new change in pion by silently ignoring
  60. // when net.Interfaces() fails, rather than throwing an error
  61. vnet, _ := stdnet.NewNet()
  62. settingsEngine.SetNet(vnet)
  63. api := webrtc.NewAPI(webrtc.WithSettingEngine(settingsEngine))
  64. config := webrtc.Configuration{
  65. ICEServers: []webrtc.ICEServer{
  66. {
  67. URLs: strings.Split(stunURL, ","),
  68. },
  69. },
  70. }
  71. pc, err := api.NewPeerConnection(config)
  72. if err != nil {
  73. return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
  74. }
  75. pc.OnDataChannel(func(dc *webrtc.DataChannel) {
  76. dc.OnOpen(func() {
  77. close(dataChanOpen)
  78. })
  79. dc.OnClose(func() {
  80. close(dataChanClosed)
  81. dc.Close()
  82. })
  83. })
  84. // As of v3.0.0, pion-webrtc uses trickle ICE by default.
  85. // We have to wait for candidate gathering to complete
  86. // before we send the offer
  87. done := webrtc.GatheringCompletePromise(pc)
  88. err = pc.SetRemoteDescription(*sdp)
  89. if err != nil {
  90. if inerr := pc.Close(); inerr != nil {
  91. log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
  92. }
  93. return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err)
  94. }
  95. answer, err := pc.CreateAnswer(nil)
  96. if err != nil {
  97. if inerr := pc.Close(); inerr != nil {
  98. log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
  99. }
  100. return nil, err
  101. }
  102. err = pc.SetLocalDescription(answer)
  103. if err != nil {
  104. if err = pc.Close(); err != nil {
  105. log.Printf("pc.Close after setting local description returned : %v", err)
  106. }
  107. return nil, err
  108. }
  109. // Wait for ICE candidate gathering to complete,
  110. // or for whatever we managed to gather before the client times out.
  111. // See https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40230
  112. select {
  113. case <-done:
  114. case <-time.After(iceGatheringTimeout):
  115. }
  116. return pc, nil
  117. }
  118. func probeHandler(stunURL string, w http.ResponseWriter, r *http.Request) {
  119. w.Header().Set("Access-Control-Allow-Origin", "*")
  120. resp, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
  121. if nil != err {
  122. log.Println("Invalid data.")
  123. w.WriteHeader(http.StatusBadRequest)
  124. return
  125. }
  126. offer, _, err := messages.DecodePollResponse(resp)
  127. if err != nil {
  128. log.Printf("Error reading offer: %s", err.Error())
  129. w.WriteHeader(http.StatusBadRequest)
  130. return
  131. }
  132. if offer == "" {
  133. log.Printf("Error processing session description: %s", err.Error())
  134. w.WriteHeader(http.StatusBadRequest)
  135. return
  136. }
  137. sdp, err := util.DeserializeSessionDescription(offer)
  138. if err != nil {
  139. log.Printf("Error processing session description: %s", err.Error())
  140. w.WriteHeader(http.StatusBadRequest)
  141. return
  142. }
  143. dataChanOpen := make(chan struct{})
  144. dataChanClosed := make(chan struct{})
  145. // TODO refactor: DRY this must be below `ResponseHeaderTimeout` in proxy
  146. // https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/blob/e1d9b4ace69897521cc29585b5084c5f4d1ce874/proxy/lib/snowflake.go#L207
  147. iceGatheringTimeout := 10 * time.Second
  148. pc, err := makePeerConnectionFromOffer(stunURL, sdp, dataChanOpen, dataChanClosed, iceGatheringTimeout)
  149. if err != nil {
  150. log.Printf("Error making WebRTC connection: %s", err)
  151. w.WriteHeader(http.StatusInternalServerError)
  152. return
  153. }
  154. // We'll set this to `false` if the signaling (this function) succeeds.
  155. closePcOnReturn := true
  156. defer func() {
  157. if closePcOnReturn {
  158. if err := pc.Close(); err != nil {
  159. log.Printf("Error calling pc.Close: %v", err)
  160. }
  161. }
  162. // Otherwise it must be closed below, wherever `closePcOnReturn` is set to `false`.
  163. }()
  164. answer, err := util.SerializeSessionDescription(pc.LocalDescription())
  165. if err != nil {
  166. log.Printf("Error making WebRTC connection: %s", err)
  167. w.WriteHeader(http.StatusInternalServerError)
  168. return
  169. }
  170. body, err := messages.EncodeAnswerRequest(answer, "stub-sid")
  171. if err != nil {
  172. log.Printf("Error making WebRTC connection: %s", err)
  173. w.WriteHeader(http.StatusInternalServerError)
  174. return
  175. }
  176. w.Write(body)
  177. // Set a timeout on peerconnection. If the connection state has not
  178. // advanced to PeerConnectionStateConnected in this time,
  179. // destroy the peer connection and return the token.
  180. closePcOnReturn = false
  181. go func() {
  182. timer := time.NewTimer(dataChannelOpenTimeout)
  183. defer timer.Stop()
  184. select {
  185. case <-dataChanOpen:
  186. // Let's not close the `PeerConnection` immediately now,
  187. // instead let's wait for the peer (or timeout)
  188. // to close the connection,
  189. // in order to ensure that the DataChannel also gets opened
  190. // on the proxy's side.
  191. // Otherwise the proxy might receive the "close PeerConnection"
  192. // "event" before they receive "dataChannel.OnOpen",
  193. // which would wrongly result in a "restricted" NAT.
  194. // See https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40387
  195. select {
  196. case <-dataChanClosed:
  197. case <-time.After(dataChannelCloseTimeout):
  198. }
  199. case <-timer.C:
  200. }
  201. if err := pc.Close(); err != nil {
  202. log.Printf("Error calling pc.Close: %v", err)
  203. }
  204. }()
  205. return
  206. }
  207. func main() {
  208. var acmeEmail string
  209. var acmeHostnamesCommas string
  210. var acmeCertCacheDir string
  211. var addr string
  212. var disableTLS bool
  213. var certFilename, keyFilename string
  214. var unsafeLogging bool
  215. var stunURL string
  216. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  217. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  218. flag.StringVar(&acmeCertCacheDir, "acme-cert-cache", "acme-cert-cache", "directory in which certificates should be cached")
  219. flag.StringVar(&certFilename, "cert", "", "TLS certificate file")
  220. flag.StringVar(&keyFilename, "key", "", "TLS private key file")
  221. flag.StringVar(&addr, "addr", ":8443", "address to listen on")
  222. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  223. flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
  224. flag.StringVar(&stunURL, "stun", defaultStunUrls, "STUN servers to use for NAT traversal (comma-separated)")
  225. flag.Parse()
  226. var logOutput io.Writer = os.Stderr
  227. if unsafeLogging {
  228. log.SetOutput(logOutput)
  229. } else {
  230. // Scrub log output just in case an address ends up there
  231. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  232. }
  233. log.SetFlags(log.LstdFlags | log.LUTC)
  234. http.Handle("/probe", ProbeHandler{stunURL, probeHandler})
  235. server := http.Server{
  236. Addr: addr,
  237. }
  238. var err error
  239. if acmeHostnamesCommas != "" {
  240. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  241. log.Printf("ACME hostnames: %q", acmeHostnames)
  242. var cache autocert.Cache
  243. if err = os.MkdirAll(acmeCertCacheDir, 0700); err != nil {
  244. log.Printf("Warning: Couldn't create cache directory %q (reason: %s) so we're *not* using our certificate cache.", acmeCertCacheDir, err)
  245. } else {
  246. cache = autocert.DirCache(acmeCertCacheDir)
  247. }
  248. certManager := autocert.Manager{
  249. Cache: cache,
  250. Prompt: autocert.AcceptTOS,
  251. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  252. Email: acmeEmail,
  253. }
  254. // start certificate manager handler
  255. go func() {
  256. log.Printf("Starting HTTP-01 listener")
  257. log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
  258. }()
  259. server.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate}
  260. err = server.ListenAndServeTLS("", "")
  261. } else if certFilename != "" && keyFilename != "" {
  262. err = server.ListenAndServeTLS(certFilename, keyFilename)
  263. } else if disableTLS {
  264. err = server.ListenAndServe()
  265. } else {
  266. log.Fatal("the --cert and --key, --acme-hostnames, or --disable-tls option is required")
  267. }
  268. if err != nil {
  269. log.Println(err)
  270. }
  271. }