Commit: added xlm-roberta

jkrukowski committed Dec 8, 2024
1 parent 8855129 commit c46d0c2
Showing 19 changed files with 894 additions and 102 deletions.
11 changes: 10 additions & 1 deletion Package.resolved
@@ -1,5 +1,5 @@
{
"originHash" : "5501a395135de94e2c743aa6ee8e034ae6347e2a5c1f556d0c0e028bfceb9b7c",
"originHash" : "3173defd78a48faa60b1c56cfa74f15c0c2b63eee978ea01ea5eb21e0b8e5939",
"pins" : [
{
"identity" : "jinja",
@@ -37,6 +37,15 @@
"version" : "0.0.6"
}
},
{
"identity" : "swift-sentencepiece",
"kind" : "remoteSourceControl",
"location" : "https://github.com/jkrukowski/swift-sentencepiece",
"state" : {
"revision" : "75d725019ff0b75fbbd7128314fe6710c5a86df0",
"version" : "0.0.5"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
5 changes: 5 additions & 0 deletions Package.swift
@@ -40,6 +40,10 @@ let package = Package(
url: "https://github.com/apple/swift-argument-parser.git",
from: "1.5.0"
),
.package(
url: "https://github.com/jkrukowski/swift-sentencepiece",
from: "0.0.5"
),
],
targets: [
.executableTarget(
@@ -57,6 +61,7 @@
"MLTensorUtils",
.product(name: "Safetensors", package: "swift-safetensors"),
.product(name: "Transformers", package: "swift-transformers"),
.product(name: "SentencepieceTokenizer", package: "swift-sentencepiece"),
]
),
.target(
38 changes: 17 additions & 21 deletions README.md
@@ -16,6 +16,13 @@ Some of the supported models on `Hugging Face`:
- [sentence-transformers/msmarco-bert-base-dot-v5](https://huggingface.co/sentence-transformers/msmarco-bert-base-dot-v5)
- [thenlper/gte-base](https://huggingface.co/thenlper/gte-base)

### XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach)

Some of the supported models on `Hugging Face`:

- [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
- [tomaarsen/xlm-roberta-base-multilingual-en-ar-fr-de-es-tr-it](https://huggingface.co/tomaarsen/xlm-roberta-base-multilingual-en-ar-fr-de-es-tr-it)
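
A minimal usage sketch for these models, assuming the XLM-RoBERTa API mirrors the `Bert` one used elsewhere in this package (the `XLMRoberta.loadModelFromHub` name and parameter label are assumptions based on the CLI subcommand, not shown in this diff):

```swift
import Embeddings

// Assumed entry point: mirrors Bert.loadModelFromHub in
// Sources/Embeddings/Bert/BertUtils.swift below.
let modelBundle = try await XLMRoberta.loadModelFromHub(
    from: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
let encoded = try modelBundle.encode("The cat is black")
```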

### CLIP (Contrastive Language–Image Pre-training)

NOTE: only text encoding is supported for now.
@@ -31,7 +38,7 @@ Add the following to your `Package.swift` file. In the package dependencies add:

```swift
dependencies: [
.package(url: "https://github.com/jkrukowski/swift-embeddings", from: "0.0.4")
.package(url: "https://github.com/jkrukowski/swift-embeddings", from: "0.0.5")
]
```

@@ -85,37 +92,26 @@ print(result)

## Command Line Demo

### BERT

To run the `BERT` command line demo, use the following command:

```bash
swift run embeddings-cli bert [--model-id <model-id>] [--text <text>] [--max-length <max-length>]
```

Command line options:
To run the command line demo, use the following command:

```bash
--model-id <model-id> (default: sentence-transformers/all-MiniLM-L6-v2)
--text <text> (default: a photo of a dog)
--max-length <max-length> (default: 512)
-h, --help Show help information.
swift run embeddings-cli <subcommand> [--model-id <model-id>] [--text <text>] [--max-length <max-length>]
```

### CLIP

To run the `CLIP` command line demo, use the following command:
Subcommands:

```bash
swift run embeddings-cli clip [--model-id <model-id>] [--text <text>] [--max-length <max-length>]
bert Encode text using BERT model
clip Encode text using CLIP model
xlm-roberta Encode text using XLMRoberta model
```

Command line options:

```bash
--model-id <model-id> (default: jkrukowski/clip-vit-base-patch16)
--text <text> (default: a photo of a dog)
--max-length <max-length> (default: 77)
--model-id <model-id> Id of the model to use
--text <text> Text to encode
--max-length <max-length> Maximum length of the input
-h, --help Show help information.
```
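
For example, to run the new `xlm-roberta` subcommand with one of the models listed above (invocation sketch; the subcommand's defaults are not shown in this diff):

```bash
swift run embeddings-cli xlm-roberta \
    --model-id sentence-transformers/paraphrase-multilingual-mpnet-base-v2 \
    --text "The cat is black" \
    --max-length 512
```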

8 changes: 4 additions & 4 deletions Sources/Embeddings/Bert/BertModel.swift
@@ -371,8 +371,8 @@ extension Bert {
public func encode(
_ text: String,
maxLength: Int = 512
) -> MLTensor {
let tokens = tokenizer.tokenize(text, maxLength: maxLength)
) throws -> MLTensor {
let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength)
let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens)
let result = model(inputIds: inputIds)
return result.sequenceOutput[0..., 0, 0...]
@@ -382,8 +382,8 @@
_ texts: [String],
padTokenId: Int = 0,
maxLength: Int = 512
) -> MLTensor {
let encodedTexts = tokenizer.tokenizePaddingToLongest(
) throws -> MLTensor {
let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest(
texts, padTokenId: padTokenId, maxLength: maxLength)
let inputIds = MLTensor(
shape: [encodedTexts.count, encodedTexts[0].count],
14 changes: 9 additions & 5 deletions Sources/Embeddings/Bert/BertUtils.swift
@@ -17,18 +17,22 @@ extension Bert {
downloadBase: URL? = nil,
useBackgroundSession: Bool = false
) async throws -> Bert.ModelBundle {
let modelUrl = try await downloadModelFromHub(
let modelFolder = try await downloadModelFromHub(
from: hubRepoId,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundSession
)
let tokenizer = try await AutoTokenizer.from(modelFolder: modelUrl)
return try await loadModelBundle(from: modelFolder)
}

public static func loadModelBundle(from modelFolder: URL) async throws -> Bert.ModelBundle {
let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder)
// NOTE: just `safetensors` support for now
let weightsUrl = modelUrl.appendingPathComponent("model.safetensors")
let configUrl = modelUrl.appendingPathComponent("config.json")
let weightsUrl = modelFolder.appendingPathComponent("model.safetensors")
let configUrl = modelFolder.appendingPathComponent("config.json")
let config = try Bert.loadConfig(at: configUrl)
let model = try Bert.loadModel(weightsUrl: weightsUrl, config: config)
return Bert.ModelBundle(model: model, tokenizer: TextTokenizerType.transformers(tokenizer))
return Bert.ModelBundle(model: model, tokenizer: TokenizerWrapper(tokenizer))
}
}
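
The refactor above splits Hub download from local loading, so a model folder already on disk can be loaded without a network round trip. A minimal sketch (the local path is hypothetical):

```swift
import Embeddings
import Foundation

// Hypothetical folder containing model.safetensors, config.json and
// tokenizer files, e.g. left behind by an earlier loadModelFromHub call.
let modelFolder = URL(fileURLWithPath: "/path/to/all-MiniLM-L6-v2")
let bundle = try await Bert.loadModelBundle(from: modelFolder)

// encode now throws, per the change in BertModel.swift above.
let embedding = try bundle.encode("Hello, world!", maxLength: 512)
```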

8 changes: 4 additions & 4 deletions Sources/Embeddings/Clip/ClipModel.swift
@@ -281,8 +281,8 @@ extension Clip {
self.tokenizer = tokenizer
}

public func encode(_ text: String, maxLength: Int = 77) -> MLTensor {
let tokens = tokenizer.tokenize(text, maxLength: maxLength)
public func encode(_ text: String, maxLength: Int = 77) throws -> MLTensor {
let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength)
let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens)
let modelOutput = textModel(inputIds: inputIds)
let textEmbeddings = textModel.textProjection(modelOutput.poolerOutput)
@@ -293,8 +293,8 @@
_ texts: [String],
padTokenId: Int = 0,
maxLength: Int = 77
) -> MLTensor {
let encodedTexts = tokenizer.tokenizePaddingToLongest(
) throws -> MLTensor {
let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest(
texts, padTokenId: padTokenId, maxLength: maxLength)
let inputIds = MLTensor(
shape: [encodedTexts.count, encodedTexts[0].count],
14 changes: 9 additions & 5 deletions Sources/Embeddings/Clip/ClipUtils.swift
@@ -16,18 +16,22 @@ extension Clip {
downloadBase: URL? = nil,
useBackgroundSession: Bool = false
) async throws -> Clip.ModelBundle {
let modelUrl = try await downloadModelFromHub(
let modelFolder = try await downloadModelFromHub(
from: hubRepoId,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundSession
)
let tokenizer = try loadClipTokenizer(at: modelUrl)
let weightsUrl = modelUrl.appendingPathComponent("model.safetensors")
let configUrl = modelUrl.appendingPathComponent("config.json")
return try await loadModelBundle(from: modelFolder)
}

public static func loadModelBundle(from modelFolder: URL) async throws -> Clip.ModelBundle {
let tokenizer = try loadClipTokenizer(at: modelFolder)
let weightsUrl = modelFolder.appendingPathComponent("model.safetensors")
let configUrl = modelFolder.appendingPathComponent("config.json")
let config = try Clip.loadConfig(at: configUrl)
// TODO: implement vision model loading
let textModel = try Clip.loadModel(weightsUrl: weightsUrl, config: config)
return Clip.ModelBundle(textModel: textModel, tokenizer: TextTokenizerType.clip(tokenizer))
return Clip.ModelBundle(textModel: textModel, tokenizer: tokenizer)
}
}
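
The same split applies to `Clip`. A sketch of batch encoding with the now-throwing API (the model folder path is hypothetical):

```swift
import Embeddings
import Foundation

let modelFolder = URL(fileURLWithPath: "/path/to/clip-vit-base-patch16")
let clipBundle = try await Clip.loadModelBundle(from: modelFolder)

// Shorter texts in the batch are padded with padTokenId up to the
// longest tokenized sequence, per tokenizeTextsPaddingToLongest.
let embeddings = try clipBundle.encode(
    ["a photo of a dog", "a photo of a cat sleeping on a sofa"],
    padTokenId: 0,
    maxLength: 77
)
```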

5 changes: 5 additions & 0 deletions Sources/Embeddings/EmbeddingsUtils.swift
@@ -16,12 +16,17 @@ func downloadModelFromHub(
"*.safetensors",
"*.py",
"tokenizer.model",
"sentencepiece*.model",
"*.tiktoken",
"*.txt",
]
)
}

enum EmbeddingsError: Error {
case fileNotFound
}

func loadConfigFromFile<Config: Codable>(at url: URL) throws -> Config {
let configData = try Data(contentsOf: url)
let decoder = JSONDecoder()
51 changes: 0 additions & 51 deletions Sources/Embeddings/Tokenizer.swift

This file was deleted.

@@ -31,7 +31,7 @@ final class ClipTokenizer: Sendable {
self.cache = Mutex([:])
}

func tokenize(_ text: String, maxLength: Int = 77, padToLength: Int? = nil) -> [Int] {
func tokenize(_ text: String, maxLength: Int, padToLength: Int? = nil) -> [Int] {
precondition(
maxLength >= 2, "maxLength must be at least 2 to accommodate BOS and EOS tokens")
let cleanText = text.lowercased().replacing(emptyStringPattern, with: " ")
@@ -96,6 +96,12 @@
}
}

extension ClipTokenizer: TextTokenizer {
func tokenizeText(_ text: String, maxLength: Int) throws -> [Int32] {
tokenize(text, maxLength: maxLength, padToLength: nil).map { Int32($0) }
}
}

func loadClipTokenizer(at url: URL) throws -> ClipTokenizer {
let mergesData = try String(
contentsOf: url.appendingPathComponent("merges.txt"),
52 changes: 52 additions & 0 deletions Sources/Embeddings/Tokenizer/TextTokenizer.swift
@@ -0,0 +1,52 @@
import Foundation
import SentencepieceTokenizer
@preconcurrency import Tokenizers

public protocol TextTokenizer: Sendable {
func tokenizeText(_ text: String, maxLength: Int) throws -> [Int32]
func tokenizeTextsPaddingToLongest(
_ texts: [String], padTokenId: Int, maxLength: Int
) throws -> [[Int32]]
}

extension TextTokenizer {
public func tokenizeTextsPaddingToLongest(
_ texts: [String],
padTokenId: Int,
maxLength: Int
) throws -> [[Int32]] {
var longest = 0
var result = [[Int32]]()
result.reserveCapacity(texts.count)
for text in texts {
let encoded = try tokenizeText(text, maxLength: maxLength)
longest = max(longest, encoded.count)
result.append(encoded)
}
return result.map {
if $0.count < longest {
return $0 + Array(repeating: Int32(padTokenId), count: longest - $0.count)
} else {
return $0
}
}
}
}

public struct TokenizerWrapper {
private let tokenizer: any Tokenizers.Tokenizer

public init(_ tokenizer: any Tokenizers.Tokenizer) {
self.tokenizer = tokenizer
}
}

extension TokenizerWrapper: TextTokenizer {
public func tokenizeText(_ text: String, maxLength: Int) throws -> [Int32] {
var encoded = tokenizer.encode(text: text)
if encoded.count > maxLength {
encoded.removeLast(encoded.count - maxLength)
}
return encoded.map { Int32($0) }
}
}
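
To illustrate the default `tokenizeTextsPaddingToLongest` implementation above, here is a toy conforming type (hypothetical, for demonstration only):

```swift
// Toy tokenizer: maps each whitespace-separated word to its length.
// Stateless structs are implicitly Sendable, satisfying TextTokenizer.
struct WordLengthTokenizer: TextTokenizer {
    func tokenizeText(_ text: String, maxLength: Int) throws -> [Int32] {
        let ids = text.split(separator: " ").map { Int32($0.count) }
        return Array(ids.prefix(maxLength))
    }
}

let tokenizer = WordLengthTokenizer()
let batch = try tokenizer.tokenizeTextsPaddingToLongest(
    ["a bb ccc", "dddd"], padTokenId: 0, maxLength: 512)
// batch == [[1, 2, 3], [4, 0, 0]]: rows shorter than the longest
// are right-padded with padTokenId.
```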