dial_test.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright (c) 2017 Arista Networks, Inc.
  2. // Use of this source code is governed by the Apache License 2.0
  3. // that can be found in the COPYING file.
  4. package dscp_test
  5. import (
  6. "fmt"
  7. "net"
  8. "strings"
  9. "testing"
  10. "time"
  11. "notabug.org/themusicgod1/goarista/dscp"
  12. )
  13. func TestDialTCPWithTOS(t *testing.T) {
  14. addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
  15. listen, err := net.ListenTCP("tcp", addr)
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. defer listen.Close()
  20. done := make(chan struct{})
  21. go func() {
  22. conn, err := listen.Accept()
  23. if err != nil {
  24. t.Fatal(err)
  25. }
  26. defer conn.Close()
  27. buf := []byte{'!'}
  28. conn.Write(buf)
  29. n, err := conn.Read(buf)
  30. if n != 1 || err != nil {
  31. t.Fatalf("Read returned %d / %s", n, err)
  32. } else if buf[0] != '!' {
  33. t.Fatalf("Expected to read '!' but got %q", buf)
  34. }
  35. close(done)
  36. }()
  37. conn, err := dscp.DialTCPWithTOS(nil, listen.Addr().(*net.TCPAddr), 40)
  38. if err != nil {
  39. t.Fatal("Connection failed:", err)
  40. }
  41. defer conn.Close()
  42. buf := make([]byte, 1)
  43. n, err := conn.Read(buf)
  44. if n != 1 || err != nil {
  45. t.Fatalf("Read returned %d / %s", n, err)
  46. } else if buf[0] != '!' {
  47. t.Fatalf("Expected to read '!' but got %q", buf)
  48. }
  49. conn.Write(buf)
  50. <-done
  51. }
  52. func TestDialTCPTimeoutWithTOS(t *testing.T) {
  53. raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
  54. for name, td := range map[string]*net.TCPAddr{
  55. "ipNoPort": &net.TCPAddr{
  56. IP: net.ParseIP("127.0.0.42"), Port: 0,
  57. },
  58. "ipWithPort": &net.TCPAddr{
  59. IP: net.ParseIP("127.0.0.42"), Port: 10001,
  60. },
  61. } {
  62. t.Run(name, func(t *testing.T) {
  63. l, err := net.ListenTCP("tcp", raddr)
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. defer l.Close()
  68. var srcAddr net.Addr
  69. done := make(chan struct{})
  70. go func() {
  71. conn, err := l.Accept()
  72. if err != nil {
  73. t.Fatal(err)
  74. }
  75. defer conn.Close()
  76. srcAddr = conn.RemoteAddr()
  77. close(done)
  78. }()
  79. conn, err := dscp.DialTCPTimeoutWithTOS(td, l.Addr().(*net.TCPAddr), 40, 5*time.Second)
  80. if err != nil {
  81. t.Fatal("Connection failed:", err)
  82. }
  83. defer conn.Close()
  84. pfx := td.IP.String() + ":"
  85. if td.Port > 0 {
  86. pfx = fmt.Sprintf("%s%d", pfx, td.Port)
  87. }
  88. <-done
  89. if !strings.HasPrefix(srcAddr.String(), pfx) {
  90. t.Fatalf("DialTCPTimeoutWithTOS wrong address: %q instead of %q", srcAddr, pfx)
  91. }
  92. })
  93. }
  94. }