Skip to content

New connections should fail if there's an already open connection #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 132 additions & 140 deletions FullStackTests/Tests/ConnectionFullStackTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,174 +13,166 @@
// limitations under the License.

import CryptoTokenKit
import XCTest
import Testing

@testable import FullStackTests
@testable import YubiKit

class ConnectionFullStackTests: XCTestCase {
@Suite("Connection Full Stack Tests", .serialized)
struct ConnectionFullStackTests {

typealias Connection = SmartCardConnection

func testSingleConnection() throws {
runAsyncTest {
do {
let connection = try await Connection.connection()
print("✅ Got connection \(connection)")
XCTAssertNotNil(connection)
} catch {
XCTFail("🚨 Failed with: \(error)")
}
}
@Test("Single Connection", .timeLimit(.minutes(1)))
func singleConnection() async throws {
let connection = try await Connection.connection()
#expect(true, "✅ Got connection \(connection)")
await connection.close(error: nil)
}

func testSerialConnections() throws {
runAsyncTest {
do {
let firstConnection = try await Connection.connection()
print("✅ Got first connection \(firstConnection)")
let task = Task {
let result = await firstConnection.connectionDidClose()
print("✅ First connection did close")
return result
}
try? await Task.sleep(for: .seconds(1))
let secondConnection = try await Connection.connection()
print("✅ Got second connection \(secondConnection)")
XCTAssertNotNil(secondConnection)
let closingError = await task.value
XCTAssertNil(closingError)
print("✅ connectionDidClose() returned: \(String(describing: closingError))")
} catch {
XCTFail("🚨 Failed with: \(error)")
}
@Test("Serial Connections", .timeLimit(.minutes(1)))
func serialConnections() async throws {
let firstConnection = try await Connection.connection()
#expect(true, "✅ Got first connection \(firstConnection)")
let task = Task {
let result = await firstConnection.connectionDidClose()
#expect(true, "✅ First connection did close")
return result
}

// attempt to create a second connection (should fail!)
try? await Task.sleep(for: .seconds(1))
let new = try? await Connection.connection()
#expect(new == nil, "✅ Second connection failed as expected")

// close the first connection
_ = await firstConnection.close(error: nil)
let closingError = await task.value
#expect(closingError == nil, "✅ connectionDidClose() returned: \(String(describing: closingError))")

// attempt to create a second connection (now it should succed!)
try? await Task.sleep(for: .seconds(1))
let secondConnection = try await Connection.connection()
#expect(true, "✅ Got second connection \(secondConnection)")

// close the second connection
await secondConnection.close(error: nil)
}

func testConnectionCancellation() {
runAsyncTest {
let task1 = Task {
try await Connection.connection()
}
let task2 = Task {
try await Connection.connection()
}
let task3 = Task {
try await Connection.connection()
}
let task4 = Task {
try await Connection.connection()
}

let result1 = try? await task1.value
print("✅ Result 1: \(String(describing: result1))")
let result2 = try? await task2.value
print("✅ Result 2: \(String(describing: result2))")
let result3 = try? await task3.value
print("✅ Result 3: \(String(describing: result3))")
let result4 = try? await task4.value
print("✅ Result 4: \(String(describing: result4))")

XCTAssert([result1, result2, result3, result4].compactMap { $0 }.count == 1)
@Test("Connection Cancellation", .timeLimit(.minutes(1)))
func connectionCancellation() async {
let task1 = Task {
try await Connection.connection()
}
let task2 = Task {
try await Connection.connection()
}
let task3 = Task {
try await Connection.connection()
}
let task4 = Task {
try await Connection.connection()
}

let result1 = try? await task1.value
print("✅ Result 1: \(String(describing: result1))")
let result2 = try? await task2.value
print("✅ Result 2: \(String(describing: result2))")
let result3 = try? await task3.value
print("✅ Result 3: \(String(describing: result3))")
let result4 = try? await task4.value
print("✅ Result 4: \(String(describing: result4))")

let connections = [result1, result2, result3, result4].compactMap { $0 }
#expect(connections.count == 1)

// close the only established connection
await connections.first?.close(error: nil)
}

func testSendManually() {
runAsyncTest {
let connection = try await Connection.connection()
// Select Management application
let apdu = APDU(
cla: 0x00,
ins: 0xa4,
p1: 0x04,
p2: 0x00,
command: Data([0xA0, 0x00, 0x00, 0x05, 0x27, 0x47, 0x11, 0x17])
)
let resultData = try await connection.send(data: apdu.data)
let result = Response(rawData: resultData)
XCTAssertEqual(result.responseStatus.status, .ok)
/// Get version number
let deviceInfoApdu = APDU(cla: 0, ins: 0x1d, p1: 0, p2: 0)
let deviceInfoResultData = try await connection.send(data: deviceInfoApdu.data)
let deviceInfoResult = Response(rawData: deviceInfoResultData)
XCTAssertEqual(deviceInfoResult.responseStatus.status, .ok)
let records = TKBERTLVRecord.sequenceOfRecords(
from: deviceInfoResult.data.subdata(in: 1..<deviceInfoResult.data.count)
)
guard let versionData = records?.filter({ $0.tag == 0x05 }).first?.value else {
XCTFail("No YubiKey version record in result.")
return
}
guard versionData.count == 3 else {
XCTFail("Wrong sized return data. Got \(versionData.hexEncodedString)")
return
}
let bytes = [UInt8](versionData)
let major = bytes[0]
let minor = bytes[1]
let micro = bytes[2]
print("✅ Got version: \(major).\(minor).\(micro)")
XCTAssertEqual(major, 5)
// Try to select non existing application
let notFoundApdu = APDU(cla: 0x00, ins: 0xa4, p1: 0x04, p2: 0x00, command: Data([0x01, 0x02, 0x03]))
let notFoundResultData = try await connection.send(data: notFoundApdu.data)
let notFoundResult = Response(rawData: notFoundResultData)
if !(notFoundResult.responseStatus.status == .fileNotFound
@Test("Send Manually", .timeLimit(.minutes(1)))
func sendManually() async throws {
let connection = try await Connection.connection()
// Select Management application
let apdu = APDU(
cla: 0x00,
ins: 0xa4,
p1: 0x04,
p2: 0x00,
command: Data([0xA0, 0x00, 0x00, 0x05, 0x27, 0x47, 0x11, 0x17])
)
let resultData = try await connection.send(data: apdu.data)
let result = Response(rawData: resultData)
#expect(result.responseStatus.status == .ok)
/// Get version number
let deviceInfoApdu = APDU(cla: 0, ins: 0x1d, p1: 0, p2: 0)
let deviceInfoResultData = try await connection.send(data: deviceInfoApdu.data)
let deviceInfoResult = Response(rawData: deviceInfoResultData)
#expect(deviceInfoResult.responseStatus.status == .ok)
let records = TKBERTLVRecord.sequenceOfRecords(
from: deviceInfoResult.data.subdata(in: 1..<deviceInfoResult.data.count)
)
let versionData = try #require(
records?.filter({ $0.tag == 0x05 }).first?.value,
"No YubiKey version record in result."
)
#expect(versionData.count == 3, "Wrong sized return data. Got \(versionData.hexEncodedString)")
let bytes = [UInt8](versionData)
let major = bytes[0]
let minor = bytes[1]
let micro = bytes[2]
print("✅ Got version: \(major).\(minor).\(micro)")
#expect(major == 5)
// Try to select non existing application
let notFoundApdu = APDU(cla: 0x00, ins: 0xa4, p1: 0x04, p2: 0x00, command: Data([0x01, 0x02, 0x03]))
let notFoundResultData = try await connection.send(data: notFoundApdu.data)
let notFoundResult = Response(rawData: notFoundResultData)
#expect(
notFoundResult.responseStatus.status == .fileNotFound
|| notFoundResult.responseStatus.status == .incorrectParameters
|| notFoundResult.responseStatus.status == .invalidInstruction)
{
XCTFail("Unexpected result: \(notFoundResult.responseStatus)")
}
}
|| notFoundResult.responseStatus.status == .invalidInstruction,
"Unexpected result: \(notFoundResult.responseStatus)"
)

await connection.close(error: nil)
}
}

#if os(iOS)
class NFCFullStackTests: XCTestCase {

func testNFCAlertMessage() throws {
runAsyncTest {
do {
let connection = try await TestableConnections.create(with: .nfc(alertMessage: "Test Alert Message"))
await connection.nfcConnection?.setAlertMessage("Updated Alert Message")
try? await Task.sleep(for: .seconds(1))
await connection.nfcConnection?.close(message: "Closing Alert Message")
} catch {
XCTFail("🚨 Failed with: \(error)")
}
}
@Suite("NFC Full Stack Tests", .serialized)
struct NFCFullStackTests {

@Test("NFC Alert Message", .timeLimit(.minutes(1)))
func nfcAlertMessage() async throws {
let connection = try await TestableConnections.create(with: .nfc(alertMessage: "Test Alert Message"))
await connection.nfcConnection?.setAlertMessage("Updated Alert Message")
try? await Task.sleep(for: .seconds(1))
await connection.nfcConnection?.close(message: "Closing Alert Message")
}

func testNFCClosingErrorMessage() throws {
runAsyncTest {
do {
let connection = try await TestableConnections.create(with: .nfc(alertMessage: "Test Alert Message"))
await connection.close(error: nil)
} catch {
XCTFail("🚨 Failed with: \(error)")
}
}
@Test("NFC Closing Error Message", .timeLimit(.minutes(1)))
func nfcClosingErrorMessage() async throws {
let connection = try await TestableConnections.create(with: .nfc(alertMessage: "Test Alert Message"))
await connection.close(error: nil)
}

}
#endif

class SmartCardConnectionFullStackTests: XCTestCase {

func testSmartCardConnectionWithSlot() throws {
runAsyncTest {
let allSlots = try await SmartCardConnection.availableSlots
allSlots.enumerated().forEach { index, slot in
print("\(index): \(slot.name)")
}
let random = allSlots.randomElement()
// we need at least one YubiKey connected
XCTAssertNotNil(random)
guard let random else { return }
let connection = try await SmartCardConnection.connection(slot: random)
print("✅ Got connection \(connection)")
XCTAssertNotNil(connection)
@Suite("SmartCard Connection Full Stack Tests", .serialized)
struct SmartCardConnectionFullStackTests {

@Test("SmartCard Connection With Slot", .timeLimit(.minutes(1)))
func smartCardConnectionWithSlot() async throws {
let allSlots = try await SmartCardConnection.availableSlots
allSlots.enumerated().forEach { index, slot in
print("\(index): \(slot.name)")
}
let random = allSlots.randomElement()
// we need at least one YubiKey connected
let slot = try #require(random, "No YubiKey slots available")
let connection = try await SmartCardConnection.connection(slot: slot)
#expect(true, "✅ Got connection \(connection)")
}

}
4 changes: 3 additions & 1 deletion YubiKit/YubiKit/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ public protocol Connection: Sendable {
}

/// Connection Errors.
public enum ConnectionError: Error {
public enum ConnectionError: Error, Sendable {
/// There is an active connection.
case busy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to some other naming here!

/// No current connection.
case noConnection
/// Unexpected result returned from YubiKey.
Expand Down
16 changes: 4 additions & 12 deletions YubiKit/YubiKit/LightningConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,19 @@ private actor LightningConnectionManager {

static let shared = LightningConnectionManager()

private var connectionTask: Task<LightningConnection, Error>?
private var pendingConnectionPromise: Promise<LightningConnection>?
private var connectionState: (connectionID: ConnectionID, didCloseConnection: (Promise<Error?>))?

private init() {}

func connect() async throws -> LightningConnection {
// If a connection task is already running, await its result
if let connectionTask {
trace(message: "awaiting existing connection task")
_ = try await connectionTask.value
// we cancel this task because only one of multiple
// concurrent connections can succed
throw ConnectionError.cancelled
// If there is already a connection the caller must close the connection first.
if connectionState != nil || pendingConnectionPromise != nil {
throw ConnectionError.busy
}

// Otherwise, create and store a new connection task.
let task = Task { () -> LightningConnection in
// When the task finishes (on any path), clear it to allow a new connection.
defer { self.connectionTask = nil }

trace(message: "begin new connection task")

do {
Expand Down Expand Up @@ -121,12 +113,12 @@ private actor LightningConnectionManager {
trace(message: "connection failed: \(error.localizedDescription)")
// Cleanup on failure
self.pendingConnectionPromise = nil
self.connectionState = nil
await EAAccessoryWrapper.shared.stopMonitoring()
throw error
}
}

self.connectionTask = task
return try await task.value
}

Expand Down
17 changes: 9 additions & 8 deletions YubiKit/YubiKit/NFCConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,22 @@ private final actor NFCConnectionManager: NSObject {
throw NFCConnectionError.unsupported
}

// To proceed with a new connection we need to acquire a lock
guard !isEstablishing else { throw ConnectionError.cancelled }
defer { isEstablishing = false }
isEstablishing = true

// Close the previous connection before establishing a new one
// if there is already a connection for this slot we throw `ConnectionError.busy`.
// The caller must close the connection first.
switch currentState {
case .inactive:
// lets continue
break
case .scanning, .connected:
// invalidate and continue
await invalidate()
// throw
throw ConnectionError.busy
}

// To proceed with a new connection we need to acquire a lock
guard !isEstablishing else { throw ConnectionError.cancelled }
defer { isEstablishing = false }
isEstablishing = true

// Start polling
guard let session = NFCTagReaderSession(pollingOption: [.iso14443], delegate: self, queue: nil) else {
throw NFCConnectionError.failedToPoll
Expand Down
Loading
Loading