2025-05-03 16:23:19 +02:00

272 lines
7.2 KiB
C++

#include "../common/defines.h"
#include <algorithm>
#include <arpa/inet.h>
#include <csignal>
#include <cstring>
#include <iostream>
#include <mutex>
#include <netinet/in.h>
#include <nlohmann/json.hpp>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <string>
#include <thread>
#include <unistd.h>
#include <vector>
using json = nlohmann::json;
#define PORT 9009

// A registered TLS peer.
//
// The accepted socket descriptor doubles as the peer's id on the wire. CONNECT
// replies list other clients' `socket` values, and MESSAGE "destinations" are
// matched against them.
struct Client {
  SSL *ssl;   // TLS session that was bound to `socket` via SSL_set_fd()
  int socket; // accepted client file descriptor / peer id
};

// Every registered client. The vector is read and modified only while holding
// clients_mutex. The server's SSL_write() calls to entries are also made under
// that lock, which serialises writes to any one client.
std::vector<Client> clients;
std::mutex clients_mutex;
// Serves one registered TLS client until its connection ends. It then
// unregisters the client and releases its SSL object and socket.
//
// Each successful SSL_read() is parsed as one JSON document. Its integer
// "type" field selects the action (constants come from defines.h):
//   CONNECT    - reply {"pears": [...]} with the socket ids of every other
//                registered client ("pears" is the key the wire format uses).
//   MESSAGE    - send the string "message" to each socket id listed in
//                "destinations"; ids not currently registered are logged.
//   DISCONNECT - log "client_id", then fall through to the raw broadcast.
//   any other type, or input that is not JSON - fall through to the raw
//                broadcast. The broadcast re-sends the received bytes
//                unchanged to every other client.
// JSON that parses but has the wrong shape (not an object, missing or
// mistyped fields) is logged and dropped.
//
// Every SSL_write() here happens while holding clients_mutex.
// NOTE(review): this thread's SSL_read() runs *without* that lock, so another
// thread may SSL_write() to this same SSL object concurrently. OpenSSL does
// not support concurrent use of one SSL object. Confirm, and consider a
// per-client write queue.
void handle_client(SSL *ssl, int client_socket) {
  char buffer[BUFFER_SIZE];
  while (true) {
    const int bytes_read = SSL_read(ssl, buffer, sizeof(buffer));
    if (bytes_read <= 0) {
      // Retry on SSL_ERROR_WANT_READ; any other outcome ends the session.
      if (SSL_get_error(ssl, bytes_read) == SSL_ERROR_WANT_READ)
        continue;
      break;
    }
    try {
      json cldata =
          json::parse(std::string(buffer, static_cast<size_t>(bytes_read)));
      switch (cldata["type"].get<int>()) {
      case CONNECT: {
        // Holding the lock also serialises this write with other threads'
        // writes to this client's SSL object.
        std::lock_guard<std::mutex> lock(clients_mutex);
        std::vector<int> socket_values;
        for (const Client &client : clients) {
          if (client.socket != client_socket)
            socket_values.push_back(client.socket);
        }
        json response;
        response["pears"] = socket_values;
        const std::string payload = response.dump();
        SSL_write(ssl, payload.c_str(), static_cast<int>(payload.size()));
        continue;
      }
      case DISCONNECT:
        std::cout << "Client disconnected: " << cldata["client_id"].get<int>()
                  << std::endl;
        break; // leaves the switch only; the raw broadcast below still runs
      case MESSAGE: {
        // Extract the payload once, before anything is sent. A missing or
        // non-string "message" is then rejected before any destination
        // receives it.
        const std::string message = cldata["message"].get<std::string>();
        std::lock_guard<std::mutex> lock(clients_mutex);
        for (const int destination_socket : cldata["destinations"]) {
          const auto it =
              std::find_if(clients.begin(), clients.end(),
                           [destination_socket](const Client &client) {
                             return client.socket == destination_socket;
                           });
          if (it != clients.end()) {
            SSL_write(it->ssl, message.c_str(),
                      static_cast<int>(message.size()));
          } else {
            std::cerr << "Client not found: " << destination_socket
                      << std::endl;
          }
        }
        continue;
      }
      default:
        std::cerr << "Unknown message type" << std::endl;
        break;
      }
    } catch (const json::parse_error &e) {
      // Not JSON at all: fall through to the raw broadcast below.
      std::cerr << "JSON parse error: " << e.what() << std::endl;
    } catch (const json::exception &e) {
      // Valid JSON with the wrong shape (e.g. json::type_error). Without this
      // handler the exception would escape this detached thread and
      // std::terminate() the whole server. Drop the message instead of
      // broadcasting it: a MESSAGE may already have reached some of its
      // destinations, and a broadcast would leak it to every client.
      std::cerr << "Malformed message: " << e.what() << std::endl;
      continue;
    }
    // Broadcast the raw bytes to all other clients.
    std::lock_guard<std::mutex> lock(clients_mutex);
    for (const Client &client : clients) {
      if (client.socket != client_socket)
        SSL_write(client.ssl, buffer, bytes_read);
    }
  }
  // Unregister this client and release its resources. The erase happens under
  // clients_mutex, and other threads reach `ssl` only through `clients` while
  // holding that lock. So nobody can write to `ssl` once it is freed.
  std::lock_guard<std::mutex> lock(clients_mutex);
  clients.erase(std::remove_if(clients.begin(), clients.end(),
                               [client_socket](const Client &c) {
                                 return c.socket == client_socket;
                               }),
                clients.end());
  SSL_free(ssl);
  close(client_socket);
}
int main() {
// Initialize OpenSSL
SSL_library_init();
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
// Create TLS 1.3 context
SSL_CTX *ctx = SSL_CTX_new(TLS_server_method());
if (!ctx) {
std::cerr << "Failed to create SSL_CTX" << std::endl;
ERR_print_errors_fp(stderr);
return 1;
}
// Set TLS 1.3 only
SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION);
SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION);
// Load post-quantum certificates
if (SSL_CTX_use_certificate_file(ctx, "certs/mldsa87_entity_cert.pem",
SSL_FILETYPE_PEM) != 1) {
std::cerr << "Error loading certificate" << std::endl;
ERR_print_errors_fp(stderr);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
if (SSL_CTX_use_PrivateKey_file(ctx, "certs/mldsa87_entity_key.pem",
SSL_FILETYPE_PEM) != 1) {
std::cerr << "Error loading private key" << std::endl;
ERR_print_errors_fp(stderr);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Configure cipher suites
if (SSL_CTX_set_ciphersuites(ctx, "TLS_AES_256_GCM_SHA384") != 1) {
std::cerr << "Error setting ciphers" << std::endl;
ERR_print_errors_fp(stderr);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Create dual-stack IPv6 socket
int server_fd = socket(AF_INET6, SOCK_STREAM, 0);
if (server_fd < 0) {
std::cerr << "Socket creation failed" << std::endl;
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Enable dual-stack support
int opt = 0;
if (setsockopt(server_fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) {
std::cerr << "Dual-stack configuration failed" << std::endl;
close(server_fd);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Set socket options
int reuse = 1;
if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) <
0) {
std::cerr << "SO_REUSEADDR failed" << std::endl;
close(server_fd);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Bind socket
sockaddr_in6 address{};
address.sin6_family = AF_INET6;
address.sin6_addr = in6addr_any;
address.sin6_port = htons(PORT);
if (bind(server_fd, (sockaddr *)&address, sizeof(address)) < 0) {
std::cerr << "Bind failed" << std::endl;
close(server_fd);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
// Listen for connections
if (listen(server_fd, 5) < 0) {
std::cerr << "Listen failed" << std::endl;
close(server_fd);
SSL_CTX_free(ctx);
EVP_cleanup();
return 1;
}
std::cout << "Dual-stack server running on port " << PORT << std::endl;
std::cout << "Using TLS 1.3 with ML-DSA-87 post-quantum algorithm"
<< std::endl;
// Main accept loop
while (true) {
sockaddr_in6 client_addr{};
socklen_t client_len = sizeof(client_addr);
int client_socket =
accept(server_fd, (sockaddr *)&client_addr, &client_len);
if (client_socket < 0) {
std::cerr << "Accept failed" << std::endl;
continue;
}
// Get client IP
char client_ip[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, &client_addr.sin6_addr, client_ip, INET6_ADDRSTRLEN);
std::cout << "Connection from: " << client_ip << std::endl;
// SSL setup
SSL *ssl = SSL_new(ctx);
if (!ssl) {
std::cerr << "SSL_new failed" << std::endl;
close(client_socket);
continue;
}
if (SSL_set_fd(ssl, client_socket) != 1) {
std::cerr << "SSL_set_fd failed" << std::endl;
SSL_free(ssl);
close(client_socket);
continue;
}
// TLS handshake
if (SSL_accept(ssl) != 1) {
std::cerr << "TLS handshake failed: " << SSL_get_error(ssl, 0)
<< std::endl;
SSL_free(ssl);
close(client_socket);
continue;
}
// Add client to list
{
std::lock_guard<std::mutex> lock(clients_mutex);
clients.push_back({ssl, client_socket});
}
// Start client thread
std::thread(handle_client, ssl, client_socket).detach();
}
// Cleanup
close(server_fd);
SSL_CTX_free(ctx);
EVP_cleanup();
return 0;
}