server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. // Copyright 2014 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package p2p
  17. import (
  18. "crypto/ecdsa"
  19. "errors"
  20. "math/rand"
  21. "net"
  22. "reflect"
  23. "testing"
  24. "time"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/crypto/sha3"
  27. "github.com/ethereum/go-ethereum/log"
  28. "github.com/ethereum/go-ethereum/p2p/discover"
  29. )
  func init() {
  	// Intentionally empty. To see error-level log output while debugging
  	// these tests, uncomment the line below (it also needs "os" imported):
  	// log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
  }
  33. type testTransport struct {
  34. id discover.NodeID
  35. *rlpx
  36. closeErr error
  37. }
  38. func newTestTransport(id discover.NodeID, fd net.Conn) transport {
  39. wrapped := newRLPX(fd).(*rlpx)
  40. wrapped.rw = newRLPXFrameRW(fd, secrets{
  41. MAC: zero16,
  42. AES: zero16,
  43. IngressMAC: sha3.NewKeccak256(),
  44. EgressMAC: sha3.NewKeccak256(),
  45. })
  46. return &testTransport{id: id, rlpx: wrapped}
  47. }
  48. func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  49. return c.id, nil
  50. }
  51. func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  52. return &protoHandshake{ID: c.id, Name: "test"}, nil
  53. }
  54. func (c *testTransport) close(err error) {
  55. c.rlpx.fd.Close()
  56. c.closeErr = err
  57. }
  58. func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
  59. config := Config{
  60. Name: "test",
  61. MaxPeers: 10,
  62. ListenAddr: "127.0.0.1:0",
  63. PrivateKey: newkey(),
  64. }
  65. server := &Server{
  66. Config: config,
  67. newPeerHook: pf,
  68. newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
  69. }
  70. if err := server.Start(); err != nil {
  71. t.Fatalf("Could not start server: %v", err)
  72. }
  73. return server
  74. }
  75. func TestServerListen(t *testing.T) {
  76. // start the test server
  77. connected := make(chan *Peer)
  78. remid := randomID()
  79. srv := startTestServer(t, remid, func(p *Peer) {
  80. if p.ID() != remid {
  81. t.Error("peer func called with wrong node id")
  82. }
  83. if p == nil {
  84. t.Error("peer func called with nil conn")
  85. }
  86. connected <- p
  87. })
  88. defer close(connected)
  89. defer srv.Stop()
  90. // dial the test server
  91. conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
  92. if err != nil {
  93. t.Fatalf("could not dial: %v", err)
  94. }
  95. defer conn.Close()
  96. select {
  97. case peer := <-connected:
  98. if peer.LocalAddr().String() != conn.RemoteAddr().String() {
  99. t.Errorf("peer started with wrong conn: got %v, want %v",
  100. peer.LocalAddr(), conn.RemoteAddr())
  101. }
  102. peers := srv.Peers()
  103. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  104. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  105. }
  106. case <-time.After(1 * time.Second):
  107. t.Error("server did not accept within one second")
  108. }
  109. }
  110. func TestServerDial(t *testing.T) {
  111. // run a one-shot TCP server to handle the connection.
  112. listener, err := net.Listen("tcp", "127.0.0.1:0")
  113. if err != nil {
  114. t.Fatalf("could not setup listener: %v", err)
  115. }
  116. defer listener.Close()
  117. accepted := make(chan net.Conn)
  118. go func() {
  119. conn, err := listener.Accept()
  120. if err != nil {
  121. t.Error("accept error:", err)
  122. return
  123. }
  124. accepted <- conn
  125. }()
  126. // start the server
  127. connected := make(chan *Peer)
  128. remid := randomID()
  129. srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
  130. defer close(connected)
  131. defer srv.Stop()
  132. // tell the server to connect
  133. tcpAddr := listener.Addr().(*net.TCPAddr)
  134. srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
  135. select {
  136. case conn := <-accepted:
  137. defer conn.Close()
  138. select {
  139. case peer := <-connected:
  140. if peer.ID() != remid {
  141. t.Errorf("peer has wrong id")
  142. }
  143. if peer.Name() != "test" {
  144. t.Errorf("peer has wrong name")
  145. }
  146. if peer.RemoteAddr().String() != conn.LocalAddr().String() {
  147. t.Errorf("peer started with wrong conn: got %v, want %v",
  148. peer.RemoteAddr(), conn.LocalAddr())
  149. }
  150. peers := srv.Peers()
  151. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  152. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  153. }
  154. case <-time.After(1 * time.Second):
  155. t.Error("server did not launch peer within one second")
  156. }
  157. case <-time.After(1 * time.Second):
  158. t.Error("server did not connect within one second")
  159. }
  160. }
  161. // This test checks that tasks generated by dialstate are
  162. // actually executed and taskdone is called for them.
  163. func TestServerTaskScheduling(t *testing.T) {
  164. var (
  165. done = make(chan *testTask)
  166. quit, returned = make(chan struct{}), make(chan struct{})
  167. tc = 0
  168. tg = taskgen{
  169. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  170. tc++
  171. return []task{&testTask{index: tc - 1}}
  172. },
  173. doneFunc: func(t task) {
  174. select {
  175. case done <- t.(*testTask):
  176. case <-quit:
  177. }
  178. },
  179. }
  180. )
  181. // The Server in this test isn't actually running
  182. // because we're only interested in what run does.
  183. srv := &Server{
  184. Config: Config{MaxPeers: 10},
  185. quit: make(chan struct{}),
  186. ntab: fakeTable{},
  187. running: true,
  188. log: log.New(),
  189. }
  190. srv.loopWG.Add(1)
  191. go func() {
  192. srv.run(tg)
  193. close(returned)
  194. }()
  195. var gotdone []*testTask
  196. for i := 0; i < 100; i++ {
  197. gotdone = append(gotdone, <-done)
  198. }
  199. for i, task := range gotdone {
  200. if task.index != i {
  201. t.Errorf("task %d has wrong index, got %d", i, task.index)
  202. break
  203. }
  204. if !task.called {
  205. t.Errorf("task %d was not called", i)
  206. break
  207. }
  208. }
  209. close(quit)
  210. srv.Stop()
  211. select {
  212. case <-returned:
  213. case <-time.After(500 * time.Millisecond):
  214. t.Error("Server.run did not return within 500ms")
  215. }
  216. }
  217. // This test checks that Server doesn't drop tasks,
  218. // even if newTasks returns more than the maximum number of tasks.
  219. func TestServerManyTasks(t *testing.T) {
  220. alltasks := make([]task, 300)
  221. for i := range alltasks {
  222. alltasks[i] = &testTask{index: i}
  223. }
  224. var (
  225. srv = &Server{
  226. quit: make(chan struct{}),
  227. ntab: fakeTable{},
  228. running: true,
  229. log: log.New(),
  230. }
  231. done = make(chan *testTask)
  232. start, end = 0, 0
  233. )
  234. defer srv.Stop()
  235. srv.loopWG.Add(1)
  236. go srv.run(taskgen{
  237. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  238. start, end = end, end+maxActiveDialTasks+10
  239. if end > len(alltasks) {
  240. end = len(alltasks)
  241. }
  242. return alltasks[start:end]
  243. },
  244. doneFunc: func(tt task) {
  245. done <- tt.(*testTask)
  246. },
  247. })
  248. doneset := make(map[int]bool)
  249. timeout := time.After(2 * time.Second)
  250. for len(doneset) < len(alltasks) {
  251. select {
  252. case tt := <-done:
  253. if doneset[tt.index] {
  254. t.Errorf("task %d got done more than once", tt.index)
  255. } else {
  256. doneset[tt.index] = true
  257. }
  258. case <-timeout:
  259. t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
  260. for i := 0; i < len(alltasks); i++ {
  261. if !doneset[i] {
  262. t.Logf("task %d not done", i)
  263. }
  264. }
  265. return
  266. }
  267. }
  268. }
  269. type taskgen struct {
  270. newFunc func(running int, peers map[discover.NodeID]*Peer) []task
  271. doneFunc func(task)
  272. }
  273. func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
  274. return tg.newFunc(running, peers)
  275. }
  276. func (tg taskgen) taskDone(t task, now time.Time) {
  277. tg.doneFunc(t)
  278. }
  279. func (tg taskgen) addStatic(*discover.Node) {
  280. }
  281. func (tg taskgen) removeStatic(*discover.Node) {
  282. }
  283. type testTask struct {
  284. index int
  285. called bool
  286. }
  287. func (t *testTask) Do(srv *Server) {
  288. t.called = true
  289. }
  290. // This test checks that connections are disconnected
  291. // just after the encryption handshake when the server is
  292. // at capacity. Trusted connections should still be accepted.
  293. func TestServerAtCap(t *testing.T) {
  294. trustedID := randomID()
  295. srv := &Server{
  296. Config: Config{
  297. PrivateKey: newkey(),
  298. MaxPeers: 10,
  299. NoDial: true,
  300. TrustedNodes: []*discover.Node{{ID: trustedID}},
  301. },
  302. }
  303. if err := srv.Start(); err != nil {
  304. t.Fatalf("could not start: %v", err)
  305. }
  306. defer srv.Stop()
  307. newconn := func(id discover.NodeID) *conn {
  308. fd, _ := net.Pipe()
  309. tx := newTestTransport(id, fd)
  310. return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
  311. }
  312. // Inject a few connections to fill up the peer set.
  313. for i := 0; i < 10; i++ {
  314. c := newconn(randomID())
  315. if err := srv.checkpoint(c, srv.addpeer); err != nil {
  316. t.Fatalf("could not add conn %d: %v", i, err)
  317. }
  318. }
  319. // Try inserting a non-trusted connection.
  320. c := newconn(randomID())
  321. if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
  322. t.Error("wrong error for insert:", err)
  323. }
  324. // Try inserting a trusted connection.
  325. c = newconn(trustedID)
  326. if err := srv.checkpoint(c, srv.posthandshake); err != nil {
  327. t.Error("unexpected error for trusted conn @posthandshake:", err)
  328. }
  329. if !c.is(trustedConn) {
  330. t.Error("Server did not set trusted flag")
  331. }
  332. }
  333. func TestServerSetupConn(t *testing.T) {
  334. id := randomID()
  335. srvkey := newkey()
  336. srvid := discover.PubkeyID(&srvkey.PublicKey)
  337. tests := []struct {
  338. dontstart bool
  339. tt *setupTransport
  340. flags connFlag
  341. dialDest *discover.Node
  342. wantCloseErr error
  343. wantCalls string
  344. }{
  345. {
  346. dontstart: true,
  347. tt: &setupTransport{id: id},
  348. wantCalls: "close,",
  349. wantCloseErr: errServerStopped,
  350. },
  351. {
  352. tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
  353. flags: inboundConn,
  354. wantCalls: "doEncHandshake,close,",
  355. wantCloseErr: errors.New("read error"),
  356. },
  357. {
  358. tt: &setupTransport{id: id},
  359. dialDest: &discover.Node{ID: randomID()},
  360. flags: dynDialedConn,
  361. wantCalls: "doEncHandshake,close,",
  362. wantCloseErr: DiscUnexpectedIdentity,
  363. },
  364. {
  365. tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
  366. dialDest: &discover.Node{ID: id},
  367. flags: dynDialedConn,
  368. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  369. wantCloseErr: DiscUnexpectedIdentity,
  370. },
  371. {
  372. tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
  373. dialDest: &discover.Node{ID: id},
  374. flags: dynDialedConn,
  375. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  376. wantCloseErr: errors.New("foo"),
  377. },
  378. {
  379. tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
  380. flags: inboundConn,
  381. wantCalls: "doEncHandshake,close,",
  382. wantCloseErr: DiscSelf,
  383. },
  384. {
  385. tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
  386. flags: inboundConn,
  387. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  388. wantCloseErr: DiscUselessPeer,
  389. },
  390. }
  391. for i, test := range tests {
  392. srv := &Server{
  393. Config: Config{
  394. PrivateKey: srvkey,
  395. MaxPeers: 10,
  396. NoDial: true,
  397. Protocols: []Protocol{discard},
  398. },
  399. newTransport: func(fd net.Conn) transport { return test.tt },
  400. log: log.New(),
  401. }
  402. if !test.dontstart {
  403. if err := srv.Start(); err != nil {
  404. t.Fatalf("couldn't start server: %v", err)
  405. }
  406. }
  407. p1, _ := net.Pipe()
  408. srv.SetupConn(p1, test.flags, test.dialDest)
  409. if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
  410. t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
  411. }
  412. if test.tt.calls != test.wantCalls {
  413. t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
  414. }
  415. }
  416. }
  417. type setupTransport struct {
  418. id discover.NodeID
  419. encHandshakeErr error
  420. phs *protoHandshake
  421. protoHandshakeErr error
  422. calls string
  423. closeErr error
  424. }
  425. func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  426. c.calls += "doEncHandshake,"
  427. return c.id, c.encHandshakeErr
  428. }
  429. func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  430. c.calls += "doProtoHandshake,"
  431. if c.protoHandshakeErr != nil {
  432. return nil, c.protoHandshakeErr
  433. }
  434. return c.phs, nil
  435. }
  436. func (c *setupTransport) close(err error) {
  437. c.calls += "close,"
  438. c.closeErr = err
  439. }
  440. // setupConn shouldn't write to/read from the connection.
  441. func (c *setupTransport) WriteMsg(Msg) error {
  442. panic("WriteMsg called on setupTransport")
  443. }
  444. func (c *setupTransport) ReadMsg() (Msg, error) {
  445. panic("ReadMsg called on setupTransport")
  446. }
  447. func newkey() *ecdsa.PrivateKey {
  448. key, err := crypto.GenerateKey()
  449. if err != nil {
  450. panic("couldn't generate key: " + err.Error())
  451. }
  452. return key
  453. }
  454. func randomID() (id discover.NodeID) {
  455. for i := range id {
  456. id[i] = byte(rand.Intn(255))
  457. }
  458. return id
  459. }