123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- package diagnostic_test
- import (
- "context"
- "encoding/json"
- "errors"
- "net"
- "net/http"
- "net/http/httptest"
- "runtime"
- "testing"
- "github.com/google/uuid"
- "github.com/rs/zerolog"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/cloudflare/cloudflared/connection"
- "github.com/cloudflare/cloudflared/diagnostic"
- "github.com/cloudflare/cloudflared/tunnelstate"
- )
- type SystemCollectorMock struct {
- systemInfo *diagnostic.SystemInformation
- err error
- }
// systemInformationKey is a string key shared by the tests of this package.
// It is not referenced in the visible part of this file.
const systemInformationKey = "sikey"

// errorKey is a string key shared by the tests of this package.
// It is not referenced in the visible part of this file.
const errorKey = "errkey"
- func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
- t.Helper()
- log := zerolog.Nop()
- tracker := tunnelstate.NewConnTracker(&log)
- for _, conn := range connections {
- tracker.OnTunnelEvent(connection.Event{
- Index: conn.Index,
- EventType: connection.Connected,
- Protocol: conn.Protocol,
- EdgeAddress: conn.EdgeAddress,
- })
- }
- return tracker
- }
- func (collector *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
- return collector.systemInfo, collector.err
- }
- func TestSystemHandler(t *testing.T) {
- t.Parallel()
- log := zerolog.Nop()
- tests := []struct {
- name string
- systemInfo *diagnostic.SystemInformation
- err error
- statusCode int
- }{
- {
- name: "happy path",
- systemInfo: diagnostic.NewSystemInformation(
- 0, 0, 0, 0,
- "string", "string", "string", "string",
- "string", "string",
- runtime.Version(), runtime.GOARCH, nil,
- ),
- err: nil,
- statusCode: http.StatusOK,
- },
- {
- name: "on error and no raw info", systemInfo: nil,
- err: errors.New("an error"), statusCode: http.StatusOK,
- },
- }
- for _, tCase := range tests {
- t.Run(tCase.name, func(t *testing.T) {
- t.Parallel()
- handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{
- systemInfo: tCase.systemInfo,
- err: tCase.err,
- }, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
- recorder := httptest.NewRecorder()
- ctx := context.Background()
- request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/system", nil)
- require.NoError(t, err)
- handler.SystemHandler(recorder, request)
- assert.Equal(t, tCase.statusCode, recorder.Code)
- if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
- var response diagnostic.SystemInformationResponse
- decoder := json.NewDecoder(recorder.Body)
- err := decoder.Decode(&response)
- require.NoError(t, err)
- assert.Equal(t, tCase.systemInfo, response.Info)
- }
- })
- }
- }
- func TestTunnelStateHandler(t *testing.T) {
- t.Parallel()
- log := zerolog.Nop()
- tests := []struct {
- name string
- tunnelID uuid.UUID
- clientID uuid.UUID
- connections []tunnelstate.IndexedConnectionInfo
- icmpSources []string
- }{
- {
- name: "case1",
- tunnelID: uuid.New(),
- clientID: uuid.New(),
- },
- {
- name: "case2",
- tunnelID: uuid.New(),
- clientID: uuid.New(),
- icmpSources: []string{"172.17.0.3", "::1"},
- connections: []tunnelstate.IndexedConnectionInfo{{
- ConnectionInfo: tunnelstate.ConnectionInfo{
- IsConnected: true,
- Protocol: connection.QUIC,
- EdgeAddress: net.IPv4(100, 100, 100, 100),
- },
- Index: 0,
- }},
- },
- }
- for _, tCase := range tests {
- t.Run(tCase.name, func(t *testing.T) {
- t.Parallel()
- tracker := newTrackerFromConns(t, tCase.connections)
- handler := diagnostic.NewDiagnosticHandler(
- &log,
- 0,
- nil,
- tCase.tunnelID,
- tCase.clientID,
- tracker,
- map[string]string{},
- tCase.icmpSources,
- )
- recorder := httptest.NewRecorder()
- handler.TunnelStateHandler(recorder, nil)
- decoder := json.NewDecoder(recorder.Body)
- var response diagnostic.TunnelState
- err := decoder.Decode(&response)
- require.NoError(t, err)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, tCase.tunnelID, response.TunnelID)
- assert.Equal(t, tCase.clientID, response.ConnectorID)
- assert.Equal(t, tCase.connections, response.Connections)
- assert.Equal(t, tCase.icmpSources, response.ICMPSources)
- })
- }
- }
- func TestConfigurationHandler(t *testing.T) {
- t.Parallel()
- log := zerolog.Nop()
- tests := []struct {
- name string
- flags map[string]string
- expected map[string]string
- }{
- {
- name: "empty cli",
- flags: make(map[string]string),
- expected: map[string]string{
- "uid": "0",
- },
- },
- {
- name: "cli with flags",
- flags: map[string]string{
- "b": "a",
- "c": "a",
- "d": "a",
- "uid": "0",
- },
- expected: map[string]string{
- "b": "a",
- "c": "a",
- "d": "a",
- "uid": "0",
- },
- },
- }
- for _, tCase := range tests {
- t.Run(tCase.name, func(t *testing.T) {
- t.Parallel()
- var response map[string]string
- handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
- recorder := httptest.NewRecorder()
- handler.ConfigurationHandler(recorder, nil)
- decoder := json.NewDecoder(recorder.Body)
- err := decoder.Decode(&response)
- require.NoError(t, err)
- _, ok := response["uid"]
- assert.True(t, ok)
- delete(tCase.expected, "uid")
- delete(response, "uid")
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, tCase.expected, response)
- })
- }
- }
|