Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable OpenAI functions in chat completion #47

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 120 additions & 174 deletions R/create_chat_completion.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @param model required; a length one character vector.
#' @param messages required; defaults to `NULL`; a list in the following
#' format: `list(list("role" = "user", "content" = "Hey! How old are you?")`
#' @param functions optional; defaults to `NULL`; OpenAI function definitions
#' @param temperature required; defaults to `1`; a length one numeric vector
#' with the value between `0` and `2`.
#' @param top_p required; defaults to `1`; a length one numeric vector with the
Expand Down Expand Up @@ -57,181 +58,126 @@
#' )
#' )
#' )
#'
#' #OpenAI function example:
#' messages <-
#' list(
#' list(
#' "role" = "system",
#' "content" = "You are a helpful assistant."
#' ),
#' list(
#' "role" = "user",
#' "content" = "What is the weather like in San Francisco?"
#' )
#' )
#'
#' functions<-
#' list(
#' list(name = "get_current_weather",
#' description = "Get the current weather",
#' parameters =
#' list(
#' type = "object",
#' properties =
#' list(
#' location =
#' list(type = "string",
#' description = "The city and state,
#' e.g. San Francisco, CA"),
#' format = list(type = "string",
#' enum = c("celsius", "fahrenheit"),
#' description = "The temperature unit to use.
#' Infer this from the users location.")),
#' required = c("location", "format"))
#' )
#' )
#'
#' create_chat_completion(model="gpt-3.5-turbo-0613",
#' messages = messages, functions = functions)
#'
#' }
#' @export
create_chat_completion<- function(
model,
messages = NULL,
temperature = 1,
top_p = 1,
n = 1,
stream = FALSE,
stop = NULL,
max_tokens = NULL,
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = NULL,
user = NULL,
openai_api_key = Sys.getenv("OPENAI_API_KEY"),
openai_organization = NULL
) {

#---------------------------------------------------------------------------
# Validate arguments

assertthat::assert_that(
assertthat::is.string(model),
assertthat::noNA(model)
)

if (!is.null(messages)) {
assertthat::assert_that(
is.list(messages)
)
}

assertthat::assert_that(
assertthat::is.number(temperature),
assertthat::noNA(temperature),
value_between(temperature, 0, 2)
)

assertthat::assert_that(
assertthat::is.number(top_p),
assertthat::noNA(top_p),
value_between(top_p, 0, 1)
)

if (both_specified(temperature, top_p)) {
warning(
"It is recommended NOT to specify temperature and top_p at a time."
)
}

assertthat::assert_that(
assertthat::is.count(n)
)

assertthat::assert_that(
assertthat::is.flag(stream),
assertthat::noNA(stream),
is_false(stream)
)

if (!is.null(stop)) {
assertthat::assert_that(
is.character(stop),
assertthat::noNA(stop),
length_between(stop, 1, 4)
)
}


if (!is.null(max_tokens)) {
assertthat::assert_that(
assertthat::is.count(max_tokens)
)
}

assertthat::assert_that(
assertthat::is.number(presence_penalty),
assertthat::noNA(presence_penalty),
value_between(presence_penalty, -2, 2)
)

assertthat::assert_that(
assertthat::is.number(frequency_penalty),
assertthat::noNA(frequency_penalty),
value_between(frequency_penalty, -2, 2)
)

if (!is.null(logit_bias)) {
assertthat::assert_that(
is.list(logit_bias)
)
create_chat_completion<-
function (model, messages = NULL, functions = NULL,
temperature = 1, top_p = 1,
n = 1,
stream = FALSE, stop = NULL, max_tokens = NULL, presence_penalty = 0,
frequency_penalty = 0, logit_bias = NULL, user = NULL,
openai_api_key = Sys.getenv("OPENAI_API_KEY"),
openai_organization = NULL)
{
assertthat::assert_that(assertthat::is.string(model), assertthat::noNA(model))
if (!is.null(messages)) {
assertthat::assert_that(is.list(messages))
}
assertthat::assert_that(assertthat::is.number(temperature),
assertthat::noNA(temperature), value_between(temperature,
0, 2))
assertthat::assert_that(assertthat::is.number(top_p), assertthat::noNA(top_p),
value_between(top_p, 0, 1))
if (both_specified(temperature, top_p)) {
warning("It is recommended NOT to specify temperature and top_p at a time.")
}
assertthat::assert_that(assertthat::is.count(n))
assertthat::assert_that(assertthat::is.flag(stream), assertthat::noNA(stream),
is_false(stream))
if (!is.null(stop)) {
assertthat::assert_that(is.character(stop), assertthat::noNA(stop),
length_between(stop, 1, 4))
}
if (!is.null(max_tokens)) {
assertthat::assert_that(assertthat::is.count(max_tokens))
}
assertthat::assert_that(assertthat::is.number(presence_penalty),
assertthat::noNA(presence_penalty), value_between(presence_penalty,
-2, 2))
assertthat::assert_that(assertthat::is.number(frequency_penalty),
assertthat::noNA(frequency_penalty), value_between(frequency_penalty,
-2, 2))
if (!is.null(logit_bias)) {
assertthat::assert_that(is.list(logit_bias))
}
if (!is.null(user)) {
assertthat::assert_that(assertthat::is.string(user),
assertthat::noNA(user))
}
assertthat::assert_that(assertthat::is.string(openai_api_key),
assertthat::noNA(openai_api_key))
if (!is.null(openai_organization)) {
assertthat::assert_that(assertthat::is.string(openai_organization),
assertthat::noNA(openai_organization))
}
task <- "chat/completions"
base_url <- glue::glue("https://api.openai.com/v1/{task}")
headers <- c(Authorization = paste("Bearer", openai_api_key),
`Content-Type` = "application/json")
if (!is.null(openai_organization)) {
headers["OpenAI-Organization"] <- openai_organization
}
body <- list()
body[["model"]] <- model
body[["messages"]] <- messages
body[["functions"]] <- functions
body[["temperature"]] <- temperature
body[["top_p"]] <- top_p
body[["n"]] <- n
body[["stream"]] <- stream
body[["stop"]] <- stop
body[["max_tokens"]] <- max_tokens
body[["presence_penalty"]] <- presence_penalty
body[["frequency_penalty"]] <- frequency_penalty
body[["logit_bias"]] <- logit_bias
body[["user"]] <- user
response <- httr::POST(url = base_url, httr::add_headers(.headers = headers),
body = body, encode = "json")
verify_mime_type(response)
parsed <- response %>% httr::content(as = "text", encoding = "UTF-8") %>%
jsonlite::fromJSON(flatten = TRUE)
if (httr::http_error(response)) {
paste0("OpenAI API request failed [", httr::status_code(response),
"]:\n\n", parsed$error$message) %>% stop(call. = FALSE)
}
parsed
}

if (!is.null(user)) {
assertthat::assert_that(
assertthat::is.string(user),
assertthat::noNA(user)
)
}

assertthat::assert_that(
assertthat::is.string(openai_api_key),
assertthat::noNA(openai_api_key)
)

if (!is.null(openai_organization)) {
assertthat::assert_that(
assertthat::is.string(openai_organization),
assertthat::noNA(openai_organization)
)
}

#---------------------------------------------------------------------------
# Build path parameters

task <- "chat/completions"

base_url <- glue::glue("https://api.openai.com/v1/{task}")

headers <- c(
"Authorization" = paste("Bearer", openai_api_key),
"Content-Type" = "application/json"
)

if (!is.null(openai_organization)) {
headers["OpenAI-Organization"] <- openai_organization
}

#---------------------------------------------------------------------------
# Build request body

body <- list()
body[["model"]] <- model
body[["messages"]] <- messages
body[["temperature"]] <- temperature
body[["top_p"]] <- top_p
body[["n"]] <- n
body[["stream"]] <- stream
body[["stop"]] <- stop
body[["max_tokens"]] <- max_tokens
body[["presence_penalty"]] <- presence_penalty
body[["frequency_penalty"]] <- frequency_penalty
body[["logit_bias"]] <- logit_bias
body[["user"]] <- user

#---------------------------------------------------------------------------
# Make a request and parse it

response <- httr::POST(
url = base_url,
httr::add_headers(.headers = headers),
body = body,
encode = "json"
)

verify_mime_type(response)

parsed <- response %>%
httr::content(as = "text", encoding = "UTF-8") %>%
jsonlite::fromJSON(flatten = TRUE)

#---------------------------------------------------------------------------
# Check whether request failed and return parsed

if (httr::http_error(response)) {
paste0(
"OpenAI API request failed [",
httr::status_code(response),
"]:\n\n",
parsed$error$message
) %>%
stop(call. = FALSE)
}

parsed

}
41 changes: 41 additions & 0 deletions man/create_chat_completion.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.