forked from langgenius/dify
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for embed file with AWS Bedrock Titan Model (langge…
…nius#3377) Co-authored-by: crazywoola <[email protected]>
- Loading branch information
1 parent
8714520
commit 4007868
Showing
5 changed files
with
219 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
1 change: 1 addition & 0 deletions
1
api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
- amazon.titan-embed-text-v1 |
8 changes: 8 additions & 0 deletions
8
...core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
model: amazon.titan-embed-text-v1 | ||
model_type: text-embedding | ||
model_properties: | ||
context_size: 8192 | ||
pricing: | ||
input: '0.0001' | ||
unit: '0.001' | ||
currency: USD |
209 changes: 209 additions & 0 deletions
209
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import json | ||
import time | ||
from typing import Optional | ||
|
||
import boto3 | ||
from botocore.config import Config | ||
from botocore.exceptions import ( | ||
ClientError, | ||
EndpointConnectionError, | ||
NoRegionError, | ||
ServiceNotInRegionError, | ||
UnknownServiceError, | ||
) | ||
|
||
from core.model_runtime.entities.model_entities import PriceType | ||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | ||
|
||
|
||
class BedrockTextEmbeddingModel(TextEmbeddingModel): | ||
|
||
|
||
def _invoke(self, model: str, credentials: dict, | ||
texts: list[str], user: Optional[str] = None) \ | ||
-> TextEmbeddingResult: | ||
""" | ||
Invoke text embedding model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param texts: texts to embed | ||
:param user: unique user id | ||
:return: embeddings result | ||
""" | ||
client_config = Config( | ||
region_name=credentials["aws_region"] | ||
) | ||
|
||
bedrock_runtime = boto3.client( | ||
service_name='bedrock-runtime', | ||
config=client_config, | ||
aws_access_key_id=credentials["aws_access_key_id"], | ||
aws_secret_access_key=credentials["aws_secret_access_key"] | ||
) | ||
|
||
embeddings = [] | ||
token_usage = 0 | ||
|
||
model_prefix = model.split('.')[0] | ||
if model_prefix == "amazon": | ||
for text in texts: | ||
body = { | ||
"inputText": text, | ||
} | ||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) | ||
embeddings.extend([response_body.get('embedding')]) | ||
token_usage += response_body.get('inputTextTokenCount') | ||
result = TextEmbeddingResult( | ||
model=model, | ||
embeddings=embeddings, | ||
usage=self._calc_response_usage( | ||
model=model, | ||
credentials=credentials, | ||
tokens=token_usage | ||
) | ||
) | ||
else: | ||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") | ||
|
||
return result | ||
|
||
|
||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | ||
""" | ||
Get number of tokens for given prompt messages | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param texts: texts to embed | ||
:return: | ||
""" | ||
num_tokens = 0 | ||
for text in texts: | ||
num_tokens += self._get_num_tokens_by_gpt2(text) | ||
return num_tokens | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
""" | ||
Map model invoke error to unified error | ||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller | ||
The value is the md = genai.GenerativeModel(model)error type thrown by the model, | ||
which needs to be converted into a unified error type for the caller. | ||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping | ||
""" | ||
return { | ||
InvokeConnectionError: [], | ||
InvokeServerUnavailableError: [], | ||
InvokeRateLimitError: [], | ||
InvokeAuthorizationError: [], | ||
InvokeBadRequestError: [] | ||
} | ||
|
||
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): | ||
""" | ||
Create payload for bedrock api call depending on model provider | ||
""" | ||
payload = dict() | ||
|
||
if model_prefix == "amazon": | ||
payload['inputText'] = texts | ||
|
||
|
||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: | ||
""" | ||
Calculate response usage | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param tokens: input tokens | ||
:return: usage | ||
""" | ||
# get input price info | ||
input_price_info = self.get_price( | ||
model=model, | ||
credentials=credentials, | ||
price_type=PriceType.INPUT, | ||
tokens=tokens | ||
) | ||
|
||
# transform usage | ||
usage = EmbeddingUsage( | ||
tokens=tokens, | ||
total_tokens=tokens, | ||
unit_price=input_price_info.unit_price, | ||
price_unit=input_price_info.unit, | ||
total_price=input_price_info.total_amount, | ||
currency=input_price_info.currency, | ||
latency=time.perf_counter() - self.started_at | ||
) | ||
|
||
return usage | ||
|
||
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: | ||
""" | ||
Map client error to invoke error | ||
:param error_code: error code | ||
:param error_msg: error message | ||
:return: invoke error | ||
""" | ||
|
||
if error_code == "AccessDeniedException": | ||
return InvokeAuthorizationError(error_msg) | ||
elif error_code in ["ResourceNotFoundException", "ValidationException"]: | ||
return InvokeBadRequestError(error_msg) | ||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: | ||
return InvokeRateLimitError(error_msg) | ||
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: | ||
return InvokeServerUnavailableError(error_msg) | ||
elif error_code == "ModelStreamErrorException": | ||
return InvokeConnectionError(error_msg) | ||
|
||
return InvokeError(error_msg) | ||
|
||
|
||
def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): | ||
accept = 'application/json' | ||
content_type = 'application/json' | ||
try: | ||
response = bedrock_runtime.invoke_model( | ||
body=json.dumps(body), | ||
modelId=model, | ||
accept=accept, | ||
contentType=content_type | ||
) | ||
response_body = json.loads(response.get('body').read().decode('utf-8')) | ||
return response_body | ||
except ClientError as ex: | ||
error_code = ex.response['Error']['Code'] | ||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" | ||
raise self._map_client_to_invoke_error(error_code, full_error_msg) | ||
|
||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: | ||
raise InvokeConnectionError(str(ex)) | ||
|
||
except UnknownServiceError as ex: | ||
raise InvokeServerUnavailableError(str(ex)) | ||
|
||
except Exception as ex: | ||
raise InvokeError(str(ex)) |