request_handler_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package socks
  2. import (
  3. "bytes"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/cloudflare/cloudflared/ipaccess"
  7. )
  8. func TestUnsupportedBind(t *testing.T) {
  9. req := createRequest(t, socks5Version, bindCommand, "2001:db8::68", 1337, false)
  10. var b bytes.Buffer
  11. requestHandler := NewRequestHandler(NewNetDialer(), nil)
  12. err := requestHandler.Handle(req, &b)
  13. assert.NoError(t, err)
  14. assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response")
  15. }
  16. func TestUnsupportedAssociate(t *testing.T) {
  17. req := createRequest(t, socks5Version, associateCommand, "127.0.0.1", 1337, false)
  18. var b bytes.Buffer
  19. requestHandler := NewRequestHandler(NewNetDialer(), nil)
  20. err := requestHandler.Handle(req, &b)
  21. assert.NoError(t, err)
  22. assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response")
  23. }
  24. func TestHandleConnect(t *testing.T) {
  25. req := createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
  26. var b bytes.Buffer
  27. requestHandler := NewRequestHandler(NewNetDialer(), nil)
  28. err := requestHandler.Handle(req, &b)
  29. assert.Error(t, err)
  30. assert.True(t, b.Bytes()[1] == connectionRefused, "expected a response")
  31. }
  32. func TestHandleConnectIPAccess(t *testing.T) {
  33. prefix := "127.0.0.0/24"
  34. rule1, _ := ipaccess.NewRuleByCIDR(&prefix, []int{1337}, true)
  35. rule2, _ := ipaccess.NewRuleByCIDR(&prefix, []int{1338}, false)
  36. rules := []ipaccess.Rule{rule1, rule2}
  37. var b bytes.Buffer
  38. accessPolicy, _ := ipaccess.NewPolicy(false, nil)
  39. requestHandler := NewRequestHandler(NewNetDialer(), accessPolicy)
  40. req := createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
  41. err := requestHandler.Handle(req, &b)
  42. assert.Error(t, err)
  43. assert.True(t, b.Bytes()[1] == ruleFailure, "expected to be denied as no rules and defaultAllow=false")
  44. b.Reset()
  45. accessPolicy, _ = ipaccess.NewPolicy(true, nil)
  46. requestHandler = NewRequestHandler(NewNetDialer(), accessPolicy)
  47. req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
  48. err = requestHandler.Handle(req, &b)
  49. assert.Error(t, err)
  50. assert.True(t, b.Bytes()[1] == connectionRefused, "expected to be allowed as no rules and defaultAllow=true")
  51. b.Reset()
  52. accessPolicy, _ = ipaccess.NewPolicy(false, rules)
  53. requestHandler = NewRequestHandler(NewNetDialer(), accessPolicy)
  54. req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
  55. err = requestHandler.Handle(req, &b)
  56. assert.Error(t, err)
  57. assert.True(t, b.Bytes()[1] == connectionRefused, "expected to be allowed as matching rule")
  58. b.Reset()
  59. req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1338, false)
  60. err = requestHandler.Handle(req, &b)
  61. assert.Error(t, err)
  62. assert.True(t, b.Bytes()[1] == ruleFailure, "expected to be denied as matching rule")
  63. b.Reset()
  64. req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1339, false)
  65. err = requestHandler.Handle(req, &b)
  66. assert.Error(t, err)
  67. assert.True(t, b.Bytes()[1] == ruleFailure, "expect to be denied as no matching rule and defaultAllow=false")
  68. }