lib.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. // -*- coding: utf-8 -*-
  2. //
  3. // Copyright (C) 2024 Michael Büsch <m@bues.ch>
  4. //
  5. // Licensed under the Apache License version 2.0
  6. // or the MIT license, at your option.
  7. // SPDX-License-Identifier: Apache-2.0 OR MIT
  8. //! This crate implements the firewall socket protocol
  9. //! for communication between the `letmeind` and `letmeinfwd` daemons.
  10. //!
//! Serializing messages into a raw byte stream and
//! deserializing a raw byte stream into a message is implemented here.
  13. #![forbid(unsafe_code)]
  14. use anyhow::{self as ah, format_err as err, Context as _};
  15. use std::{
  16. future::Future,
  17. net::{IpAddr, Ipv4Addr},
  18. };
  19. use tokio::io::ErrorKind;
  20. #[cfg(any(target_os = "linux", target_os = "android"))]
  21. use tokio::net::UnixStream;
  22. #[cfg(target_os = "windows")]
  23. use tokio::net::windows::named_pipe::{NamedPipeClient, NamedPipeServer};
  24. /// Firewall daemon Unix socket file name.
  25. pub const SOCK_FILE: &str = "letmeinfwd.sock";
  26. /// The operation to perform on the firewall.
  27. #[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
  28. #[repr(u16)]
  29. pub enum FirewallOperation {
  30. /// Not-Acknowledge message.
  31. #[default]
  32. Nack,
  33. /// Acknowledge message.
  34. Ack,
  35. /// Open a port.
  36. Open,
  37. }
  38. impl TryFrom<u16> for FirewallOperation {
  39. type Error = ah::Error;
  40. fn try_from(value: u16) -> Result<Self, Self::Error> {
  41. const OPERATION_OPEN: u16 = FirewallOperation::Open as u16;
  42. const OPERATION_ACK: u16 = FirewallOperation::Ack as u16;
  43. const OPERATION_NACK: u16 = FirewallOperation::Nack as u16;
  44. match value {
  45. OPERATION_OPEN => Ok(Self::Open),
  46. OPERATION_ACK => Ok(Self::Ack),
  47. OPERATION_NACK => Ok(Self::Nack),
  48. _ => Err(err!("Invalid FirewallMessage/Operation value")),
  49. }
  50. }
  51. }
  52. impl From<FirewallOperation> for u16 {
  53. fn from(operation: FirewallOperation) -> u16 {
  54. operation as _
  55. }
  56. }
  57. /// The type of port to open in the firewall.
  58. #[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
  59. #[repr(u16)]
  60. pub enum PortType {
  61. /// TCP port only.
  62. #[default]
  63. Tcp,
  64. /// UDP port only.
  65. Udp,
  66. /// TCP and UDP port.
  67. TcpUdp,
  68. }
  69. impl TryFrom<u16> for PortType {
  70. type Error = ah::Error;
  71. fn try_from(value: u16) -> Result<Self, Self::Error> {
  72. const PORTTYPE_TCP: u16 = PortType::Tcp as u16;
  73. const PORTTYPE_UDP: u16 = PortType::Udp as u16;
  74. const PORTTYPE_TCPUDP: u16 = PortType::TcpUdp as u16;
  75. match value {
  76. PORTTYPE_TCP => Ok(Self::Tcp),
  77. PORTTYPE_UDP => Ok(Self::Udp),
  78. PORTTYPE_TCPUDP => Ok(Self::TcpUdp),
  79. _ => Err(err!("Invalid FirewallMessage/PortType value")),
  80. }
  81. }
  82. }
  83. impl From<PortType> for u16 {
  84. fn from(port_type: PortType) -> u16 {
  85. port_type as _
  86. }
  87. }
  88. /// The type of address to open in the firewall.
  89. #[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
  90. #[repr(u16)]
  91. pub enum AddrType {
  92. /// IPv6 address.
  93. #[default]
  94. Ipv6,
  95. /// IPv4 address.
  96. Ipv4,
  97. }
  98. impl TryFrom<u16> for AddrType {
  99. type Error = ah::Error;
  100. fn try_from(value: u16) -> Result<Self, Self::Error> {
  101. const ADDRTYPE_IPV6: u16 = AddrType::Ipv6 as u16;
  102. const ADDRTYPE_IPV4: u16 = AddrType::Ipv4 as u16;
  103. match value {
  104. ADDRTYPE_IPV6 => Ok(Self::Ipv6),
  105. ADDRTYPE_IPV4 => Ok(Self::Ipv4),
  106. _ => Err(err!("Invalid FirewallMessage/AddrType value")),
  107. }
  108. }
  109. }
  110. impl From<AddrType> for u16 {
  111. fn from(addr_type: AddrType) -> u16 {
  112. addr_type as _
  113. }
  114. }
  115. /// Size of the `addr` field in the message.
  116. const ADDR_SIZE: usize = 16;
  117. /// Size of the firewall control message.
  118. const FWMSG_SIZE: usize = 2 + 2 + 2 + 2 + ADDR_SIZE;
  119. /// Byte offset of the `operation` field in the firewall control message.
  120. const FWMSG_OFFS_OPERATION: usize = 0;
  121. /// Byte offset of the `port_type` field in the firewall control message.
  122. const FWMSG_OFFS_PORT_TYPE: usize = 2;
  123. /// Byte offset of the `port` field in the firewall control message.
  124. const FWMSG_OFFS_PORT: usize = 4;
  125. /// Byte offset of the `addr_type` field in the firewall control message.
  126. const FWMSG_OFFS_ADDR_TYPE: usize = 6;
  127. /// Byte offset of the `addr` field in the firewall control message.
  128. const FWMSG_OFFS_ADDR: usize = 8;
  129. /// A message to control the firewall.
  130. #[derive(PartialEq, Eq, Debug, Default)]
  131. pub struct FirewallMessage {
  132. operation: FirewallOperation,
  133. port_type: PortType,
  134. port: u16,
  135. addr_type: AddrType,
  136. addr: [u8; ADDR_SIZE],
  137. }
  138. /// Convert an `IpAddr` to the `operation` and `addr` fields of a firewall control message.
  139. fn addr_to_octets(addr: IpAddr) -> (AddrType, [u8; ADDR_SIZE]) {
  140. match addr {
  141. IpAddr::V4(addr) => {
  142. let o = addr.octets();
  143. (
  144. AddrType::Ipv4,
  145. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, o[0], o[1], o[2], o[3]],
  146. )
  147. }
  148. IpAddr::V6(addr) => (AddrType::Ipv6, addr.octets()),
  149. }
  150. }
  151. /// Convert a firewall control message `operation` and `addr` fields to an `IpAddr`.
  152. fn octets_to_addr(addr_type: AddrType, addr: &[u8; ADDR_SIZE]) -> IpAddr {
  153. match addr_type {
  154. AddrType::Ipv4 => Ipv4Addr::new(addr[12], addr[13], addr[14], addr[15]).into(),
  155. AddrType::Ipv6 => (*addr).into(),
  156. }
  157. }
  158. impl FirewallMessage {
  159. /// Construct a new message that requests installing a firewall-port-open rule.
  160. pub fn new_open(addr: IpAddr, port_type: PortType, port: u16) -> Self {
  161. let (addr_type, addr) = addr_to_octets(addr);
  162. Self {
  163. operation: FirewallOperation::Open,
  164. port_type,
  165. port,
  166. addr_type,
  167. addr,
  168. }
  169. }
  170. /// Construct a new acknowledge message.
  171. pub fn new_ack() -> Self {
  172. Self {
  173. operation: FirewallOperation::Ack,
  174. ..Default::default()
  175. }
  176. }
  177. /// Construct a new not-acknowledge message.
  178. pub fn new_nack() -> Self {
  179. Self {
  180. operation: FirewallOperation::Nack,
  181. ..Default::default()
  182. }
  183. }
  184. /// Get the operation type from this message.
  185. pub fn operation(&self) -> FirewallOperation {
  186. self.operation
  187. }
  188. /// Get the port number from this message.
  189. pub fn port(&self) -> Option<(PortType, u16)> {
  190. match self.operation {
  191. FirewallOperation::Open => Some((self.port_type, self.port)),
  192. FirewallOperation::Ack | FirewallOperation::Nack => None,
  193. }
  194. }
  195. /// Get the `IpAddr` from this message.
  196. pub fn addr(&self) -> Option<IpAddr> {
  197. match self.operation {
  198. FirewallOperation::Open => Some(octets_to_addr(self.addr_type, &self.addr)),
  199. FirewallOperation::Ack | FirewallOperation::Nack => None,
  200. }
  201. }
  202. /// Serialize this message into a byte stream.
  203. pub fn msg_serialize(&self) -> ah::Result<[u8; FWMSG_SIZE]> {
  204. // The serialization is simple enough to do manually.
  205. // Therefore, we don't use the `serde` crate here.
  206. #[inline]
  207. fn serialize_u16(buf: &mut [u8], value: u16) {
  208. buf[0..2].copy_from_slice(&value.to_be_bytes());
  209. }
  210. let mut buf = [0; FWMSG_SIZE];
  211. serialize_u16(&mut buf[FWMSG_OFFS_OPERATION..], self.operation.into());
  212. serialize_u16(&mut buf[FWMSG_OFFS_PORT_TYPE..], self.port_type.into());
  213. serialize_u16(&mut buf[FWMSG_OFFS_PORT..], self.port);
  214. serialize_u16(&mut buf[FWMSG_OFFS_ADDR_TYPE..], self.addr_type.into());
  215. buf[FWMSG_OFFS_ADDR..FWMSG_OFFS_ADDR + ADDR_SIZE].copy_from_slice(&self.addr);
  216. Ok(buf)
  217. }
  218. /// Try to deserialize a byte stream into a message.
  219. pub fn try_msg_deserialize(buf: &[u8]) -> ah::Result<Self> {
  220. if buf.len() != FWMSG_SIZE {
  221. return Err(err!("Deserialize: Raw message size mismatch."));
  222. }
  223. // The deserialization is simple enough to do manually.
  224. // Therefore, we don't use the `serde` crate here.
  225. #[inline]
  226. fn deserialize_u16(buf: &[u8]) -> ah::Result<u16> {
  227. Ok(u16::from_be_bytes(buf[0..2].try_into()?))
  228. }
  229. let operation = deserialize_u16(&buf[FWMSG_OFFS_OPERATION..])?;
  230. let port_type = deserialize_u16(&buf[FWMSG_OFFS_PORT_TYPE..])?;
  231. let port = deserialize_u16(&buf[FWMSG_OFFS_PORT..])?;
  232. let addr_type = deserialize_u16(&buf[FWMSG_OFFS_ADDR_TYPE..])?;
  233. let addr = &buf[FWMSG_OFFS_ADDR..FWMSG_OFFS_ADDR + ADDR_SIZE];
  234. Ok(Self {
  235. operation: operation.try_into()?,
  236. port_type: port_type.try_into()?,
  237. port,
  238. addr_type: addr_type.try_into()?,
  239. addr: addr.try_into()?,
  240. })
  241. }
  242. /// Send this message over a [Stream].
  243. pub async fn send(&self, stream: &mut impl Stream) -> ah::Result<()> {
  244. let txbuf = self.msg_serialize()?;
  245. let mut txcount = 0;
  246. loop {
  247. stream.writable().await.context("Socket polling (tx)")?;
  248. match stream.try_write(&txbuf[txcount..]) {
  249. Ok(n) => {
  250. txcount += n;
  251. assert!(txcount <= txbuf.len());
  252. if txcount == txbuf.len() {
  253. return Ok(());
  254. }
  255. }
  256. Err(e) if e.kind() == ErrorKind::WouldBlock => (),
  257. Err(e) => {
  258. return Err(err!("Socket write: {e}"));
  259. }
  260. }
  261. }
  262. }
  263. /// Try to receive a message from a [Stream].
  264. pub async fn recv(stream: &mut impl Stream) -> ah::Result<Option<Self>> {
  265. let mut rxbuf = [0; FWMSG_SIZE];
  266. let mut rxcount = 0;
  267. loop {
  268. stream.readable().await.context("Socket polling (rx)")?;
  269. match stream.try_read(&mut rxbuf[rxcount..]) {
  270. Ok(n) => {
  271. if n == 0 {
  272. return Ok(None);
  273. }
  274. rxcount += n;
  275. assert!(rxcount <= FWMSG_SIZE);
  276. if rxcount == FWMSG_SIZE {
  277. return Ok(Some(Self::try_msg_deserialize(&rxbuf)?));
  278. }
  279. }
  280. Err(e) if e.kind() == ErrorKind::WouldBlock => (),
  281. Err(e) => {
  282. return Err(err!("Socket read: {e}"));
  283. }
  284. }
  285. }
  286. }
  287. }
  288. /// Communication stream abstraction.
  289. pub trait Stream {
  290. fn readable(&self) -> impl Future<Output = std::io::Result<()>> + Send;
  291. fn try_read(&self, buf: &mut [u8]) -> std::io::Result<usize>;
  292. fn writable(&self) -> impl Future<Output = std::io::Result<()>> + Send;
  293. fn try_write(&self, buf: &[u8]) -> std::io::Result<usize>;
  294. }
  295. macro_rules! impl_stream_for {
  296. ($ty:ty) => {
  297. impl Stream for $ty {
  298. fn readable(&self) -> impl Future<Output = std::io::Result<()>> + Send {
  299. self.readable()
  300. }
  301. fn try_read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
  302. self.try_read(buf)
  303. }
  304. fn writable(&self) -> impl Future<Output = std::io::Result<()>> + Send {
  305. self.writable()
  306. }
  307. fn try_write(&self, buf: &[u8]) -> std::io::Result<usize> {
  308. self.try_write(buf)
  309. }
  310. }
  311. };
  312. }
  313. #[cfg(any(target_os = "linux", target_os = "android"))]
  314. impl_stream_for!(UnixStream);
  315. #[cfg(target_os = "windows")]
  316. impl_stream_for!(NamedPipeClient);
  317. #[cfg(target_os = "windows")]
  318. impl_stream_for!(NamedPipeServer);
  319. #[cfg(test)]
  320. mod tests {
  321. use super::*;
  322. fn check_ser_de(msg: &FirewallMessage) {
  323. // Serialize a message and then deserialize the byte stream
  324. // and check if the resulting message is the same.
  325. let bytes = msg.msg_serialize().unwrap();
  326. let msg_de = FirewallMessage::try_msg_deserialize(&bytes).unwrap();
  327. assert_eq!(*msg, msg_de);
  328. }
  329. #[test]
  330. fn test_msg_open_v6() {
  331. let msg = FirewallMessage::new_open("::1".parse().unwrap(), PortType::Tcp, 0x9876);
  332. assert_eq!(msg.operation(), FirewallOperation::Open);
  333. assert_eq!(msg.port(), Some((PortType::Tcp, 0x9876)));
  334. assert_eq!(msg.addr(), Some("::1".parse().unwrap()));
  335. check_ser_de(&msg);
  336. let msg = FirewallMessage::new_open(
  337. "0102:0304:0506:0708:090A:0B0C:0D0E:0F10".parse().unwrap(),
  338. PortType::Tcp,
  339. 0x9876,
  340. );
  341. let bytes = msg.msg_serialize().unwrap();
  342. assert_eq!(
  343. bytes,
  344. [
  345. 0x00, 0x02, // operation
  346. 0x00, 0x00, // port_type
  347. 0x98, 0x76, // port
  348. 0x00, 0x00, // addr_type
  349. 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // addr
  350. 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, // addr
  351. ]
  352. );
  353. let msg = FirewallMessage::new_open(
  354. "0102:0304:0506:0708:090A:0B0C:0D0E:0F10".parse().unwrap(),
  355. PortType::Udp,
  356. 0x9876,
  357. );
  358. let bytes = msg.msg_serialize().unwrap();
  359. assert_eq!(
  360. bytes,
  361. [
  362. 0x00, 0x02, // operation
  363. 0x00, 0x01, // port_type
  364. 0x98, 0x76, // port
  365. 0x00, 0x00, // addr_type
  366. 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // addr
  367. 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, // addr
  368. ]
  369. );
  370. let msg = FirewallMessage::new_open(
  371. "0102:0304:0506:0708:090A:0B0C:0D0E:0F10".parse().unwrap(),
  372. PortType::TcpUdp,
  373. 0x9876,
  374. );
  375. let bytes = msg.msg_serialize().unwrap();
  376. assert_eq!(
  377. bytes,
  378. [
  379. 0x00, 0x02, // operation
  380. 0x00, 0x02, // port_type
  381. 0x98, 0x76, // port
  382. 0x00, 0x00, // addr_type
  383. 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // addr
  384. 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, // addr
  385. ]
  386. );
  387. }
  388. #[test]
  389. fn test_msg_open_v4() {
  390. let msg = FirewallMessage::new_open("1.2.3.4".parse().unwrap(), PortType::Tcp, 0x9876);
  391. assert_eq!(msg.operation(), FirewallOperation::Open);
  392. assert_eq!(msg.port(), Some((PortType::Tcp, 0x9876)));
  393. assert_eq!(msg.addr(), Some("1.2.3.4".parse().unwrap()));
  394. check_ser_de(&msg);
  395. let bytes = msg.msg_serialize().unwrap();
  396. assert_eq!(
  397. bytes,
  398. [
  399. 0x00, 0x02, // operation
  400. 0x00, 0x00, // port_type
  401. 0x98, 0x76, // port
  402. 0x00, 0x01, // addr_type
  403. 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // addr
  404. 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, // addr
  405. ]
  406. );
  407. }
  408. #[test]
  409. fn test_msg_ack() {
  410. let msg = FirewallMessage::new_ack();
  411. assert_eq!(msg.operation(), FirewallOperation::Ack);
  412. assert_eq!(msg.port(), None);
  413. assert_eq!(msg.addr(), None);
  414. check_ser_de(&msg);
  415. let bytes = msg.msg_serialize().unwrap();
  416. assert_eq!(
  417. bytes,
  418. [
  419. 0x00, 0x01, // operation
  420. 0x00, 0x00, // port_type
  421. 0x00, 0x00, // port
  422. 0x00, 0x00, // addr_type
  423. 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // addr
  424. 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // addr
  425. ]
  426. );
  427. }
  428. #[test]
  429. fn test_msg_nack() {
  430. let msg = FirewallMessage::new_nack();
  431. assert_eq!(msg.operation(), FirewallOperation::Nack);
  432. assert_eq!(msg.port(), None);
  433. assert_eq!(msg.addr(), None);
  434. check_ser_de(&msg);
  435. let bytes = msg.msg_serialize().unwrap();
  436. assert_eq!(
  437. bytes,
  438. [
  439. 0x00, 0x00, // operation
  440. 0x00, 0x00, // port_type
  441. 0x00, 0x00, // port
  442. 0x00, 0x00, // addr_type
  443. 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // addr
  444. 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // addr
  445. ]
  446. );
  447. }
  448. }
  449. // vim: ts=4 sw=4 expandtab