diff --git a/Package.swift b/Package.swift index f3ee95372..fb7e87748 100644 --- a/Package.swift +++ b/Package.swift @@ -23,6 +23,10 @@ let package = Package( name: "ImageSerializationPlugin", targets: ["ImageSerializationPlugin"] ), + .library( + name: "FileSerializationPlugin", + targets: ["FileSerializationPlugin"] + ), .library( name: "InlineSnapshotTesting", targets: ["InlineSnapshotTesting"] @@ -35,11 +39,16 @@ let package = Package( .target( name: "SnapshotTesting", dependencies: [ + "FileSerializationPlugin", "ImageSerializationPlugin", "SnapshotTestingPlugin" ] ), .target(name: "SnapshotTestingPlugin"), + .target( + name: "FileSerializationPlugin", + dependencies: ["SnapshotTestingPlugin"] + ), .target( name: "ImageSerializationPlugin", dependencies: ["SnapshotTestingPlugin"] diff --git a/Sources/FileSerializationPlugin/FileSerializationPlugin.swift b/Sources/FileSerializationPlugin/FileSerializationPlugin.swift new file mode 100644 index 000000000..9477bf1b8 --- /dev/null +++ b/Sources/FileSerializationPlugin/FileSerializationPlugin.swift @@ -0,0 +1,35 @@ +import SnapshotTestingPlugin +import Foundation + +public typealias FileSerializationPlugin = FileSerialization & SnapshotTestingPlugin + +@preconcurrency public protocol FileSerialization { + associatedtype Configuration + static var location: FileSerializationLocation { get } + func write(_ data: Data, to url: URL, options: Data.WritingOptions) async throws + func read(_ url: URL) async throws -> Data? + + // This should not be called ofter + func start(_ configuration: Configuration) async throws + func stop() async throws +} + +public enum FileSerializationLocation: RawRepresentable, Sendable, Equatable { + + public static let defaultValue: FileSerializationLocation = .local + + case local + + case plugins(String) + + public init?(rawValue: String) { + self = rawValue == "local" ? .local : .plugins(rawValue) + } + + public var rawValue: String { + switch self { + case .local: return "local" + case let .plugins(value): return value + } + } +} diff --git a/Sources/SnapshotTesting/AssertSnapshot.swift b/Sources/SnapshotTesting/AssertSnapshot.swift index 55ebc0e94..d90cfc866 100644 --- a/Sources/SnapshotTesting/AssertSnapshot.swift +++ b/Sources/SnapshotTesting/AssertSnapshot.swift @@ -1,5 +1,6 @@ import XCTest import ImageSerializationPlugin +import FileSerializationPlugin #if canImport(Testing) // NB: We are importing only the implementation of Testing because that framework is not available @@ -7,6 +8,8 @@ import ImageSerializationPlugin @_implementationOnly import Testing #endif +public var fileLocation: FileSerializationLocation = .defaultValue + /// Whether or not to change the default output image format to something else. public var imageFormat: ImageSerializationFormat { get { @@ -387,7 +390,7 @@ public func verifySnapshot( } func recordSnapshot() throws { - try snapshotting.diffing.toData(diffable).write(to: snapshotFileUrl) + try FileSerializer().write(snapshotting.diffing.toData(diffable), to: snapshotFileUrl, location: fileLocation) #if !os(Linux) && !os(Windows) if !isSwiftTesting, ProcessInfo.processInfo.environment.keys.contains("__XCODE_BUILT_PRODUCTS_DIR_PATHS") @@ -424,7 +427,8 @@ public func verifySnapshot( """ } - let data = try Data(contentsOf: snapshotFileUrl) + guard let data = try FileSerializer().read(snapshotFileUrl, location: fileLocation) + else { return nil } let reference = snapshotting.diffing.fromData(data) #if os(iOS) || os(tvOS) diff --git a/Sources/SnapshotTesting/Plugins/FileSerializer.swift b/Sources/SnapshotTesting/Plugins/FileSerializer.swift new file mode 100644 index 000000000..50ebfb06f --- /dev/null +++ b/Sources/SnapshotTesting/Plugins/FileSerializer.swift @@ -0,0 +1,57 @@ +#if canImport(SwiftUI) +import Foundation +import FileSerializationPlugin + +final class FileSerializer { + + /// A collection of plugins that conform to the `FileSerialization` protocol. + private let plugins: [any FileSerialization] + + init() { + self.plugins = PluginRegistry.allPlugins() + } + + func write(_ data: Data, to url: URL, options: Data.WritingOptions = [], location: FileSerializationLocation = .defaultValue) throws { + if let plugin = self.plugins.first(where: { type(of: $0).location == location }) { + Task { + try await plugin.write(data, to: url, options: options) + } + return + } + + try data.write(to: url) + } + + + func read(_ url: URL, location: FileSerializationLocation = .defaultValue) throws -> Data? { + if let plugin = self.plugins.first(where: { type(of: $0).location == location }) { + let semaphore = DispatchSemaphore(value: 0) + var result: Result? + + Task { + do { + let data = try await plugin.read(url) + result = .success(data) + } catch { + result = .failure(error) + } + semaphore.signal() // Release the semaphore once async task is done + } + + semaphore.wait() // Wait for async task to complete + + switch result { + case .success(let data): + return data + case .failure(let error): + throw error + case .none: + fatalError("Unexpected error occurred") + } + } + + // Synchronous path for fallback + return try Data(contentsOf: url) + } +} +#endif diff --git a/Tests/SnapshotTestingTests/FileSerializationPluginTests.swift b/Tests/SnapshotTestingTests/FileSerializationPluginTests.swift new file mode 100644 index 000000000..979475713 --- /dev/null +++ b/Tests/SnapshotTestingTests/FileSerializationPluginTests.swift @@ -0,0 +1,83 @@ +#if canImport(SwiftUI) && canImport(ObjectiveC) +import XCTest +import SnapshotTestingPlugin +@testable import SnapshotTesting +import FileSerializationPlugin + +class InMemoryFileSerializationPlugin: FileSerializationPlugin { + static var location: FileSerializationLocation = .plugins("inMemory") + var inMemory: [String: Data] = [:] + + func write(_ data: Data, to url: URL, options: Data.WritingOptions) async throws { + inMemory[url.absoluteString] = data + } + + func read(_ url: URL) async throws -> Data? { + return inMemory[url.absoluteString] + } + + // MARK: - SnapshotTestingPlugin + static var identifier: String = "FileSerializationPlugin.InMemoryFileSerializationPlugin.mock" + required init() {} +} + +class FileSerializerTests: XCTestCase { + + var fileSerializer: FileSerializer! + let testData = "Test Data".data(using: .utf8)! + let testURL = URL(string: "file:///test.txt")! + + override func setUp() { + super.setUp() + PluginRegistry.reset() // Reset state before each test + + // Register the mock plugin in the PluginRegistry + PluginRegistry.registerPlugin(InMemoryFileSerializationPlugin() as SnapshotTestingPlugin) + + fileSerializer = FileSerializer() + } + + override func tearDown() { + fileSerializer = nil + PluginRegistry.reset() // Reset state after each test + super.tearDown() + } + + func testReadAndWriteUsingMockPlugin() async throws { + try fileSerializer.write( + testData, + to: testURL, + location: InMemoryFileSerializationPlugin.location + ) + + let storedData = try fileSerializer.read(testURL, location: InMemoryFileSerializationPlugin.location) + XCTAssertNotNil(storedData, "Data should be stored in the in-memory plugin.") + XCTAssertEqual(storedData, testData, "Stored data should match the test data.") + } + + func testReadAndWriteDefaultPlugin() async throws { + let tmpURL = FileManager.default.temporaryDirectory.appending(path: UUID().uuidString) + try fileSerializer.write( + testData, + to: tmpURL + ) + + let storedData = try fileSerializer.read(tmpURL) + XCTAssertNotNil(storedData, "Data should be stored in the in-memory plugin.") + XCTAssertEqual(storedData, testData, "Stored data should match the test data.") + } + + func testReadNonExistantFileUsingMockPlugin() async throws { + let data = try fileSerializer.read(URL(string: "https://www.pointfree.co")!, location: InMemoryFileSerializationPlugin.location) + XCTAssertNil(data, "This should be empty.") + } + + func testPluginRegistryShouldContainRegisteredPlugins() { + let plugins = PluginRegistry.allPlugins() as [FileSerialization] + + XCTAssertEqual(plugins.count, 1, "There should be one registered plugin.") + XCTAssertEqual(type(of: plugins[0]).location.rawValue, "inMemory", "The plugin should support the 'inMemory' location.") + } +} + +#endif