Skip to content

Commit

Permalink
feat: Add support for embed file with AWS Bedrock Titan Model (langge…
Browse files Browse the repository at this point in the history
…nius#3377)

Co-authored-by: crazywoola <[email protected]>
  • Loading branch information
longzhihun and crazywoola authored Apr 11, 2024
1 parent 8714520 commit 4007868
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ help:
en_US: https://console.aws.amazon.com/
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- amazon.titan-embed-text-v1
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
model: amazon.titan-embed-text-v1
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: '0.0001'
unit: '0.001'
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import json
import time
from typing import Optional

import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
NoRegionError,
ServiceNotInRegionError,
UnknownServiceError,
)

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel


class BedrockTextEmbeddingModel(TextEmbeddingModel):


def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
client_config = Config(
region_name=credentials["aws_region"]
)

bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
config=client_config,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"]
)

embeddings = []
token_usage = 0

model_prefix = model.split('.')[0]
if model_prefix == "amazon":
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
)
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

return result


def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
}

def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
"""
Create payload for bedrock api call depending on model provider
"""
payload = dict()

if model_prefix == "amazon":
payload['inputText'] = texts


def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage

def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
:param error_code: error code
:param error_msg: error message
:return: invoke error
"""

if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
return InvokeRateLimitError(error_msg)
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

return InvokeError(error_msg)


def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
accept = 'application/json'
content_type = 'application/json'
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))
return response_body
except ClientError as ex:
error_code = ex.response['Error']['Code']
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)

except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))

except UnknownServiceError as ex:
raise InvokeServerUnavailableError(str(ex))

except Exception as ex:
raise InvokeError(str(ex))

0 comments on commit 4007868

Please sign in to comment.