123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- // Copyright (C) 2018 The Syncthing Authors.
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at https://mozilla.org/MPL/2.0/.
- package main
- import (
- "context"
- "crypto/tls"
- "encoding/binary"
- "fmt"
- io "io"
- "log"
- "net"
- "time"
- "github.com/syncthing/syncthing/lib/protocol"
- )
- const (
- replicationReadTimeout = time.Minute
- replicationWriteTimeout = 30 * time.Second
- replicationHeartbeatInterval = time.Second * 30
- )
- type replicator interface {
- send(key string, addrs []DatabaseAddress, seen int64)
- }
- // a replicationSender tries to connect to the remote address and provide
- // them with a feed of replication updates.
- type replicationSender struct {
- dst string
- cert tls.Certificate // our certificate
- allowedIDs []protocol.DeviceID
- outbox chan ReplicationRecord
- }
- func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
- return &replicationSender{
- dst: dst,
- cert: cert,
- allowedIDs: allowedIDs,
- outbox: make(chan ReplicationRecord, replicationOutboxSize),
- }
- }
- func (s *replicationSender) Serve(ctx context.Context) error {
- // Sleep a little at startup. Peers often restart at the same time, and
- // this avoid the service failing and entering backoff state
- // unnecessarily, while also reducing the reconnect rate to something
- // reasonable by default.
- time.Sleep(2 * time.Second)
- tlsCfg := &tls.Config{
- Certificates: []tls.Certificate{s.cert},
- MinVersion: tls.VersionTLS12,
- InsecureSkipVerify: true,
- }
- // Dial the TLS connection.
- conn, err := tls.Dial("tcp", s.dst, tlsCfg)
- if err != nil {
- log.Println("Replication connect:", err)
- return err
- }
- defer func() {
- conn.SetWriteDeadline(time.Now().Add(time.Second))
- conn.Close()
- }()
- // The replication stream is not especially latency sensitive, but it is
- // quite a lot of data in small writes. Make it more efficient.
- if tcpc, ok := conn.NetConn().(*net.TCPConn); ok {
- _ = tcpc.SetNoDelay(false)
- }
- // Get the other side device ID.
- remoteID, err := deviceID(conn)
- if err != nil {
- log.Println("Replication connect:", err)
- return err
- }
- // Verify it's in the set of allowed device IDs.
- if !deviceIDIn(remoteID, s.allowedIDs) {
- log.Println("Replication connect: unexpected device ID:", remoteID)
- return err
- }
- heartBeatTicker := time.NewTicker(replicationHeartbeatInterval)
- defer heartBeatTicker.Stop()
- // Send records.
- buf := make([]byte, 1024)
- for {
- select {
- case <-heartBeatTicker.C:
- if len(s.outbox) > 0 {
- // No need to send heartbeats if there are events/prevrious
- // heartbeats to send, they will keep the connection alive.
- continue
- }
- // Empty replication message is the heartbeat:
- s.outbox <- ReplicationRecord{}
- case rec := <-s.outbox:
- // Buffer must hold record plus four bytes for size
- size := rec.Size()
- if len(buf) < size+4 {
- buf = make([]byte, size+4)
- }
- // Record comes after the four bytes size
- n, err := rec.MarshalTo(buf[4:])
- if err != nil {
- // odd to get an error here, but we haven't sent anything
- // yet so it's not fatal
- replicationSendsTotal.WithLabelValues("error").Inc()
- log.Println("Replication marshal:", err)
- continue
- }
- binary.BigEndian.PutUint32(buf, uint32(n))
- // Send
- conn.SetWriteDeadline(time.Now().Add(replicationWriteTimeout))
- if _, err := conn.Write(buf[:4+n]); err != nil {
- replicationSendsTotal.WithLabelValues("error").Inc()
- log.Println("Replication write:", err)
- // Yes, we are losing the replication event here.
- return err
- }
- replicationSendsTotal.WithLabelValues("success").Inc()
- case <-ctx.Done():
- return nil
- }
- }
- }
- func (s *replicationSender) String() string {
- return fmt.Sprintf("replicationSender(%q)", s.dst)
- }
- func (s *replicationSender) send(key string, ps []DatabaseAddress, _ int64) {
- item := ReplicationRecord{
- Key: key,
- Addresses: ps,
- }
- // The send should never block. The inbox is suitably buffered for at
- // least a few seconds of stalls, which shouldn't happen in practice.
- select {
- case s.outbox <- item:
- default:
- replicationSendsTotal.WithLabelValues("drop").Inc()
- }
- }
- // a replicationMultiplexer sends to multiple replicators
- type replicationMultiplexer []replicator
- func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
- for _, s := range m {
- // each send is nonblocking
- s.send(key, ps, seen)
- }
- }
- // replicationListener accepts incoming connections and reads replication
- // items from them. Incoming items are applied to the KV store.
- type replicationListener struct {
- addr string
- cert tls.Certificate
- allowedIDs []protocol.DeviceID
- db database
- }
- func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
- return &replicationListener{
- addr: addr,
- cert: cert,
- allowedIDs: allowedIDs,
- db: db,
- }
- }
- func (l *replicationListener) Serve(ctx context.Context) error {
- tlsCfg := &tls.Config{
- Certificates: []tls.Certificate{l.cert},
- ClientAuth: tls.RequestClientCert,
- MinVersion: tls.VersionTLS12,
- InsecureSkipVerify: true,
- }
- lst, err := tls.Listen("tcp", l.addr, tlsCfg)
- if err != nil {
- log.Println("Replication listen:", err)
- return err
- }
- defer lst.Close()
- for {
- select {
- case <-ctx.Done():
- return nil
- default:
- }
- // Accept a connection
- conn, err := lst.Accept()
- if err != nil {
- log.Println("Replication accept:", err)
- return err
- }
- // Figure out the other side device ID
- remoteID, err := deviceID(conn.(*tls.Conn))
- if err != nil {
- log.Println("Replication accept:", err)
- conn.SetWriteDeadline(time.Now().Add(time.Second))
- conn.Close()
- continue
- }
- // Verify it is in the set of allowed device IDs
- if !deviceIDIn(remoteID, l.allowedIDs) {
- log.Println("Replication accept: unexpected device ID:", remoteID)
- conn.SetWriteDeadline(time.Now().Add(time.Second))
- conn.Close()
- continue
- }
- go l.handle(ctx, conn)
- }
- }
- func (l *replicationListener) String() string {
- return fmt.Sprintf("replicationListener(%q)", l.addr)
- }
- func (l *replicationListener) handle(ctx context.Context, conn net.Conn) {
- defer func() {
- conn.SetWriteDeadline(time.Now().Add(time.Second))
- conn.Close()
- }()
- buf := make([]byte, 1024)
- for {
- select {
- case <-ctx.Done():
- return
- default:
- }
- conn.SetReadDeadline(time.Now().Add(replicationReadTimeout))
- // First four bytes are the size
- if _, err := io.ReadFull(conn, buf[:4]); err != nil {
- log.Println("Replication read size:", err)
- replicationRecvsTotal.WithLabelValues("error").Inc()
- return
- }
- // Read the rest of the record
- size := int(binary.BigEndian.Uint32(buf[:4]))
- if len(buf) < size {
- buf = make([]byte, size)
- }
- if size == 0 {
- // Heartbeat, ignore
- continue
- }
- if _, err := io.ReadFull(conn, buf[:size]); err != nil {
- log.Println("Replication read record:", err)
- replicationRecvsTotal.WithLabelValues("error").Inc()
- return
- }
- // Unmarshal
- var rec ReplicationRecord
- if err := rec.Unmarshal(buf[:size]); err != nil {
- log.Println("Replication unmarshal:", err)
- replicationRecvsTotal.WithLabelValues("error").Inc()
- continue
- }
- // Store
- l.db.merge(rec.Key, rec.Addresses, rec.Seen)
- replicationRecvsTotal.WithLabelValues("success").Inc()
- }
- }
- func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
- // Handshake may not be complete on the server side yet, which we need
- // to get the client certificate.
- if !conn.ConnectionState().HandshakeComplete {
- if err := conn.Handshake(); err != nil {
- return protocol.DeviceID{}, err
- }
- }
- // We expect exactly one certificate.
- certs := conn.ConnectionState().PeerCertificates
- if len(certs) != 1 {
- return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
- }
- return protocol.NewDeviceID(certs[0].Raw), nil
- }
- func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
- for _, candidate := range ids {
- if id == candidate {
- return true
- }
- }
- return false
- }
|