From 5dfad8264ef9d2cb0f2b7e650c29f42d3d39751a Mon Sep 17 00:00:00 2001 From: PoliEcho Date: Thu, 31 Jul 2025 20:39:07 +0200 Subject: [PATCH] add actual hole punching --- src/client/main.rs | 100 +++++++------- src/client/net.rs | 285 +++++++++++++++++++--------------------- src/client/net_utils.rs | 112 ---------------- src/lib.rs | 5 +- src/server/net.rs | 33 +++-- src/server/types.rs | 12 +- src/shared/mod.rs | 1 + src/shared/net.rs | 213 ++++++++++++++++++++++++++++++ 8 files changed, 428 insertions(+), 333 deletions(-) delete mode 100644 src/client/net_utils.rs create mode 100644 src/shared/net.rs diff --git a/src/client/main.rs b/src/client/main.rs index 1e4286c..e652365 100644 --- a/src/client/main.rs +++ b/src/client/main.rs @@ -1,5 +1,4 @@ mod net; -mod net_utils; mod tun; mod types; use colored::Colorize; @@ -192,28 +191,23 @@ fn main() -> std::io::Result<()> { "[LOG]".blue() ); let mut network_write_lock = virtual_network.write().unwrap(); // avoid deadlock - - + let encrypted = network_write_lock.encrypted; let key = network_write_lock.key; - network_write_lock - .peers - .iter_mut() - .for_each(|peer| { - match net::P2P_query(&mut buf, &peer.sock_addr, &socket, encrypted,key) { - Ok(ip) => { - ips_used[ip.octets()[3] as usize] = true; - peer.private_ip = ip; - } - Err(e) => eprintln!( - "{} while getting ip from peer: {}, Error: {}", - "[ERROR]".red(), - peer.sock_addr, - e - ), - }; - }); - + network_write_lock.peers.iter_mut().for_each(|peer| { + match net::P2P_query(&mut buf, &peer.sock_addr, &socket, encrypted, key) { + Ok(ip) => { + ips_used[ip.octets()[3] as usize] = true; + peer.private_ip = ip; + } + Err(e) => eprintln!( + "{} while getting ip from peer: {}, Error: {}", + "[ERROR]".red(), + peer.sock_addr, + e + ), + }; + }); network_write_lock.private_ip = std::net::Ipv4Addr::new( DEFAULT_NETWORK_PREFIX[0], @@ -226,30 +220,28 @@ fn main() -> std::io::Result<()> { .peers .retain(|peer| peer.private_ip != std::net::Ipv4Addr::UNSPECIFIED); // remove all peers without ip - network_write_lock - .peers - .iter() - .for_each(|peer| { - match net::P2P_hello( - &mut buf, - &peer.sock_addr, - &socket, - network_write_lock.private_ip, - encrypted,key - ) { - Ok(_) => eprintln!( - "{} registered with peer: {}", - "[SUCCESS]".green(), - peer.sock_addr - ), - Err(e) => eprintln!( - "{} failed to register with peer: {}, Error: {}", - "[ERROR]".red(), - peer.sock_addr, - e - ), - } - }); + network_write_lock.peers.iter().for_each(|peer| { + match net::P2P_hello( + &mut buf, + &peer.sock_addr, + &socket, + network_write_lock.private_ip, + encrypted, + key, + ) { + Ok(_) => eprintln!( + "{} registered with peer: {}", + "[SUCCESS]".green(), + peer.sock_addr + ), + Err(e) => eprintln!( + "{} failed to register with peer: {}, Error: {}", + "[ERROR]".red(), + peer.sock_addr, + e + ), + } + }); } let tun_iface = Arc::new( @@ -270,17 +262,15 @@ fn main() -> std::io::Result<()> { #[cfg(not(feature = "no-timeout"))] socket.set_read_timeout(None)?; - {let tun_iface_clone = tun_iface.clone(); - let socket_clone = socket.clone(); - let virtual_network_clone = virtual_network.clone(); + { + let tun_iface_clone = tun_iface.clone(); + let socket_clone = socket.clone(); + let virtual_network_clone = virtual_network.clone(); - std::thread::spawn(move || { - tun::read_tun_iface( - tun_iface_clone, - socket_clone, - virtual_network_clone, - ) -});} // just let me have my thread + std::thread::spawn(move || { + tun::read_tun_iface(tun_iface_clone, socket_clone, virtual_network_clone) + }); + } // just let me have my thread smol::block_on(async { loop { diff --git a/src/client/net.rs b/src/client/net.rs index f2c7907..098a321 100644 --- a/src/client/net.rs +++ b/src/client/net.rs @@ -1,116 +1,14 @@ use std::{ - io::ErrorKind, net::{Ipv4Addr, SocketAddr, UdpSocket}, str::FromStr, sync::{Arc, RwLock}, }; use super::types; -use crate::net_utils; -use crate::types::Peer; use colored::Colorize; -use libc::socket; -use pea_2_pea::*; +use pea_2_pea::{shared::net::send_and_recv_with_retry, *}; use rand::{RngCore, rng}; -// return data_lenght and number of retryes -pub fn send_and_recv_with_retry( - buf: &mut [u8; UDP_BUFFER_SIZE], - send_buf: &[u8], - dst: &SocketAddr, - socket: &UdpSocket, - retry_max: usize, -) -> Result<(usize, usize), ServerErrorResponses> { - #[cfg(any(target_os = "linux", target_os = "windows"))] - net_utils::enable_icmp_errors(socket)?; - - let mut retry_count: usize = 0; - - loop { - match socket.send_to(send_buf, dst) { - Ok(s) => { - #[cfg(debug_assertions)] - eprintln!("send {} bytes", s); - } - Err(e) => match e.kind() { - ErrorKind::ConnectionReset - | ErrorKind::ConnectionRefused - | ErrorKind::NetworkUnreachable - | ErrorKind::HostUnreachable => { - return Err(ServerErrorResponses::IO(std::io::Error::new( - e.kind(), - format!("Destination unreachable: {}", e), - ))); - } - _ => return Err(ServerErrorResponses::IO(e)), - }, - } - - #[cfg(target_os = "linux")] - if let Err(icmp_error) = net_utils::check_icmp_error_queue(socket) { - return Err(ServerErrorResponses::IO(icmp_error)); - } - - match socket.recv_from(buf) { - Ok((data_length, src)) => { - if src != *dst { - continue; - } - match buf[0] { - x if x == send_buf[0] as u8 => { - return Ok((data_length, retry_count)); - } - x if x == ServerResponse::GENERAL_ERROR as u8 => { - return Err(ServerErrorResponses::IO(std::io::Error::new( - std::io::ErrorKind::InvalidData, - match std::str::from_utf8(&buf[1..data_length]) { - Ok(s) => s.to_string(), - Err(e) => format!("invalid error string: {}", e), - }, - ))); - } - x if x == ServerResponse::ID_DOESNT_EXIST as u8 => { - return Err(ServerErrorResponses::ID_DOESNT_EXIST); - } - x if x == ServerResponse::ID_EXISTS as u8 => { - return Err(ServerErrorResponses::ID_EXISTS); - } - _ => { - continue; - } - } - } - Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => { - #[cfg(target_os = "linux")] - if let Err(icmp_error) = net_utils::check_icmp_error_queue(socket) { - return Err(ServerErrorResponses::IO(icmp_error)); - } - - if retry_count >= retry_max { - return Err(ServerErrorResponses::IO(std::io::Error::new( - ErrorKind::TimedOut, - "Max retry count reached - destination may be unreachable", - ))); - } - retry_count += 1; - continue; - } - Err(e) => match e.kind() { - ErrorKind::ConnectionReset - | ErrorKind::ConnectionRefused - | ErrorKind::NetworkUnreachable - | ErrorKind::HostUnreachable => { - return Err(ServerErrorResponses::IO(std::io::Error::new( - e.kind(), - format!("Destination unreachable during receive: {}", e), - ))); - } - _ => return Err(ServerErrorResponses::IO(e)), - }, - } - } -} - pub fn query_request( buf: &mut [u8; UDP_BUFFER_SIZE], dst: &SocketAddr, @@ -249,7 +147,7 @@ pub fn get_request( .unwrap(); let mut offset: usize = 0; - let mut peers: Vec = Vec::with_capacity(1); // at least one client + let mut peers: Vec = Vec::with_capacity(1); // at least one client let key: [u8; 32] = match password { Some(p) => shared::crypto::derive_key_from_password(p.as_bytes(), &salt), @@ -413,7 +311,7 @@ pub fn P2P_query( dst: &SocketAddr, socket: &UdpSocket, encrypted: bool, // avoid deadlock - key: [u8; 32] + key: [u8; 32], ) -> Result> { #[cfg(debug_assertions)] println!("P2P QUERY method"); @@ -426,43 +324,39 @@ pub fn P2P_query( STANDARD_RETRY_MAX, )?; - let iv: [u8; BLOCK_SIZE] = buf[P2PStandardDataPositions::IV as usize - ..P2PStandardDataPositions::IV as usize + BLOCK_SIZE] + let iv: [u8; BLOCK_SIZE] = buf + [P2PStandardDataPositions::IV as usize..P2PStandardDataPositions::IV as usize + BLOCK_SIZE] .try_into() .expect("this should never happen"); let tmp_decrypted: Vec; - return Ok(std::net::Ipv4Addr::from_str( - if encrypted { - match shared::crypto::decrypt( - &key, - &iv, - &buf[P2PStandardDataPositions::DATA as usize..data_lenght - 1], - ) { - Ok(decrypted) => { - tmp_decrypted = decrypted; - match std::str::from_utf8(&tmp_decrypted) { - Ok(s) => s, - Err(e) => return Err(Box::new(e)), - } - } - Err(e) => { - return Err(Box::new(ServerErrorResponses::GENERAL_ERROR(format!( - "{}", - e - )))); + return Ok(std::net::Ipv4Addr::from_str(if encrypted { + match shared::crypto::decrypt( + &key, + &iv, + &buf[P2PStandardDataPositions::DATA as usize..data_lenght - 1], + ) { + Ok(decrypted) => { + tmp_decrypted = decrypted; + match std::str::from_utf8(&tmp_decrypted) { + Ok(s) => s, + Err(e) => return Err(Box::new(e)), } } - } else { - match std::str::from_utf8( - &buf[P2PStandardDataPositions::DATA as usize..data_lenght - 1], - ) { - Ok(s) => s, - Err(e) => return Err(Box::new(e)), + Err(e) => { + return Err(Box::new(ServerErrorResponses::GENERAL_ERROR(format!( + "{}", + e + )))); } - }, - )?); + } + } else { + match std::str::from_utf8(&buf[P2PStandardDataPositions::DATA as usize..data_lenght - 1]) { + Ok(s) => s, + Err(e) => return Err(Box::new(e)), + } + })?); } #[allow(non_snake_case)] @@ -480,13 +374,9 @@ pub fn P2P_hello( let mut iv: [u8; BLOCK_SIZE] = [0u8; BLOCK_SIZE]; rng.fill_bytes(&mut iv); ( - shared::crypto::encrypt( - &key, - &iv, - &private_ip_str.as_bytes(), - ) - .unwrap() - .into_boxed_slice(), + shared::crypto::encrypt(&key, &iv, &private_ip_str.as_bytes()) + .unwrap() + .into_boxed_slice(), iv, ) } else { @@ -510,8 +400,8 @@ pub fn P2P_hello( ); send_buf[0] = P2PMethods::PEER_HELLO as u8; - send_buf[P2PStandardDataPositions::IV as usize - ..P2PStandardDataPositions::IV as usize + BLOCK_SIZE] + send_buf + [P2PStandardDataPositions::IV as usize..P2PStandardDataPositions::IV as usize + BLOCK_SIZE] .copy_from_slice(&iv); send_buf[P2PStandardDataPositions::DATA as usize..].copy_from_slice(&private_ip_final); @@ -523,7 +413,7 @@ pub fn P2P_hello( } pub async fn handle_incoming_connection( - buf: [u8; UDP_BUFFER_SIZE], + mut buf: [u8; UDP_BUFFER_SIZE], src: SocketAddr, network: Arc>, tun_iface: Arc, @@ -533,7 +423,6 @@ pub async fn handle_incoming_connection( #[cfg(debug_assertions)] eprintln!("recived method 0x{:02x}", buf[0]); match buf[0] { - x if x == P2PMethods::PACKET as u8 => { #[cfg(debug_assertions)] println!("PACKET from difernt peer receved"); @@ -571,7 +460,14 @@ pub async fn handle_incoming_connection( let private_ip = network.read().unwrap().private_ip; let private_ip_str = private_ip.to_string(); let mut send_buf: Box<[u8]> = if encrypted { - vec![0; P2PStandardDataPositions::DATA as usize + 1 + (private_ip_str.len() + (BLOCK_SIZE - (private_ip_str.len() % BLOCK_SIZE)))].into() // calculate lenght of data with block alligment + vec![ + 0; + P2PStandardDataPositions::DATA as usize + + 1 + + (private_ip_str.len() + + (BLOCK_SIZE - (private_ip_str.len() % BLOCK_SIZE))) + ] + .into() // calculate lenght of data with block alligment } else { vec![0; P2PStandardDataPositions::DATA as usize + 1 + private_ip_str.len()].into() }; @@ -581,10 +477,26 @@ pub async fn handle_incoming_connection( if encrypted { let mut rng = rng(); rng.fill_bytes(&mut iv); - send_buf[P2PStandardDataPositions::IV as usize..P2PStandardDataPositions::IV as usize+BLOCK_SIZE].copy_from_slice(&iv); - send_buf[P2PStandardDataPositions::DATA as usize..P2PStandardDataPositions::DATA as usize + (private_ip_str.len() + (BLOCK_SIZE - (private_ip_str.len() % BLOCK_SIZE)))].copy_from_slice(shared::crypto::encrypt(&network.read().unwrap().key, &iv, private_ip_str.as_bytes()).unwrap().as_slice()); + send_buf[P2PStandardDataPositions::IV as usize + ..P2PStandardDataPositions::IV as usize + BLOCK_SIZE] + .copy_from_slice(&iv); + send_buf[P2PStandardDataPositions::DATA as usize + ..P2PStandardDataPositions::DATA as usize + + (private_ip_str.len() + + (BLOCK_SIZE - (private_ip_str.len() % BLOCK_SIZE)))] + .copy_from_slice( + shared::crypto::encrypt( + &network.read().unwrap().key, + &iv, + private_ip_str.as_bytes(), + ) + .unwrap() + .as_slice(), + ); } else { - send_buf[P2PStandardDataPositions::DATA as usize..P2PStandardDataPositions::DATA as usize + private_ip_str.len()].copy_from_slice(private_ip_str.as_bytes()); + send_buf[P2PStandardDataPositions::DATA as usize + ..P2PStandardDataPositions::DATA as usize + private_ip_str.len()] + .copy_from_slice(private_ip_str.as_bytes()); } match socket.send_to(&send_buf, &src) { Ok(s) => { @@ -595,7 +507,7 @@ pub async fn handle_incoming_connection( eprintln!("Error sending data: {}", e); } } - }, + } x if x == P2PMethods::PEER_HELLO as u8 => { println!("{} peer hello receved from: {}", "[LOG]".blue(), src); @@ -605,7 +517,7 @@ pub async fn handle_incoming_connection( let key: [u8; 32] = network_write_lock.key; let encrypted: bool = network_write_lock.encrypted; #[cfg(debug_assertions)] - eprintln!( + eprintln!( "registering network:\niv: {}\nIP: {}", &buf[P2PStandardDataPositions::IV as usize ..P2PStandardDataPositions::IV as usize + BLOCK_SIZE].iter().map(|x| format!("{:02X} ", x)).collect::(), @@ -614,7 +526,7 @@ pub async fn handle_incoming_connection( .map(|x| format!("{:02X} ", x)) .collect::(), ); - network_write_lock.peers.push(Peer::new( + network_write_lock.peers.push(types::Peer::new( src, Some( match std::net::Ipv4Addr::from_str( @@ -691,7 +603,7 @@ pub async fn handle_incoming_connection( Ok(ip) => ip, Err(e) => {eprintln!("{} error parsing ip, Error: {}", "[ERROR]".red(), e); return false;}, } && peer.sock_addr == src}); - match socket.send_to(&[P2PMethods::PEER_GOODBYE as u8], &src) { + match socket.send_to(&[P2PMethods::PEER_GOODBYE as u8], &src) { Ok(s) => { #[cfg(debug_assertions)] eprintln!("send {} bytes", s); @@ -701,7 +613,74 @@ pub async fn handle_incoming_connection( } } } + x if x == P2PMethods::NEW_CLIENT_NOTIFY as u8 => { + println!( + "{} Notified about new client, creating NAT mapping", + "[LOG]".blue() + ); + let data_tmp: Box<[u8]>; + let peer_addr: std::net::SocketAddr = match std::net::SocketAddr::from_str( + match std::str::from_utf8(if network.read().unwrap().encrypted { + match shared::crypto::decrypt( + &network.read().unwrap().key, + &buf[P2PStandardDataPositions::IV as usize + ..P2PStandardDataPositions::IV as usize + BLOCK_SIZE], + &buf[P2PStandardDataPositions::DATA as usize..], + ) { + Ok(v) => { + data_tmp = v.into_boxed_slice(); + &data_tmp + } + Err(e) => { + eprintln!( + "{} failed to decrypt sock addr of new client connection not posible Error: {}", + "[ERROR]".red(), + e + ); + return; + } + } + } else { + &buf[P2PStandardDataPositions::DATA as usize..] + }) { + Ok(s) => s, + Err(e) => { + eprintln!( + "{} failed to decode sock addr of new client connection not posible Error: {}", + "[ERROR]".red(), + e + ); + return; + } + }, + ) { + Ok(sa) => sa, + Err(e) => { + eprintln!( + "{} failed to parse sock addr of new client connection not posible Error: {}", + "[ERROR]".red(), + e + ); + return; + } + }; + match P2P_query( + // create NAT mapping + &mut buf, + &peer_addr, + &socket, + network.read().unwrap().encrypted, + network.read().unwrap().key, + ) { + Ok(_) => {} + Err(e) => eprintln!( + "{} failed to create NAT mapping to peer connection may not work Error: {}", + "[ERROR]".red(), + e + ), + }; + } _ => { eprintln!( "{} unknown method ID: 0x{:02x}, Droping!", diff --git a/src/client/net_utils.rs b/src/client/net_utils.rs deleted file mode 100644 index 72ae560..0000000 --- a/src/client/net_utils.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::net::UdpSocket; - -#[cfg(target_os = "windows")] -use std::os::windows::io::AsRawSocket; -#[cfg(target_os = "windows")] -use winapi::shared::minwindef::{BOOL, DWORD, FALSE}; -#[cfg(target_os = "windows")] -use winapi::um::mswsock::SIO_UDP_CONNRESET; -#[cfg(target_os = "windows")] -use winapi::um::winsock2::{SOCKET_ERROR, WSAIoctl}; - -#[cfg(target_os = "linux")] -use std::os::unix::io::AsRawFd; - -#[cfg(target_os = "windows")] -pub fn enable_icmp_errors(socket: &UdpSocket) -> std::io::Result<()> { - let socket_handle = socket.as_raw_socket(); - let mut bytes_returned: DWORD = 0; - let enable: BOOL = FALSE; - - let result = unsafe { - WSAIoctl( - socket_handle as usize, - SIO_UDP_CONNRESET, - &enable as *const _ as *mut _, - std::mem::size_of::() as DWORD, - std::ptr::null_mut(), - 0, - &mut bytes_returned, - std::ptr::null_mut(), - None, - ) - }; - - if result == SOCKET_ERROR { - Err(std::io::Error::last_os_error()) - } else { - Ok(()) - } -} - -#[cfg(target_os = "linux")] -pub fn enable_icmp_errors(socket: &UdpSocket) -> std::io::Result<()> { - let fd = socket.as_raw_fd(); - let optval: libc::c_int = 1; - - let ret = unsafe { - libc::setsockopt( - fd, - libc::SOL_IP, - libc::IP_RECVERR, - &optval as *const _ as *const libc::c_void, - std::mem::size_of::() as libc::socklen_t, - ) - }; - - if ret < 0 { - Err(std::io::Error::last_os_error()) - } else { - Ok(()) - } -} - -#[cfg(target_os = "linux")] -pub fn check_icmp_error_queue(socket: &UdpSocket) -> std::io::Result<()> { - use libc::{MSG_ERRQUEUE, iovec, msghdr, recvmsg}; - - let fd = socket.as_raw_fd(); - let mut buf = [0u8; 1024]; - let mut control_buf = [0u8; 1024]; - - let mut iov = iovec { - iov_base: buf.as_mut_ptr() as *mut libc::c_void, - iov_len: buf.len(), - }; - - let mut msg: msghdr = unsafe { std::mem::zeroed() }; - msg.msg_iov = &mut iov; - msg.msg_iovlen = 1; - msg.msg_control = control_buf.as_mut_ptr() as *mut libc::c_void; - msg.msg_controllen = control_buf.len(); - - let result = unsafe { recvmsg(fd, &mut msg, MSG_ERRQUEUE) }; - - if result < 0 { - let error = std::io::Error::last_os_error(); - if error.kind() == std::io::ErrorKind::WouldBlock { - return Ok(()); - } - return Err(error); - } - - Err(std::io::Error::new( - std::io::ErrorKind::NetworkUnreachable, - "ICMP destination unreachable received", - )) -} - -#[cfg(target_os = "windows")] -fn check_icmp_error_queue(_socket: &UdpSocket) -> std::io::Result<()> { - Ok(()) -} - -#[cfg(not(any(target_os = "linux", target_os = "windows")))] -fn enable_icmp_errors(_socket: &UdpSocket) -> std::io::Result<()> { - Ok(()) -} - -#[cfg(not(any(target_os = "linux", target_os = "windows")))] -fn check_icmp_error_queue(_socket: &UdpSocket) -> std::io::Result<()> { - Ok(()) -} diff --git a/src/lib.rs b/src/lib.rs index b5de5e5..c418d9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub const IPV4_SIZE: usize = 4; pub const DEFAULT_NETWORK_PREFIX: [u8; 3] = [172, 22, 44]; #[repr(u8)] +#[allow(non_camel_case_types)] pub enum ServerMethods { QUERY = 0, // return IP and port of the client REGISTER = 1, @@ -98,8 +99,7 @@ pub enum GetResponseDataPositions { ENCRYPTED = 1, // this feeld should be 0 if not encrypted NUM_OF_CLIENTS = 2, SALT = 3, - CLIENTS = - (BLOCK_SIZE as usize + RegisterRequestDataPositions::SALT as usize) - 1 as usize, + CLIENTS = (BLOCK_SIZE as usize + RegisterRequestDataPositions::SALT as usize) - 1 as usize, // after this there will be blocks of this sturcture: one byte size of sockaddr than there will be IV that is SALT_AND_IV_SIZE long and after that there will be sockaddr this repeats until the end of packet } @@ -119,6 +119,7 @@ pub enum P2PMethods { PEER_HELLO = 21, // sends private ip encrypted if on PEER_GOODBYE = 22, // sends private ip encrypted if on PACKET = 23, // sends IP packet encrypted if on + NEW_CLIENT_NOTIFY = 24, } #[repr(usize)] pub enum P2PStandardDataPositions { diff --git a/src/server/net.rs b/src/server/net.rs index ccffc0a..f8e255f 100644 --- a/src/server/net.rs +++ b/src/server/net.rs @@ -2,6 +2,7 @@ use crate::utils::send_general_error_to_client; use super::types; use super::utils; +use colored::Colorize; use orx_concurrent_vec::ConcurrentVec; use pea_2_pea::*; use rayon::prelude::*; @@ -197,15 +198,13 @@ pub async fn handle_request( if encrypted { salt = Some( buf[(RegisterRequestDataPositions::SALT as usize) - ..(RegisterRequestDataPositions::SALT as usize) - + (BLOCK_SIZE as usize)] + ..(RegisterRequestDataPositions::SALT as usize) + (BLOCK_SIZE as usize)] .try_into() .expect("this should never happen"), ); iv = Some( buf[(RegisterRequestDataPositions::IV as usize) - ..(RegisterRequestDataPositions::IV as usize) - + (BLOCK_SIZE as usize)] + ..(RegisterRequestDataPositions::IV as usize) + (BLOCK_SIZE as usize)] .try_into() .expect("this should never happen"), ) @@ -244,6 +243,7 @@ pub async fn handle_request( chrono::Utc::now().timestamp(), salt, iv, + src )); match socket.send_to(&[ServerMethods::REGISTER as u8], src) { Ok(s) => { @@ -297,11 +297,10 @@ pub async fn handle_request( } }; - let iv: [u8; BLOCK_SIZE as usize] = - buf[HeartBeatRequestDataPositions::IV as usize - ..HeartBeatRequestDataPositions::IV as usize + BLOCK_SIZE as usize] - .try_into() - .unwrap(); + let iv: [u8; BLOCK_SIZE as usize] = buf[HeartBeatRequestDataPositions::IV as usize + ..HeartBeatRequestDataPositions::IV as usize + BLOCK_SIZE as usize] + .try_into() + .unwrap(); let sock_addr: Vec = buf[HeartBeatRequestDataPositions::DATA as usize + id_len as usize @@ -330,7 +329,21 @@ pub async fn handle_request( match r.clients.par_iter_mut().find_any(|c| *c.client_sock_addr == *sock_addr && c.iv == iv) { Some(c) => c.last_heart_beat = current_time, None => {// add new client if it isn't found - r.clients.push(types::Client::new(sock_addr.clone(), current_time, iv)); + r.clients.par_iter().for_each(|c| {let mut send_buf: Box<[u8]> = vec![0; P2PStandardDataPositions::DATA as usize + sock_addr_len as usize].into(); + send_buf[0] = P2PMethods::NEW_CLIENT_NOTIFY as u8; + send_buf[P2PStandardDataPositions::IV as usize..P2PStandardDataPositions::IV as usize+ BLOCK_SIZE].copy_from_slice(&iv); + send_buf[P2PStandardDataPositions::DATA as usize..P2PStandardDataPositions::DATA as usize + sock_addr_len as usize].copy_from_slice(&sock_addr); + let mut resp_buf: [u8; UDP_BUFFER_SIZE] = [0u8; UDP_BUFFER_SIZE]; + match shared::net::send_and_recv_with_retry(&mut resp_buf, &send_buf, &c.src, &socket, STANDARD_RETRY_MAX) { + Ok((data_lenght, _)) => { + #[cfg(debug_assertions)] + eprintln!("send {} bytes", data_lenght); + }, + Err(e) => eprintln!("{} failed to send data to client Error: {}", "[ERROR]".red(), e), + }; + }); + + r.clients.push(types::Client::new(sock_addr.clone(), current_time, iv, src)); } }; }); diff --git a/src/server/types.rs b/src/server/types.rs index 4d89c48..25b9ad9 100644 --- a/src/server/types.rs +++ b/src/server/types.rs @@ -9,14 +9,22 @@ pub struct Client { pub last_heart_beat: i64, #[readonly] pub iv: [u8; BLOCK_SIZE as usize], + #[readonly] + pub src: std::net::SocketAddr, } impl Client { - pub fn new(client_addr: Vec, heart_beat: i64, iv: [u8; BLOCK_SIZE as usize]) -> Self { + pub fn new( + client_addr: Vec, + heart_beat: i64, + iv: [u8; BLOCK_SIZE as usize], + src: std::net::SocketAddr, + ) -> Self { Client { client_sock_addr: client_addr, last_heart_beat: heart_beat, iv, + src, } } } @@ -43,6 +51,7 @@ impl Registration { heart_beat: i64, salt: Option<[u8; BLOCK_SIZE as usize]>, iv: Option<[u8; BLOCK_SIZE as usize]>, + src: std::net::SocketAddr, ) -> Self { Registration { net_id, @@ -50,6 +59,7 @@ impl Registration { client_addr, heart_beat, iv.unwrap_or([0; BLOCK_SIZE as usize]), + src, )], encrypted, last_heart_beat: heart_beat, diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 274f0ed..11ab1ac 100644 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1 +1,2 @@ pub mod crypto; +pub mod net; diff --git a/src/shared/net.rs b/src/shared/net.rs new file mode 100644 index 0000000..ef28bd6 --- /dev/null +++ b/src/shared/net.rs @@ -0,0 +1,213 @@ +use std::io::ErrorKind; +use std::net::{SocketAddr, UdpSocket}; + +use crate::*; + +#[cfg(target_os = "windows")] +use std::os::windows::io::AsRawSocket; +#[cfg(target_os = "windows")] +use winapi::shared::minwindef::{BOOL, DWORD, FALSE}; +#[cfg(target_os = "windows")] +use winapi::um::mswsock::SIO_UDP_CONNRESET; +#[cfg(target_os = "windows")] +use winapi::um::winsock2::{SOCKET_ERROR, WSAIoctl}; + +#[cfg(target_os = "linux")] +use std::os::unix::io::AsRawFd; + +#[cfg(target_os = "windows")] +fn enable_icmp_errors(socket: &UdpSocket) -> std::io::Result<()> { + let socket_handle = socket.as_raw_socket(); + let mut bytes_returned: DWORD = 0; + let enable: BOOL = FALSE; + + let result = unsafe { + WSAIoctl( + socket_handle as usize, + SIO_UDP_CONNRESET, + &enable as *const _ as *mut _, + std::mem::size_of::() as DWORD, + std::ptr::null_mut(), + 0, + &mut bytes_returned, + std::ptr::null_mut(), + None, + ) + }; + + if result == SOCKET_ERROR { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(target_os = "linux")] +fn enable_icmp_errors(socket: &UdpSocket) -> std::io::Result<()> { + let fd = socket.as_raw_fd(); + let optval: libc::c_int = 1; + + let ret = unsafe { + libc::setsockopt( + fd, + libc::SOL_IP, + libc::IP_RECVERR, + &optval as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ) + }; + + if ret < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(target_os = "linux")] +fn check_icmp_error_queue(socket: &UdpSocket) -> std::io::Result<()> { + use libc::{MSG_ERRQUEUE, iovec, msghdr, recvmsg}; + + let fd = socket.as_raw_fd(); + let mut buf = [0u8; 1024]; + let mut control_buf = [0u8; 1024]; + + let mut iov = iovec { + iov_base: buf.as_mut_ptr() as *mut libc::c_void, + iov_len: buf.len(), + }; + + let mut msg: msghdr = unsafe { std::mem::zeroed() }; + msg.msg_iov = &mut iov; + msg.msg_iovlen = 1; + msg.msg_control = control_buf.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = control_buf.len(); + + let result = unsafe { recvmsg(fd, &mut msg, MSG_ERRQUEUE) }; + + if result < 0 { + let error = std::io::Error::last_os_error(); + if error.kind() == std::io::ErrorKind::WouldBlock { + return Ok(()); + } + return Err(error); + } + + Err(std::io::Error::new( + std::io::ErrorKind::NetworkUnreachable, + "ICMP destination unreachable received", + )) +} + +#[cfg(target_os = "windows")] +fn check_icmp_error_queue(_socket: &UdpSocket) -> std::io::Result<()> { + Ok(()) +} + +#[cfg(not(any(target_os = "linux", target_os = "windows")))] +fn enable_icmp_errors(_socket: &UdpSocket) -> std::io::Result<()> { + Ok(()) +} + +#[cfg(not(any(target_os = "linux", target_os = "windows")))] +fn check_icmp_error_queue(_socket: &UdpSocket) -> std::io::Result<()> { + Ok(()) +} + +// return data_lenght and number of retryes +pub fn send_and_recv_with_retry( + buf: &mut [u8; UDP_BUFFER_SIZE], + send_buf: &[u8], + dst: &SocketAddr, + socket: &UdpSocket, + retry_max: usize, +) -> Result<(usize, usize), ServerErrorResponses> { + #[cfg(any(target_os = "linux", target_os = "windows"))] + enable_icmp_errors(socket)?; + + let mut retry_count: usize = 0; + + loop { + match socket.send_to(send_buf, dst) { + Ok(s) => { + #[cfg(debug_assertions)] + eprintln!("send {} bytes", s); + } + Err(e) => match e.kind() { + ErrorKind::ConnectionReset + | ErrorKind::ConnectionRefused + | ErrorKind::NetworkUnreachable + | ErrorKind::HostUnreachable => { + return Err(ServerErrorResponses::IO(std::io::Error::new( + e.kind(), + format!("Destination unreachable: {}", e), + ))); + } + _ => return Err(ServerErrorResponses::IO(e)), + }, + } + + #[cfg(target_os = "linux")] + if let Err(icmp_error) = check_icmp_error_queue(socket) { + return Err(ServerErrorResponses::IO(icmp_error)); + } + + match socket.recv_from(buf) { + Ok((data_length, src)) => { + if src != *dst { + continue; + } + match buf[0] { + x if x == send_buf[0] as u8 => { + return Ok((data_length, retry_count)); + } + x if x == ServerResponse::GENERAL_ERROR as u8 => { + return Err(ServerErrorResponses::IO(std::io::Error::new( + std::io::ErrorKind::InvalidData, + match std::str::from_utf8(&buf[1..data_length]) { + Ok(s) => s.to_string(), + Err(e) => format!("invalid error string: {}", e), + }, + ))); + } + x if x == ServerResponse::ID_DOESNT_EXIST as u8 => { + return Err(ServerErrorResponses::ID_DOESNT_EXIST); + } + x if x == ServerResponse::ID_EXISTS as u8 => { + return Err(ServerErrorResponses::ID_EXISTS); + } + _ => { + continue; + } + } + } + Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => { + #[cfg(target_os = "linux")] + if let Err(icmp_error) = check_icmp_error_queue(socket) { + return Err(ServerErrorResponses::IO(icmp_error)); + } + + if retry_count >= retry_max { + return Err(ServerErrorResponses::IO(std::io::Error::new( + ErrorKind::TimedOut, + "Max retry count reached - destination may be unreachable", + ))); + } + retry_count += 1; + continue; + } + Err(e) => match e.kind() { + ErrorKind::ConnectionReset + | ErrorKind::ConnectionRefused + | ErrorKind::NetworkUnreachable + | ErrorKind::HostUnreachable => { + return Err(ServerErrorResponses::IO(std::io::Error::new( + e.kind(), + format!("Destination unreachable during receive: {}", e), + ))); + } + _ => return Err(ServerErrorResponses::IO(e)), + }, + } + } +}