1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
//! Implementation of an ad-hoc protocol on top of UDP used for sending gameplay
//! updates between the server and client.
//!
//! This protocol provides the following qualities on top of UDP:
//! - Reliability. Messages are queued up and resent until the peer confirms
//!   that they've received them.
//! - Ordering. Messages are sent with ordered message ids, so the receiver can
//!   build up their own ordered queue of received messages.
//! - Compression. The sent messages are compressed after serialization with
//!   Google's Snappy compression, via the [snap] library.
//! - Encryption. Messages are encrypted using ChaCha20-Poly1305, with a
//!   randomly generated nonce included in plaintext in front of the message.
//!   The encryption key is shared during the login process which is implemented
//!   directly in the [server](../../server/index.html) and
//!   [client](../../client/index.html) modules.

use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use snap::read::FrameDecoder;
use snap::write::FrameEncoder;
use std::collections::VecDeque;
use std::fmt::Display;
use std::io::{self, Cursor};
use std::time::Duration;

use crate::symmetric_crypto;

/// Maximum number of queued messages packed into a single datagram by
/// [MessageOrderer::transport_send]. Any further queued messages are left in
/// the queue for a later datagram.
const MAX_MESSAGES_AT_ONCE: usize = 10;
/// The amount of time after which the peer is considered to have disconnected.
pub const DISCONNECT_THRESHOLD: Duration = Duration::from_secs(5);

/// A message container with metadata about the sender's acked messages (so that
/// the receiver can clear their to-be-sent queue) and the messages' order
/// number (which can be used to determine if the message has already been
/// processed and should be skipped).
///
/// This is the payload that gets bincode-serialized, compressed and encrypted
/// into each datagram. bincode encodes struct fields positionally, so changing
/// the field order or types changes the wire format and breaks compatibility
/// with peers running the old layout.
#[derive(Serialize, Deserialize)]
struct ReliableMessage<T> {
    /// Messages sent from the peer, in sending order. They should be
    /// interpreted in sending order, where the first element is the first
    /// message that was sent.
    pub messages: Vec<T>,
    /// The message number of the first message in `messages`. This should be
    /// used to filter incoming messages to avoid processing any messages twice.
    pub first_message_number: u64,
    /// How many messages the peer has received. This should be used to clear
    /// out the message queue up to this point.
    pub received_messages: u64,
}

/// Owns any relevant state and provides the functions needed to use the
/// gameplay message protocol with a specific peer who has a matching
/// [MessageOrderer].
///
/// Provides two sets of send/receive functions: one to pass in received UDP
/// datagrams and get UDP datagrams to send, another for sending and receiving
/// messages of the generic types `S` and `R`. The peer's [MessageOrderer] must
/// have the same `S` and `R` types, except flipped.
pub struct MessageOrderer<S, R> {
    /// Total number of messages ever passed to `send`. Together with the
    /// length of `sent_message_queue` this gives the message number of the
    /// first queued message.
    sent_messages: u64,
    /// Sent messages the peer has not yet acknowledged, oldest first. They are
    /// resent until the peer's reported received count covers them.
    sent_message_queue: Vec<S>,
    /// Number of the peer's messages accepted so far; reported back to the
    /// peer as our ack.
    received_messages: u64,
    /// Accepted messages waiting to be taken out with `recv`, oldest first.
    received_message_queue: VecDeque<R>,
    /// Whether a datagram should be produced even if `sent_message_queue` is
    /// empty, so the peer learns our current `received_messages` count.
    ack_needed: bool,
    /// Symmetric key used to encrypt outgoing and decrypt incoming datagrams.
    encryption_key: [u8; symmetric_crypto::KEY_LEN],
}

impl<S: Serialize + DeserializeOwned + Clone, R: Serialize + DeserializeOwned>
    MessageOrderer<S, R>
{
    /// Creates a new [MessageOrderer] with the given key. One should be created
    /// for each peer, so the client should have one, and the server should have
    /// multiple.
    pub fn new(encryption_key: [u8; symmetric_crypto::KEY_LEN]) -> MessageOrderer<S, R> {
        MessageOrderer {
            sent_messages: 0,
            sent_message_queue: Vec::new(),
            received_messages: 0,
            received_message_queue: VecDeque::new(),
            ack_needed: false,
            encryption_key,
        }
    }

    /// Get the next message, if there are any. Messages come out in the order
    /// the peer sent them.
    pub fn recv(&mut self) -> Option<R> {
        self.received_message_queue.pop_front()
    }

    /// Send a message. Note that this only adds the message to this struct's
    /// queue, and needs to be separately sent to the intended recipient over
    /// UDP (see [MessageOrderer::transport_send]).
    pub fn send(&mut self, message: S) {
        self.sent_messages += 1;
        self.sent_message_queue.push(message);
    }

    /// Pass in a UDP datagram's bytes to be handled as a received message,
    /// produced on the other end by [MessageOrderer::transport_send].
    ///
    /// The datagram is decrypted, decompressed and deserialized; any failure
    /// is logged at debug level and the datagram is ignored. Messages that
    /// directly continue the already-received sequence are appended to the
    /// queue read by [MessageOrderer::recv]. The peer's received count is then
    /// used to drop acknowledged messages from our send queue.
    ///
    /// `peer_id` is just used for tagging logs appropriately.
    pub fn transport_recv<D: Display>(&mut self, peer_id: D, encrypted_datagram: &mut [u8]) {
        let compressed_datagram = match symmetric_crypto::decrypt(
            self.encryption_key,
            encrypted_datagram,
        ) {
            Ok(datagram) => datagram,
            Err(_) => {
                log::debug!("<- [{peer_id}]: could not decrypt received datagram! Probably caused by mismatched encryption keys.");
                return;
            }
        };

        let mut serialized_datagram = Vec::with_capacity(compressed_datagram.len() * 2);
        let mut decoder = FrameDecoder::new(&*compressed_datagram);
        if let Err(err) = io::copy(&mut decoder, &mut serialized_datagram) {
            log::debug!("<- [{peer_id}]: error decompressing UDP datagram: {err}");
            return;
        }

        let message = match bincode::deserialize::<ReliableMessage<R>>(&serialized_datagram) {
            Ok(message) => message,
            Err(err) => {
                log::debug!("<- [{peer_id}]: error deserializing UDP datagram: {err}");
                return;
            }
        };

        let total = message.messages.len();
        if total > 0 {
            // Any datagram carrying messages, even ones we have already seen,
            // means the peer may not have received our latest ack (e.g. it was
            // lost). Re-ack, otherwise a peer resending an already-received
            // batch would never learn it can move on.
            self.ack_needed = true;
        }

        if message.first_message_number > self.received_messages {
            // Messages between our count and `first_message_number` are
            // missing. Appending these would break ordering and desync
            // `received_messages` from real message numbers (causing later
            // resends to be delivered twice), so drop them and just re-ack.
            let skipped = message.first_message_number - self.received_messages;
            log::debug!("<- [{peer_id}]: skipped {skipped} messages!");
        } else {
            let discarded = self.received_messages - message.first_message_number;
            log::trace!("<- [{peer_id}]: received {total} messages, {discarded} already seen");
            // `discarded` is capped at `total`, so the cast cannot truncate.
            let already_seen = discarded.min(total as u64) as usize;
            self.received_messages += (total - already_seen) as u64;
            self.received_message_queue
                .extend(message.messages.into_iter().skip(already_seen));
        }

        if self.sent_messages < message.received_messages {
            log::trace!("<- [{peer_id}]: misbehaving peer, they report more received messages than we've sent");
            return;
        }
        // Compare in u64 so a huge difference can't wrap when cast to usize.
        let messages_to_send_left = self.sent_messages - message.received_messages;
        let queue_len = self.sent_message_queue.len();
        if messages_to_send_left > queue_len as u64 {
            log::trace!("<- [{peer_id}]: out-of-order datagram (implied that the peer hasn't received previously acked messages)");
            return;
        }
        // Safe cast: `messages_to_send_left <= queue_len` was checked above.
        let received_messages_in_queue = queue_len - messages_to_send_left as usize;
        self.sent_message_queue.drain(..received_messages_in_queue);
    }

    /// Returns a UDP datagram to send if there's anything to send. The other
    /// end should pass the datagram to [MessageOrderer::transport_recv].
    ///
    /// A datagram is produced when there are unacked messages queued, or when
    /// an ack is owed to the peer. At most `MAX_MESSAGES_AT_ONCE` of the oldest
    /// unacked messages are included; the rest go out in later datagrams once
    /// these are acked. Returns `None` (after logging a warning) if
    /// serialization or compression fails.
    ///
    /// `peer_id` is just used for tagging logs appropriately.
    pub fn transport_send<D: Display>(&mut self, peer_id: D) -> Option<Vec<u8>> {
        let queue_len = self.sent_message_queue.len();
        if queue_len == 0 && !self.ack_needed {
            return None;
        }
        self.ack_needed = false;

        let batch_len = queue_len.min(MAX_MESSAGES_AT_ONCE);
        let message = ReliableMessage {
            messages: self.sent_message_queue[..batch_len].to_vec(),
            // The queue holds exactly the unacked tail of everything sent, so
            // its first element has this message number.
            first_message_number: self.sent_messages - queue_len as u64,
            received_messages: self.received_messages,
        };

        let bytes = match bincode::serialize(&message) {
            Ok(bytes) => bytes,
            Err(err) => {
                log::warn!("-> [{peer_id}]: failed to send {batch_len} messages, could not serialize: {err}");
                return None;
            }
        };

        let mut compressed_bytes = Cursor::new(Vec::with_capacity(bytes.len()));
        let mut encoder = FrameEncoder::new(&mut compressed_bytes);
        // FrameEncoder buffers internally and ignores flush errors when it is
        // dropped, so flush explicitly to surface any compression failure.
        let compression_result = io::copy(&mut Cursor::new(bytes), &mut encoder)
            .and_then(|_| io::Write::flush(&mut encoder));
        // Release the encoder's borrow of `compressed_bytes`.
        drop(encoder);
        if let Err(err) = compression_result {
            log::warn!("-> [{peer_id}]: failed to send {batch_len} messages due to compression failing: {err}");
            return None;
        }

        let mut compressed_bytes = compressed_bytes.into_inner();
        symmetric_crypto::encrypt(self.encryption_key, &mut compressed_bytes);
        log::trace!("-> [{peer_id}]: sent {batch_len} messages");
        Some(compressed_bytes)
    }
}