2025-05-02 13:24:10 +02:00

156 lines
3.9 KiB
C++

#include "../common/defines.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
/*
 * Fatal-error helper for failed system calls.
 *
 * Writes "msg: <description of the current errno>" to stderr (via perror)
 * and terminates the process with EXIT_FAILURE. Never returns.
 */
void error(const char *msg)
{
  const int failure_status = EXIT_FAILURE;

  perror(msg);
  exit(failure_status);
}
int main(int argc, char *argv[]) {
if (argc < 3) {
fprintf(stderr, "Usage: %s <server_ip> <port>\n", argv[0]);
exit(EXIT_FAILURE);
}
// Initialize OpenSSL (for OpenSSL 1.1.0+ this is optional as it's automatic)
OPENSSL_init_ssl(0, NULL);
// Create a TLS 1.3 client context
SSL_CTX *ctx = SSL_CTX_new(TLS_client_method());
if (ctx == NULL) {
ERR_print_errors_fp(stderr);
error("Failed to create SSL_CTX");
}
// Set minimum TLS version to 1.3
if (SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION) != 1) {
SSL_CTX_free(ctx);
ERR_print_errors_fp(stderr);
error("Error setting minimum TLS version");
}
// Load CA certificate for server verification
if (SSL_CTX_load_verify_locations(ctx, "certs/mldsa87_root_cert.pem", NULL) !=
1) {
SSL_CTX_free(ctx);
ERR_print_errors_fp(stderr);
error("Error loading CA certificate");
}
// Set cipher suite for TLS 1.3
if (SSL_CTX_set_ciphersuites(ctx, "TLS_AES_256_GCM_SHA384") != 1) {
SSL_CTX_free(ctx);
ERR_print_errors_fp(stderr);
error("Error setting cipher list");
}
// Create a TCP socket
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0)
error("Socket creation failed");
// Set up server address
struct sockaddr_in serv_addr;
memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(atoi(argv[2]));
if (inet_pton(AF_INET, argv[1], &serv_addr.sin_addr) <= 0)
error("Invalid address/address not supported");
// Connect to server
if (connect(sockfd, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
error("Connection failed");
printf("Connected to %s:%s\n", argv[1], argv[2]);
// Create an SSL object
SSL *ssl = SSL_new(ctx);
if (ssl == NULL) {
close(sockfd);
SSL_CTX_free(ctx);
ERR_print_errors_fp(stderr);
error("Failed to create SSL object");
}
// Associate the socket with SSL
if (SSL_set_fd(ssl, sockfd) != 1) {
SSL_free(ssl);
close(sockfd);
SSL_CTX_free(ctx);
ERR_print_errors_fp(stderr);
error("Failed to set SSL file descriptor");
}
// Connect using TLS
int ret = SSL_connect(ssl);
if (ret != 1) {
fprintf(stderr, "TLS handshake failed: %d\n", SSL_get_error(ssl, ret));
ERR_print_errors_fp(stderr);
SSL_free(ssl);
close(sockfd);
SSL_CTX_free(ctx);
exit(EXIT_FAILURE);
}
printf("TLS 1.3 handshake successful using %s\n", SSL_get_cipher(ssl));
// Fork to handle sending and receiving
pid_t pid = fork();
if (pid < 0)
error("Fork failed");
if (pid == 0) { // Child process: receive from server
char buffer[BUFFER_SIZE];
int bytes_read;
while (1) {
bytes_read = SSL_read(ssl, buffer, BUFFER_SIZE);
if (bytes_read > 0) {
fwrite(buffer, 1, bytes_read, stdout);
fflush(stdout);
memset(buffer, 0, BUFFER_SIZE);
} else {
int err = SSL_get_error(ssl, bytes_read);
if (err == SSL_ERROR_WANT_READ)
continue;
else
break;
}
}
kill(getppid(), SIGTERM); // Signal parent to terminate
exit(0);
} else { // Parent process: send to server
char buffer[BUFFER_SIZE];
int bytes_read;
while ((bytes_read = read(STDIN_FILENO, buffer, BUFFER_SIZE)) > 0) {
if (SSL_write(ssl, buffer, bytes_read) < 0) {
fprintf(stderr, "SSL_write error: %d\n", SSL_get_error(ssl, 0));
break;
}
}
kill(pid, SIGTERM); // Signal child to terminate
}
// Clean up
SSL_shutdown(ssl);
SSL_free(ssl);
close(sockfd);
SSL_CTX_free(ctx);
return 0;
}