Skip to content

Commit

Permalink
Enable custom headers in OpenAIPrompt (#2259)
Browse files Browse the repository at this point in the history
Co-authored-by: Shyam Sai <[email protected]>
  • Loading branch information
sss04 and Shyam Sai authored Aug 7, 2024
1 parent 6da5f57 commit 7c23d83
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions"
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
val messages = r.getAs[Seq[Row]](getMessagesCol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/completions"
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
override protected[openai] def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
getValueOpt(r, prompt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import com.microsoft.azure.synapse.ml.core.spark.Functions
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}

import scala.collection.JavaConverters._

Expand All @@ -25,6 +26,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with HasOpenAICognitiveServiceInput
with ComplexParamsWritable with SynapseMLLogging {

logClass(FeatureNames.AiServices.OpenAI)
Expand Down Expand Up @@ -174,6 +176,16 @@ class OpenAIPrompt(override val uid: String) extends Transformer
completion
}

// Builds the HTTP request entity by delegating to whichever completion
// backend (chat vs. plain completion) this prompt is configured with.
// Both backends expose prepareEntity with protected[openai] visibility.
override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
  openAICompletion match {
    case chat: OpenAIChatCompletion => chat.prepareEntity(row)
    case plain: OpenAICompletion => plain.prepareEntity(row)
  }
}

private def getParser: OutputParser = {
val opts = getPostProcessingOptions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,33 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.get(0) != null))
}

// Exercises OpenAIPrompt against a custom endpoint. When an AAD access token
// is supplied via CUSTOM_ACCESS_TOKEN we authenticate with it and attach the
// custom model-routing header; otherwise we fall back to key-based auth
// against the standard service/deployment. Ignored: requires live credentials.
ignore("Custom EndPoint") {
  lazy val aadToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
  lazy val rootUrl: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
  lazy val routingHeaders: Map[String, String] = Map("X-ModelType" -> "gpt-4-turbo-chat-completions")

  lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
    .setCustomUrlRoot(rootUrl)
    .setOutputCol("outParsed")
    .setTemperature(0)

  if (aadToken.nonEmpty) {
    // Token-based auth path: custom headers route the request to the model.
    prompt.setAADToken(aadToken)
      .setCustomHeaders(routingHeaders)
  } else {
    // Key-based auth path against the standard Azure OpenAI service.
    prompt.setSubscriptionKey(openAIAPIKey)
      .setDeploymentName(deploymentNameGpt4)
      .setCustomServiceName(openAIServiceName)
  }

  prompt.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
    .setPostProcessing("csv")
    .transform(df)
    .select("outParsed")
    .collect()
    .count(r => Option(r.getSeq[String](0)).isDefined)
}

// Compares the two DataFrames ignoring the model-output columns, whose
// contents are nondeterministic across runs.
override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
  val volatileCols = Seq("out", "outParsed")
  super.assertDFEq(df1.drop(volatileCols: _*), df2.drop(volatileCols: _*))(eq)
}
Expand Down

0 comments on commit 7c23d83

Please sign in to comment.