From 1d60269112f7125234ba5c0c98da92d10a56d7cf Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 18 Dec 2024 19:39:05 +0100 Subject: [PATCH 01/11] add: init vertex --- gemini-vertex-model-provider/README.md | 27 ++ gemini-vertex-model-provider/main.py | 373 ++++++++++++++++++ gemini-vertex-model-provider/requirements.txt | 4 + gemini-vertex-model-provider/tool.gpt | 12 + 4 files changed, 416 insertions(+) create mode 100644 gemini-vertex-model-provider/README.md create mode 100644 gemini-vertex-model-provider/main.py create mode 100644 gemini-vertex-model-provider/requirements.txt create mode 100644 gemini-vertex-model-provider/tool.gpt diff --git a/gemini-vertex-model-provider/README.md b/gemini-vertex-model-provider/README.md new file mode 100644 index 00000000..81166048 --- /dev/null +++ b/gemini-vertex-model-provider/README.md @@ -0,0 +1,27 @@ +Expects you to be authenticated with Google Cloud + +example: +``` +gcloud auth application-default login +``` + +## Usage Example + +``` +gptscript --default-model='gemini-1.0-pro from github.com/gptscript-ai/gemini-vertexai-provider' examples/bob.gpt +``` + +## Development + +Run using the following commands + +``` +python -m venv .venv +source ./.venv/bin/activate +pip install -r requirements.txt +DEBUG=true ./run.sh +``` + +``` +gptscript --default-model='gemini-1.0-pro from http://127.0.0.1:8000/v1' examples/bob.gpt +``` diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py new file mode 100644 index 00000000..38318f1b --- /dev/null +++ b/gemini-vertex-model-provider/main.py @@ -0,0 +1,373 @@ +import json +import os +from typing import AsyncIterable, List + +import google.auth.exceptions +import vertexai.preview.generative_models as generative_models +from fastapi import FastAPI, HTTPException, Request +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.routing import APIRouter +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta +from vertexai.preview.generative_models import Content, FunctionDeclaration, GenerationConfig, GenerationResponse, \ + GenerativeModel, Part, Tool + +debug = os.environ.get("GPTSCRIPT_DEBUG", "false") == "true" +app = FastAPI() +router = APIRouter() + +uri = "http://127.0.0.1:" + os.environ.get("PORT", "8000") + + + +def log(*args): + if debug: + print(*args) + + +@app.middleware("http") +async def log_body(request: Request, call_next): + body = await request.body() + log("REQUEST BODY: ", body) + return await call_next(request) + + +@app.post("/") +@app.get("/") +async def get_root(): + return uri + + +@app.get("/v1/models") +def list_models() -> JSONResponse: + content = { + "data": [ + { + "id": "gemini-1.0-pro", + "name": "Gemini 1.0 Pro", + }, + { + "id": "gemini-1.5-pro", + "name": "Gemini 1.5 Pro", + }, + { + "id": "gemini-1.5-pro-preview-0409", + "name": "Gemini 1.5 Pro Preview 0409" + } + ] + } + return JSONResponse(content=content) + + +async def map_tools(req_tools: List | None = None) -> List[Tool] | None: + if req_tools is None or len(req_tools) < 1: + return None + + function_declarations = [] + for tool in req_tools: + parameters = tool['function'].get('parameters', { + "properties": {}, + "type": "object" + }) + + function_declarations.append( + FunctionDeclaration( + name=tool["function"]["name"], + description=tool["function"]["description"], + parameters=parameters, + ) + ) + + tools: list["Tool"] = 
[Tool.from_function_declarations(function_declarations)] + + return tools + + +def merge_consecutive_dicts_with_same_value(list_of_dicts, key) -> list[dict]: + merged_list = [] + index = 0 + while index < len(list_of_dicts): + current_dict = list_of_dicts[index] + value_to_match = current_dict.get(key) + compared_index = index + 1 + while compared_index < len(list_of_dicts) and list_of_dicts[compared_index].get(key) == value_to_match: + list_of_dicts[compared_index]["content"] = current_dict["content"] + "\n" + list_of_dicts[compared_index][ + "content"] + current_dict.update(list_of_dicts[compared_index]) + compared_index += 1 + merged_list.append(current_dict) + index = compared_index + return merged_list + + +async def map_messages(req_messages: list) -> list[Content] | None: + messages: list[Content] = [] + log(req_messages) + + if req_messages is not None: + system: str = """ +You are a task oriented system. +Be as brief as possible when answering the user. +Only give the required answer. +Do not give your thought process. +Use functions or tools as needed to complete the tasks given to you. +You are referred to as a tool. +Do not call functions or tools unless you need to. +Ensure you are passing the correct arguments to the functions or tools you call. +Do not move on to the next task until the current task is completed. +Do not make up arguments for tools. +Call functions one at a time to make sure you have the correct inputs. +""" + req_messages = [ + {"role": "system", "content": system}, + {"role": "model", "content": "Understood."} + ] + req_messages + + for message in req_messages: + match message["role"]: + case "system": + message['role'] = "user" + case "user": + message['role'] = "user" + case "assistant": + message['role'] = "model" + case "model": + message['role'] = "model" + case "tool": + message['role'] = "function" + case _: + message['role'] = "user" + req_messages = merge_consecutive_dicts_with_same_value(req_messages, "role") + + for message in req_messages: + if 'tool_call_id' in message.keys(): + convert_message = Content( + role=message['role'], + parts=[Part.from_function_response( + name=message.get('name', ''), + response={ + 'name': message.get('name', ''), + 'content': message['content'] + }, + )] + ) + elif 'tool_calls' in message.keys(): + tool_call_parts: list[Part] = [] + for tool_call in message['tool_calls']: + function_call = { + "functionCall": { + "name": tool_call['function']['name'], + "args": json.loads(tool_call['function']['arguments']) + } + } + tool_call_parts.append(Part.from_dict(function_call)) + convert_message = Content( + role=message['role'], + parts=tool_call_parts + ) + elif 'content' in message.keys(): + convert_message = Content( + role=message['role'], + parts=[Part.from_text(message["content"])] + ) + messages.append(convert_message) + + return messages + + return None + + +@app.post("/v1/chat/completions") +async def chat_completion(request: Request): + data = await request.body() + data = json.loads(data) + + req_tools = data.get("tools", None) + tools: list[Tool] | None = None + if req_tools is not None: + tools = await map_tools(req_tools) + + req_messages = data["messages"] + messages = await map_messages(req_messages) + + temperature = data.get("temperature", None) + if temperature is not None: + temperature = float(temperature) + + stream = data.get("stream", False) + + top_k = data.get("top_k", None) + if top_k is not None: + top_k = float(top_k) + + top_p = data.get("top_p", None) + if top_p is not None: + top_p = 
float(top_p) + + max_output_tokens = data.get("max_tokens", None) + if max_output_tokens is not None: + max_output_tokens = float(max_output_tokens) + + log() + log("GEMINI TOOLS: ", tools) + try: + model = GenerativeModel(data["model"]) + except google.auth.exceptions.GoogleAuthError as e: + log("AUTH ERROR: ", e) + raise HTTPException(status_code=401, + detail="Authentication error. Please ensure you are properly authenticated with GCP and have the correct project configured.") + except Exception as e: + log("ERROR: ", e) + log(type(e)) + raise HTTPException(status_code=500, detail=str(e)) + try: + response = model.generate_content( + contents=messages, + tools=tools, + generation_config=GenerationConfig( + temperature=temperature, + top_k=top_k, + top_p=top_p, + candidate_count=1, + max_output_tokens=max_output_tokens, + ), + safety_settings={ + generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, + generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, + generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, + generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, + } + ) + if not stream: + return JSONResponse(content=jsonable_encoder(response)) + + return StreamingResponse(to_chunk(data['model'], response), media_type="application/x-ndjson") + + except Exception as e: + log("ERROR: ", e) + raise HTTPException(status_code=500, detail=str(e)) + + +async def to_chunk(model: str, response: GenerationResponse) -> AsyncIterable[str]: + mapped_chunk = map_resp(model, response) + if mapped_chunk is None: + yield "data: " + json.dumps({}) + "\n\n" + else: + log("RESPONSE CHUNK: ", mapped_chunk.model_dump_json()) + yield "data: " + mapped_chunk.model_dump_json() + "\n\n" + + +def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | None: + tool_calls = [] + if len(chunk.candidates) > 0: + if len(chunk.candidates[0].function_calls) > 0: + parts = chunk.candidates[0].to_dict().get('content', {}).get('parts', []) + + for idx, part in enumerate(parts): + call = part.get('function_call', None) + if not call: + continue + + tool_calls.append({ + "index": idx, + "id": call['name'] + "_" + str(idx), + "function": { + "name": call['name'], + "arguments": json.dumps(call['args']) + }, + "type": "function" + }) + + try: + content = chunk.candidates[0].content.text + except: + content = None + + match chunk.candidates[0].content.role: + case "system": + role = "user" + case "user": + role = "user" + case "assistant": + role = "model" + case "model": + role = "assistant" + case "function": + role = "tool" + case _: + role = "user" + + try: + if len(tool_calls) > 0: + finish_reason = "tool_calls" + else: + finish_reason = map_finish_reason(str(chunk.candidates[0].finish_reason)) + except KeyError: + finish_reason = None + + log("FINISH_REASON: ", finish_reason) + + resp = ChatCompletionChunk( + id="0", + choices=[ + Choice( + delta=ChoiceDelta( + content=content, + tool_calls=tool_calls, + role=role + ), + finish_reason=finish_reason, + index=0, + ) + ], + created=0, + model=model, + object="chat.completion.chunk", + ) + return resp + + return None + + +def map_finish_reason(finish_reason: str) -> str: + if (finish_reason == "ERROR"): + return "stop" + elif (finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"): + return "stop" + 
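A note on the streaming path: `to_chunk()` wraps each mapped response in OpenAI-style server-sent-event framing, one `data: <json>` line followed by a blank line. A minimal sketch of that wire format (field values here are illustrative only):

```python
import json

def sse_event(payload: dict | None) -> str:
    # One stream element: "data: <json>\n\n"; an unmappable chunk
    # degrades to an empty JSON object, mirroring to_chunk() above.
    return "data: " + json.dumps(payload or {}) + "\n\n"

print(sse_event({
    "id": "0",
    "object": "chat.completion.chunk",
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}],
}))
```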
elif finish_reason == "SAFETY": + return "content_filter" + elif finish_reason == "STOP": + return "stop" + elif finish_reason == "0": + return "stop" + elif finish_reason == "1": + return "stop" + elif finish_reason == "2": + return "length" + elif finish_reason == "3": + return "content_filter" + elif finish_reason == "4": + return "content_filter" + elif finish_reason == "5": + return "stop" + elif finish_reason == "6": + return "content_filter" + elif finish_reason == "7": + return "content_filter" + elif finish_reason == "8": + return "content_filter" + # elif finish_reason == None: + # return "tool_calls" + return finish_reason + + +if __name__ == "__main__": + import uvicorn + import asyncio + + try: + uvicorn.run("main:app", host="127.0.0.1", port=int(os.environ.get("PORT", "8000")), + log_level="debug" if debug else "critical", access_log=debug) + except (KeyboardInterrupt, asyncio.CancelledError): + pass diff --git a/gemini-vertex-model-provider/requirements.txt b/gemini-vertex-model-provider/requirements.txt new file mode 100644 index 00000000..04bf2fb3 --- /dev/null +++ b/gemini-vertex-model-provider/requirements.txt @@ -0,0 +1,4 @@ +fastapi +uvicorn[standard] +openai +google-cloud-aiplatform \ No newline at end of file diff --git a/gemini-vertex-model-provider/tool.gpt b/gemini-vertex-model-provider/tool.gpt new file mode 100644 index 00000000..9128ab51 --- /dev/null +++ b/gemini-vertex-model-provider/tool.gpt @@ -0,0 +1,12 @@ +Name: Anthropic on AWS Bedrock +Description: Model provider for AWS Bedrock hosted Anthropic models +Metadata: envVars: OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_ACCESS_KEY_ID,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_SECRET_ACCESS_KEY,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_REGION,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_SESSION_TOKEN +Model Provider: true +Credential: ../model-provider-credential as anthropic-bedrock-model-provider + + +#!sys.daemon /usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/main.py + +--- +!metadata:*:icon +/admin/assets/anthropic_bedrock_icon.svg \ No newline at end of file From 6a04cb98adf6dc64cea228a483634c47dc6ba1d3 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Thu, 19 Dec 2024 09:53:17 +0100 Subject: [PATCH 02/11] chore: formatting --- gemini-vertex-model-provider/README.md | 27 ---- gemini-vertex-model-provider/main.py | 193 ++++++++++++++++--------- gemini-vertex-model-provider/tool.gpt | 10 +- index.yaml | 2 + 4 files changed, 135 insertions(+), 97 deletions(-) delete mode 100644 gemini-vertex-model-provider/README.md diff --git a/gemini-vertex-model-provider/README.md b/gemini-vertex-model-provider/README.md deleted file mode 100644 index 81166048..00000000 --- a/gemini-vertex-model-provider/README.md +++ /dev/null @@ -1,27 +0,0 @@ -Expects you to be authenticated with Google Cloud - -example: -``` -gcloud auth application-default login -``` - -## Usage Example - -``` -gptscript --default-model='gemini-1.0-pro from github.com/gptscript-ai/gemini-vertexai-provider' examples/bob.gpt -``` - -## Development - -Run using the following commands - -``` -python -m venv .venv -source ./.venv/bin/activate -pip install -r requirements.txt -DEBUG=true ./run.sh -``` - -``` -gptscript --default-model='gemini-1.0-pro from http://127.0.0.1:8000/v1' examples/bob.gpt -``` diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py index 38318f1b..0121cf90 100644 --- a/gemini-vertex-model-provider/main.py +++ b/gemini-vertex-model-provider/main.py @@ -1,7 +1,10 @@ import json import os +import tempfile from typing import 
AsyncIterable, List +from contextlib import asynccontextmanager + import google.auth.exceptions import vertexai.preview.generative_models as generative_models from fastapi import FastAPI, HTTPException, Request @@ -10,8 +13,54 @@ from fastapi.routing import APIRouter from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta -from vertexai.preview.generative_models import Content, FunctionDeclaration, GenerationConfig, GenerationResponse, \ - GenerativeModel, Part, Tool +from vertexai.preview.generative_models import ( + Content, + FunctionDeclaration, + GenerationConfig, + GenerationResponse, + GenerativeModel, + Part, + Tool, +) + + +@asynccontextmanager +async def lifespan(a: FastAPI): + project = os.environ.get( + "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT", None + ) + if not project: + raise Exception("Google Cloud project is required.") + + creds_json = os.environ.get( + "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON", None + ) + if not creds_json: + raise Exception("Google application credentials content is required.") + + try: + _ = json.loads(creds_json) + except json.JSONDecodeError as e: + raise Exception("Invalid JSON in Google application credentials: ", e) + + creds_file: str + with tempfile.NamedTemporaryFile( + suffix="-google-credentials.json", delete=False, delete_on_close=False + ) as f: + creds_file = f.name + f.write(creds_json.encode("utf-8")) + f.close() + + # secure access by setting file permissions + os.chmod(creds_file, 0o600) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file + + os.environ["GOOGLE_CLOUD_PROJECT"] = project + + yield # App shutdown + + os.remove(creds_file) + debug = os.environ.get("GPTSCRIPT_DEBUG", "false") == "true" app = FastAPI() @@ -20,7 +69,6 @@ uri = "http://127.0.0.1:" + os.environ.get("PORT", "8000") - def log(*args): if debug: print(*args) @@ -53,8 +101,8 @@ def list_models() -> JSONResponse: }, { "id": "gemini-1.5-pro-preview-0409", - "name": "Gemini 1.5 Pro Preview 0409" - } + "name": "Gemini 1.5 Pro Preview 0409", + }, ] } return JSONResponse(content=content) @@ -66,10 +114,9 @@ async def map_tools(req_tools: List | None = None) -> List[Tool] | None: function_declarations = [] for tool in req_tools: - parameters = tool['function'].get('parameters', { - "properties": {}, - "type": "object" - }) + parameters = tool["function"].get( + "parameters", {"properties": {}, "type": "object"} + ) function_declarations.append( FunctionDeclaration( @@ -91,9 +138,15 @@ def merge_consecutive_dicts_with_same_value(list_of_dicts, key) -> list[dict]: current_dict = list_of_dicts[index] value_to_match = current_dict.get(key) compared_index = index + 1 - while compared_index < len(list_of_dicts) and list_of_dicts[compared_index].get(key) == value_to_match: - list_of_dicts[compared_index]["content"] = current_dict["content"] + "\n" + list_of_dicts[compared_index][ - "content"] + while ( + compared_index < len(list_of_dicts) + and list_of_dicts[compared_index].get(key) == value_to_match + ): + list_of_dicts[compared_index]["content"] = ( + current_dict["content"] + + "\n" + + list_of_dicts[compared_index]["content"] + ) current_dict.update(list_of_dicts[compared_index]) compared_index += 1 merged_list.append(current_dict) @@ -120,56 +173,54 @@ async def map_messages(req_messages: list) -> list[Content] | None: Call functions one at a time to make sure you have the correct inputs. 
""" req_messages = [ - {"role": "system", "content": system}, - {"role": "model", "content": "Understood."} - ] + req_messages + {"role": "system", "content": system}, + {"role": "model", "content": "Understood."}, + ] + req_messages for message in req_messages: match message["role"]: case "system": - message['role'] = "user" + message["role"] = "user" case "user": - message['role'] = "user" + message["role"] = "user" case "assistant": - message['role'] = "model" + message["role"] = "model" case "model": - message['role'] = "model" + message["role"] = "model" case "tool": - message['role'] = "function" + message["role"] = "function" case _: - message['role'] = "user" + message["role"] = "user" req_messages = merge_consecutive_dicts_with_same_value(req_messages, "role") for message in req_messages: - if 'tool_call_id' in message.keys(): + if "tool_call_id" in message.keys(): convert_message = Content( - role=message['role'], - parts=[Part.from_function_response( - name=message.get('name', ''), - response={ - 'name': message.get('name', ''), - 'content': message['content'] - }, - )] + role=message["role"], + parts=[ + Part.from_function_response( + name=message.get("name", ""), + response={ + "name": message.get("name", ""), + "content": message["content"], + }, + ) + ], ) - elif 'tool_calls' in message.keys(): + elif "tool_calls" in message.keys(): tool_call_parts: list[Part] = [] - for tool_call in message['tool_calls']: + for tool_call in message["tool_calls"]: function_call = { "functionCall": { - "name": tool_call['function']['name'], - "args": json.loads(tool_call['function']['arguments']) + "name": tool_call["function"]["name"], + "args": json.loads(tool_call["function"]["arguments"]), } } tool_call_parts.append(Part.from_dict(function_call)) + convert_message = Content(role=message["role"], parts=tool_call_parts) + elif "content" in message.keys(): convert_message = Content( - role=message['role'], - parts=tool_call_parts - ) - elif 'content' in message.keys(): - convert_message = Content( - role=message['role'], - parts=[Part.from_text(message["content"])] + role=message["role"], parts=[Part.from_text(message["content"])] ) messages.append(convert_message) @@ -215,8 +266,10 @@ async def chat_completion(request: Request): model = GenerativeModel(data["model"]) except google.auth.exceptions.GoogleAuthError as e: log("AUTH ERROR: ", e) - raise HTTPException(status_code=401, - detail="Authentication error. Please ensure you are properly authenticated with GCP and have the correct project configured.") + raise HTTPException( + status_code=401, + detail="Authentication error. 
Please ensure you are properly authenticated with GCP and have the correct project configured.", + ) except Exception as e: log("ERROR: ", e) log(type(e)) @@ -237,12 +290,14 @@ async def chat_completion(request: Request): generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, - } + }, ) if not stream: return JSONResponse(content=jsonable_encoder(response)) - return StreamingResponse(to_chunk(data['model'], response), media_type="application/x-ndjson") + return StreamingResponse( + to_chunk(data["model"], response), media_type="application/x-ndjson" + ) except Exception as e: log("ERROR: ", e) @@ -262,26 +317,28 @@ def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | Non tool_calls = [] if len(chunk.candidates) > 0: if len(chunk.candidates[0].function_calls) > 0: - parts = chunk.candidates[0].to_dict().get('content', {}).get('parts', []) + parts = chunk.candidates[0].to_dict().get("content", {}).get("parts", []) for idx, part in enumerate(parts): - call = part.get('function_call', None) + call = part.get("function_call", None) if not call: continue - tool_calls.append({ - "index": idx, - "id": call['name'] + "_" + str(idx), - "function": { - "name": call['name'], - "arguments": json.dumps(call['args']) - }, - "type": "function" - }) + tool_calls.append( + { + "index": idx, + "id": call["name"] + "_" + str(idx), + "function": { + "name": call["name"], + "arguments": json.dumps(call["args"]), + }, + "type": "function", + } + ) try: content = chunk.candidates[0].content.text - except: + except KeyError: content = None match chunk.candidates[0].content.role: @@ -302,7 +359,9 @@ def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | Non if len(tool_calls) > 0: finish_reason = "tool_calls" else: - finish_reason = map_finish_reason(str(chunk.candidates[0].finish_reason)) + finish_reason = map_finish_reason( + str(chunk.candidates[0].finish_reason) + ) except KeyError: finish_reason = None @@ -313,9 +372,7 @@ def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | Non choices=[ Choice( delta=ChoiceDelta( - content=content, - tool_calls=tool_calls, - role=role + content=content, tool_calls=tool_calls, role=role ), finish_reason=finish_reason, index=0, @@ -331,9 +388,9 @@ def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | Non def map_finish_reason(finish_reason: str) -> str: - if (finish_reason == "ERROR"): + if finish_reason == "ERROR": return "stop" - elif (finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"): + elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": return "stop" elif finish_reason == "SAFETY": return "content_filter" @@ -367,7 +424,13 @@ def map_finish_reason(finish_reason: str) -> str: import asyncio try: - uvicorn.run("main:app", host="127.0.0.1", port=int(os.environ.get("PORT", "8000")), - log_level="debug" if debug else "critical", access_log=debug) + uvicorn.run( + "main:app", + workers=4, + host="127.0.0.1", + port=int(os.environ.get("PORT", "8000")), + log_level="debug" if debug else "critical", + access_log=debug, + ) except (KeyboardInterrupt, asyncio.CancelledError): pass diff --git a/gemini-vertex-model-provider/tool.gpt 
b/gemini-vertex-model-provider/tool.gpt index 9128ab51..e98bf199 100644 --- a/gemini-vertex-model-provider/tool.gpt +++ b/gemini-vertex-model-provider/tool.gpt @@ -1,12 +1,12 @@ -Name: Anthropic on AWS Bedrock -Description: Model provider for AWS Bedrock hosted Anthropic models -Metadata: envVars: OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_ACCESS_KEY_ID,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_SECRET_ACCESS_KEY,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_REGION,OBOT_ANTHROPIC_BEDROCK_MODEL_PROVIDER_SESSION_TOKEN +Name: Google Gemini Vertex AI Model Provider +Description: Model provider for Google Gemini Vertex AI +Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON,OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT Model Provider: true -Credential: ../model-provider-credential as anthropic-bedrock-model-provider +Credential: ../model-provider-credential as gemini-vertex-model-provider #!sys.daemon /usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/main.py --- !metadata:*:icon -/admin/assets/anthropic_bedrock_icon.svg \ No newline at end of file +/admin/assets/gemini_icon.svg \ No newline at end of file diff --git a/index.yaml b/index.yaml index a1f9dbfa..b551ef64 100644 --- a/index.yaml +++ b/index.yaml @@ -116,3 +116,5 @@ modelProviders: reference: ./xai-model-provider deepseek-model-provider: reference: ./deepseek-model-provider + gemini-vertex-model-provider: + reference: ./gemini-vertex-model-provider From 49fd76b50b2c06e33b6e39905fb464106fd506a8 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Fri, 20 Dec 2024 09:52:49 +0100 Subject: [PATCH 03/11] change: optional project_id field + error handling --- gemini-vertex-model-provider/main.py | 68 ++++++++++++++++++--------- gemini-vertex-model-provider/tool.gpt | 2 +- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py index 0121cf90..a9baf948 100644 --- a/gemini-vertex-model-provider/main.py +++ b/gemini-vertex-model-provider/main.py @@ -1,5 +1,6 @@ import json import os +import sys import tempfile from typing import AsyncIterable, List @@ -23,47 +24,62 @@ Tool, ) +startup_err: str | None = None @asynccontextmanager async def lifespan(a: FastAPI): - project = os.environ.get( - "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT", None - ) - if not project: - raise Exception("Google Cloud project is required.") + global startup_err creds_json = os.environ.get( "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON", None ) if not creds_json: - raise Exception("Google application credentials content is required.") + startup_err = "Google application credentials content is required." + else: - try: - _ = json.loads(creds_json) - except json.JSONDecodeError as e: - raise Exception("Invalid JSON in Google application credentials: ", e) + c: dict | None = None + try: + c = json.loads(creds_json) + except json.JSONDecodeError as e: + startup_err = f"Invalid JSON in Google application credentials: {e}" + + + if c is not None and "project_id" in c: + os.environ["GOOGLE_CLOUD_PROJECT"] = c["project_id"] + else: + p = os.getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT", None) + if not p: + startup_err = "Google Cloud project ID is required." 
+ else: + os.environ["GOOGLE_CLOUD_PROJECT"] = p - creds_file: str - with tempfile.NamedTemporaryFile( - suffix="-google-credentials.json", delete=False, delete_on_close=False - ) as f: - creds_file = f.name - f.write(creds_json.encode("utf-8")) - f.close() - # secure access by setting file permissions - os.chmod(creds_file, 0o600) - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file + if creds_json: + creds_file: str + with tempfile.NamedTemporaryFile( + suffix="-google-credentials.json", delete=False, delete_on_close=False + ) as f: + creds_file = f.name + f.write(creds_json.encode("utf-8")) + f.close() + + # secure access by setting file permissions + os.chmod(creds_file, 0o600) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file - os.environ["GOOGLE_CLOUD_PROJECT"] = project yield # App shutdown - os.remove(creds_file) + try: + os.remove(creds_file) + except Exception: + pass debug = os.environ.get("GPTSCRIPT_DEBUG", "false") == "true" -app = FastAPI() +app = FastAPI( + lifespan=lifespan, +) router = APIRouter() uri = "http://127.0.0.1:" + os.environ.get("PORT", "8000") @@ -89,6 +105,10 @@ async def get_root(): @app.get("/v1/models") def list_models() -> JSONResponse: + if startup_err: + return JSONResponse( + content={"error": startup_err}, status_code=500 + ) content = { "data": [ { @@ -234,6 +254,8 @@ async def chat_completion(request: Request): data = await request.body() data = json.loads(data) + log("Env: ", os.environ) + req_tools = data.get("tools", None) tools: list[Tool] | None = None if req_tools is not None: diff --git a/gemini-vertex-model-provider/tool.gpt b/gemini-vertex-model-provider/tool.gpt index e98bf199..15f118f8 100644 --- a/gemini-vertex-model-provider/tool.gpt +++ b/gemini-vertex-model-provider/tool.gpt @@ -1,6 +1,6 @@ Name: Google Gemini Vertex AI Model Provider Description: Model provider for Google Gemini Vertex AI -Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON,OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT +Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON Model Provider: true Credential: ../model-provider-credential as gemini-vertex-model-provider From 6e8413f877aab2b5460b8c78e1b895b05315548d Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 23 Dec 2024 16:19:25 +0100 Subject: [PATCH 04/11] add: credential validation --- gemini-vertex-model-provider/main.py | 33 ++++++++++++++---------- gemini-vertex-model-provider/tool.gpt | 12 ++++++++- gemini-vertex-model-provider/validate.py | 31 ++++++++++++++++++++++ 3 files changed, 62 insertions(+), 14 deletions(-) create mode 100644 gemini-vertex-model-provider/validate.py diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py index a9baf948..a93b4b33 100644 --- a/gemini-vertex-model-provider/main.py +++ b/gemini-vertex-model-provider/main.py @@ -26,38 +26,33 @@ startup_err: str | None = None -@asynccontextmanager -async def lifespan(a: FastAPI): - global startup_err - +def configure(): creds_json = os.environ.get( "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON", None ) if not creds_json: - startup_err = "Google application credentials content is required." 
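For context on the fallback above: the credentials blob is a standard service-account key, which usually carries its own `project_id`, so the separate project setting is only needed when the key omits it. A rough sketch of that resolution order, using placeholder values:

```python
import os

# Placeholder service-account key; a real key also contains private_key data.
creds: dict = {
    "type": "service_account",
    "project_id": "my-gcp-project",
    "client_email": "provider@my-gcp-project.iam.gserviceaccount.com",
}

project = creds.get("project_id") or os.environ.get(
    "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT"
)
if not project:
    raise KeyError("Google Cloud project ID is required.")
os.environ["GOOGLE_CLOUD_PROJECT"] = project
```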
+ raise KeyError("Google application credentials content is required.") else: c: dict | None = None try: c = json.loads(creds_json) except json.JSONDecodeError as e: - startup_err = f"Invalid JSON in Google application credentials: {e}" - + raise ValueError(f"Invalid JSON in Google application credentials: {e}") if c is not None and "project_id" in c: os.environ["GOOGLE_CLOUD_PROJECT"] = c["project_id"] else: p = os.getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT", None) if not p: - startup_err = "Google Cloud project ID is required." + raise KeyError("Google Cloud project ID is required.") else: os.environ["GOOGLE_CLOUD_PROJECT"] = p - - if creds_json: + if creds_json: creds_file: str with tempfile.NamedTemporaryFile( - suffix="-google-credentials.json", delete=False, delete_on_close=False + suffix="-google-credentials.json", delete=False, delete_on_close=False ) as f: creds_file = f.name f.write(creds_json.encode("utf-8")) @@ -67,12 +62,24 @@ async def lifespan(a: FastAPI): os.chmod(creds_file, 0o600) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file +def cleanup(): + os.remove(os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")) + +@asynccontextmanager +async def lifespan(a: FastAPI): + global startup_err + + try: + configure() + except Exception as e: + startup_err = e + yield # App shutdown try: - os.remove(creds_file) - except Exception: + cleanup() + except FileNotFoundError: pass diff --git a/gemini-vertex-model-provider/tool.gpt b/gemini-vertex-model-provider/tool.gpt index 15f118f8..b41adf0a 100644 --- a/gemini-vertex-model-provider/tool.gpt +++ b/gemini-vertex-model-provider/tool.gpt @@ -1,6 +1,7 @@ Name: Google Gemini Vertex AI Model Provider Description: Model provider for Google Gemini Vertex AI Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON +Metadata: optionalEnvVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_PROJECT_ID Model Provider: true Credential: ../model-provider-credential as gemini-vertex-model-provider @@ -9,4 +10,13 @@ Credential: ../model-provider-credential as gemini-vertex-model-provider --- !metadata:*:icon -/admin/assets/gemini_icon.svg \ No newline at end of file +/admin/assets/gemini_icon.svg + +--- +Name: Validate Credentials +Description: Validate the credentials for the Google Gemini Vertex AI Model Provider +Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON +Metadata: optionalEnvVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_PROJECT_ID +Credential: ../model-provider-credential as gemini-vertex-model-provider + +#!/usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/validate.py \ No newline at end of file diff --git a/gemini-vertex-model-provider/validate.py b/gemini-vertex-model-provider/validate.py new file mode 100644 index 00000000..1a92c329 --- /dev/null +++ b/gemini-vertex-model-provider/validate.py @@ -0,0 +1,31 @@ +import json + +from main import configure, cleanup, log + +from vertexai.preview.generative_models import GenerativeModel + +from google.auth.exceptions import GoogleAuthError + +def validate_credentials(): + try: + configure() + test_model() + log("Credentials are valid") + except Exception as e: + print(json.dumps({"error": str(e)})) + exit(1) + finally: + cleanup() + +def test_model(): + try: + _ = GenerativeModel("gemini-1.5-pro") + except GoogleAuthError as e: + print(json.dumps({"error": f"Invalid Google Credentials: {str(e)}"})) + exit(1) + except Exception as e: + print(json.dumps({"error": f"Unknown Error: {str(e)}"})) + exit(1) + +if __name__ == "__main__": + 
validate_credentials() \ No newline at end of file From 2f1d7f93f2d941c8f07ce77250760a588cf80727 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Fri, 3 Jan 2025 13:55:47 +0100 Subject: [PATCH 05/11] add: validate endpoint --- gemini-vertex-model-provider/main.py | 8 ++++++-- gemini-vertex-model-provider/tool.gpt | 3 +-- gemini-vertex-model-provider/validate.py | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py index a93b4b33..e2ff73f4 100644 --- a/gemini-vertex-model-provider/main.py +++ b/gemini-vertex-model-provider/main.py @@ -63,7 +63,10 @@ def configure(): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file def cleanup(): - os.remove(os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")) + try: + os.remove(os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")) + except FileNotFoundError: + pass @asynccontextmanager async def lifespan(a: FastAPI): @@ -112,9 +115,10 @@ async def get_root(): @app.get("/v1/models") def list_models() -> JSONResponse: + global startup_err if startup_err: return JSONResponse( - content={"error": startup_err}, status_code=500 + content={"error": str(startup_err)}, status_code=500 ) content = { "data": [ diff --git a/gemini-vertex-model-provider/tool.gpt b/gemini-vertex-model-provider/tool.gpt index b41adf0a..393fe4e6 100644 --- a/gemini-vertex-model-provider/tool.gpt +++ b/gemini-vertex-model-provider/tool.gpt @@ -13,10 +13,9 @@ Credential: ../model-provider-credential as gemini-vertex-model-provider /admin/assets/gemini_icon.svg --- -Name: Validate Credentials +Name: validate Description: Validate the credentials for the Google Gemini Vertex AI Model Provider Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON Metadata: optionalEnvVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_PROJECT_ID -Credential: ../model-provider-credential as gemini-vertex-model-provider #!/usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/validate.py \ No newline at end of file diff --git a/gemini-vertex-model-provider/validate.py b/gemini-vertex-model-provider/validate.py index 1a92c329..bd621c1f 100644 --- a/gemini-vertex-model-provider/validate.py +++ b/gemini-vertex-model-provider/validate.py @@ -13,7 +13,7 @@ def validate_credentials(): log("Credentials are valid") except Exception as e: print(json.dumps({"error": str(e)})) - exit(1) + exit(0) finally: cleanup() @@ -22,10 +22,10 @@ def test_model(): _ = GenerativeModel("gemini-1.5-pro") except GoogleAuthError as e: print(json.dumps({"error": f"Invalid Google Credentials: {str(e)}"})) - exit(1) + exit(0) except Exception as e: print(json.dumps({"error": f"Unknown Error: {str(e)}"})) - exit(1) + exit(0) if __name__ == "__main__": validate_credentials() \ No newline at end of file From ff5c772575d70dfb254f841b10b69b03b8e38085 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Thu, 9 Jan 2025 19:56:19 +0100 Subject: [PATCH 06/11] change: rewrite gemini-vertex-model-provider in Go --- gemini-vertex-model-provider/go.mod | 16 + gemini-vertex-model-provider/go.sum | 14 + gemini-vertex-model-provider/main.go | 99 ++++ gemini-vertex-model-provider/main.py | 469 --------------- gemini-vertex-model-provider/requirements.txt | 4 - gemini-vertex-model-provider/server/server.go | 536 ++++++++++++++++++ gemini-vertex-model-provider/tool.gpt | 4 +- gemini-vertex-model-provider/validate.py | 31 - 8 files changed, 667 insertions(+), 506 deletions(-) create mode 100644 gemini-vertex-model-provider/go.mod create mode 100644 
gemini-vertex-model-provider/go.sum create mode 100644 gemini-vertex-model-provider/main.go delete mode 100644 gemini-vertex-model-provider/main.py delete mode 100644 gemini-vertex-model-provider/requirements.txt create mode 100644 gemini-vertex-model-provider/server/server.go delete mode 100644 gemini-vertex-model-provider/validate.py diff --git a/gemini-vertex-model-provider/go.mod b/gemini-vertex-model-provider/go.mod new file mode 100644 index 00000000..673746e1 --- /dev/null +++ b/gemini-vertex-model-provider/go.mod @@ -0,0 +1,16 @@ +module github.com/obot-platform/tools/gemini-vertex-model-provider + +go 1.23.4 + +require ( + github.com/gptscript-ai/chat-completion-client v0.0.0-20241219123536-85c44096bc10 + golang.org/x/oauth2 v0.23.0 + google.golang.org/genai v0.0.0-20250107232730-7bfd6e5a3ff7 +) + +require ( + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect + github.com/google/go-cmp v0.6.0 // indirect + golang.org/x/sys v0.25.0 // indirect +) diff --git a/gemini-vertex-model-provider/go.sum b/gemini-vertex-model-provider/go.sum new file mode 100644 index 00000000..5d418fef --- /dev/null +++ b/gemini-vertex-model-provider/go.sum @@ -0,0 +1,14 @@ +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gptscript-ai/chat-completion-client v0.0.0-20241219123536-85c44096bc10 h1:v251qdhjAE+mCi3s+ekmGbqV9BurrMTl0Vd8/0MvsTY= +github.com/gptscript-ai/chat-completion-client v0.0.0-20241219123536-85c44096bc10/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +google.golang.org/genai v0.0.0-20250107232730-7bfd6e5a3ff7 h1:sGAkRQ7mZxfaV5S6gK/E8UU0ULMKGUuYkzlGcSXjGl4= +google.golang.org/genai v0.0.0-20250107232730-7bfd6e5a3ff7/go.mod h1:oOXmTgRmvfizGLLCWeqvGyKJjDluaibHnZdFIZEob0k= diff --git a/gemini-vertex-model-provider/main.go b/gemini-vertex-model-provider/main.go new file mode 100644 index 00000000..1955ccc1 --- /dev/null +++ b/gemini-vertex-model-provider/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "github.com/obot-platform/tools/gemini-vertex-model-provider/server" + "golang.org/x/oauth2/google" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + args := os.Args[1:] + if len(args) == 1 && args[0] == "validate" { + if err := validate(ctx); err != nil { + fmt.Printf("{\"error\": \"%s\"}\n", err) + } + os.Exit(0) + } + + c, err := configure(ctx) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + port := os.Getenv("PORT") + if port == "" { + port = "8000" + } + + if err := server.Run(c, port); err != nil { + panic(err) + } +} + +func validate(ctx context.Context) error { + _, err := configure(ctx) + return err +} + +func configure(ctx context.Context) 
(*genai.Client, error) {
+
+	// Ensure that we have some valid credentials JSON data
+	credsJSON := os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON")
+	if credsJSON == "" {
+		return nil, fmt.Errorf("google application credentials content is required")
+	}
+
+	var creds map[string]any
+	if err := json.Unmarshal([]byte(credsJSON), &creds); err != nil {
+		return nil, fmt.Errorf("failed to parse google application credentials json: %w", err)
+	}
+
+	gcreds, err := google.CredentialsFromJSON(ctx, []byte(credsJSON))
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse google credentials JSON: %w", err)
+	}
+
+	// Ensure that we have a Project ID set
+	var pid string
+	if p, ok := creds["project_id"]; ok {
+		pid = p.(string)
+	} else {
+		pid = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT")
+	}
+	if pid == "" {
+		return nil, fmt.Errorf("google cloud project id is required")
+	}
+
+	// Ensure that we have a Location set
+	var loc string
+	if l, ok := creds["location"]; ok {
+		loc = l.(string)
+	} else {
+		loc = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_LOCATION")
+	}
+	if loc == "" {
+		return nil, fmt.Errorf("google cloud location is required")
+	}
+
+	cc := &genai.ClientConfig{
+		Backend:     genai.BackendVertexAI,
+		Credentials: gcreds,
+		Project:     pid,
+		Location:    loc,
+	}
+
+	client, err := genai.NewClient(ctx, cc)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create genai client: %w", err)
+	}
+
+	return client, nil
+}
diff --git a/gemini-vertex-model-provider/main.py b/gemini-vertex-model-provider/main.py
deleted file mode 100644
index e2ff73f4..00000000
--- a/gemini-vertex-model-provider/main.py
+++ /dev/null
@@ -1,469 +0,0 @@
-import json
-import os
-import sys
-import tempfile
-from typing import AsyncIterable, List
-
-from contextlib import asynccontextmanager
-
-import google.auth.exceptions
-import vertexai.preview.generative_models as generative_models
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import JSONResponse, StreamingResponse
-from fastapi.routing import APIRouter
-from openai.types.chat import ChatCompletionChunk
-from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
-from vertexai.preview.generative_models import (
-    Content,
-    FunctionDeclaration,
-    GenerationConfig,
-    GenerationResponse,
-    GenerativeModel,
-    Part,
-    Tool,
-)
-
-startup_err: str | None = None
-
-def configure():
-    creds_json = os.environ.get(
-        "OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON", None
-    )
-    if not creds_json:
-        raise KeyError("Google application credentials content is required.")
-    else:
-
-        c: dict | None = None
-        try:
-            c = json.loads(creds_json)
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Invalid JSON in Google application credentials: {e}")
-
-        if c is not None and "project_id" in c:
-            os.environ["GOOGLE_CLOUD_PROJECT"] = c["project_id"]
-        else:
-            p = os.getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_PROJECT", None)
-            if not p:
-                raise KeyError("Google Cloud project ID is required.")
-            else:
-                os.environ["GOOGLE_CLOUD_PROJECT"] = p
-
-    if creds_json:
-        creds_file: str
-        with tempfile.NamedTemporaryFile(
-            suffix="-google-credentials.json", delete=False, delete_on_close=False
-        ) as f:
-            creds_file = f.name
-            f.write(creds_json.encode("utf-8"))
-            f.close()
-
-        # secure access by setting file permissions
-        os.chmod(creds_file, 0o600)
-        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file
-
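For comparison with the Go `configure()` above, a rough equivalent using the `google-genai` Python SDK (the package usage and default location here are assumptions for illustration, not part of this change):

```python
import os

from google import genai  # assumes the google-genai package is installed

client = genai.Client(
    vertexai=True,  # use the Vertex AI backend rather than the Gemini API
    project=os.environ["GOOGLE_CLOUD_PROJECT"],
    location=os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1"),
)
print(client.models.generate_content(
    model="gemini-1.5-pro-002",
    contents="Say hello in one word.",
).text)
```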
-def cleanup(): - try: - os.remove(os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")) - except FileNotFoundError: - pass - -@asynccontextmanager -async def lifespan(a: FastAPI): - global startup_err - - try: - configure() - except Exception as e: - startup_err = e - - - yield # App shutdown - - try: - cleanup() - except FileNotFoundError: - pass - - -debug = os.environ.get("GPTSCRIPT_DEBUG", "false") == "true" -app = FastAPI( - lifespan=lifespan, -) -router = APIRouter() - -uri = "http://127.0.0.1:" + os.environ.get("PORT", "8000") - - -def log(*args): - if debug: - print(*args) - - -@app.middleware("http") -async def log_body(request: Request, call_next): - body = await request.body() - log("REQUEST BODY: ", body) - return await call_next(request) - - -@app.post("/") -@app.get("/") -async def get_root(): - return uri - - -@app.get("/v1/models") -def list_models() -> JSONResponse: - global startup_err - if startup_err: - return JSONResponse( - content={"error": str(startup_err)}, status_code=500 - ) - content = { - "data": [ - { - "id": "gemini-1.0-pro", - "name": "Gemini 1.0 Pro", - }, - { - "id": "gemini-1.5-pro", - "name": "Gemini 1.5 Pro", - }, - { - "id": "gemini-1.5-pro-preview-0409", - "name": "Gemini 1.5 Pro Preview 0409", - }, - ] - } - return JSONResponse(content=content) - - -async def map_tools(req_tools: List | None = None) -> List[Tool] | None: - if req_tools is None or len(req_tools) < 1: - return None - - function_declarations = [] - for tool in req_tools: - parameters = tool["function"].get( - "parameters", {"properties": {}, "type": "object"} - ) - - function_declarations.append( - FunctionDeclaration( - name=tool["function"]["name"], - description=tool["function"]["description"], - parameters=parameters, - ) - ) - - tools: list["Tool"] = [Tool.from_function_declarations(function_declarations)] - - return tools - - -def merge_consecutive_dicts_with_same_value(list_of_dicts, key) -> list[dict]: - merged_list = [] - index = 0 - while index < len(list_of_dicts): - current_dict = list_of_dicts[index] - value_to_match = current_dict.get(key) - compared_index = index + 1 - while ( - compared_index < len(list_of_dicts) - and list_of_dicts[compared_index].get(key) == value_to_match - ): - list_of_dicts[compared_index]["content"] = ( - current_dict["content"] - + "\n" - + list_of_dicts[compared_index]["content"] - ) - current_dict.update(list_of_dicts[compared_index]) - compared_index += 1 - merged_list.append(current_dict) - index = compared_index - return merged_list - - -async def map_messages(req_messages: list) -> list[Content] | None: - messages: list[Content] = [] - log(req_messages) - - if req_messages is not None: - system: str = """ -You are a task oriented system. -Be as brief as possible when answering the user. -Only give the required answer. -Do not give your thought process. -Use functions or tools as needed to complete the tasks given to you. -You are referred to as a tool. -Do not call functions or tools unless you need to. -Ensure you are passing the correct arguments to the functions or tools you call. -Do not move on to the next task until the current task is completed. -Do not make up arguments for tools. -Call functions one at a time to make sure you have the correct inputs. 
-""" - req_messages = [ - {"role": "system", "content": system}, - {"role": "model", "content": "Understood."}, - ] + req_messages - - for message in req_messages: - match message["role"]: - case "system": - message["role"] = "user" - case "user": - message["role"] = "user" - case "assistant": - message["role"] = "model" - case "model": - message["role"] = "model" - case "tool": - message["role"] = "function" - case _: - message["role"] = "user" - req_messages = merge_consecutive_dicts_with_same_value(req_messages, "role") - - for message in req_messages: - if "tool_call_id" in message.keys(): - convert_message = Content( - role=message["role"], - parts=[ - Part.from_function_response( - name=message.get("name", ""), - response={ - "name": message.get("name", ""), - "content": message["content"], - }, - ) - ], - ) - elif "tool_calls" in message.keys(): - tool_call_parts: list[Part] = [] - for tool_call in message["tool_calls"]: - function_call = { - "functionCall": { - "name": tool_call["function"]["name"], - "args": json.loads(tool_call["function"]["arguments"]), - } - } - tool_call_parts.append(Part.from_dict(function_call)) - convert_message = Content(role=message["role"], parts=tool_call_parts) - elif "content" in message.keys(): - convert_message = Content( - role=message["role"], parts=[Part.from_text(message["content"])] - ) - messages.append(convert_message) - - return messages - - return None - - -@app.post("/v1/chat/completions") -async def chat_completion(request: Request): - data = await request.body() - data = json.loads(data) - - log("Env: ", os.environ) - - req_tools = data.get("tools", None) - tools: list[Tool] | None = None - if req_tools is not None: - tools = await map_tools(req_tools) - - req_messages = data["messages"] - messages = await map_messages(req_messages) - - temperature = data.get("temperature", None) - if temperature is not None: - temperature = float(temperature) - - stream = data.get("stream", False) - - top_k = data.get("top_k", None) - if top_k is not None: - top_k = float(top_k) - - top_p = data.get("top_p", None) - if top_p is not None: - top_p = float(top_p) - - max_output_tokens = data.get("max_tokens", None) - if max_output_tokens is not None: - max_output_tokens = float(max_output_tokens) - - log() - log("GEMINI TOOLS: ", tools) - try: - model = GenerativeModel(data["model"]) - except google.auth.exceptions.GoogleAuthError as e: - log("AUTH ERROR: ", e) - raise HTTPException( - status_code=401, - detail="Authentication error. 
Please ensure you are properly authenticated with GCP and have the correct project configured.", - ) - except Exception as e: - log("ERROR: ", e) - log(type(e)) - raise HTTPException(status_code=500, detail=str(e)) - try: - response = model.generate_content( - contents=messages, - tools=tools, - generation_config=GenerationConfig( - temperature=temperature, - top_k=top_k, - top_p=top_p, - candidate_count=1, - max_output_tokens=max_output_tokens, - ), - safety_settings={ - generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, - generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, - generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, - generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, - }, - ) - if not stream: - return JSONResponse(content=jsonable_encoder(response)) - - return StreamingResponse( - to_chunk(data["model"], response), media_type="application/x-ndjson" - ) - - except Exception as e: - log("ERROR: ", e) - raise HTTPException(status_code=500, detail=str(e)) - - -async def to_chunk(model: str, response: GenerationResponse) -> AsyncIterable[str]: - mapped_chunk = map_resp(model, response) - if mapped_chunk is None: - yield "data: " + json.dumps({}) + "\n\n" - else: - log("RESPONSE CHUNK: ", mapped_chunk.model_dump_json()) - yield "data: " + mapped_chunk.model_dump_json() + "\n\n" - - -def map_resp(model: str, chunk: GenerationResponse) -> ChatCompletionChunk | None: - tool_calls = [] - if len(chunk.candidates) > 0: - if len(chunk.candidates[0].function_calls) > 0: - parts = chunk.candidates[0].to_dict().get("content", {}).get("parts", []) - - for idx, part in enumerate(parts): - call = part.get("function_call", None) - if not call: - continue - - tool_calls.append( - { - "index": idx, - "id": call["name"] + "_" + str(idx), - "function": { - "name": call["name"], - "arguments": json.dumps(call["args"]), - }, - "type": "function", - } - ) - - try: - content = chunk.candidates[0].content.text - except KeyError: - content = None - - match chunk.candidates[0].content.role: - case "system": - role = "user" - case "user": - role = "user" - case "assistant": - role = "model" - case "model": - role = "assistant" - case "function": - role = "tool" - case _: - role = "user" - - try: - if len(tool_calls) > 0: - finish_reason = "tool_calls" - else: - finish_reason = map_finish_reason( - str(chunk.candidates[0].finish_reason) - ) - except KeyError: - finish_reason = None - - log("FINISH_REASON: ", finish_reason) - - resp = ChatCompletionChunk( - id="0", - choices=[ - Choice( - delta=ChoiceDelta( - content=content, tool_calls=tool_calls, role=role - ), - finish_reason=finish_reason, - index=0, - ) - ], - created=0, - model=model, - object="chat.completion.chunk", - ) - return resp - - return None - - -def map_finish_reason(finish_reason: str) -> str: - if finish_reason == "ERROR": - return "stop" - elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": - return "stop" - elif finish_reason == "SAFETY": - return "content_filter" - elif finish_reason == "STOP": - return "stop" - elif finish_reason == "0": - return "stop" - elif finish_reason == "1": - return "stop" - elif finish_reason == "2": - return "length" - elif finish_reason == "3": - return "content_filter" - elif finish_reason == "4": - return "content_filter" - elif 
finish_reason == "5": - return "stop" - elif finish_reason == "6": - return "content_filter" - elif finish_reason == "7": - return "content_filter" - elif finish_reason == "8": - return "content_filter" - # elif finish_reason == None: - # return "tool_calls" - return finish_reason - - -if __name__ == "__main__": - import uvicorn - import asyncio - - try: - uvicorn.run( - "main:app", - workers=4, - host="127.0.0.1", - port=int(os.environ.get("PORT", "8000")), - log_level="debug" if debug else "critical", - access_log=debug, - ) - except (KeyboardInterrupt, asyncio.CancelledError): - pass diff --git a/gemini-vertex-model-provider/requirements.txt b/gemini-vertex-model-provider/requirements.txt deleted file mode 100644 index 04bf2fb3..00000000 --- a/gemini-vertex-model-provider/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -fastapi -uvicorn[standard] -openai -google-cloud-aiplatform \ No newline at end of file diff --git a/gemini-vertex-model-provider/server/server.go b/gemini-vertex-model-provider/server/server.go new file mode 100644 index 00000000..9214d521 --- /dev/null +++ b/gemini-vertex-model-provider/server/server.go @@ -0,0 +1,536 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + openai "github.com/gptscript-ai/chat-completion-client" + "google.golang.org/genai" +) + +const systemPrompt = `You are a task oriented system. +Be as brief as possible when answering the user. +Only give the required answer. +Do not give your thought process. +Use functions or tools as needed to complete the tasks given to you. +You are referred to as a tool. +Do not call functions or tools unless you need to. +Ensure you are passing the correct arguments to the functions or tools you call. +Do not move on to the next task until the current task is completed. +Do not make up arguments for tools. 
+Call functions one at a time to make sure you have the correct inputs.` + +type server struct { + port string + client *genai.Client +} + +func Run(client *genai.Client, port string) error { + + mux := http.NewServeMux() + + s := &server{ + client: client, + port: port, + } + + mux.HandleFunc("/{$}", s.healthz) + mux.HandleFunc("/v1/models", s.listModels) + mux.HandleFunc("/v1/chat/completions", s.chatCompletions) + mux.HandleFunc("/v1/embeddings", s.embeddings) + + httpServer := &http.Server{ + Addr: "127.0.0.1:" + port, + Handler: mux, + } + + if err := httpServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return err + } + + return nil +} + +func (s *server) healthz(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("http://127.0.0.1:" + s.port)) +} + +func (s *server) listModels(w http.ResponseWriter, r *http.Request) { + content := map[string]any{ + "data": []map[string]any{ + // LLMs: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#supported-models + { + "id": "gemini-1.5-flash-001", + "name": "Gemini 1.5 Flash (001)", + }, + { + "id": "gemini-1.5-flash-002", + "name": "Gemini 1.5 Flash (002)", + }, + { + "id": "gemini-1.5-pro-001", + "name": "Gemini 1.5 Pro (001)", + }, + { + "id": "gemini-1.5-pro-002", + "name": "Gemini 1.5 Pro (002)", + }, + { + "id": "gemini-1.0-pro-vision-001", + "name": "Gemini 1.0 Pro Vision (001)", + }, + { + "id": "gemini-1.0-pro", + "name": "Gemini 1.0 Pro", + }, + { + "id": "gemini-1.0-pro-001", + "name": "Gemini 1.0 Pro (001)", + }, + { + "id": "gemini-1.0-pro-002", + "name": "Gemini 1.0 Pro (002)", + }, + // Embedding Models: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models + { + "id": "textembedding-gecko@001", + "name": "Text Embedding Gecko (001) [EN]", + }, + { + "id": "textembedding-gecko@003", + "name": "Text Embedding Gecko (003) [EN]", + }, + { + "id": "text-embedding-004", + "name": "Text Embedding 004 [EN]", + }, + { + "id": "text-embedding-005", + "name": "Text Embedding 005 [EN]", + }, + { + "id": "textembedding-gecko-multilingual@001", + "name": "Text Embedding Gecko Multilingual (001)", + }, + { + "id": "text-multilingual-embedding-002", + "name": "Text Multilingual Embedding 002", + }, + }, + } + if err := json.NewEncoder(w).Encode(content); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (s *server) chatCompletions(w http.ResponseWriter, r *http.Request) { + var cr openai.ChatCompletionRequest + if err := json.NewDecoder(r.Body).Decode(&cr); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Tools + tools, err := mapToolsFromOpenAI(cr.Tools) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Messages + contents, err := mapMessagesFromOpenAI(cr.Messages) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Temperature + var temperature *float64 + if cr.Temperature != nil { + t := float64(*cr.Temperature) + temperature = &t + } + + // TopP + var topP *float64 + if cr.TopP > 0 { + t := float64(cr.TopP) + topP = &t + } + + // MaxTokens + var maxTokens *int64 + if cr.MaxTokens > 0 { + m := int64(cr.MaxTokens) + maxTokens = &m + } + + // Options + config := &genai.GenerateContentConfig{ + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{ + { + Text: systemPrompt, + }, + }, + Role: "user", + }, + Tools: tools, + Temperature: temperature, + TopP: topP, + 
MaxOutputTokens: maxTokens, + CandidateCount: int64(cr.N), + SafetySettings: []*genai.SafetySetting{ + { + Method: genai.HarmBlockMethodSeverity, + Category: genai.HarmCategoryHateSpeech, + Threshold: genai.HarmBlockThresholdBlockOnlyHigh, + }, + { + Method: genai.HarmBlockMethodSeverity, + Category: genai.HarmCategoryDangerousContent, + Threshold: genai.HarmBlockThresholdBlockOnlyHigh, + }, + { + Method: genai.HarmBlockMethodSeverity, + Category: genai.HarmCategorySexuallyExplicit, + Threshold: genai.HarmBlockThresholdBlockOnlyHigh, + }, + { + Method: genai.HarmBlockMethodSeverity, + Category: genai.HarmCategoryHarassment, + Threshold: genai.HarmBlockThresholdBlockOnlyHigh, + }, + }, + } + + if cr.Stream { + for result, err := range s.client.Models.GenerateContentStream(r.Context(), cr.Model, contents, config) { + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + choices, err := mapToOpenAIStreamChoice(result.Candidates) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resp := openai.ChatCompletionStreamResponse{ + ID: "0", + Choices: choices, + Created: 0, + Model: cr.Model, + Object: "chat.completion.chunk", + Usage: mapUsageToOpenAI(result.UsageMetadata), + } + + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + } + + } else { + result, err := s.client.Models.GenerateContent(r.Context(), cr.Model, contents, config) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + choices, err := mapToOpenAIChoice(result.Candidates) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resp := openai.ChatCompletionResponse{ + ID: "0", + Object: "chat.completion", + Created: 0, + Model: cr.Model, + Choices: choices, + Usage: mapUsageToOpenAI(result.UsageMetadata), + } + + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + return +} + +func mapUsageToOpenAI(usage *genai.GenerateContentResponseUsageMetadata) openai.Usage { + if usage == nil { + return openai.Usage{} + } + return openai.Usage{ + PromptTokens: int(usage.PromptTokenCount), + CompletionTokens: int(usage.CandidatesTokenCount), + TotalTokens: int(usage.TotalTokenCount), + } +} + +func mapToOpenAIContentAndToolCalls(parts []*genai.Part) (string, []openai.ToolCall, error) { + var toolCalls []openai.ToolCall + content := "" + for idx, p := range parts { + tidx := idx + if p.Text != "" { + content += "\n" + p.Text + } + if p.FunctionCall != nil { + args, err := json.Marshal(p.FunctionCall.Args) + if err != nil { + return "", nil, fmt.Errorf("failed to marshal function arguments: %w", err) + } + toolCalls = append(toolCalls, openai.ToolCall{ + Index: &tidx, + ID: p.FunctionCall.ID, + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{ + Name: p.FunctionCall.Name, + Arguments: string(args), + }, + }) + } + } + return content, toolCalls, nil +} + +func mapToOpenAIStreamChoice(candidates []*genai.Candidate) ([]openai.ChatCompletionStreamChoice, error) { + if len(candidates) == 0 { + return nil, nil + } + + var choices []openai.ChatCompletionStreamChoice + for i, c := range candidates { + + content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) + if err != nil { + return nil, fmt.Errorf("failed to map content and tool calls: %w", err) + } + + var finishReason openai.FinishReason 
+ if len(toolCalls) > 0 { + finishReason = openai.FinishReasonFunctionCall + } else { + finishReason = mapFinishReasonToOpenAI(c.FinishReason) + } + + choice := openai.ChatCompletionStreamChoice{ + Index: i, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: strings.TrimSpace(content), + ToolCalls: toolCalls, + Role: mapRoleToOpenAI(c.Content.Role), + }, + FinishReason: finishReason, + ContentFilterResults: openai.ContentFilterResults{}, // TODO: fill based on Google's finish_reason? + } + choices = append(choices, choice) + } + + return choices, nil +} + +func mapToOpenAIChoice(candidates []*genai.Candidate) ([]openai.ChatCompletionChoice, error) { + if len(candidates) == 0 { + return nil, nil + } + + var choices []openai.ChatCompletionChoice + for i, c := range candidates { + + content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) + if err != nil { + return nil, fmt.Errorf("failed to map content and tool calls: %w", err) + } + + var finishReason openai.FinishReason + if len(toolCalls) > 0 { + finishReason = openai.FinishReasonFunctionCall + } else { + finishReason = mapFinishReasonToOpenAI(c.FinishReason) + } + + choice := openai.ChatCompletionChoice{ + Index: i, + FinishReason: finishReason, + Message: openai.ChatCompletionMessage{ + Role: mapRoleToOpenAI(c.Content.Role), + Content: content, + ToolCalls: toolCalls, + }, + LogProbs: nil, + } + choices = append(choices, choice) + } + + return choices, nil +} + +func mapFinishReasonToOpenAI(reason genai.FinishReason) openai.FinishReason { + switch reason { + case genai.FinishReasonStop, genai.FinishReasonUnspecified, genai.FinishReasonOther: + return openai.FinishReasonStop + case genai.FinishReasonMaxTokens: + return openai.FinishReasonLength + case genai.FinishReasonBlocklist, genai.FinishReasonRecitation, genai.FinishReasonSafety, genai.FinishReasonSPII, genai.FinishReasonProhibitedContent: + return openai.FinishReasonContentFilter + default: + return openai.FinishReasonStop + } +} + +var roleMapFromOpenAI = map[string]string{ + "system": "user", + "user": "user", + "assistant": "model", + "model": "model", + "tool": "function", +} + +func mapRoleFromOpenAI(role string) string { + if r, ok := roleMapFromOpenAI[role]; ok { + return r + } + return "user" +} + +var roleMapToOpenAI = map[string]string{ + "system": "user", + "user": "user", + "assistant": "model", + "model": "assistant", + "function": "tool", +} + +func mapRoleToOpenAI(role string) string { + if r, ok := roleMapToOpenAI[role]; ok { + return r + } + return "user" +} + +func mapMessagesFromOpenAI(messages []openai.ChatCompletionMessage) ([]*genai.Content, error) { + + var contents []*genai.Content + if len(messages) > 0 { + contents = append(contents, &genai.Content{ + Parts: []*genai.Part{ + { + Text: systemPrompt, + }, + }, + Role: "user", + }) + } + + for _, m := range messages { + content := &genai.Content{ + Parts: []*genai.Part{}, + Role: mapRoleFromOpenAI(m.Role), + } + + if m.ToolCallID != "" { + // Tool Call Response + content.Parts = append(content.Parts, &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + ID: m.ToolCallID, + Name: m.Name, + Response: map[string]any{ + "name": m.Name, + "content": m.Content, + }, + }, + }) + } else if len(m.ToolCalls) > 0 { + // Tool Calls + for _, tc := range m.ToolCalls { + var args map[string]any + if tc.Function.Arguments != "" { + err := json.Unmarshal([]byte(tc.Function.Arguments), &args) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal function arguments: %w", err) + } + } + 
content.Parts = append(content.Parts, &genai.Part{ + FunctionCall: &genai.FunctionCall{ + ID: tc.ID, + Name: tc.Function.Name, + Args: args, + }, + }) + } + + } else if m.Content != "" { + // Pure text content + content.Parts = append(content.Parts, &genai.Part{ + Text: m.Content, + }) + } + + contents = append(contents, content) + } + return contents, nil +} + +func mapToolsFromOpenAI(oaiTools []openai.Tool) ([]*genai.Tool, error) { + var tools []*genai.Tool + for _, t := range oaiTools { + f, err := mapFunctionDefinitionFromOpenAI(t.Function) + if err != nil { + return nil, fmt.Errorf("failed to map functions: %w", err) + } + if len(f) > 0 { + tools = append(tools, &genai.Tool{ + FunctionDeclarations: f, + }) + } + } + + return tools, nil +} + +func mapFunctionDefinitionFromOpenAI(funcDef *openai.FunctionDefinition) ([]*genai.FunctionDeclaration, error) { + if funcDef == nil { + return nil, nil + } + var functions []*genai.FunctionDeclaration + + var params *genai.Schema + if funcDef.Parameters != nil { + pb, err := json.Marshal(funcDef.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to marshal function parameters: %w", err) + } + + if err := json.Unmarshal(pb, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal function parameters: %w", err) + } + } else { + params = &genai.Schema{ + Properties: map[string]*genai.Schema{}, + Type: genai.TypeObject, + } + } + + functions = append(functions, &genai.FunctionDeclaration{ + Description: funcDef.Description, + Name: funcDef.Name, + Parameters: params, + }) + return functions, nil +} + +func (s *server) embeddings(w http.ResponseWriter, r *http.Request) { + return +} diff --git a/gemini-vertex-model-provider/tool.gpt b/gemini-vertex-model-provider/tool.gpt index 393fe4e6..e49b2f7c 100644 --- a/gemini-vertex-model-provider/tool.gpt +++ b/gemini-vertex-model-provider/tool.gpt @@ -6,7 +6,7 @@ Model Provider: true Credential: ../model-provider-credential as gemini-vertex-model-provider -#!sys.daemon /usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/main.py +#!sys.daemon ${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool --- !metadata:*:icon @@ -18,4 +18,4 @@ Description: Validate the credentials for the Google Gemini Vertex AI Model Prov Metadata: envVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON Metadata: optionalEnvVars: OBOT_GEMINI_VERTEX_MODEL_PROVIDER_PROJECT_ID -#!/usr/bin/env python3 ${GPTSCRIPT_TOOL_DIR}/validate.py \ No newline at end of file +#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool validate \ No newline at end of file diff --git a/gemini-vertex-model-provider/validate.py b/gemini-vertex-model-provider/validate.py deleted file mode 100644 index bd621c1f..00000000 --- a/gemini-vertex-model-provider/validate.py +++ /dev/null @@ -1,31 +0,0 @@ -import json - -from main import configure, cleanup, log - -from vertexai.preview.generative_models import GenerativeModel - -from google.auth.exceptions import GoogleAuthError - -def validate_credentials(): - try: - configure() - test_model() - log("Credentials are valid") - except Exception as e: - print(json.dumps({"error": str(e)})) - exit(0) - finally: - cleanup() - -def test_model(): - try: - _ = GenerativeModel("gemini-1.5-pro") - except GoogleAuthError as e: - print(json.dumps({"error": f"Invalid Google Credentials: {str(e)}"})) - exit(0) - except Exception as e: - print(json.dumps({"error": f"Unknown Error: {str(e)}"})) - exit(0) - -if __name__ == "__main__": - validate_credentials() \ No newline at end of file From 0fa890a3fd83aa6b9da642f583d0879242c5efcf 
Mon Sep 17 00:00:00 2001
From: Thorsten Klein
Date: Fri, 10 Jan 2025 13:08:25 +0100
Subject: [PATCH 07/11] add: custom embeddings implementation

---
 gemini-vertex-model-provider/main.go          |   6 +-
 gemini-vertex-model-provider/server/server.go | 119 ++++++++++++++++++
 2 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/gemini-vertex-model-provider/main.go b/gemini-vertex-model-provider/main.go
index 1955ccc1..4f01d0cd 100644
--- a/gemini-vertex-model-provider/main.go
+++ b/gemini-vertex-model-provider/main.go
@@ -19,7 +19,7 @@ func main() {
 		if err := validate(ctx); err != nil {
 			fmt.Printf("{\"error\": \"%s\"}\n", err)
 		}
-		os.Exit(0)
+		os.Exit(1)
 	}
 
 	c, err := configure(ctx)
@@ -56,7 +56,7 @@ func configure(ctx context.Context) (*genai.Client, error) {
 		return nil, fmt.Errorf("failed to parse google application credentials json: %w", err)
 	}
 
-	gcreds, err := google.CredentialsFromJSON(ctx, []byte(credsJSON))
+	gcreds, err := google.CredentialsFromJSON(ctx, []byte(credsJSON), "https://www.googleapis.com/auth/cloud-platform")
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse google credentials JSON: %w", err)
 	}
@@ -77,7 +77,7 @@ func configure(ctx context.Context) (*genai.Client, error) {
 	if l, ok := creds["location"]; ok {
 		loc = l.(string)
 	} else {
-		pid = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_LOCATION")
+		loc = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_LOCATION")
 	}
 	if loc == "" {
 		return nil, fmt.Errorf("google cloud location is required")
 	}
diff --git a/gemini-vertex-model-provider/server/server.go b/gemini-vertex-model-provider/server/server.go
index 9214d521..90c26618 100644
--- a/gemini-vertex-model-provider/server/server.go
+++ b/gemini-vertex-model-provider/server/server.go
@@ -1,9 +1,11 @@
 package server
 
 import (
+	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"strings"
 
@@ -531,6 +533,123 @@ func mapFunctionDefinitionFromOpenAI(funcDef *openai.FunctionDefinition) ([]*gen
 	return functions, nil
 }
 
+// openAIEmbeddingRequest - not (yet) provided by the Chat Completion Client package
+type openAIEmbeddingRequest struct {
+	Input          string `json:"input"`
+	Model          string `json:"model"`
+	EncodingFormat string `json:"encoding_format,omitempty"`
+	Dimensions     *int   `json:"dimensions,omitempty"`
+}
+
+type openAIResponse struct {
+	Data []openAIResponseData `json:"data"`
+}
+
+type openAIResponseData struct {
+	Embedding []float32 `json:"embedding"`
+}
+
+type vertexEmbeddingResponse struct {
+	Predictions []vertexPrediction `json:"predictions"`
+}
+
+type vertexPrediction struct {
+	Embeddings vertexEmbeddings `json:"embeddings"`
+}
+
+type vertexEmbeddings struct {
+	Values []float32 `json:"values"`
+	// leaving out what we don't need just yet
+}
+
+// embeddings - not (yet) provided by the Google GenAI package
 func (s *server) embeddings(w http.ResponseWriter, r *http.Request) {
+
+	var er openAIEmbeddingRequest
+	if err := json.NewDecoder(r.Body).Decode(&er); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", s.client.ClientConfig().Location, s.client.ClientConfig().Project, s.client.ClientConfig().Location, er.Model)
+
+	payload := map[string]any{
+		"instances": []map[string]any{
+			{
+				"task_type":  "QUESTION_ANSWERING",
+				"content":    er.Input,
+				"parameters": map[string]any{},
+			},
+		},
+	}
+
+	if er.Dimensions != nil {
+		payload["parameters"] = map[string]any{
+			"outputDimensionality":
*er.Dimensions, + } + } + + reqBody, err := json.Marshal(payload) + if err != nil { + http.Error(w, fmt.Sprintf("couldn't marshal request body: %v", err), http.StatusInternalServerError) + return + } + + req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewBuffer(reqBody)) + if err != nil { + http.Error(w, fmt.Sprintf("couldn't create request: %v", err), http.StatusInternalServerError) + return + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.ClientConfig().HTTPClient.Do(req) + if err != nil { + http.Error(w, fmt.Sprintf("couldn't make request: %v", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + http.Error(w, fmt.Sprintf("unexpected status code: %d", resp.StatusCode), http.StatusInternalServerError) + return + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, fmt.Sprintf("couldn't read response body: %v", err), http.StatusInternalServerError) + return + } + + var embeddingResponse vertexEmbeddingResponse + err = json.Unmarshal(body, &embeddingResponse) + if err != nil { + http.Error(w, fmt.Sprintf("couldn't unmarshal response body: %v", err), http.StatusInternalServerError) + return + } + + if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 { + http.Error(w, "no embeddings found in the response", http.StatusInternalServerError) + return + } + + if len(embeddingResponse.Predictions) > 1 { + fmt.Println("Info: multiple predictions found in the response - using only the first one") + } + + oaiResp := openAIResponse{ + Data: []openAIResponseData{ + { + Embedding: embeddingResponse.Predictions[0].Embeddings.Values, + }, + }, + } + + if err := json.NewEncoder(w).Encode(oaiResp); err != nil { + http.Error(w, fmt.Sprintf("couldn't encode response: %v", err), http.StatusInternalServerError) + return + } + return } From 560e56542102726f88a38d6ca79ac280ea164fbd Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 13 Jan 2025 09:42:52 +0100 Subject: [PATCH 08/11] fix: golint whitespaces --- gemini-vertex-model-provider/main.go | 1 - gemini-vertex-model-provider/server/server.go | 8 -------- 2 files changed, 9 deletions(-) diff --git a/gemini-vertex-model-provider/main.go b/gemini-vertex-model-provider/main.go index 4f01d0cd..a326fc36 100644 --- a/gemini-vertex-model-provider/main.go +++ b/gemini-vertex-model-provider/main.go @@ -44,7 +44,6 @@ func validate(ctx context.Context) error { } func configure(ctx context.Context) (*genai.Client, error) { - // Ensure that we have some valid credentials JSON data credsJSON := os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CREDENTIALS_JSON") if credsJSON == "" { diff --git a/gemini-vertex-model-provider/server/server.go b/gemini-vertex-model-provider/server/server.go index 90c26618..777a975e 100644 --- a/gemini-vertex-model-provider/server/server.go +++ b/gemini-vertex-model-provider/server/server.go @@ -31,7 +31,6 @@ type server struct { } func Run(client *genai.Client, port string) error { - mux := http.NewServeMux() s := &server{ @@ -235,9 +234,7 @@ func (s *server) chatCompletions(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - } - } else { result, err := s.client.Models.GenerateContent(r.Context(), cr.Model, contents, config) if err != nil { @@ -314,7 +311,6 @@ func mapToOpenAIStreamChoice(candidates []*genai.Candidate) 
([]openai.ChatComple var choices []openai.ChatCompletionStreamChoice for i, c := range candidates { - content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) if err != nil { return nil, fmt.Errorf("failed to map content and tool calls: %w", err) @@ -350,7 +346,6 @@ func mapToOpenAIChoice(candidates []*genai.Candidate) ([]openai.ChatCompletionCh var choices []openai.ChatCompletionChoice for i, c := range candidates { - content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) if err != nil { return nil, fmt.Errorf("failed to map content and tool calls: %w", err) @@ -423,7 +418,6 @@ func mapRoleToOpenAI(role string) string { } func mapMessagesFromOpenAI(messages []openai.ChatCompletionMessage) ([]*genai.Content, error) { - var contents []*genai.Content if len(messages) > 0 { contents = append(contents, &genai.Content{ @@ -472,7 +466,6 @@ func mapMessagesFromOpenAI(messages []openai.ChatCompletionMessage) ([]*genai.Co }, }) } - } else if m.Content != "" { // Pure text content content.Parts = append(content.Parts, &genai.Part{ @@ -564,7 +557,6 @@ type vertexEmbeddings struct { // embeddings - not (yet) provided by the Google GenAI package func (s *server) embeddings(w http.ResponseWriter, r *http.Request) { - var er openAIEmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&er); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) From 7c1f2249f760dd4d0c741227625ef551aab48974 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 13 Jan 2025 09:51:49 +0100 Subject: [PATCH 09/11] fix: PR comments --- gemini-vertex-model-provider/main.go | 3 +- gemini-vertex-model-provider/server/server.go | 37 ++++++------------- 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/gemini-vertex-model-provider/main.go b/gemini-vertex-model-provider/main.go index a326fc36..88a23619 100644 --- a/gemini-vertex-model-provider/main.go +++ b/gemini-vertex-model-provider/main.go @@ -18,8 +18,9 @@ func main() { if len(args) == 1 && args[0] == "validate" { if err := validate(ctx); err != nil { fmt.Printf("{\"error\": \"%s\"}\n", err) + os.Exit(1) } - os.Exit(1) + os.Exit(0) } c, err := configure(ctx) diff --git a/gemini-vertex-model-provider/server/server.go b/gemini-vertex-model-provider/server/server.go index 777a975e..b8a7ccdc 100644 --- a/gemini-vertex-model-provider/server/server.go +++ b/gemini-vertex-model-provider/server/server.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "strings" @@ -39,9 +38,9 @@ func Run(client *genai.Client, port string) error { } mux.HandleFunc("/{$}", s.healthz) - mux.HandleFunc("/v1/models", s.listModels) - mux.HandleFunc("/v1/chat/completions", s.chatCompletions) - mux.HandleFunc("/v1/embeddings", s.embeddings) + mux.HandleFunc("GET /v1/models", s.listModels) + mux.HandleFunc("POST /v1/chat/completions", s.chatCompletions) + mux.HandleFunc("POST /v1/embeddings", s.embeddings) httpServer := &http.Server{ Addr: "127.0.0.1:" + port, @@ -278,10 +277,11 @@ func mapUsageToOpenAI(usage *genai.GenerateContentResponseUsageMetadata) openai. 
} func mapToOpenAIContentAndToolCalls(parts []*genai.Part) (string, []openai.ToolCall, error) { - var toolCalls []openai.ToolCall - content := "" + var ( + toolCalls []openai.ToolCall + content string + ) for idx, p := range parts { - tidx := idx if p.Text != "" { content += "\n" + p.Text } @@ -291,7 +291,7 @@ func mapToOpenAIContentAndToolCalls(parts []*genai.Part) (string, []openai.ToolC return "", nil, fmt.Errorf("failed to marshal function arguments: %w", err) } toolCalls = append(toolCalls, openai.ToolCall{ - Index: &tidx, + Index: &idx, ID: p.FunctionCall.ID, Type: openai.ToolTypeFunction, Function: openai.FunctionCall{ @@ -305,10 +305,6 @@ func mapToOpenAIContentAndToolCalls(parts []*genai.Part) (string, []openai.ToolC } func mapToOpenAIStreamChoice(candidates []*genai.Candidate) ([]openai.ChatCompletionStreamChoice, error) { - if len(candidates) == 0 { - return nil, nil - } - var choices []openai.ChatCompletionStreamChoice for i, c := range candidates { content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) @@ -340,10 +336,6 @@ func mapToOpenAIStreamChoice(candidates []*genai.Candidate) ([]openai.ChatComple } func mapToOpenAIChoice(candidates []*genai.Candidate) ([]openai.ChatCompletionChoice, error) { - if len(candidates) == 0 { - return nil, nil - } - var choices []openai.ChatCompletionChoice for i, c := range candidates { content, toolCalls, err := mapToOpenAIContentAndToolCalls(c.Content.Parts) @@ -604,20 +596,13 @@ func (s *server) embeddings(w http.ResponseWriter, r *http.Request) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - http.Error(w, fmt.Sprintf("unexpected status code: %d", resp.StatusCode), http.StatusInternalServerError) - return - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, fmt.Sprintf("couldn't read response body: %v", err), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("unexpected status code: %d", resp.StatusCode), resp.StatusCode) return } var embeddingResponse vertexEmbeddingResponse - err = json.Unmarshal(body, &embeddingResponse) - if err != nil { - http.Error(w, fmt.Sprintf("couldn't unmarshal response body: %v", err), http.StatusInternalServerError) + if err := json.NewDecoder(resp.Body).Decode(&embeddingResponse); err != nil { + http.Error(w, fmt.Sprintf("couldn't decode response: %v", err), http.StatusInternalServerError) return } From 7d6c10f618d23c76d88e94167069fb2413f59a8c Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 15 Jan 2025 15:34:58 +0100 Subject: [PATCH 10/11] Update gemini-vertex-model-provider/server/server.go Co-authored-by: Donnie Adams --- gemini-vertex-model-provider/server/server.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/gemini-vertex-model-provider/server/server.go b/gemini-vertex-model-provider/server/server.go index b8a7ccdc..4604e28a 100644 --- a/gemini-vertex-model-provider/server/server.go +++ b/gemini-vertex-model-provider/server/server.go @@ -261,8 +261,6 @@ func (s *server) chatCompletions(w http.ResponseWriter, r *http.Request) { return } } - - return } func mapUsageToOpenAI(usage *genai.GenerateContentResponseUsageMetadata) openai.Usage { From 754ce15a0a57af77a3c635a1fdf8fa1e42325f8d Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 15 Jan 2025 15:35:07 +0100 Subject: [PATCH 11/11] Update gemini-vertex-model-provider/server/server.go Co-authored-by: Donnie Adams --- gemini-vertex-model-provider/server/server.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/gemini-vertex-model-provider/server/server.go 
b/gemini-vertex-model-provider/server/server.go
index 4604e28a..c5698ff5 100644
--- a/gemini-vertex-model-provider/server/server.go
+++ b/gemini-vertex-model-provider/server/server.go
@@ -625,6 +625,4 @@ func (s *server) embeddings(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, fmt.Sprintf("couldn't encode response: %v", err), http.StatusInternalServerError)
 		return
 	}
-
-	return
 }
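
For reference, the chat endpoint wired up across these patches speaks the OpenAI chat-completions wire format, so any plain HTTP client can exercise it. Below is a minimal client sketch in Go, assuming the provider daemon is already listening on 127.0.0.1:8000 and that `gemini-1.5-pro-002` (one of the IDs served by `/v1/models`) is enabled in your Google Cloud project; the port and model name are illustrative assumptions, not something these patches pin down.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// An OpenAI-style chat completion request; the provider translates
	// this into a Vertex AI GenerateContent call behind the scenes.
	reqBody, err := json.Marshal(map[string]any{
		"model": "gemini-1.5-pro-002",
		"messages": []map[string]string{
			{"role": "user", "content": "Say hello in one word."},
		},
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:8000/v1/chat/completions", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the fields we need from the OpenAI-shaped response.
	var out struct {
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	for _, c := range out.Choices {
		fmt.Printf("%s: %s\n", c.Message.Role, c.Message.Content)
	}
}
```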
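The embeddings handler added in PATCH 07 accepts the single-input OpenAI request shape (`input`, `model`, optional `dimensions`) and responds with `data[0].embedding` taken from the first Vertex prediction. A matching client sketch under the same assumptions (local provider on port 8000; `text-embedding-004` from the `/v1/models` list):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Mirrors the openAIEmbeddingRequest struct the handler decodes;
	// this implementation supports a single input string only.
	reqBody, err := json.Marshal(map[string]any{
		"input": "The quick brown fox",
		"model": "text-embedding-004",
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:8000/v1/embeddings", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Mirrors the openAIResponse shape the handler encodes.
	var out struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Data) > 0 {
		fmt.Printf("got a %d-dimensional embedding\n", len(out.Data[0].Embedding))
	}
}
```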