diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml
index f8b34ed..bdbfb02 100644
--- a/.github/workflows/swift.yml
+++ b/.github/workflows/swift.yml
@@ -5,9 +5,9 @@ name: Swift
 
 on:
   push:
-    branches: [ "main" ]
+    branches: [ "main", "local" ]
   pull_request:
-    branches: [ "main" ]
+    branches: [ "main", "local" ]
 
 jobs:
   build:
diff --git a/Package.resolved b/Package.resolved
index 47d5427..2215a3a 100644
--- a/Package.resolved
+++ b/Package.resolved
@@ -72,6 +72,15 @@
         "version" : "4.2.2"
       }
     },
+    {
+      "identity" : "llmfarm_core.swift",
+      "kind" : "remoteSourceControl",
+      "location" : "https://github.com/buhe/llmfarm_core.swift",
+      "state" : {
+        "branch" : "langchain",
+        "revision" : "927d670751bc8aebbc5eb845afd36fe1eeef4f5a"
+      }
+    },
     {
       "identity" : "openai-kit",
       "kind" : "remoteSourceControl",
diff --git a/Package.swift b/Package.swift
index f90bac0..26fa2e3 100644
--- a/Package.swift
+++ b/Package.swift
@@ -26,6 +26,7 @@ let package = Package(
         .package(url: "https://github.com/juyan/swift-filestore", .upToNextMajor(from: "0.5.0")),
         .package(url: "https://github.com/buhe/similarity-search-kit", from: "0.0.16"),
         .package(url: "https://github.com/google/generative-ai-swift", .upToNextMajor(from: "0.4.4")),
+        .package(url: "https://github.com/buhe/llmfarm_core.swift", .branch("langchain")),
         .package(url: "https://github.com/buhe/SwiftyNotion", .upToNextMajor(from: "0.1.3")),
         .package(url: "https://github.com/nmdias/FeedKit", .upToNextMajor(from: "9.1.2")),
     ],
@@ -45,6 +46,7 @@ let package = Package(
 //                .product(name: "SimilaritySearchKitDistilbert", package: "similarity-search-kit", condition: .when(platforms: [.macOS, .iOS, .visionOS])),
                 .product(name: "GoogleGenerativeAI", package: "generative-ai-swift"),
                 .product(name: "SwiftyNotion", package: "SwiftyNotion"),
+                .product(name: "llmfarm_core", package: "llmfarm_core.swift"),
                 .product(name: "FeedKit", package: "FeedKit"),
             ]
 
diff --git a/Sources/LangChain/llms/Local.swift b/Sources/LangChain/llms/Local.swift
index 8ef6fab..7cec3e1 100644
--- a/Sources/LangChain/llms/Local.swift
+++ b/Sources/LangChain/llms/Local.swift
@@ -1,46 +1,42 @@
-////
-////  File.swift
-////
-////
-////  Created by ι‘Ύθ‰³εŽ on 1/22/24.
-////
-//import llmfarm_core
-//import Foundation
-//
-//public class Local: LLM {
-//    let modelPath: String
-//    let useMetal: Bool
-//    let inference: ModelInference
-//
-//    public init(inference: ModelInference, modelPath: String, useMetal: Bool = false, callbacks: [BaseCallbackHandler] = [], cache: BaseCache? = nil) {
-//        self.inference = inference
-//        self.modelPath = modelPath
-//        self.useMetal = useMetal
-//        super.init(callbacks: callbacks, cache: cache)
-//    }
-//    public override func _send(text: String, stops: [String] = []) async throws -> LLMResult {
-//        let ai = AI(_modelPath: self.modelPath, _chatName: "chat")
-//        var params:ModelAndContextParams = .default
-//        params.use_metal = useMetal
-//        params.promptFormat = .Custom
-//        params.custom_prompt_format = "{{prompt}}"
-//        try? ai.loadModel(inference, contextParams: params)
-//        let output = try? ai.model.predict(text, mainCallback)
-////        print("πŸš—\(output)")
-//        total_output = 0
-//        return LLMResult(llm_output: output)
-//    }
-//
-//    let maxOutputLength = 256
-//    var total_output = 0
-//
-//    func mainCallback(_ str: String, _ time: Double) -> Bool {
-//        print("\(str)",terminator: "")
-//        total_output += str.count
-//        if(total_output>maxOutputLength){
-//            return true
-//        }
-//        return false
-//    }
-//}
-//
+import llmfarm_core
+import Foundation
+
+public class Local: LLM {
+    let modelPath: String
+    let useMetal: Bool
+    let inference: ModelInference
+    var params: ModelAndContextParams
+    var sampleParams: ModelSampleParams
+
+    public init(inference: ModelInference, modelPath: String, useMetal: Bool = false, params: ModelAndContextParams = .default, sampleParams: ModelSampleParams = .default, callbacks: [BaseCallbackHandler] = [], cache: BaseCache? = nil) {
+        self.inference = inference
+        self.modelPath = modelPath
+        self.useMetal = useMetal
+        self.params = params
+        self.sampleParams = sampleParams
+        super.init(callbacks: callbacks, cache: cache)
+    }
+    public override func _send(text: String, stops: [String] = []) async throws -> LLMResult {
+        let ai = AI(_modelPath: self.modelPath, _chatName: "chat")
+        self.params.use_metal = useMetal
+
+        let _ = try? ai.loadModel(inference, contextParams: self.params)
+        ai.model.sampleParams = self.sampleParams
+        let output = try? ai.model.predict(text, mainCallback)
+//        print("πŸš—\(output)")
+        total_output = 0
+        return LLMResult(llm_output: output)
+    }
+
+    let maxOutputLength = 256
+    var total_output = 0
+
+    func mainCallback(_ str: String, _ time: Double) -> Bool {
+        print(str, terminator: "")
+        total_output += str.count
+        if total_output > maxOutputLength {
+            return true
+        }
+        return false
+    }
+}
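For reference, a minimal usage sketch of the new Local backend follows (not part of the diff). The model path is hypothetical, the .LLama_gguf case is an assumption about llmfarm_core's ModelInference enum, and generate(text:) is assumed to be the public entry point on the LLM base class that drives the overridden _send(text:stops:) above.

import LangChain
import llmfarm_core

// Hypothetical local model file; any format llmfarm_core can load should work.
let llm = Local(
    inference: .LLama_gguf,           // assumed ModelInference case for GGUF models
    modelPath: "/path/to/model.gguf", // hypothetical path, replace with a real model
    useMetal: true                    // route inference through Metal where available
)

// generate(text:) is assumed to wrap the overridden _send(text:stops:)
// with the callback and cache handling of the LLM base class.
if let result = await llm.generate(text: "What is LangChain?") {
    print(result.llm_output ?? "")
}

Note that the new _send implementation ignores the stops parameter, and mainCallback stops prediction once total_output exceeds maxOutputLength (256 characters), so long completions are truncated by design.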