stream_test.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package stream
  2. import (
  3. "fmt"
  4. "io"
  5. "sync"
  6. "testing"
  7. "time"
  8. "github.com/rs/zerolog"
  9. "github.com/stretchr/testify/require"
  10. )
  11. func TestPipeBidirectionalFinishBothSides(t *testing.T) {
  12. fun := func(upstream, downstream *mockedStream) {
  13. downstream.closeReader()
  14. upstream.closeReader()
  15. }
  16. testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
  17. }
  18. func TestPipeBidirectionalFinishOneSideTimeout(t *testing.T) {
  19. fun := func(upstream, downstream *mockedStream) {
  20. downstream.closeReader()
  21. }
  22. testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
  23. }
  24. func TestPipeBidirectionalClosingWriteBothSidesAlsoExists(t *testing.T) {
  25. fun := func(upstream, downstream *mockedStream) {
  26. downstream.CloseWrite()
  27. upstream.CloseWrite()
  28. downstream.writeToReader("abc")
  29. upstream.writeToReader("abc")
  30. }
  31. testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
  32. }
  33. func TestPipeBidirectionalClosingWriteSingleSideAlsoExists(t *testing.T) {
  34. fun := func(upstream, downstream *mockedStream) {
  35. downstream.CloseWrite()
  36. downstream.writeToReader("abc")
  37. upstream.writeToReader("abc")
  38. }
  39. testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
  40. }
  41. func testPipeBidirectionalUnblocking(t *testing.T, afterFun func(*mockedStream, *mockedStream), timeout time.Duration, expectTimeout bool) {
  42. logger := zerolog.Nop()
  43. downstream := newMockedStream()
  44. upstream := newMockedStream()
  45. resultCh := make(chan error)
  46. go func() {
  47. resultCh <- PipeBidirectional(downstream, upstream, timeout, &logger)
  48. }()
  49. afterFun(upstream, downstream)
  50. select {
  51. case err := <-resultCh:
  52. if expectTimeout {
  53. require.NotNil(t, err)
  54. } else {
  55. require.Nil(t, err)
  56. }
  57. case <-time.After(timeout * 2):
  58. require.Fail(t, "test timeout")
  59. }
  60. }
  61. func newMockedStream() *mockedStream {
  62. return &mockedStream{
  63. readCh: make(chan *string),
  64. writeCh: make(chan struct{}),
  65. }
  66. }
  // mockedStream is a scriptable stand-in for a stream: tests push values into
  // its reader and decide when its Write calls are released.
  type mockedStream struct {
  	// readCh supplies Read: a non-nil *string is a payload, and nil (including
  	// the zero value received after the channel is closed) means io.EOF.
  	readCh chan *string

  	// writeCh gates Write; it is only ever closed, via CloseWrite, and
  	// writeCloseOnce guarantees that close happens at most once.
  	writeCh        chan struct{}
  	writeCloseOnce sync.Once
  }
  72. func (m *mockedStream) Read(p []byte) (n int, err error) {
  73. result := <-m.readCh
  74. if result == nil {
  75. return 0, io.EOF
  76. }
  77. return len(*result), nil
  78. }
  79. func (m *mockedStream) Write(p []byte) (n int, err error) {
  80. <-m.writeCh
  81. return 0, fmt.Errorf("closed")
  82. }
  83. func (m *mockedStream) CloseWrite() error {
  84. m.writeCloseOnce.Do(func() {
  85. close(m.writeCh)
  86. })
  87. return nil
  88. }
  89. func (m *mockedStream) closeReader() {
  90. close(m.readCh)
  91. }
  92. func (m *mockedStream) writeToReader(content string) {
  93. m.readCh <- &content
  94. }