request_test.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. package socks
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "net"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. )
  // createRequestData encodes a raw request with the SOCKS5 layout:
  // VER, CMD, RSV (always 0), ATYP, DST.ADDR, DST.PORT (big-endian uint16).
  //
  // IPv4 addresses (in either 4-byte or 16-byte IPv4-mapped form) are encoded
  // with ATYP 1 and four address bytes. Any other address uses ATYP 4 followed
  // by ip.To16(). A nil ip therefore produces ATYP 4 with no address bytes,
  // which gives a deliberately truncated request.
  func createRequestData(version, command uint8, ip net.IP, port uint16) []byte {
  	b := []byte{version, command, 0}
  	// net.ParseIP returns IPv4 addresses in 16-byte IPv4-mapped form, so a
  	// length check would misencode them as IPv6; To4 recognizes both forms.
  	if ip4 := ip.To4(); ip4 != nil {
  		b = append(b, 1)
  		b = append(b, ip4...)
  	} else {
  		b = append(b, 4)
  		b = append(b, ip.To16()...)
  	}
  	p := make([]byte, 2)
  	binary.BigEndian.PutUint16(p, port)
  	return append(b, p...)
  }
  26. func createRequest(t *testing.T, version, command uint8, ipStr string, port uint16, shouldFail bool) *Request {
  27. ip := net.ParseIP(ipStr)
  28. data := createRequestData(version, command, ip, port)
  29. reader := bytes.NewReader(data)
  30. req, err := NewRequest(reader)
  31. if shouldFail {
  32. assert.Error(t, err)
  33. return nil
  34. }
  35. assert.NoError(t, err)
  36. assert.True(t, req.Version == socks5Version, "version doesn't match expectation: %v", req.Version)
  37. assert.True(t, req.Command == command, "command doesn't match expectation: %v", req.Command)
  38. assert.True(t, req.DestAddr.Port == int(port), "port doesn't match expectation: %v", req.DestAddr.Port)
  39. assert.True(t, req.DestAddr.IP.String() == ipStr, "ip doesn't match expectation: %v", req.DestAddr.IP.String())
  40. return req
  41. }
  42. func TestValidConnectRequest(t *testing.T) {
  43. createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
  44. }
  45. func TestValidBindRequest(t *testing.T) {
  46. createRequest(t, socks5Version, bindCommand, "2001:db8::68", 1337, false)
  47. }
  48. func TestValidAssociateRequest(t *testing.T) {
  49. createRequest(t, socks5Version, associateCommand, "127.0.0.1", 1234, false)
  50. }
  51. func TestInValidVersionRequest(t *testing.T) {
  52. createRequest(t, 4, connectCommand, "127.0.0.1", 1337, true)
  53. }
  54. func TestInValidIPRequest(t *testing.T) {
  55. createRequest(t, 4, connectCommand, "127.0.01", 1337, true)
  56. }