diff --git a/Package.swift b/Package.swift index 95f6b77..43c5696 100644 --- a/Package.swift +++ b/Package.swift @@ -1,8 +1,13 @@ -// swift-tools-version: 5.8 +// swift-tools-version: 5.9 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription +/// Define the strict concurrency settings to be applied to all targets. +let swiftSettings: [SwiftSetting] = [ + .enableExperimentalFeature("StrictConcurrency"), +] + let package = Package( name: "swift-transformers", platforms: [.iOS(.v16), .macOS(.v13)], @@ -24,13 +29,13 @@ let package = Package( ] ), .executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]), - .target(name: "Hub", resources: [.process("FallbackConfigs")]), + .target(name: "Hub", resources: [.process("FallbackConfigs")], swiftSettings: swiftSettings), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), .target(name: "TensorUtils"), .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]), .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), - .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), + .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 0e16e7a..cc0171f 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -9,10 +9,11 @@ import Combine import Foundation -class Downloader: NSObject, ObservableObject { - private(set) var destination: URL - - private let chunkSize = 10 * 1024 * 1024 // 10MB +final class Downloader: NSObject, Sendable, ObservableObject { + private let destination: URL + private let incompleteDestination: URL + private let downloadResumeState: DownloadResumeState = .init() + private let chunkSize: Int enum DownloadState { case notStarted @@ -27,37 +28,25 @@ class Downloader: NSObject, ObservableObject { case tempFileNotFound } - private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted) - private var stateSubscriber: Cancellable? - - private(set) var tempFilePath: URL - private(set) var expectedSize: Int? - private(set) var downloadedSize: Int = 0 + private let broadcaster: Broadcaster = Broadcaster { + DownloadState.notStarted + } - var session: URLSession? = nil - var downloadTask: Task? = nil + private let sessionConfig: URLSessionConfiguration + let session: SessionActor = .init() + private let task: TaskActor = .init() init( - from url: URL, to destination: URL, incompleteDestination: URL, - using authToken: String? = nil, inBackground: Bool = false, - headers: [String: String]? = nil, - expectedSize: Int? = nil, - timeout: TimeInterval = 10, - numRetries: Int = 5 + chunkSize: Int = 10 * 1024 * 1024 // 10MB ) { self.destination = destination - self.expectedSize = expectedSize - // Create incomplete file path based on destination - tempFilePath = incompleteDestination + self.incompleteDestination = incompleteDestination + self.chunkSize = chunkSize - // If resume size wasn't specified, check for an existing incomplete file - let resumeSize = Self.incompleteFileSize(at: incompleteDestination) - - super.init() let sessionIdentifier = "swift-transformers.hub.downloader" var config = URLSessionConfiguration.default @@ -66,23 +55,33 @@ class Downloader: NSObject, ObservableObject { config.isDiscretionary = false config.sessionSendsLaunchEvents = true } - - session = URLSession(configuration: config, delegate: self, delegateQueue: nil) - - setUpDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries) + sessionConfig = config } - /// Check if an incomplete file exists for the destination and returns its size - /// - Parameter destination: The destination URL for the download - /// - Returns: Size of the incomplete file if it exists, otherwise 0 - static func incompleteFileSize(at incompletePath: URL) -> Int { - if FileManager.default.fileExists(atPath: incompletePath.path) { - if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int { - return fileSize - } + func download( + from url: URL, + using authToken: String? = nil, + headers: [String: String]? = nil, + expectedSize: Int? = nil, + timeout: TimeInterval = 10, + numRetries: Int = 5 + ) async -> AsyncStream { + if let task = await task.get() { + task.cancel() } - - return 0 + await downloadResumeState.setExpectedSize(expectedSize) + let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + await session.set(URLSession(configuration: sessionConfig, delegate: self, delegateQueue: nil)) + await setUpDownload( + from: url, + with: authToken, + resumeSize: resumeSize, + headers: headers, + timeout: timeout, + numRetries: numRetries + ) + + return await broadcaster.subscribe() } /// Sets up and initiates a file download operation @@ -100,77 +99,92 @@ class Downloader: NSObject, ObservableObject { with authToken: String?, resumeSize: Int, headers: [String: String]?, - expectedSize: Int?, timeout: TimeInterval, numRetries: Int - ) { - session?.getAllTasks { tasks in - // If there's an existing pending background task with the same URL, let it proceed. - if let existing = tasks.filter({ $0.originalRequest?.url == url }).first { - switch existing.state { - case .running: - return - case .suspended: - existing.resume() - return - case .canceling, .completed: - existing.cancel() - @unknown default: - existing.cancel() - } + ) async { + let resumeSize = Self.incompleteFileSize(at: incompleteDestination) + guard let tasks = await session.get()?.allTasks else { + return + } + + // If there's an existing pending background task with the same URL, let it proceed. + if let existing = tasks.filter({ $0.originalRequest?.url == url }).first { + switch existing.state { + case .running: + return + case .suspended: + existing.resume() + return + case .canceling, .completed: + existing.cancel() + break + @unknown default: + existing.cancel() } + } - self.downloadTask = Task { + await task.set( + Task { do { - // Set up the request with appropriate headers var request = URLRequest(url: url) + + // Use headers from argument else create an empty header dictionary var requestHeaders = headers ?? [:] + // Populate header auth and range fields if let authToken { requestHeaders["Authorization"] = "Bearer \(authToken)" } - self.downloadedSize = resumeSize + await self.downloadResumeState.setDownloadedSize(resumeSize) + + if resumeSize > 0 { + requestHeaders["Range"] = "bytes=\(resumeSize)-" + } // Set Range header if we're resuming if resumeSize > 0 { requestHeaders["Range"] = "bytes=\(resumeSize)-" // Calculate and show initial progress - if let expectedSize, expectedSize > 0 { + if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize > 0 { let initialProgress = Double(resumeSize) / Double(expectedSize) - self.downloadState.value = .downloading(initialProgress) + await self.broadcaster.broadcast(state: .downloading(initialProgress)) } else { - self.downloadState.value = .downloading(0) + await self.broadcaster.broadcast(state: .downloading(0)) } } else { - self.downloadState.value = .downloading(0) + await self.broadcaster.broadcast(state: .downloading(0)) } request.timeoutInterval = timeout request.allHTTPHeaderFields = requestHeaders // Open the incomplete file for writing - let tempFile = try FileHandle(forWritingTo: self.tempFilePath) + let tempFile = try FileHandle(forWritingTo: self.incompleteDestination) // If resuming, seek to end of file if resumeSize > 0 { try tempFile.seekToEnd() } - try await self.httpGet(request: request, tempFile: tempFile, resumeSize: self.downloadedSize, numRetries: numRetries, expectedSize: expectedSize) + defer { tempFile.closeFile() } - // Clean up and move the completed download to its final destination - tempFile.closeFile() + try await self.httpGet(request: request, tempFile: tempFile, numRetries: numRetries) try Task.checkCancellation() - try FileManager.default.moveDownloadedFile(from: self.tempFilePath, to: self.destination) - self.downloadState.value = .completed(self.destination) + try FileManager.default.moveDownloadedFile(from: self.incompleteDestination, to: self.destination) + + // // Clean up and move the completed download to its final destination + // tempFile.closeFile() + // try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination) + + await self.broadcaster.broadcast(state: .completed(self.destination)) } catch { - self.downloadState.value = .failed(error) + await self.broadcaster.broadcast(state: .failed(error)) } } - } + ) } /// Downloads a file from given URL using chunked transfer and handles retries. @@ -187,27 +201,26 @@ class Downloader: NSObject, ObservableObject { private func httpGet( request: URLRequest, tempFile: FileHandle, - resumeSize: Int, - numRetries: Int, - expectedSize: Int? + numRetries: Int ) async throws { - guard let session else { + guard let session = await session.get() else { throw DownloadError.unexpectedError } // Create a new request with Range header for resuming var newRequest = request - if resumeSize > 0 { - newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range") + if await downloadResumeState.downloadedSize > 0 { + await newRequest.setValue("bytes=\(downloadResumeState.downloadedSize)-", forHTTPHeaderField: "Range") } // Start the download and get the byte stream let (asyncBytes, response) = try await session.bytes(for: newRequest) - guard let httpResponse = response as? HTTPURLResponse else { + guard let response = response as? HTTPURLResponse else { throw DownloadError.unexpectedError } - guard (200..<300).contains(httpResponse.statusCode) else { + + guard (200..<300).contains(response.statusCode) else { throw DownloadError.unexpectedError } @@ -223,18 +236,19 @@ class Downloader: NSObject, ObservableObject { if !buffer.isEmpty { // Filter out keep-alive chunks try tempFile.write(contentsOf: buffer) buffer.removeAll(keepingCapacity: true) - downloadedSize += chunkSize + + await downloadResumeState.incDownloadedSize(chunkSize) newNumRetries = 5 - guard let expectedSize else { continue } - let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0 - downloadState.value = .downloading(progress) + guard let expectedSize = await downloadResumeState.expectedSize else { continue } + let progress = await expectedSize != 0 ? Double(downloadResumeState.downloadedSize) / Double(expectedSize) : 0 + await broadcaster.broadcast(state: .downloading(progress)) } } } if !buffer.isEmpty { try tempFile.write(contentsOf: buffer) - downloadedSize += buffer.count + await downloadResumeState.incDownloadedSize(buffer.count) buffer.removeAll(keepingCapacity: true) newNumRetries = 5 } @@ -244,74 +258,73 @@ class Downloader: NSObject, ObservableObject { } try await Task.sleep(nanoseconds: 1_000_000_000) - let config = URLSessionConfiguration.default - self.session = URLSession(configuration: config, delegate: self, delegateQueue: nil) + await self.session.set(URLSession(configuration: self.sessionConfig, delegate: self, delegateQueue: nil)) try await httpGet( request: request, tempFile: tempFile, - resumeSize: self.downloadedSize, - numRetries: newNumRetries - 1, - expectedSize: expectedSize + numRetries: newNumRetries - 1 ) + return } // Verify the downloaded file size matches the expected size let actualSize = try tempFile.seekToEnd() - if let expectedSize, expectedSize != actualSize { + if let expectedSize = await downloadResumeState.expectedSize, expectedSize != actualSize { throw DownloadError.unexpectedError } } - @discardableResult - func waitUntilDone() throws -> URL { - // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky) - let semaphore = DispatchSemaphore(value: 0) - stateSubscriber = downloadState.sink { state in - switch state { - case .completed: semaphore.signal() - case .failed: semaphore.signal() - default: break - } - } - semaphore.wait() + func cancel() async { + await session.get()?.invalidateAndCancel() + await task.get()?.cancel() + await broadcaster.broadcast(state: .failed(URLError(.cancelled))) + } - switch downloadState.value { - case let .completed(url): return url - case let .failed(error): throw error - default: throw DownloadError.unexpectedError + /// Check if an incomplete file exists for the destination and returns its size + /// - Parameter destination: The destination URL for the download + /// - Returns: Size of the incomplete file if it exists, otherwise 0 + static func incompleteFileSize(at incompletePath: URL) -> Int { + if FileManager.default.fileExists(atPath: incompletePath.path) { + if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int { + return fileSize + } } - } - func cancel() { - session?.invalidateAndCancel() - downloadTask?.cancel() - downloadState.value = .failed(URLError(.cancelled)) + return 0 } } extension Downloader: URLSessionDownloadDelegate { func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) { - downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)) + Task { + await self.broadcaster.broadcast(state: .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))) + } } func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { do { // If the downloaded file already exists on the filesystem, overwrite it try FileManager.default.moveDownloadedFile(from: location, to: destination) - downloadState.value = .completed(destination) + Task { + await self.broadcaster.broadcast(state: .completed(destination)) + } } catch { - downloadState.value = .failed(error) + Task { + await self.broadcaster.broadcast(state: .failed(error)) + } } } func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { if let error { - downloadState.value = .failed(error) -// } else if let response = task.response as? HTTPURLResponse { -// print("HTTP response status code: \(response.statusCode)") -// let headers = response.allHeaderFields -// print("HTTP response headers: \(headers)") + Task { + await self.broadcaster.broadcast(state: .failed(error)) + } + // } else if let response = task.response as? HTTPURLResponse { + // print("HTTP response status code: \(response.statusCode)") + // let headers = response.allHeaderFields + // print("HTTP response headers: \(headers)") } } } @@ -328,3 +341,96 @@ extension FileManager { try moveItem(at: srcURL, to: dstURL) } } + +private actor DownloadResumeState { + var expectedSize: Int? + var downloadedSize: Int = 0 + + func setExpectedSize(_ size: Int?) { + expectedSize = size + } + + func setDownloadedSize(_ size: Int) { + downloadedSize = size + } + + func incDownloadedSize(_ size: Int) { + downloadedSize += size + } +} + +actor Broadcaster { + private let initialState: @Sendable () async -> E? + private var latestState: E? + private var continuations: [UUID: AsyncStream.Continuation] = [:] + + init(initialState: @Sendable @escaping () async -> E?) { + self.initialState = initialState + } + + deinit { + self.continuations.removeAll() + } + + func subscribe() -> AsyncStream { + AsyncStream { continuation in + let id = UUID() + self.continuations[id] = continuation + + continuation.onTermination = { @Sendable status in + Task { + await self.unsubscribe(id) + } + } + + Task { + if let state = self.latestState { + continuation.yield(state) + return + } + if let state = await self.initialState() { + continuation.yield(state) + } + } + } + } + + private func unsubscribe(_ id: UUID) { + continuations.removeValue(forKey: id) + } + + func broadcast(state: E) async { + latestState = state + await withTaskGroup(of: Void.self) { group in + for continuation in continuations.values { + group.addTask { + continuation.yield(state) + } + } + } + } +} + +actor SessionActor { + private var urlSession: URLSession? + + func set(_ urlSession: URLSession?) { + self.urlSession = urlSession + } + + func get() -> URLSession? { + urlSession + } +} + +actor TaskActor { + private var task: Task? + + func set(_ task: Task?) { + self.task = task + } + + func get() -> Task? { + task + } +} diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 74ce0bf..00b7b2c 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -7,7 +7,7 @@ import Foundation -public struct Hub { } +public struct Hub: Sendable { } public extension Hub { enum HubClientError: LocalizedError { @@ -68,14 +68,14 @@ public extension Hub { } } -public class LanguageModelConfigurationFromHub { +public final class LanguageModelConfigurationFromHub: Sendable { struct Configurations { var modelConfig: Config var tokenizerConfig: Config? var tokenizerData: Config } - private var configPromise: Task? + private let configPromise: Task public init( modelName: String, @@ -83,7 +83,7 @@ public class LanguageModelConfigurationFromHub { hubApi: HubApi = .shared ) { configPromise = Task.init { - try await self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) + try await Self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi) } } @@ -92,21 +92,21 @@ public class LanguageModelConfigurationFromHub { hubApi: HubApi = .shared ) { configPromise = Task { - try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) + try await Self.loadConfig(modelFolder: modelFolder, hubApi: hubApi) } } public var modelConfig: Config { get async throws { - try await configPromise!.value.modelConfig + try await configPromise.value.modelConfig } } public var tokenizerConfig: Config? { get async throws { - if let hubConfig = try await configPromise!.value.tokenizerConfig { + if let hubConfig = try await configPromise.value.tokenizerConfig { // Try to guess the class if it's not present and the modelType is - if let _: String = hubConfig.tokenizerClass?.string() { return hubConfig } + if hubConfig.tokenizerClass?.string() != nil { return hubConfig } guard let modelType = try await modelType else { return hubConfig } // If the config exists but doesn't contain a tokenizerClass, use a fallback config if we have it @@ -129,7 +129,7 @@ public class LanguageModelConfigurationFromHub { public var tokenizerData: Config { get async throws { - try await configPromise!.value.tokenizerData + try await configPromise.value.tokenizerData } } @@ -139,7 +139,7 @@ public class LanguageModelConfigurationFromHub { } } - func loadConfig( + static func loadConfig( modelName: String, revision: String, hubApi: HubApi = .shared @@ -167,7 +167,7 @@ public class LanguageModelConfigurationFromHub { } } - func loadConfig( + static func loadConfig( modelFolder: URL, hubApi: HubApi = .shared ) async throws -> Configurations { diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index adfbf4a..7c8f61e 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -10,18 +10,24 @@ import Foundation import Network import os -public struct HubApi { +public struct HubApi: Sendable { var downloadBase: URL var hfToken: String? var endpoint: String var useBackgroundSession: Bool - var useOfflineMode: Bool? + var useOfflineMode: Bool? = nil private let networkMonitor = NetworkMonitor() public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) { + public init( + downloadBase: URL? = nil, + hfToken: String? = nil, + endpoint: String = "https://huggingface.co", + useBackgroundSession: Bool = false, + useOfflineMode: Bool? = nil + ) { self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -389,7 +395,9 @@ public extension HubApi { let remoteCommitHash = remoteMetadata.commitHash ?? "" // Local file exists + metadata exists + commit_hash matches => return file - if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil, localCommitHash == remoteCommitHash { + if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil, + localCommitHash == remoteCommitHash + { return destination } @@ -427,51 +435,46 @@ public extension HubApi { let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete") try prepareCacheDestination(incompleteDestination) - let downloader = Downloader( - from: source, - to: destination, - incompleteDestination: incompleteDestination, - using: hfToken, - inBackground: backgroundSession, - expectedSize: remoteSize - ) + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination, inBackground: backgroundSession) - return try await withTaskCancellationHandler { - let downloadSubscriber = downloader.downloadState.sink { state in + try await withTaskCancellationHandler { + let sub = await downloader.download(from: source, using: hfToken, expectedSize: remoteSize) + listen: for await state in sub { switch state { + case .notStarted: + continue case let .downloading(progress): progressHandler(progress) - case .completed, .failed, .notStarted: - break - } - } - do { - _ = try withExtendedLifetime(downloadSubscriber) { - try downloader.waitUntilDone() + case let .failed(error): + throw error + case .completed: + break listen } - - try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) - - return destination - } catch { - // If download fails, leave the incomplete file in place for future resume - throw error } } onCancel: { - downloader.cancel() + Task { + await downloader.cancel() + } } + + try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) + + return destination } } @discardableResult - func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) + async throws -> URL + { let repoDestination = localRepoLocation(repo) - let repoMetadataDestination = repoDestination - .appendingPathComponent(".cache") - .appendingPathComponent("huggingface") - .appendingPathComponent("download") + let repoMetadataDestination = + repoDestination + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + .appendingPathComponent("download") - if useOfflineMode ?? NetworkMonitor.shared.shouldUseOfflineMode() { + if await NetworkMonitor.shared.state.shouldUseOfflineMode() || useOfflineMode == true { if !FileManager.default.fileExists(atPath: repoDestination.path) { throw EnvironmentError.offlineModeError(String(localized: "Repository not available locally")) } @@ -482,10 +485,12 @@ public extension HubApi { } for fileUrl in fileUrls { - let metadataPath = URL(fileURLWithPath: fileUrl.path.replacingOccurrences( - of: repoDestination.path, - with: repoMetadataDestination.path - ) + ".metadata") + let metadataPath = URL( + fileURLWithPath: fileUrl.path.replacingOccurrences( + of: repoDestination.path, + with: repoMetadataDestination.path + ) + ".metadata" + ) let localMetadata = try readDownloadMetadata(metadataPath: metadataPath) @@ -521,12 +526,18 @@ public extension HubApi { endpoint: endpoint, backgroundSession: useBackgroundSession ) + try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) progressHandler(progress) } + if Task.isCancelled { + return repoDestination + } + fileProgress.completedUnitCount = 100 } + progressHandler(progress) return repoDestination } @@ -624,14 +635,31 @@ public extension HubApi { } /// Network monitor helper class to help decide whether to use offline mode -private extension HubApi { - private final class NetworkMonitor { - private var monitor: NWPathMonitor - private var queue: DispatchQueue +extension HubApi { + private actor NetworkStateActor { + public var isConnected: Bool = false + public var isExpensive: Bool = false + public var isConstrained: Bool = false + + func update(path: NWPath) { + isConnected = path.status == .satisfied + isExpensive = path.isExpensive + isConstrained = path.isConstrained + } - private(set) var isConnected: Bool = false - private(set) var isExpensive: Bool = false - private(set) var isConstrained: Bool = false + func shouldUseOfflineMode() -> Bool { + if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" { + return false + } + return !isConnected || isExpensive || isConstrained + } + } + + private final class NetworkMonitor: Sendable { + private let monitor: NWPathMonitor + private let queue: DispatchQueue + + public let state: NetworkStateActor = .init() static let shared = NetworkMonitor() @@ -644,10 +672,9 @@ private extension HubApi { func startMonitoring() { monitor.pathUpdateHandler = { [weak self] path in guard let self else { return } - - isConnected = path.status == .satisfied - isExpensive = path.isExpensive - isConstrained = path.isConstrained + Task { + await self.state.update(path: path) + } } monitor.start(queue: queue) @@ -657,13 +684,6 @@ private extension HubApi { monitor.cancel() } - func shouldUseOfflineMode() -> Bool { - if ProcessInfo.processInfo.environment["CI_DISABLE_NETWORK_MONITOR"] == "1" { - return false - } - return !isConnected || isExpensive || isConstrained - } - deinit { stopMonitoring() } @@ -692,7 +712,9 @@ public extension Hub { try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler) } - static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws + -> URL + { try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler) } @@ -740,11 +762,13 @@ public extension FileManager { var fileUrls = [URL]() // Get all contents including subdirectories - guard let enumerator = FileManager.default.enumerator( - at: directoryUrl, - includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey], - options: [.skipsHiddenFiles] - ) else { + guard + let enumerator = FileManager.default.enumerator( + at: directoryUrl, + includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey], + options: [.skipsHiddenFiles] + ) + else { return fileUrls } @@ -765,8 +789,14 @@ public extension FileManager { /// Only allow relative redirects and reject others /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 -private class RedirectDelegate: NSObject, URLSessionTaskDelegate { - func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { +private final class RedirectDelegate: NSObject, URLSessionTaskDelegate, Sendable { + func urlSession( + _ session: URLSession, + task: URLSessionTask, + willPerformHTTPRedirection response: HTTPURLResponse, + newRequest request: URLRequest, + completionHandler: @escaping (URLRequest?) -> Void + ) { // Check if it's a redirect status code (300-399) if (300...399).contains(response.statusCode) { // Get the Location header diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift index d62d2e8..d1533d2 100644 --- a/Tests/HubTests/DownloaderTests.swift +++ b/Tests/HubTests/DownloaderTests.swift @@ -6,6 +6,8 @@ // import Combine +import XCTest + @testable import Hub import XCTest @@ -25,8 +27,8 @@ enum DownloadError: LocalizedError { } private extension Downloader { - func interruptDownload() { - session?.invalidateAndCancel() + func interruptDownload() async { + await session.get()?.invalidateAndCancel() } } @@ -72,33 +74,22 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination - ) - - // Store subscriber outside the continuation to maintain its lifecycle - var subscriber: AnyCancellable? - - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - subscriber = downloader.downloadState.sink { state in - switch state { - case .completed: - continuation.resume() - case let .failed(error): - continuation.resume(throwing: error) - case .downloading: - break - case .notStarted: - break - } + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + let sub = await downloader.download(from: url) + + listen: for await state in sub { + switch state { + case .notStarted: + continue + case .downloading: + continue + case let .failed(error): + throw error + case .completed: + break listen } } - // Cancel subscription after continuation completes - subscriber?.cancel() - // Verify download completed successfully XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path)) XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent) @@ -116,18 +107,22 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - // Create downloader with incorrect expected size - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination, - expectedSize: 999999 // Incorrect size - ) - - do { - try downloader.waitUntilDone() - XCTFail("Download should have failed due to size mismatch") - } catch { } + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + // Download with incorrect expected size + let sub = await downloader.download(from: url, expectedSize: 999999) // Incorrect size + listen: for await state in sub { + switch state { + case .notStarted: + continue + case .downloading: + continue + case .failed: + break listen + case .completed: + XCTFail("Download should have failed due to size mismatch") + break listen + } + } // Verify no file was created at destination XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path)) @@ -141,8 +136,10 @@ final class DownloaderTests: XCTestCase { let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip") // Create parent directory if it doesn't exist - try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(), - withIntermediateDirectories: true) + try FileManager.default.createDirectory( + at: destination.deletingLastPathComponent(), + withIntermediateDirectories: true + ) let cacheDir = tempDir.appendingPathComponent("cache") try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true) @@ -150,42 +147,32 @@ final class DownloaderTests: XCTestCase { let incompleteDestination = cacheDir.appendingPathComponent("config.json.\(etag).incomplete") FileManager.default.createFile(atPath: incompleteDestination.path, contents: nil, attributes: nil) - let downloader = Downloader( - from: url, - to: destination, - incompleteDestination: incompleteDestination, - expectedSize: 73194001 // Correct size for verification - ) + let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination) + let sub = await downloader.download(from: url, expectedSize: 73_194_001) // Correct size for verification // First interruption point at 50% var threshold = 0.5 - var subscriber: AnyCancellable? - do { // Monitor download progress and interrupt at thresholds to test if // download continues from where it left off - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - subscriber = downloader.downloadState.sink { state in - switch state { - case let .downloading(progress): - if threshold != 1.0, progress >= threshold { - // Move to next threshold and interrupt - threshold = threshold == 0.5 ? 0.75 : 1.0 - downloader.interruptDownload() - } - case .completed: - continuation.resume() - case let .failed(error): - continuation.resume(throwing: error) - case .notStarted: - break + listen: for await state in sub { + switch state { + case .notStarted: + continue + case let .downloading(progress): + if threshold != 1.0, progress >= threshold { + // Move to next threshold and interrupt + threshold = threshold == 0.5 ? 0.75 : 1.0 + await downloader.interruptDownload() } + case let .failed(error): + throw error + case .completed: + break listen } } - subscriber?.cancel() - // Verify the file exists and is complete if FileManager.default.fileExists(atPath: destination.path) { let attributes = try FileManager.default.attributesOfItem(atPath: destination.path) diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index b03716f..61d4968 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -4,9 +4,10 @@ // Created by Pedro Cuenca on 20231230. // -@testable import Hub import XCTest +@testable import Hub + class HubApiTests: XCTestCase { override func setUp() { // Put setup code here. This method is called before the invocation of each test method in the class. @@ -150,10 +151,14 @@ class HubApiTests: XCTestCase { do { let revision = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" let etag = "fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" - let location = "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" + let location = + "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" let size = 504766 - let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + let url = URL( + string: + "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel" + ) let metadata = try await Hub.getFileMetadata(fileURL: url!) XCTAssertEqual(metadata.commitHash, revision) @@ -188,7 +193,12 @@ class SnapshotDownloadTests: XCTestCase { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") - if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.isRegularFileKey], options: [.skipsHiddenFiles], errorHandler: nil) { + if let enumerator = FileManager.default.enumerator( + at: url, + includingPropertiesForKeys: [.isRegularFileKey], + options: [.skipsHiddenFiles], + errorHandler: nil + ) { for case let fileURL as URL in enumerator { do { let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey]) @@ -915,7 +925,11 @@ class SnapshotDownloadTests: XCTestCase { let metadataDestination = downloadedTo.appendingPathComponent(".cache/huggingface/download").appendingPathComponent("x.bin.metadata") - try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(to: metadataDestination, atomically: true, encoding: .utf8) + try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write( + to: metadataDestination, + atomically: true, + encoding: .utf8 + ) hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) @@ -972,7 +986,9 @@ class SnapshotDownloadTests: XCTestCase { func testResumeDownloadFromEmptyIncomplete() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml") + var downloadedTo = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent( + "Library/Caches/huggingface-tests/models/coreml-projects/Llama-2-7b-chat-coreml" + ) let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") @@ -1070,6 +1086,7 @@ class SnapshotDownloadTests: XCTestCase { // Cancel the download once we've seen progress downloadTask.cancel() + try await Task.sleep(nanoseconds: 5_000_000_000) // Resume download with a new task @@ -1078,8 +1095,10 @@ class SnapshotDownloadTests: XCTestCase { } let filePath = downloadedTo.appendingPathComponent(targetFile) - XCTAssertTrue(FileManager.default.fileExists(atPath: filePath.path), - "Downloaded file should exist at \(filePath.path)") + XCTAssertTrue( + FileManager.default.fileExists(atPath: filePath.path), + "Downloaded file should exist at \(filePath.path)" + ) } func testDownloadWithRevision() async throws {