diff --git a/src/llmperf/common.py b/src/llmperf/common.py index 3efefa1..c7f908a 100644 --- a/src/llmperf/common.py +++ b/src/llmperf/common.py @@ -5,6 +5,7 @@ ) from llmperf.ray_clients.sagemaker_client import SageMakerClient from llmperf.ray_clients.vertexai_client import VertexAIClient +from llmperf.ray_clients.bedrock_client import BedrockClient from llmperf.ray_llm_client import LLMClient @@ -28,6 +29,8 @@ def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]: clients = [SageMakerClient.remote() for _ in range(num_clients)] elif llm_api == "vertexai": clients = [VertexAIClient.remote() for _ in range(num_clients)] + elif llm_api == "bedrock": + clients = [BedrockClient.remote() for _ in range(num_clients)] elif llm_api in SUPPORTED_APIS: clients = [LiteLLMClient.remote() for _ in range(num_clients)] else: