diff --git a/src/oatpp-websocket/AsyncWebSocket.cpp b/src/oatpp-websocket/AsyncWebSocket.cpp index c5ecfeb..5db9770 100644 --- a/src/oatpp-websocket/AsyncWebSocket.cpp +++ b/src/oatpp-websocket/AsyncWebSocket.cpp @@ -56,15 +56,13 @@ void AsyncWebSocket::setConfig(const Config& config) { } bool AsyncWebSocket::checkForContinuation(const Frame::Header& frameHeader) { - if(m_lastOpcode == Frame::OPCODE_TEXT || m_lastOpcode == Frame::OPCODE_BINARY) { - return false; - } + bool flag = m_lastOpcode == -1; if(frameHeader.fin) { m_lastOpcode = -1; - } else { + } else if(frameHeader.opcode != Frame::OPCODE_CONTINUATION && flag) { m_lastOpcode = frameHeader.opcode; } - return true; + return flag; } oatpp::async::CoroutineStarter AsyncWebSocket::readFrameHeaderAsync(const std::shared_ptr& frameHeader) { @@ -353,10 +351,12 @@ oatpp::async::CoroutineStarter AsyncWebSocket::handleFrameAsync(const std::share switch (m_frameHeader->opcode) { case Frame::OPCODE_CONTINUATION: - if(m_socket->m_lastOpcode < 0) { - throw std::runtime_error("[oatpp::web::protocol::websocket::AsyncWebSocket::handleFrameAsync(){HandleFrameCoroutine}]: Invalid communication state."); + if(m_socket->checkForContinuation(*m_frameHeader)) { + throw std::runtime_error("[oatpp::web::protocol::websocket::AsyncWebSocket::handleFrameAsync(){HandleFrameCoroutine}]: Invalid communication state. OPCODE_CONTINUATION unexpected"); + } else { + return m_socket->readPayloadAsync(m_frameHeader, nullptr).next(finish()); } - return m_socket->readPayloadAsync(m_frameHeader, nullptr).next(finish()); + break; case Frame::OPCODE_TEXT: if(m_socket->checkForContinuation(*m_frameHeader)) { @@ -364,6 +364,7 @@ oatpp::async::CoroutineStarter AsyncWebSocket::handleFrameAsync(const std::share } else { throw std::runtime_error("[oatpp::web::protocol::websocket::AsyncWebSocket::handleFrameAsync(){HandleFrameCoroutine}]: Invalid communication state. OPCODE_CONTINUATION expected"); } + break; case Frame::OPCODE_BINARY: if(m_socket->checkForContinuation(*m_frameHeader)) { @@ -371,6 +372,7 @@ oatpp::async::CoroutineStarter AsyncWebSocket::handleFrameAsync(const std::share } else { throw std::runtime_error("[oatpp::web::protocol::websocket::AsyncWebSocket::handleFrameAsync(){HandleFrameCoroutine}]: Invalid communication state. OPCODE_CONTINUATION expected"); } + break; case Frame::OPCODE_CLOSE: m_shortMessageStream = std::make_shared(); diff --git a/src/oatpp-websocket/WebSocket.cpp b/src/oatpp-websocket/WebSocket.cpp index 0235684..d893dd8 100644 --- a/src/oatpp-websocket/WebSocket.cpp +++ b/src/oatpp-websocket/WebSocket.cpp @@ -56,15 +56,13 @@ void WebSocket::setConfig(const Config& config) { } bool WebSocket::checkForContinuation(const Frame::Header& frameHeader) { - if(m_lastOpcode == Frame::OPCODE_TEXT || m_lastOpcode == Frame::OPCODE_BINARY) { - return false; - } + bool flag = m_lastOpcode == -1; if(frameHeader.fin) { m_lastOpcode = -1; - } else { + } else if(flag && frameHeader.opcode != Frame::OPCODE_CONTINUATION) { m_lastOpcode = frameHeader.opcode; } - return true; + return flag; } void WebSocket::readFrameHeader(Frame::Header& frameHeader) const { @@ -203,10 +201,11 @@ void WebSocket::handleFrame(const Frame::Header& frameHeader) { switch (frameHeader.opcode) { case Frame::OPCODE_CONTINUATION: - if(m_lastOpcode < 0) { - throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Invalid communication state."); + if(checkForContinuation(frameHeader)) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Invalid communication state. OPCODE_CONTINUATION unexpected"); + } else { + readPayload(frameHeader, nullptr); } - readPayload(frameHeader, nullptr); break; case Frame::OPCODE_TEXT: