Skip to content

Commit

Permalink
add fabric token provider in cog service
Browse files Browse the repository at this point in the history
  • Loading branch information
mslhrotk committed Jan 25, 2024
1 parent 8122f86 commit aeb302d
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 36 deletions.
1 change: 0 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ val extraDependencies = Seq(
"io.spray" %% "spray-json" % "1.3.5",
"com.jcraft" % "jsch" % "0.1.54",
"com.pauldijou" %% "jwt-core" % "3.0.0",
"org.json" % "json" % "20180130",
"org.apache.httpcomponents.client5" % "httpclient5" % "5.1.3",
"org.apache.httpcomponents" % "httpmime" % "4.5.13",
"com.linkedin.isolation-forest" %% "isolation-forest_3.4.1" % "3.0.3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package com.microsoft.azure.synapse.ml.services
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
Expand Down Expand Up @@ -218,18 +220,6 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
| return self
|
|def _transform(self, dataset: DataFrame) -> DataFrame:
| if running_on_synapse_internal():
| try:
| from synapse.ml.fabric.token_utils import TokenUtils
| from synapse.ml.fabric.service_discovery import get_fabric_env_config
| fabric_env_config = get_fabric_env_config().fabric_env_config
| if self._java_obj.getInternalServiceType() != "openai":
| self._java_obj.setDefaultAADToken(TokenUtils().get_aad_token())
| else:
| self._java_obj.setDefaultCustomAuthHeader(TokenUtils().get_openai_auth_header())
| self.setDefaultInternalEndpoint(fabric_env_config.get_mlflow_workload_endpoint())
| except ModuleNotFoundError as e:
| pass
| return super()._transform(dataset)
|""".stripMargin
}
Expand Down Expand Up @@ -327,6 +317,15 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA

protected def contentType: Row => String = { _ => "application/json" }

protected def getCustomAuthHeader(row: Row): Option[String] = {
var customHeader = getValueOpt(row, CustomAuthHeader)
if (customHeader.isEmpty && PlatformDetails.runningOnFabric()) {
customHeader = Option(TokenLibrary.getAuthHeader)
logInfo("Using Default AAD Token On Fabric")
}
customHeader
}

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
Expand Down Expand Up @@ -364,7 +363,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getValueOpt(row, CustomAuthHeader))
getCustomAuthHeader(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down Expand Up @@ -501,7 +500,12 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform

setDefault(
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"))
errorCol -> (this.uid + "_error")
)

if(PlatformDetails.runningOnFabric()) {
setDefaultInternalEndpoint(FabricClient.MLWorkloadEndpointML)
}

protected def handlingFunc(client: CloseableHttpClient,
request: HTTPRequestData): HTTPResponseData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.fabric.OpenAITokenLibrary
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion,
HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -244,6 +247,17 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}
}

trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
override protected def getCustomAuthHeader(row: Row): Option[String] = {
var customHeader = getValueOpt(row, CustomAuthHeader)
if (customHeader.isEmpty && PlatformDetails.runningOnFabric()) {
customHeader = Option(OpenAITokenLibrary.getAuthHeader)
logInfo("Using Default OpenAI Token On Fabric")
}
customHeader
}
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) {
setDefault(timeout -> 360.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput,
HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json._
import spray.json.DefaultJsonProtocol._
import spray.json._

import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasCognitiveServiceInput
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase,
HasCognitiveServiceInput, HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.util._
Expand All @@ -20,7 +19,7 @@ import scala.language.existentials
object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]

class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,25 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.core.contracts.HasInputCol
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import spray.json._
import org.apache.spark.ml.linalg.{Vector, Vectors}

import scala.language.existentials

object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasCognitiveServiceInput with SynapseMLLogging {
with HasOpenAISharedParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package com.microsoft.azure.synapse.ml.fabric

import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import org.json.JSONObject
import spray.json._
import pdi.jwt.{Jwt, JwtOptions}
import spray.json.DefaultJsonProtocol.StringJsonFormat

import java.util.Date
import scala.util.{Failure, Success, Try}

object OpenAITokenLibrary extends SynapseMLLogging{
object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider {
var MLMWCToken = "";
val BackgroundRefreshExpiryCushionInMillis: Long = 5 * 60 * 1000L
val OpenAIFeatureName = "SparkCodeFirst"

def getAccessToken: String = {
def getAuthHeader: String = {
if (MLMWCToken != "" && !isTokenExpired(MLMWCToken)) {
logInfo("using cached openai mwc token")
MLMWCToken
Expand All @@ -29,13 +29,13 @@ object OpenAITokenLibrary extends SynapseMLLogging{
val url: String = FabricClient.MLWorkloadEndpointML + "cognitive/openai/generatemwctoken";

try {
var token = FabricClient.usagePost(url, payload).asJsObject.fields("Token").convertTo[String];
val token = FabricClient.usagePost(url, payload).asJsObject.fields("Token").convertTo[String];
logInfo("successfully fetch openai mwc token")
token
"MwcToken " + token
} catch {
case e: Throwable =>
logInfo("openai mwc token not available, using aad token", e)
TokenLibrary.getAccessToken;
"Bearer" + TokenLibrary.getAccessToken;
}
}
}
Expand All @@ -46,8 +46,7 @@ object OpenAITokenLibrary extends SynapseMLLogging{
val jwtTokenDecoded: Try[(String, String, String)] = Jwt.decodeRawAll(accessToken, jwtOptions)
jwtTokenDecoded match {
case Success((_, payload, _)) =>
val jsonPayload: JSONObject = new JSONObject(payload)
val expiry = jsonPayload.get("exp").toString
val expiry = payload.parseJson.asJsObject().fields("exp").convertTo[String]
new Date(expiry.toLong * 1000)
case Failure(t) =>
throw t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ package com.microsoft.azure.synapse.ml.fabric

import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
object TokenLibrary {

trait AuthHeaderProvider {
def getAuthHeader: String
}

object TokenLibrary extends AuthHeaderProvider {
def getAccessToken: String = {
val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
val mirror = currentMirror
Expand All @@ -24,4 +29,7 @@ object TokenLibrary {
val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
methodMirror("pbi").asInstanceOf[String]
}


def getAuthHeader: String = "Bearer " + getAccessToken
}

0 comments on commit aeb302d

Please sign in to comment.