Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Formatting with apple/swift-format #90

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .swift-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"version": 1,
"lineLength": 120,
"indentation": {
"spaces": 4
},
"maximumBlankLines": 1,
"respectsExistingLineBreaks": true,
"lineBreakBeforeControlFlowKeywords": true,
"lineBreakBeforeEachArgument": true,
"multiElementCollectionTrailingCommas": true,
"spacesAroundRangeFormationOperators": true
}
14 changes: 14 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Contributing to Swift Transformers

## Code Styling and Linting

Code formatting is enforced with `swift-format`, the formatting utility from Apple.
To install it and run it on all the files in the project, use the following commands:

```bash
brew install swift-format
swift-format . -i -r
```

The style is controlled by the `.swift-format` JSON file in the root of the repository.
Because there is no single standard for Swift formatting, even Apple's own `swift-format` tool and Xcode differ in their formatting rules and available settings.
17 changes: 13 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,27 @@ let package = Package(
name: "TransformersCLI",
dependencies: [
"Models", "Generation", "Tokenizers",
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.product(name: "ArgumentParser", package: "swift-argument-parser"),
]
),
.executableTarget(
name: "HubCLI",
dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]
),
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
.target(name: "Tokenizers", dependencies: ["Hub"]),
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
.testTarget(
name: "TokenizersTests",
dependencies: ["Tokenizers", "Models", "Hub"],
resources: [.process("Resources"), .process("Vocabs")]
),
.testTarget(name: "HubTests", dependencies: ["Hub"]),
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
]
)
49 changes: 38 additions & 11 deletions Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//
// Generation.swift
//
//
//
// Created by Pedro Cuenca on 7/5/23.
//

import Tokenizers
import CoreML
import TensorUtils
import Tokenizers

public enum GenerationMode {
case contrastiveSearch
Expand All @@ -29,13 +29,29 @@ public typealias PredictionStringCallback = (String) -> Void

// TODO: callbacks (for streaming)
public protocol Generation {
func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput

func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
func greedySearch(
config: GenerationConfig,
tokens: InputTokens,
model: NextTokenModel,
callback: PredictionTokensCallback?
) async -> GenerationOutput

func generate(
config: GenerationConfig,
prompt: String,
model: NextTokenModel,
tokenizer: Tokenizer,
callback: PredictionStringCallback?
) async -> String
}

public extension Generation {
func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
extension Generation {
public func greedySearch(
config: GenerationConfig,
tokens: InputTokens,
model: NextTokenModel,
callback: PredictionTokensCallback? = nil
) async -> GenerationOutput {
// Iterate until we find the eos token or reach the max length
// TODO: additional stopping criteria
var outputTokens = tokens
Expand All @@ -48,9 +64,14 @@ public extension Generation {
}
return outputTokens
}

/// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
public func sample(
config: GenerationConfig,
tokens: InputTokens,
model: NextTokenModel,
callback: PredictionTokensCallback? = nil
) async -> GenerationOutput {
// Iterate until we find the eos token or reach the max length
// TODO: additional stopping criteria
var outputTokens = tokens
Expand All @@ -68,7 +89,13 @@ public extension Generation {
return outputTokens
}

func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String {
public func generate(
config: GenerationConfig,
prompt: String,
model: NextTokenModel,
tokenizer: Tokenizer,
callback: PredictionStringCallback? = nil
) async -> String {
let tokens = tokenizer.encode(text: prompt)
var generationConfig = config
generationConfig.maxLength = config.maxNewTokens + tokens.count
Expand All @@ -86,7 +113,7 @@ public extension Generation {
default:
fatalError("Generation mode \(generationConfig.generationMode) not implemented yet")
}

return tokenizer.decode(tokens: output)
}

Expand Down
33 changes: 22 additions & 11 deletions Sources/Generation/GenerationConfig.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// GenerationConfig.swift
//
//
//
// Created by Pedro Cuenca on 7/5/23.
//
Expand All @@ -19,12 +19,23 @@ public struct GenerationConfig {
public var topK = 50
public var topP = 1.0
public var repetitionPenalty = 1.0

public var padTokenId: Int? = nil
public var bosTokenId: Int? = nil
public var eosTokenId: Int? = nil

public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {

public init(
maxLength: Int = 20,
maxNewTokens: Int,
doSample: Bool = false,
numBeams: Int = 1,
numBeamGroups: Int = 1,
penaltyAlpha: Double? = nil,
temperature: Double = 1.0,
topK: Int = 50,
topP: Double = 1.0,
repetitionPenalty: Double = 1.0
) {
self.maxLength = maxLength
self.maxNewTokens = maxNewTokens
self.doSample = doSample
Expand All @@ -38,19 +49,19 @@ public struct GenerationConfig {
}
}

public extension GenerationConfig {
var generationMode: GenerationMode {
extension GenerationConfig {
public var generationMode: GenerationMode {
// Exclude this case from the pattern matching below
if topK > 1 && !doSample && penaltyAlpha != nil && penaltyAlpha! > 0 {
return .contrastiveSearch
}

switch (numBeams, numBeamGroups, doSample) {
case (1, 1, false) : return .greedy
case (1, 1, true) : return .sample
case (1, 1, false): return .greedy
case (1, 1, true): return .sample
case (2..., 1, false): return .beam
case (2..., 2..., _) : return .groupBeam
default : return .unsupported
case (2..., 2..., _): return .groupBeam
default: return .unsupported
}
}
}
Expand Down
29 changes: 18 additions & 11 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//

import Foundation
import Combine
import Foundation

class Downloader: NSObject, ObservableObject {
private(set) var destination: URL
Expand Down Expand Up @@ -86,16 +86,16 @@ class Downloader: NSObject, ObservableObject {
stateSubscriber = downloadState.sink { state in
switch state {
case .completed: semaphore.signal()
case .failed: semaphore.signal()
default: break
case .failed: semaphore.signal()
default: break
}
}
semaphore.wait()

switch downloadState.value {
case .completed(let url): return url
case .failed(let error): throw error
default: throw DownloadError.unexpectedError
case .failed(let error): throw error
default: throw DownloadError.unexpectedError
}
}

Expand All @@ -105,7 +105,13 @@ class Downloader: NSObject, ObservableObject {
}

extension Downloader: URLSessionDownloadDelegate {
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
func urlSession(
_: URLSession,
downloadTask: URLSessionDownloadTask,
didWriteData _: Int64,
totalBytesWritten: Int64,
totalBytesExpectedToWrite: Int64
) {
downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
}

Expand All @@ -114,18 +120,19 @@ extension Downloader: URLSessionDownloadDelegate {
// If the downloaded file already exists on the filesystem, overwrite it
try FileManager.default.moveDownloadedFile(from: location, to: self.destination)
downloadState.value = .completed(destination)
} catch {
}
catch {
downloadState.value = .failed(error)
}
}

func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
if let error = error {
downloadState.value = .failed(error)
// } else if let response = task.response as? HTTPURLResponse {
// print("HTTP response status code: \(response.statusCode)")
// let headers = response.allHeaderFields
// print("HTTP response headers: \(headers)")
// } else if let response = task.response as? HTTPURLResponse {
// print("HTTP response status code: \(response.statusCode)")
// let headers = response.allHeaderFields
// print("HTTP response headers: \(headers)")
}
}
}
Expand Down
Loading