Skip to content

Support Codable Messages #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
],
targets: [
.target(name: "WebSocketKit", dependencies: [
Expand All @@ -22,6 +23,7 @@ let package = Package(
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "NIOWebSocket", package: "swift-nio"),
.product(name: "Logging", package: "swift-log")
]),
.testTarget(name: "WebSocketKitTests", dependencies: [
.target(name: "WebSocketKit"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ extension WebSocket {
return try await promise.futureResult.get()
}

public func send<T>(_ data: T, type: WebSocketSendType = .text) async throws where T: Codable {
let promise = eventLoop.makePromise(of: Void.self)
send(data, type: type, promise: promise)
return try await promise.futureResult.get()
}

public func sendPing() async throws {
let promise = eventLoop.makePromise(of: Void.self)
sendPing(promise: promise)
Expand Down
106 changes: 102 additions & 4 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ import NIOHTTP1
import NIOSSL
import Foundation
import NIOFoundationCompat
import Logging

public final class WebSocket {
enum PeerType {
case server
case client
}

public enum WebSocketSendType {
case text
case binary
}

public var eventLoop: EventLoop {
return channel.eventLoop
}
Expand All @@ -24,7 +30,11 @@ public final class WebSocket {
self.channel.closeFuture
}

private let jsonDecoder = JSONDecoder()
private let jsonEncoder = JSONEncoder()

private let channel: Channel
private let logger: Logger
private var onTextCallback: (WebSocket, String) -> ()
private var onBinaryCallback: (WebSocket, ByteBuffer) -> ()
private var onPongCallback: (WebSocket) -> ()
Expand All @@ -34,8 +44,9 @@ public final class WebSocket {
private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?
private var events: [String : (WebSocket, Data) -> Void] = [:]

init(channel: Channel, type: PeerType) {
init(channel: Channel, type: PeerType, logger: Logger) {
self.channel = channel
self.type = type
self.onTextCallback = { _, _ in }
Expand All @@ -45,6 +56,7 @@ public final class WebSocket {
self.waitingForPong = false
self.waitingForClose = false
self.scheduledTimeoutTask = nil
self.logger = logger
}

public func onText(_ callback: @escaping (WebSocket, String) -> ()) {
Expand All @@ -54,7 +66,7 @@ public final class WebSocket {
public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) {
self.onBinaryCallback = callback
}

public func onPong(_ callback: @escaping (WebSocket) -> ()) {
self.onPongCallback = callback
}
Expand All @@ -63,6 +75,31 @@ public final class WebSocket {
self.onPingCallback = callback
}

public func onEvent(_ identifier: String, _ handler: @escaping (WebSocket) -> Void) {
events[identifier] = { ws, data in
handler(ws)
}
}

public func onEvent<T>(_ identifier: String, _ type: T.Type, _ handler: @escaping (WebSocket, T) -> Void) where T: Codable {
onEvent(identifier, handler)
}

public func onEvent<T>(_ identifier: String, _ handler: @escaping (WebSocket, T) -> Void) where T: Codable {
events[identifier] = { ws, data in
do {
let res = try self.jsonDecoder.decode(WebSocketEvent<T>.self, from: data)
if let data = res.data {
handler(ws, data)
} else {
self.logger.trace("Unable to unwrap data for event `\(identifier)`, because it is unexpectedly nil. Please use another `bind` method which support optional payload to avoid this message.")
}
} catch {
self.logger.debug("Unable to decode incoming event `\(identifier)`: \(error)")
}
}
}

/// If set, this will trigger automatic pings on the connection. If ping is not answered before
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed
/// automatically.
Expand Down Expand Up @@ -96,6 +133,18 @@ public final class WebSocket {
self.send(raw: binary, opcode: .binary, fin: true, promise: promise)
}

public func send<T>(_ data: T, type: WebSocketSendType = .text, promise: EventLoopPromise<Void>? = nil) where T: Codable {
guard let data = try? jsonEncoder.encode(data), let dataString = String(data: data, encoding: .utf8) else {
return
}
switch type {
case .text:
self.send(dataString, promise: promise)
case .binary:
self.send(String(dataString).utf8.map{ UInt8($0) }, promise: promise)
}
}

public func sendPing(promise: EventLoopPromise<Void>? = nil) {
self.send(
raw: Data(),
Expand Down Expand Up @@ -242,9 +291,13 @@ public final class WebSocket {
if let frameSequence = self.frameSequence, frame.fin {
switch frameSequence.type {
case .binary:
self.onBinaryCallback(self, frameSequence.binaryBuffer)
if !proceedEventData(self, frameSequence.binaryBuffer) {
self.onBinaryCallback(self, frameSequence.binaryBuffer)
}
case .text:
self.onTextCallback(self, frameSequence.textBuffer)
if !proceedEventData(self, frameSequence.textBuffer) {
self.onTextCallback(self, frameSequence.textBuffer)
}
case .pong:
self.waitingForPong = false
self.onPongCallback(self)
Expand All @@ -256,6 +309,37 @@ public final class WebSocket {
}
}

private func proceedEventData(_ socket: WebSocket, _ text: String) -> Bool {
guard !events.isEmpty, let data = text.data(using: .utf8) else { return false }
do {
let prototype = try jsonDecoder.decode(WebSocketEventPrototype.self, from: data)
if let bind = events.first(where: { $0.0 == prototype.event }) {
bind.value(socket, data)
return true
}
} catch {
logger.trace("Unable to decode incoming event cause it doesn't conform to `WebSocketEventPrototype` model: \(error)")
}
return false
}

private func proceedEventData(_ socket: WebSocket, _ byteBuffer: ByteBuffer) -> Bool {
guard !events.isEmpty, byteBuffer.readableBytes > 0 else { return false }
do {
var bytes: [UInt8] = byteBuffer.getBytes(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes) ?? []
let data = Data(bytes: &bytes, count: byteBuffer.readableBytes)

let prototype = try jsonDecoder.decode(WebSocketEventPrototype.self, from: data)
if let bind = events.first(where: { $0.0 == prototype.event }) {
bind.value(socket, data)
return true
}
} catch {
logger.trace("Unable to decode incoming event because it doesn't conform to `WebSocketEventPrototype` model: \(error)")
}
return false
}

private func pingAndScheduleNextTimeoutTask() {
guard channel.isActive, let pingInterval = pingInterval else {
return
Expand Down Expand Up @@ -311,3 +395,17 @@ private struct WebSocketFrameSequence {
}
}
}

private struct WebSocketEvent<T: Codable>: Codable {
public let event: String
public let data: T?
public init (event: String, data: T? = nil) {
self.event = event
self.data = data
}
}

private struct WebSocketEventPrototype: Codable {
var event: String
}

7 changes: 5 additions & 2 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import NIOConcurrencyHelpers
import NIOHTTP1
import NIOWebSocket
import NIOSSL
import Logging

public final class WebSocketClient {
public enum Error: Swift.Error, LocalizedError {
Expand Down Expand Up @@ -37,8 +38,9 @@ public final class WebSocketClient {
let group: EventLoopGroup
let configuration: Configuration
let isShutdown = NIOAtomic.makeAtomic(value: false)
let logger: Logger

public init(eventLoopGroupProvider: EventLoopGroupProvider, configuration: Configuration = .init()) {
public init(eventLoopGroupProvider: EventLoopGroupProvider, configuration: Configuration = .init(), logger: Logger = Logger(label: "codes.vapor.websocket")) {
self.eventLoopGroupProvider = eventLoopGroupProvider
switch self.eventLoopGroupProvider {
case .shared(let group):
Expand All @@ -47,6 +49,7 @@ public final class WebSocketClient {
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
}
self.configuration = configuration
self.logger = logger
}

public func connect(
Expand Down Expand Up @@ -80,7 +83,7 @@ public final class WebSocketClient {
maxFrameSize: self.configuration.maxFrameSize,
automaticErrorHandling: true,
upgradePipelineHandler: { channel, req in
return WebSocket.client(on: channel, onUpgrade: onUpgrade)
return WebSocket.client(on: channel, logger: self.logger, onUpgrade: onUpgrade)
}
)

Expand Down
29 changes: 26 additions & 3 deletions Sources/WebSocketKit/WebSocketHandler.swift
Original file line number Diff line number Diff line change
@@ -1,27 +1,50 @@
import NIO
import NIOWebSocket
import Logging

extension WebSocket {

@available(*, deprecated, message: "use Websocket.client(on:logger:onUpgrade:)")
public static func client(
on channel: Channel,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
let logger = Logger(label: "codes.vapor.websocket")
return self.handle(on: channel, as: .client, logger: logger, onUpgrade: onUpgrade)
}

@available(*, deprecated, message: "use Websocket.server(on:logger:onUpgrade:)")
public static func server(
on channel: Channel,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
let logger = Logger(label: "codes.vapor.websocket")
return self.handle(on: channel, as: .server, logger: logger, onUpgrade: onUpgrade)
}

public static func client(
on channel: Channel,
logger: Logger,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .client, onUpgrade: onUpgrade)
return self.handle(on: channel, as: .client, logger: logger, onUpgrade: onUpgrade)
}

public static func server(
on channel: Channel,
logger: Logger,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .server, onUpgrade: onUpgrade)
return self.handle(on: channel, as: .server, logger: logger, onUpgrade: onUpgrade)
}

private static func handle(
on channel: Channel,
as type: PeerType,
logger: Logger,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
let webSocket = WebSocket(channel: channel, type: type)
let webSocket = WebSocket(channel: channel, type: type, logger: logger)
return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in
onUpgrade(webSocket)
}
Expand Down
45 changes: 44 additions & 1 deletion Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ import XCTest
import NIO
import NIOHTTP1
import NIOWebSocket
import Logging
@testable import WebSocketKit

@available(macOS 12, iOS 15, watchOS 8, tvOS 15, *)
final class AsyncWebSocketKitTests: XCTestCase {
func testWebSocketEcho() async throws {
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
let server = try ServerBootstrap.webSocket(on: self.elg, logger: self.logger) { req, ws in
ws.onText { ws, text in
ws.send(text)
}
Expand Down Expand Up @@ -42,10 +43,52 @@ final class AsyncWebSocketKitTests: XCTestCase {
try await server.close(mode: .all)
}

func testWebSocketSendCodable() async throws {
let server = try ServerBootstrap.webSocket(on: self.elg, logger: self.logger) { req, ws in
ws.onEvent("hello", User.self) { ws, user in
ws.send("Hello \(user.firstName) \(user.lastName)")
}
}.bind(host: "localhost", port: 0).wait()

guard let port = server.localAddress?.port else {
XCTFail("couldn't get port from \(server.localAddress.debugDescription)")
return
}

try await WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in
do {
try await ws.send(Response(event: "hello", data: User(firstName: "Vapor", lastName: "WebSocket")))
ws.onText { ws, string in
XCTAssertEqual(string, "Hello Vapor WebSocket")
do {
try await ws.close()
} catch {
XCTFail("Failed to close websocket, error: \(error)")
}
}
} catch {
XCTFail("Failed to connect, error: \(error)")
}
}

try await server.close(mode: .all)

struct Response: Codable {
let event: String
let data: User
}

struct User: Codable {
let firstName, lastName: String
}
}

var elg: EventLoopGroup!
var logger: Logger!
override func setUp() {
// needs to be at least two to avoid client / server on same EL timing issues
self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2)
self.logger = Logger(label: "com.vapor.websocketkit.tests")
}
override func tearDown() {
try! self.elg.syncShutdownGracefully()
Expand Down
Loading