Description
I have been working for a week on creating a DTLS (Datagram Transport Layer Security) client-server setup, but I am consistently failing to achieve a successful handshake. Despite multiple attempts and configurations, the handshake process does not complete as expected.
`import socket
import logging
from OpenSSL import SSL
from openssl_psk import patch_context
import time
import threading
import hashlib
patch_context()
logging.basicConfig(level=logging.INFO)
def psk_client_callback(connection, hint):
logging.info(f"[TLSClient] PSK client callback called with hint: {hint}")
identity = b'client-identity'
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSClient] Returning identity: {identity}, key: {key}")
return (identity, key)
class TLSClient:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_client_callback(psk_client_callback)
self.context.set_options(SSL.OP_NO_RENEGOTIATION)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSClient] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
self.client_socket = None
self.config = config
self.ssl_conn = None
self.callback_running = False
self._running = False
def log_handshake_progress(self, conn):
state = conn.get_state_string()
pending = conn.pending()
cipher_name = conn.get_cipher_name()
version = conn.get_protocol_version_name()
logging.info(f"[TLSClient] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")
def start_client(self):
try:
self._running = True
self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
#self.client_socket.setblocking(False)
self.client_socket.connect(self.config['address'])
self.context.set_timeout(30)
self.ssl_conn = SSL.Connection(self.context, self.client_socket)
self.ssl_conn.set_connect_state()
logging.info("[TLSClient] Starting DTLS handshake...")
while self._running:
try:
self.log_handshake_progress(self.ssl_conn)
self.ssl_conn.do_handshake()
except SSL.WantReadError:
self.log_handshake_progress(self.ssl_conn)
pass
else:
logging.info("[TLSClient] else handshake.")
self._running = False
self.log_handshake_progress(self.ssl_conn)
logging.info("[TLSClient] DTLS handshake completed.")
# Send a message to the server
message = b"Hello from Client!"
self.ssl_conn.send(message)
logging.info(f"[TLSClient] Sent to server: {message}")
# Receive a response from the server
data = self.ssl_conn.recv(self.config['buffer_size'])
logging.info(f"[TLSClient] Received from server: {data.decode()}")
self.ssl_conn.shutdown()
self.ssl_conn.close()
except SSL.Error as e:
logging.error(f"[TLSClient] SSL error: {e}")
except Exception as e:
logging.error(f"[TLSClient] Error: {e}")
finally:
self.callback_running = False # Stop callback thread
if self.client_socket:
self.client_socket.close()
logging.info("[TLSClient] Client stopped")
def psk_server_callback(connection, identity):
logging.info(f"[TLSServer] PSK server callback called with identity: {identity}")
if identity == b'client-identity':
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSServer] Returning key: {key}")
return key
return None
class TLSServer:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_server_callback(psk_server_callback)
self.context.set_options(SSL.OP_NO_QUERY_MTU)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSServer] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
# Setup cookie generation and verification
self.context.set_cookie_generate_callback(self.generate_cookie)
self.context.set_cookie_verify_callback(self.verify_cookie)
self.server_socket = None
self._running = False
self.config = config
self.ssl_conn = None
def generate_cookie(self, ssl):
logging.info("[TLSServer] generate_cookie")
return b"xyzzy"
def verify_cookie(self, ssl, cookie):
logging.info("[TLSServer] verify_cookie")
return cookie == b"xyzzy"
def log_handshake_progress(self, conn: SSL.Connection):
state = conn.get_state_string()
pending = conn.pending()
cipher_name = conn.get_cipher_name()
version = conn.get_protocol_version_name()
logging.info(f"[TLSServer] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")
def start_server(self):
try:
self._running = True
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
#self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
#self.server_socket.setblocking(False)
self.server_socket.bind(self.config['address'])
logging.info("[TLSServer] Server is running and waiting for connections...")
s_handshaking = False
self.context.set_timeout(30)
s_listening = True
import select
while self._running:
try:
#ready_sockets, _, _ = select.select([self.server_socket], [], [])
#sock = self.server_socket
#for sock in ready_sockets:
data, addr = self.server_socket.recvfrom(self.config['buffer_size'])
ssl_conn = SSL.Connection(self.context, self.server_socket)
ssl_conn.set_accept_state()
ssl_conn.set_tlsext_host_name(self.config['address'][0].encode())
ssl_conn.set_ciphertext_mtu(1500)
#self.invoke_client_callback(data, addr)
self.log_handshake_progress(ssl_conn)
if len(data) > 0 and data[0] == 22 and data[13] == 1:
logging.info("[TLSServer] Received ClientHello from client")
logging.info(f"[TLSServer] Received initial data from {addr}: {data}")
if s_listening:
try:
ssl_conn.DTLSv1_listen()
logging.info("[TLSServer] After DTLSv1_listen")
except SSL.WantReadError:
logging.info("[TLSServer] WantReadError during DTLSv1_listen")
continue
else:
s_listening = False
s_handshaking = True
logging.info("[TLSServer] s_listening=False")
ssl_conn.bio_write(data)
logging.info(f"[TLSServer] Starting DTLS handshake with {addr}...")
while s_handshaking:
try:
self.log_handshake_progress(ssl_conn)
ssl_conn.do_handshake()
break
except SSL.WantReadError:
self.log_handshake_progress(ssl_conn)
self._running = False
s_handshaking = False
pass
except SSL.Error as e:
logging.error(f"[TLSServer] SSL error occurred during handshake: {e}")
self.log_handshake_progress(ssl_conn)
self._running = False
s_handshaking = False
break
self.log_handshake_progress(ssl_conn)
logging.info(f"[TLSServer] DTLS handshake with {addr} completed.")
except SSL.Error as e:
logging.error(f"[TLSServer] SSL error occurred: {e}")
except Exception as e:
logging.error(f"[TLSServer] An error occurred: {e}")
except socket.error as e:
logging.error(f"[TLSServer] Socket error: {e}")
finally:
self.cleanup()
def cleanup(self):
self._running = False
if self.server_socket:
self.server_socket.close()
logging.info("[TLSServer] Server cleaned up and stopped.")
if name == "main":
server_config = {
'address': ('localhost', 4433),
'buffer_size': 4096
}
client_config = {
'address': ('localhost', 4433),
'buffer_size': 4096
}
server = TLSServer(server_config)
server_thread = threading.Thread(target=server.start_server)
server_thread.start()
time.sleep(1)
""" client = TLSClient(client_config)
client_thread = threading.Thread(target=client.start_client)
client_thread.start() """
time.sleep(120)
server._running = False
#client_thread.join()
server_thread.join()`