diff --git a/R/create_chat_completion.R b/R/create_chat_completion.R index ba06571..0fc1eee 100644 --- a/R/create_chat_completion.R +++ b/R/create_chat_completion.R @@ -10,6 +10,7 @@ #' @param model required; a length one character vector. #' @param messages required; defaults to `NULL`; a list in the following #' format: `list(list("role" = "user", "content" = "Hey! How old are you?")` +#' @param functions optional; defaults to `NULL`; OpenAI function definitions #' @param temperature required; defaults to `1`; a length one numeric vector #' with the value between `0` and `2`. #' @param top_p required; defaults to `1`; a length one numeric vector with the @@ -57,181 +58,126 @@ #' ) #' ) #' ) +#' +#' #OpenAI function example: +#' messages <- +#' list( +#' list( +#' "role" = "system", +#' "content" = "You are a helpful assistant." +#' ), +#' list( +#' "role" = "user", +#' "content" = "What is the weather like in San Francisco?" +#' ) +#' ) +#' +#' functions<- +#' list( +#' list(name = "get_current_weather", +#' description = "Get the current weather", +#' parameters = +#' list( +#' type = "object", +#' properties = +#' list( +#' location = +#' list(type = "string", +#' description = "The city and state, +#' e.g. San Francisco, CA"), +#' format = list(type = "string", +#' enum = c("celsius", "fahrenheit"), +#' description = "The temperature unit to use. +#' Infer this from the users location.")), +#' required = c("location", "format")) +#' ) +#' ) +#' +#' create_chat_completion(model="gpt-3.5-turbo-0613", +#' messages = messages, functions = functions) +#' #' } #' @export -create_chat_completion<- function( - model, - messages = NULL, - temperature = 1, - top_p = 1, - n = 1, - stream = FALSE, - stop = NULL, - max_tokens = NULL, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = NULL, - user = NULL, - openai_api_key = Sys.getenv("OPENAI_API_KEY"), - openai_organization = NULL -) { - - #--------------------------------------------------------------------------- - # Validate arguments - - assertthat::assert_that( - assertthat::is.string(model), - assertthat::noNA(model) - ) - - if (!is.null(messages)) { - assertthat::assert_that( - is.list(messages) - ) - } - - assertthat::assert_that( - assertthat::is.number(temperature), - assertthat::noNA(temperature), - value_between(temperature, 0, 2) - ) - - assertthat::assert_that( - assertthat::is.number(top_p), - assertthat::noNA(top_p), - value_between(top_p, 0, 1) - ) - - if (both_specified(temperature, top_p)) { - warning( - "It is recommended NOT to specify temperature and top_p at a time." - ) - } - - assertthat::assert_that( - assertthat::is.count(n) - ) - - assertthat::assert_that( - assertthat::is.flag(stream), - assertthat::noNA(stream), - is_false(stream) - ) - - if (!is.null(stop)) { - assertthat::assert_that( - is.character(stop), - assertthat::noNA(stop), - length_between(stop, 1, 4) - ) - } - - - if (!is.null(max_tokens)) { - assertthat::assert_that( - assertthat::is.count(max_tokens) - ) - } - - assertthat::assert_that( - assertthat::is.number(presence_penalty), - assertthat::noNA(presence_penalty), - value_between(presence_penalty, -2, 2) - ) - - assertthat::assert_that( - assertthat::is.number(frequency_penalty), - assertthat::noNA(frequency_penalty), - value_between(frequency_penalty, -2, 2) - ) - if (!is.null(logit_bias)) { - assertthat::assert_that( - is.list(logit_bias) - ) +create_chat_completion<- + function (model, messages = NULL, functions = NULL, + temperature = 1, top_p = 1, + n = 1, + stream = FALSE, stop = NULL, max_tokens = NULL, presence_penalty = 0, + frequency_penalty = 0, logit_bias = NULL, user = NULL, + openai_api_key = Sys.getenv("OPENAI_API_KEY"), + openai_organization = NULL) + { + assertthat::assert_that(assertthat::is.string(model), assertthat::noNA(model)) + if (!is.null(messages)) { + assertthat::assert_that(is.list(messages)) + } + assertthat::assert_that(assertthat::is.number(temperature), + assertthat::noNA(temperature), value_between(temperature, + 0, 2)) + assertthat::assert_that(assertthat::is.number(top_p), assertthat::noNA(top_p), + value_between(top_p, 0, 1)) + if (both_specified(temperature, top_p)) { + warning("It is recommended NOT to specify temperature and top_p at a time.") + } + assertthat::assert_that(assertthat::is.count(n)) + assertthat::assert_that(assertthat::is.flag(stream), assertthat::noNA(stream), + is_false(stream)) + if (!is.null(stop)) { + assertthat::assert_that(is.character(stop), assertthat::noNA(stop), + length_between(stop, 1, 4)) + } + if (!is.null(max_tokens)) { + assertthat::assert_that(assertthat::is.count(max_tokens)) + } + assertthat::assert_that(assertthat::is.number(presence_penalty), + assertthat::noNA(presence_penalty), value_between(presence_penalty, + -2, 2)) + assertthat::assert_that(assertthat::is.number(frequency_penalty), + assertthat::noNA(frequency_penalty), value_between(frequency_penalty, + -2, 2)) + if (!is.null(logit_bias)) { + assertthat::assert_that(is.list(logit_bias)) + } + if (!is.null(user)) { + assertthat::assert_that(assertthat::is.string(user), + assertthat::noNA(user)) + } + assertthat::assert_that(assertthat::is.string(openai_api_key), + assertthat::noNA(openai_api_key)) + if (!is.null(openai_organization)) { + assertthat::assert_that(assertthat::is.string(openai_organization), + assertthat::noNA(openai_organization)) + } + task <- "chat/completions" + base_url <- glue::glue("https://api.openai.com/v1/{task}") + headers <- c(Authorization = paste("Bearer", openai_api_key), + `Content-Type` = "application/json") + if (!is.null(openai_organization)) { + headers["OpenAI-Organization"] <- openai_organization + } + body <- list() + body[["model"]] <- model + body[["messages"]] <- messages + body[["functions"]] <- functions + body[["temperature"]] <- temperature + body[["top_p"]] <- top_p + body[["n"]] <- n + body[["stream"]] <- stream + body[["stop"]] <- stop + body[["max_tokens"]] <- max_tokens + body[["presence_penalty"]] <- presence_penalty + body[["frequency_penalty"]] <- frequency_penalty + body[["logit_bias"]] <- logit_bias + body[["user"]] <- user + response <- httr::POST(url = base_url, httr::add_headers(.headers = headers), + body = body, encode = "json") + verify_mime_type(response) + parsed <- response %>% httr::content(as = "text", encoding = "UTF-8") %>% + jsonlite::fromJSON(flatten = TRUE) + if (httr::http_error(response)) { + paste0("OpenAI API request failed [", httr::status_code(response), + "]:\n\n", parsed$error$message) %>% stop(call. = FALSE) + } + parsed } - - if (!is.null(user)) { - assertthat::assert_that( - assertthat::is.string(user), - assertthat::noNA(user) - ) - } - - assertthat::assert_that( - assertthat::is.string(openai_api_key), - assertthat::noNA(openai_api_key) - ) - - if (!is.null(openai_organization)) { - assertthat::assert_that( - assertthat::is.string(openai_organization), - assertthat::noNA(openai_organization) - ) - } - - #--------------------------------------------------------------------------- - # Build path parameters - - task <- "chat/completions" - - base_url <- glue::glue("https://api.openai.com/v1/{task}") - - headers <- c( - "Authorization" = paste("Bearer", openai_api_key), - "Content-Type" = "application/json" - ) - - if (!is.null(openai_organization)) { - headers["OpenAI-Organization"] <- openai_organization - } - - #--------------------------------------------------------------------------- - # Build request body - - body <- list() - body[["model"]] <- model - body[["messages"]] <- messages - body[["temperature"]] <- temperature - body[["top_p"]] <- top_p - body[["n"]] <- n - body[["stream"]] <- stream - body[["stop"]] <- stop - body[["max_tokens"]] <- max_tokens - body[["presence_penalty"]] <- presence_penalty - body[["frequency_penalty"]] <- frequency_penalty - body[["logit_bias"]] <- logit_bias - body[["user"]] <- user - - #--------------------------------------------------------------------------- - # Make a request and parse it - - response <- httr::POST( - url = base_url, - httr::add_headers(.headers = headers), - body = body, - encode = "json" - ) - - verify_mime_type(response) - - parsed <- response %>% - httr::content(as = "text", encoding = "UTF-8") %>% - jsonlite::fromJSON(flatten = TRUE) - - #--------------------------------------------------------------------------- - # Check whether request failed and return parsed - - if (httr::http_error(response)) { - paste0( - "OpenAI API request failed [", - httr::status_code(response), - "]:\n\n", - parsed$error$message - ) %>% - stop(call. = FALSE) - } - - parsed - -} diff --git a/man/create_chat_completion.Rd b/man/create_chat_completion.Rd index a13b81c..817bf4d 100644 --- a/man/create_chat_completion.Rd +++ b/man/create_chat_completion.Rd @@ -7,6 +7,7 @@ create_chat_completion( model, messages = NULL, + functions = NULL, temperature = 1, top_p = 1, n = 1, @@ -27,6 +28,8 @@ create_chat_completion( \item{messages}{required; defaults to \code{NULL}; a list in the following format: \verb{list(list("role" = "user", "content" = "Hey! How old are you?")}} +\item{functions}{optional; defaults to \code{NULL}; OpenAI function definitions} + \item{temperature}{required; defaults to \code{1}; a length one numeric vector with the value between \code{0} and \code{2}.} @@ -96,5 +99,43 @@ create_chat_completion( ) ) ) + +#OpenAI function example: +messages <- +list( + list( + "role" = "system", + "content" = "You are a helpful assistant." + ), + list( + "role" = "user", + "content" = "What is the weather like in San Francisco?" + ) +) + +functions<- + list( + list(name = "get_current_weather", + description = "Get the current weather", + parameters = + list( + type = "object", + properties = + list( + location = + list(type = "string", + description = "The city and state, + e.g. San Francisco, CA"), + format = list(type = "string", + enum = c("celsius", "fahrenheit"), + description = "The temperature unit to use. + Infer this from the users location.")), + required = c("location", "format")) + ) + ) + +create_chat_completion(model="gpt-3.5-turbo-0613", + messages = messages, functions = functions) + } }