Skip to content

Commit

Permalink
Merge pull request #82 from nathanborror/main
Browse files Browse the repository at this point in the history
Add functionCall to ChatQuery
  • Loading branch information
Krivoblotsky authored Jul 10, 2023
2 parents d7f478e + c3f1c96 commit a9965f0
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
32 changes: 31 additions & 1 deletion Sources/OpenAI/Public/Models/ChatQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public let messages: [Chat]
/// A list of functions the model may generate JSON inputs for.
public let functions: [ChatFunctionDeclaration]?
/// Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between responding to the end-user or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present.
public let functionCall: FunctionCall?
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and We generally recommend altering this or top_p but not both.
public let temperature: Double?
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
Expand All @@ -211,11 +213,38 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public let user: String?

var stream: Bool = false

/// Controls which (if any) function the model calls. Mirrors the OpenAI
/// `function_call` request field:
/// - `none`: the model does not call a function and responds to the end-user
///   (encoded as the bare string `"none"`).
/// - `auto`: the model may pick between responding and calling a function
///   (encoded as the bare string `"auto"`).
/// - `function(name)`: forces the model to call the named function
///   (encoded as the object `{"name": "<name>"}`).
public enum FunctionCall: Codable, Equatable {
    case none
    case auto
    case function(String)

    enum CodingKeys: String, CodingKey {
        case none = "none"
        case auto = "auto"
        case function = "name"
    }

    public func encode(to encoder: Encoder) throws {
        switch self {
        case .none:
            // "none" and "auto" are plain JSON strings, not objects.
            var container = encoder.singleValueContainer()
            try container.encode(CodingKeys.none.rawValue)
        case .auto:
            var container = encoder.singleValueContainer()
            try container.encode(CodingKeys.auto.rawValue)
        case .function(let name):
            // A forced function call is the object {"name": "<name>"}.
            var container = encoder.container(keyedBy: CodingKeys.self)
            try container.encode(name, forKey: .function)
        }
    }

    /// Decodes the same wire format that `encode(to:)` produces. Without this,
    /// the compiler-synthesized decoder would expect the keyed associated-value
    /// representation (e.g. `{"auto":{}}`), which does not match the custom
    /// encoder's output, so round-tripping through `Codable` would fail.
    public init(from decoder: Decoder) throws {
        // First try the bare-string forms: "none" / "auto".
        if let container = try? decoder.singleValueContainer(),
           let value = try? container.decode(String.self) {
            switch value {
            case CodingKeys.none.rawValue:
                self = .none
            case CodingKeys.auto.rawValue:
                self = .auto
            default:
                throw DecodingError.dataCorrupted(DecodingError.Context(
                    codingPath: decoder.codingPath,
                    debugDescription: "Unknown function_call value: \(value)"))
            }
            return
        }
        // Otherwise expect the {"name": "<name>"} object form.
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self = .function(try container.decode(String.self, forKey: .function))
    }
}

enum CodingKeys: String, CodingKey {
case model
case messages
case functions
case functionCall = "function_call"
case temperature
case topP = "top_p"
case n
Expand All @@ -228,10 +257,11 @@ public struct ChatQuery: Equatable, Codable, Streamable {
case user
}

public init(model: Model, messages: [Chat], functions: [ChatFunctionDeclaration]? = nil, temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil, stream: Bool = false) {
public init(model: Model, messages: [Chat], functions: [ChatFunctionDeclaration]? = nil, functionCall: FunctionCall? = nil, temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil, stream: Bool = false) {
self.model = model
self.messages = messages
self.functions = functions
self.functionCall = functionCall
self.temperature = temperature
self.topP = topP
self.n = n
Expand Down
24 changes: 23 additions & 1 deletion Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,29 @@ class OpenAITests: XCTestCase {

let result = try await openAI.chats(query: query)
XCTAssertEqual(result, chatResult)
}
}

func testChatsFunction() async throws {
    // A single callable function, matching OpenAI's get_current_weather example.
    let weatherFunction = ChatFunctionDeclaration(
        name: "get_current_weather",
        description: "Get the current weather in a given location",
        parameters: .init(
            type: .object,
            properties: [
                "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"),
                "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"])
            ],
            required: ["location"]
        )
    )

    // Let the model decide whether to call the function (`function_call: "auto"`).
    let query = ChatQuery(
        model: .gpt3_5Turbo0613,
        messages: [
            .init(role: .system, content: "You are Weather-GPT. You know everything about the weather."),
            .init(role: .user, content: "What's the weather like in Boston?"),
        ],
        functions: [weatherFunction],
        functionCall: .auto
    )

    // Canned result the stubbed transport returns for the request.
    let expected = ChatResult(
        id: "id-12312",
        object: "foo",
        created: 100,
        model: .gpt3_5Turbo,
        choices: [
            .init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"),
            .init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"),
            .init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2")
        ],
        usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300)
    )
    try self.stub(result: expected)

    let received = try await openAI.chats(query: query)
    XCTAssertEqual(received, expected)
}

func testChatsError() async throws {
let query = ChatQuery(model: .gpt4, messages: [
Expand Down

0 comments on commit a9965f0

Please sign in to comment.