handlers_test.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package diagnostic_test
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "runtime"
  10. "testing"
  11. "github.com/google/uuid"
  12. "github.com/rs/zerolog"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. "github.com/cloudflare/cloudflared/connection"
  16. "github.com/cloudflare/cloudflared/diagnostic"
  17. "github.com/cloudflare/cloudflared/tunnelstate"
  18. )
  19. type SystemCollectorMock struct {
  20. systemInfo *diagnostic.SystemInformation
  21. err error
  22. }
  23. const (
  24. systemInformationKey = "sikey"
  25. errorKey = "errkey"
  26. )
  27. func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
  28. t.Helper()
  29. log := zerolog.Nop()
  30. tracker := tunnelstate.NewConnTracker(&log)
  31. for _, conn := range connections {
  32. tracker.OnTunnelEvent(connection.Event{
  33. Index: conn.Index,
  34. EventType: connection.Connected,
  35. Protocol: conn.Protocol,
  36. EdgeAddress: conn.EdgeAddress,
  37. })
  38. }
  39. return tracker
  40. }
  41. func (collector *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
  42. return collector.systemInfo, collector.err
  43. }
  44. func TestSystemHandler(t *testing.T) {
  45. t.Parallel()
  46. log := zerolog.Nop()
  47. tests := []struct {
  48. name string
  49. systemInfo *diagnostic.SystemInformation
  50. err error
  51. statusCode int
  52. }{
  53. {
  54. name: "happy path",
  55. systemInfo: diagnostic.NewSystemInformation(
  56. 0, 0, 0, 0,
  57. "string", "string", "string", "string",
  58. "string", "string",
  59. runtime.Version(), runtime.GOARCH, nil,
  60. ),
  61. err: nil,
  62. statusCode: http.StatusOK,
  63. },
  64. {
  65. name: "on error and no raw info", systemInfo: nil,
  66. err: errors.New("an error"), statusCode: http.StatusOK,
  67. },
  68. }
  69. for _, tCase := range tests {
  70. t.Run(tCase.name, func(t *testing.T) {
  71. t.Parallel()
  72. handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{
  73. systemInfo: tCase.systemInfo,
  74. err: tCase.err,
  75. }, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
  76. recorder := httptest.NewRecorder()
  77. ctx := context.Background()
  78. request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/system", nil)
  79. require.NoError(t, err)
  80. handler.SystemHandler(recorder, request)
  81. assert.Equal(t, tCase.statusCode, recorder.Code)
  82. if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
  83. var response diagnostic.SystemInformationResponse
  84. decoder := json.NewDecoder(recorder.Body)
  85. err := decoder.Decode(&response)
  86. require.NoError(t, err)
  87. assert.Equal(t, tCase.systemInfo, response.Info)
  88. }
  89. })
  90. }
  91. }
  92. func TestTunnelStateHandler(t *testing.T) {
  93. t.Parallel()
  94. log := zerolog.Nop()
  95. tests := []struct {
  96. name string
  97. tunnelID uuid.UUID
  98. clientID uuid.UUID
  99. connections []tunnelstate.IndexedConnectionInfo
  100. icmpSources []string
  101. }{
  102. {
  103. name: "case1",
  104. tunnelID: uuid.New(),
  105. clientID: uuid.New(),
  106. },
  107. {
  108. name: "case2",
  109. tunnelID: uuid.New(),
  110. clientID: uuid.New(),
  111. icmpSources: []string{"172.17.0.3", "::1"},
  112. connections: []tunnelstate.IndexedConnectionInfo{{
  113. ConnectionInfo: tunnelstate.ConnectionInfo{
  114. IsConnected: true,
  115. Protocol: connection.QUIC,
  116. EdgeAddress: net.IPv4(100, 100, 100, 100),
  117. },
  118. Index: 0,
  119. }},
  120. },
  121. }
  122. for _, tCase := range tests {
  123. t.Run(tCase.name, func(t *testing.T) {
  124. t.Parallel()
  125. tracker := newTrackerFromConns(t, tCase.connections)
  126. handler := diagnostic.NewDiagnosticHandler(
  127. &log,
  128. 0,
  129. nil,
  130. tCase.tunnelID,
  131. tCase.clientID,
  132. tracker,
  133. map[string]string{},
  134. tCase.icmpSources,
  135. )
  136. recorder := httptest.NewRecorder()
  137. handler.TunnelStateHandler(recorder, nil)
  138. decoder := json.NewDecoder(recorder.Body)
  139. var response diagnostic.TunnelState
  140. err := decoder.Decode(&response)
  141. require.NoError(t, err)
  142. assert.Equal(t, http.StatusOK, recorder.Code)
  143. assert.Equal(t, tCase.tunnelID, response.TunnelID)
  144. assert.Equal(t, tCase.clientID, response.ConnectorID)
  145. assert.Equal(t, tCase.connections, response.Connections)
  146. assert.Equal(t, tCase.icmpSources, response.ICMPSources)
  147. })
  148. }
  149. }
  150. func TestConfigurationHandler(t *testing.T) {
  151. t.Parallel()
  152. log := zerolog.Nop()
  153. tests := []struct {
  154. name string
  155. flags map[string]string
  156. expected map[string]string
  157. }{
  158. {
  159. name: "empty cli",
  160. flags: make(map[string]string),
  161. expected: map[string]string{
  162. "uid": "0",
  163. },
  164. },
  165. {
  166. name: "cli with flags",
  167. flags: map[string]string{
  168. "b": "a",
  169. "c": "a",
  170. "d": "a",
  171. "uid": "0",
  172. },
  173. expected: map[string]string{
  174. "b": "a",
  175. "c": "a",
  176. "d": "a",
  177. "uid": "0",
  178. },
  179. },
  180. }
  181. for _, tCase := range tests {
  182. t.Run(tCase.name, func(t *testing.T) {
  183. t.Parallel()
  184. var response map[string]string
  185. handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
  186. recorder := httptest.NewRecorder()
  187. handler.ConfigurationHandler(recorder, nil)
  188. decoder := json.NewDecoder(recorder.Body)
  189. err := decoder.Decode(&response)
  190. require.NoError(t, err)
  191. _, ok := response["uid"]
  192. assert.True(t, ok)
  193. delete(tCase.expected, "uid")
  194. delete(response, "uid")
  195. assert.Equal(t, http.StatusOK, recorder.Code)
  196. assert.Equal(t, tCase.expected, response)
  197. })
  198. }
  199. }