Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Support Token Provider Mode #2160

Merged
merged 14 commits into from
Feb 28, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package com.microsoft.azure.synapse.ml.services
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
Expand Down Expand Up @@ -218,18 +220,6 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
| return self
|
|def _transform(self, dataset: DataFrame) -> DataFrame:
| if running_on_synapse_internal():
| try:
| from synapse.ml.fabric.token_utils import TokenUtils
| from synapse.ml.fabric.service_discovery import get_fabric_env_config
| fabric_env_config = get_fabric_env_config().fabric_env_config
| if self._java_obj.getInternalServiceType() != "openai":
| self._java_obj.setDefaultAADToken(TokenUtils().get_aad_token())
| else:
| self._java_obj.setDefaultCustomAuthHeader(TokenUtils().get_openai_auth_header())
| self.setDefaultInternalEndpoint(fabric_env_config.get_mlflow_workload_endpoint())
| except ModuleNotFoundError as e:
| pass
| return super()._transform(dataset)
|""".stripMargin
}
Expand Down Expand Up @@ -327,6 +317,15 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA

protected def contentType: Row => String = { _ => "application/json" }

protected def getCustomAuthHeader(row: Row): Option[String] = {
var customHeader = getValueOpt(row, CustomAuthHeader)
if (customHeader.isEmpty && PlatformDetails.runningOnFabric()) {
customHeader = Option(TokenLibrary.getAuthHeader)
logInfo("Using Default AAD Token On Fabric")
}
customHeader
lhrotk marked this conversation as resolved.
Show resolved Hide resolved
}

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
Expand Down Expand Up @@ -364,7 +363,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getValueOpt(row, CustomAuthHeader))
getCustomAuthHeader(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down Expand Up @@ -501,7 +500,12 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform

setDefault(
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"))
errorCol -> (this.uid + "_error")
)

if(PlatformDetails.runningOnFabric()) {
setDefaultInternalEndpoint(FabricClient.MLWorkloadEndpointML)
}

protected def handlingFunc(client: CloseableHttpClient,
request: HTTPRequestData): HTTPResponseData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.fabric.OpenAITokenLibrary
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion,
HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -244,6 +247,17 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}
}

trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
override protected def getCustomAuthHeader(row: Row): Option[String] = {
var customHeader = getValueOpt(row, CustomAuthHeader)
if (customHeader.isEmpty && PlatformDetails.runningOnFabric()) {
customHeader = Option(OpenAITokenLibrary.getAuthHeader)
logInfo("Using Default OpenAI Token On Fabric")
}
customHeader
lhrotk marked this conversation as resolved.
Show resolved Hide resolved
}
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) {
setDefault(timeout -> 360.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput,
HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json._
import spray.json.DefaultJsonProtocol._
import spray.json._

import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasCognitiveServiceInput
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase,
HasCognitiveServiceInput, HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.util._
Expand All @@ -20,7 +19,7 @@ import scala.language.existentials
object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]

class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,25 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.core.contracts.HasInputCol
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import spray.json._
import org.apache.spark.ml.linalg.{Vector, Vectors}

import scala.language.existentials

object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasCognitiveServiceInput with SynapseMLLogging {
with HasOpenAISharedParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ private[ml] abstract class TextAnalyticsBaseNoBinding(uid: String)
} else {
import TAJSONFormat._
val post = new HttpPost(prepareUrl(row))
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
val json = TARequest(makeDocuments(row)).toJson.compactPrint
post.setEntity(new StringEntity(json, "UTF-8"))
Some(post)
Expand Down Expand Up @@ -648,7 +654,13 @@ class TextAnalyze(override val uid: String) extends TextAnalyticsBaseNoBinding(u
None
} else {
val post = new HttpPost(getUrl)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
val tasks = TextAnalyzeTasks(
entityRecognitionTasks = getTaskHelper(getIncludeEntityRecognition, getEntityRecognitionParams),
entityLinkingTasks = getTaskHelper(getIncludeEntityLinking, getEntityLinkingParams),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ trait TextAsOnlyEntity extends HasTextInput with HasCognitiveServiceInput with H
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
Expand Down Expand Up @@ -248,7 +254,13 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
Expand Down Expand Up @@ -533,7 +545,13 @@ class DictionaryExamples(override val uid: String) extends TextTranslatorBase(ui
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = textAndTranslations.head.getClass.getTypeName match {
Expand Down
Loading
Loading