Skip to content

Issue with DTLS: Unable to Achieve Handshake Between Client and Serve #1323

Open
@hamma96

Description

@hamma96

I have been working for a week on creating a DTLS (Datagram Transport Layer Security) client-server setup, but I am consistently failing to achieve a successful handshake. Despite multiple attempts and configurations, the handshake process does not complete as expected.

`import socket
import logging
from OpenSSL import SSL
from openssl_psk import patch_context
import time
import threading
import hashlib

patch_context()

logging.basicConfig(level=logging.INFO)

def psk_client_callback(connection, hint):
logging.info(f"[TLSClient] PSK client callback called with hint: {hint}")
identity = b'client-identity'
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSClient] Returning identity: {identity}, key: {key}")
return (identity, key)

class TLSClient:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_client_callback(psk_client_callback)
self.context.set_options(SSL.OP_NO_RENEGOTIATION)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSClient] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
self.client_socket = None
self.config = config
self.ssl_conn = None
self.callback_running = False
self._running = False

def log_handshake_progress(self, conn):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSClient] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")


def start_client(self):
    try:
        self._running = True
        self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.client_socket.setblocking(False)
        self.client_socket.connect(self.config['address'])
        
        self.context.set_timeout(30)
        self.ssl_conn = SSL.Connection(self.context, self.client_socket)            
        self.ssl_conn.set_connect_state()

        logging.info("[TLSClient] Starting DTLS handshake...")
        while self._running:
            try:   

                self.log_handshake_progress(self.ssl_conn)
                self.ssl_conn.do_handshake()
            
            except SSL.WantReadError:
                self.log_handshake_progress(self.ssl_conn)
                pass
            else:
                logging.info("[TLSClient] else handshake.")
                self._running = False
            
        self.log_handshake_progress(self.ssl_conn)
        logging.info("[TLSClient] DTLS handshake completed.")

        # Send a message to the server
        message = b"Hello from Client!"
        self.ssl_conn.send(message)
        logging.info(f"[TLSClient] Sent to server: {message}")

        # Receive a response from the server
        data = self.ssl_conn.recv(self.config['buffer_size'])
        logging.info(f"[TLSClient] Received from server: {data.decode()}")

        self.ssl_conn.shutdown()
        self.ssl_conn.close()

    except SSL.Error as e:
        logging.error(f"[TLSClient] SSL error: {e}")
    except Exception as e:
        logging.error(f"[TLSClient] Error: {e}")
    finally:
        self.callback_running = False  # Stop callback thread
        if self.client_socket:
            self.client_socket.close()
        logging.info("[TLSClient] Client stopped")

def psk_server_callback(connection, identity):
logging.info(f"[TLSServer] PSK server callback called with identity: {identity}")
if identity == b'client-identity':
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSServer] Returning key: {key}")
return key
return None

class TLSServer:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_server_callback(psk_server_callback)
self.context.set_options(SSL.OP_NO_QUERY_MTU)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSServer] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
# Setup cookie generation and verification
self.context.set_cookie_generate_callback(self.generate_cookie)
self.context.set_cookie_verify_callback(self.verify_cookie)

    self.server_socket = None
    self._running = False
    self.config = config
    self.ssl_conn = None

def generate_cookie(self, ssl):
        logging.info("[TLSServer] generate_cookie")
        return b"xyzzy"

def verify_cookie(self, ssl, cookie):
        logging.info("[TLSServer] verify_cookie")
        return cookie == b"xyzzy"

def log_handshake_progress(self, conn: SSL.Connection):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSServer] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")

def start_server(self):
    try:
        self._running = True
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        #self.server_socket.setblocking(False)
        self.server_socket.bind(self.config['address'])

        logging.info("[TLSServer] Server is running and waiting for connections...")
        s_handshaking = False
        self.context.set_timeout(30)
        s_listening = True

        import select
        while self._running:
            try:
                #ready_sockets, _, _ = select.select([self.server_socket], [], [])
                #sock = self.server_socket
                #for sock in ready_sockets:
                data, addr = self.server_socket.recvfrom(self.config['buffer_size'])
                ssl_conn = SSL.Connection(self.context, self.server_socket)          
                ssl_conn.set_accept_state()
                ssl_conn.set_tlsext_host_name(self.config['address'][0].encode())
                ssl_conn.set_ciphertext_mtu(1500)
                #self.invoke_client_callback(data, addr)
                self.log_handshake_progress(ssl_conn)
                if len(data) > 0 and data[0] == 22 and data[13] == 1:
                    logging.info("[TLSServer] Received ClientHello from client")
                    logging.info(f"[TLSServer] Received initial data from {addr}: {data}")
                    if s_listening:
                        try:
                            ssl_conn.DTLSv1_listen()
                            logging.info("[TLSServer] After DTLSv1_listen")
                        except SSL.WantReadError:
                            logging.info("[TLSServer] WantReadError during DTLSv1_listen")
                            continue
                        else:
                            s_listening = False
                            s_handshaking = True
                            logging.info("[TLSServer] s_listening=False")
                            ssl_conn.bio_write(data)
                        

                    logging.info(f"[TLSServer] Starting DTLS handshake with {addr}...")
                    while s_handshaking:
                        try:
                            self.log_handshake_progress(ssl_conn)
                            ssl_conn.do_handshake()
                            break
                        except SSL.WantReadError:
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            pass
                        except SSL.Error as e:
                            logging.error(f"[TLSServer] SSL error occurred during handshake: {e}")
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            break
                    self.log_handshake_progress(ssl_conn)
                    logging.info(f"[TLSServer] DTLS handshake with {addr} completed.")


            except SSL.Error as e:
                logging.error(f"[TLSServer] SSL error occurred: {e}")
            except Exception as e:
                logging.error(f"[TLSServer] An error occurred: {e}")

    except socket.error as e:
        logging.error(f"[TLSServer] Socket error: {e}")
    finally:
        self.cleanup()



def cleanup(self):
    self._running = False
    if self.server_socket:
        self.server_socket.close()
    logging.info("[TLSServer] Server cleaned up and stopped.")

if name == "main":
server_config = {
'address': ('localhost', 4433),
'buffer_size': 4096
}

client_config = {
    'address': ('localhost', 4433),
    'buffer_size': 4096
}

server = TLSServer(server_config)
server_thread = threading.Thread(target=server.start_server)
server_thread.start()

time.sleep(1)

""" client = TLSClient(client_config)
client_thread = threading.Thread(target=client.start_client)
client_thread.start() """

time.sleep(120)

server._running = False
#client_thread.join()
server_thread.join()`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions