websocket_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
  2. // See LICENSE.txt for license information.
  3. package api4
  4. import (
  5. "fmt"
  6. "net/http"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/gorilla/websocket"
  11. "github.com/stretchr/testify/require"
  12. "github.com/mattermost/mattermost-server/v5/model"
  13. )
  14. func TestWebSocket(t *testing.T) {
  15. th := Setup(t).InitBasic()
  16. defer th.TearDown()
  17. WebSocketClient, err := th.CreateWebSocketClient()
  18. require.Nil(t, err)
  19. defer WebSocketClient.Close()
  20. time.Sleep(300 * time.Millisecond)
  21. // Test closing and reconnecting
  22. WebSocketClient.Close()
  23. err = WebSocketClient.Connect()
  24. require.Nil(t, err)
  25. WebSocketClient.Listen()
  26. resp := <-WebSocketClient.ResponseChannel
  27. require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge")
  28. WebSocketClient.SendMessage("ping", nil)
  29. resp = <-WebSocketClient.ResponseChannel
  30. require.Equal(t, resp.Data["text"].(string), "pong", "wrong response")
  31. WebSocketClient.SendMessage("", nil)
  32. resp = <-WebSocketClient.ResponseChannel
  33. require.Equal(t, resp.Error.Id, "api.web_socket_router.no_action.app_error", "should have been no action response")
  34. WebSocketClient.SendMessage("junk", nil)
  35. resp = <-WebSocketClient.ResponseChannel
  36. require.Equal(t, resp.Error.Id, "api.web_socket_router.bad_action.app_error", "should have been bad action response")
  37. WebSocketClient.UserTyping("", "")
  38. resp = <-WebSocketClient.ResponseChannel
  39. require.Equal(t, resp.Error.Id, "api.websocket_handler.invalid_param.app_error", "should have been invalid param response")
  40. require.Equal(t, resp.Error.DetailedError, "", "detailed error not cleared")
  41. WebSocketClient.UserTyping(th.BasicChannel.Id, "")
  42. resp = <-WebSocketClient.ResponseChannel
  43. require.Nil(t, resp.Error)
  44. WebSocketClient.UserTyping(th.BasicPrivateChannel2.Id, "")
  45. resp = <-WebSocketClient.ResponseChannel
  46. require.Equal(t, resp.Error.Id, "api.websocket_handler.invalid_param.app_error", "should have been invalid param response")
  47. require.Equal(t, resp.Error.DetailedError, "", "detailed error not cleared")
  48. }
  49. func TestWebSocketTrailingSlash(t *testing.T) {
  50. th := Setup(t)
  51. defer th.TearDown()
  52. url := fmt.Sprintf("ws://localhost:%v", th.App.Srv().ListenAddr.Port)
  53. _, _, err := websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket/", nil)
  54. require.NoError(t, err)
  55. }
  56. func TestWebSocketEvent(t *testing.T) {
  57. th := Setup(t).InitBasic()
  58. defer th.TearDown()
  59. WebSocketClient, err := th.CreateWebSocketClient()
  60. require.Nil(t, err)
  61. defer WebSocketClient.Close()
  62. WebSocketClient.Listen()
  63. resp := <-WebSocketClient.ResponseChannel
  64. require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge")
  65. omitUser := make(map[string]bool, 1)
  66. omitUser["somerandomid"] = true
  67. evt1 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", th.BasicChannel.Id, "", omitUser)
  68. evt1.Add("user_id", "somerandomid")
  69. th.App.Publish(evt1)
  70. time.Sleep(300 * time.Millisecond)
  71. stop := make(chan bool)
  72. eventHit := false
  73. go func() {
  74. for {
  75. select {
  76. case resp := <-WebSocketClient.EventChannel:
  77. if resp.EventType() == model.WEBSOCKET_EVENT_TYPING && resp.GetData()["user_id"].(string) == "somerandomid" {
  78. eventHit = true
  79. }
  80. case <-stop:
  81. return
  82. }
  83. }
  84. }()
  85. time.Sleep(400 * time.Millisecond)
  86. stop <- true
  87. require.True(t, eventHit, "did not receive typing event")
  88. evt2 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", "somerandomid", "", nil)
  89. th.App.Publish(evt2)
  90. time.Sleep(300 * time.Millisecond)
  91. eventHit = false
  92. go func() {
  93. for {
  94. select {
  95. case resp := <-WebSocketClient.EventChannel:
  96. if resp.EventType() == model.WEBSOCKET_EVENT_TYPING {
  97. eventHit = true
  98. }
  99. case <-stop:
  100. return
  101. }
  102. }
  103. }()
  104. time.Sleep(400 * time.Millisecond)
  105. stop <- true
  106. require.False(t, eventHit, "got typing event for bad channel id")
  107. }
  108. func TestCreateDirectChannelWithSocket(t *testing.T) {
  109. th := Setup(t).InitBasic()
  110. defer th.TearDown()
  111. Client := th.Client
  112. user2 := th.BasicUser2
  113. users := make([]*model.User, 0)
  114. users = append(users, user2)
  115. for i := 0; i < 10; i++ {
  116. users = append(users, th.CreateUser())
  117. }
  118. WebSocketClient, err := th.CreateWebSocketClient()
  119. require.Nil(t, err)
  120. defer WebSocketClient.Close()
  121. WebSocketClient.Listen()
  122. resp := <-WebSocketClient.ResponseChannel
  123. require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge")
  124. wsr := <-WebSocketClient.EventChannel
  125. require.Equal(t, wsr.EventType(), model.WEBSOCKET_EVENT_HELLO, "missing hello")
  126. stop := make(chan bool)
  127. count := 0
  128. go func() {
  129. for {
  130. select {
  131. case wsr := <-WebSocketClient.EventChannel:
  132. if wsr != nil && wsr.EventType() == model.WEBSOCKET_EVENT_DIRECT_ADDED {
  133. count = count + 1
  134. }
  135. case <-stop:
  136. return
  137. }
  138. }
  139. }()
  140. for _, user := range users {
  141. time.Sleep(100 * time.Millisecond)
  142. _, resp := Client.CreateDirectChannel(th.BasicUser.Id, user.Id)
  143. require.Nil(t, resp.Error, "failed to create DM channel")
  144. }
  145. time.Sleep(5000 * time.Millisecond)
  146. stop <- true
  147. require.Equal(t, count, len(users), "We didn't get the proper amount of direct_added messages")
  148. }
  149. func TestWebsocketOriginSecurity(t *testing.T) {
  150. th := Setup(t)
  151. defer th.TearDown()
  152. url := fmt.Sprintf("ws://localhost:%v", th.App.Srv().ListenAddr.Port)
  153. // Should fail because origin doesn't match
  154. _, _, err := websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  155. "Origin": []string{"http://www.evil.com"},
  156. })
  157. require.NotNil(t, err, "Should have errored because Origin does not match host! SECURITY ISSUE!")
  158. // We are not a browser so we can spoof this just fine
  159. _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  160. "Origin": []string{fmt.Sprintf("http://localhost:%v", th.App.Srv().ListenAddr.Port)},
  161. })
  162. require.Nil(t, err, err)
  163. // Should succeed now because open CORS
  164. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "*" })
  165. _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  166. "Origin": []string{"http://www.evil.com"},
  167. })
  168. require.Nil(t, err, err)
  169. // Should succeed now because matching CORS
  170. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.evil.com" })
  171. _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  172. "Origin": []string{"http://www.evil.com"},
  173. })
  174. require.Nil(t, err, err)
  175. // Should fail because non-matching CORS
  176. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com" })
  177. _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  178. "Origin": []string{"http://www.evil.com"},
  179. })
  180. require.NotNil(t, err, "Should have errored because Origin contain AllowCorsFrom")
  181. // Should fail because non-matching CORS
  182. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "http://www.good.com" })
  183. _, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX+"/websocket", http.Header{
  184. "Origin": []string{"http://www.good.co"},
  185. })
  186. require.NotNil(t, err, "Should have errored because Origin does not match host! SECURITY ISSUE!")
  187. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowCorsFrom = "" })
  188. }
  189. func TestWebSocketStatuses(t *testing.T) {
  190. th := Setup(t).InitBasic()
  191. defer th.TearDown()
  192. Client := th.Client
  193. WebSocketClient, err := th.CreateWebSocketClient()
  194. require.Nil(t, err, err)
  195. defer WebSocketClient.Close()
  196. WebSocketClient.Listen()
  197. resp := <-WebSocketClient.ResponseChannel
  198. require.Equal(t, resp.Status, model.STATUS_OK, "should have responded OK to authentication challenge")
  199. team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewRandomTeamName() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN}
  200. rteam, _ := Client.CreateTeam(&team)
  201. user := model.User{Email: strings.ToLower(model.NewId()) + "success+test@simulator.amazonses.com", Nickname: "Corey Hulen", Password: "passwd1"}
  202. ruser := Client.Must(Client.CreateUser(&user)).(*model.User)
  203. th.LinkUserToTeam(ruser, rteam)
  204. _, err = th.App.Srv().Store.User().VerifyEmail(ruser.Id, ruser.Email)
  205. require.Nil(t, err)
  206. user2 := model.User{Email: strings.ToLower(model.NewId()) + "success+test@simulator.amazonses.com", Nickname: "Corey Hulen", Password: "passwd1"}
  207. ruser2 := Client.Must(Client.CreateUser(&user2)).(*model.User)
  208. th.LinkUserToTeam(ruser2, rteam)
  209. _, err = th.App.Srv().Store.User().VerifyEmail(ruser2.Id, ruser2.Email)
  210. require.Nil(t, err)
  211. Client.Login(user.Email, user.Password)
  212. th.LoginBasic2()
  213. WebSocketClient2, err2 := th.CreateWebSocketClient()
  214. require.Nil(t, err2, err2)
  215. time.Sleep(1000 * time.Millisecond)
  216. WebSocketClient.GetStatuses()
  217. resp = <-WebSocketClient.ResponseChannel
  218. require.Nil(t, resp.Error, resp.Error)
  219. require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number")
  220. allowedValues := [4]string{model.STATUS_OFFLINE, model.STATUS_AWAY, model.STATUS_ONLINE, model.STATUS_DND}
  221. for _, status := range resp.Data {
  222. require.Containsf(t, allowedValues, status, "one of the statuses had an invalid value status=%v", status)
  223. }
  224. status, ok := resp.Data[th.BasicUser2.Id]
  225. require.True(t, ok, "should have had user status")
  226. require.Equal(t, status, model.STATUS_ONLINE, "status should have been online status=%v", status)
  227. WebSocketClient.GetStatusesByIds([]string{th.BasicUser2.Id})
  228. resp = <-WebSocketClient.ResponseChannel
  229. require.Nil(t, resp.Error, resp.Error)
  230. require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number")
  231. allowedValues = [4]string{model.STATUS_OFFLINE, model.STATUS_AWAY, model.STATUS_ONLINE}
  232. for _, status := range resp.Data {
  233. require.Containsf(t, allowedValues, status, "one of the statuses had an invalid value status")
  234. }
  235. status, ok = resp.Data[th.BasicUser2.Id]
  236. require.True(t, ok, "should have had user status")
  237. require.Equal(t, status, model.STATUS_ONLINE, "status should have been online status=%v", status)
  238. require.Equal(t, len(resp.Data), 1, "only 1 status should be returned")
  239. WebSocketClient.GetStatusesByIds([]string{ruser2.Id, "junk"})
  240. resp = <-WebSocketClient.ResponseChannel
  241. require.Nil(t, resp.Error, resp.Error)
  242. require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number")
  243. require.Equal(t, len(resp.Data), 2, "2 statuses should be returned")
  244. WebSocketClient.GetStatusesByIds([]string{})
  245. if resp2 := <-WebSocketClient.ResponseChannel; resp2.Error == nil {
  246. require.Equal(t, resp2.SeqReply, WebSocketClient.Sequence-1, "bad sequence number")
  247. require.NotNil(t, resp2.Error, "should have errored - empty user ids")
  248. }
  249. WebSocketClient2.Close()
  250. th.App.SetStatusAwayIfNeeded(th.BasicUser.Id, false)
  251. awayTimeout := *th.App.Config().TeamSettings.UserStatusAwayTimeout
  252. defer func() {
  253. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.UserStatusAwayTimeout = awayTimeout })
  254. }()
  255. th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.UserStatusAwayTimeout = 1 })
  256. time.Sleep(1500 * time.Millisecond)
  257. th.App.SetStatusAwayIfNeeded(th.BasicUser.Id, false)
  258. th.App.SetStatusOnline(th.BasicUser.Id, false)
  259. time.Sleep(1500 * time.Millisecond)
  260. WebSocketClient.GetStatuses()
  261. resp = <-WebSocketClient.ResponseChannel
  262. require.Nil(t, resp.Error)
  263. require.Equal(t, resp.SeqReply, WebSocketClient.Sequence-1, "bad sequence number")
  264. _, ok = resp.Data[th.BasicUser2.Id]
  265. require.False(t, ok, "should not have had user status")
  266. stop := make(chan bool)
  267. onlineHit := false
  268. awayHit := false
  269. go func() {
  270. for {
  271. select {
  272. case resp := <-WebSocketClient.EventChannel:
  273. if resp.EventType() == model.WEBSOCKET_EVENT_STATUS_CHANGE && resp.GetData()["user_id"].(string) == th.BasicUser.Id {
  274. status := resp.GetData()["status"].(string)
  275. if status == model.STATUS_ONLINE {
  276. onlineHit = true
  277. } else if status == model.STATUS_AWAY {
  278. awayHit = true
  279. }
  280. }
  281. case <-stop:
  282. return
  283. }
  284. }
  285. }()
  286. time.Sleep(500 * time.Millisecond)
  287. stop <- true
  288. require.True(t, onlineHit, "didn't get online event")
  289. require.True(t, awayHit, "didn't get away event")
  290. time.Sleep(500 * time.Millisecond)
  291. WebSocketClient.Close()
  292. }