diff --git a/src/bigquery.py b/src/bigquery.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/document_extract.py b/src/document_extract.py
index 1285ac2..d3dd105 100644
--- a/src/document_extract.py
+++ b/src/document_extract.py
@@ -22,7 +22,7 @@ def async_document_extract(
     bucket: str,
     name: str,
     timeout: int = 420,
-):
+) -> tuple[str, str]:
     """Perform OCR with PDF/TIFF as source files on GCS.
 
     Original sample is here:
diff --git a/src/function.py b/src/function.py
new file mode 100644
index 0000000..d81af62
--- /dev/null
+++ b/src/function.py
@@ -0,0 +1,96 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functions_framework
+
+from dataclasses import dataclass
+import datetime
+import re
+
+from document_extract import async_document_extract
+from storage import upload_to_gcs
+from vertex_llm import predict_large_language_model_hack
+
+
+def coerce_datetime_zulu(input_datetime: str) -> datetime.datetime:
+    """Coerce a Zulu-suffixed timestamp string into a timezone-aware datetime.
+
+    Args:
+        input_datetime (str): the timestamp string to coerce,
+            e.g. "2023-05-08T19:28:55.255Z"
+
+    Returns:
+        datetime.datetime: the parsed, UTC-aware datetime
+    """
+    regex = re.compile(r"(.*)(Z$)")
+    regex_match = regex.search(input_datetime)
+    if regex_match:
+        assert input_datetime.startswith(regex_match.group(1))
+        assert input_datetime.endswith(regex_match.group(2))
+        return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00')
+    raise RuntimeError(
+        'The input datetime is not in the expected format. '
+        'Please check the format of the input datetime. Expected "Z" at the end.'
+    )
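+
+# For example, for a Cloud Storage event timestamp:
+#   coerce_datetime_zulu('2023-05-08T19:28:55.255Z')
+#   -> datetime.datetime(2023, 5, 8, 19, 28, 55, 255000,
+#                        tzinfo=datetime.timezone.utc)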
Expected "Z" at the end' + ) + + +@dataclass +class CloudEventData: + event_id: str + event_type: str + bucket: str + name: str + metageneration: str + timeCreated: str + updated: str + + @classmethod + def read_datetimes(cls, kwargs): + for key in ['timeCreated', 'updated']: + kwargs[key] = coerce_datetime_zulu(kwargs.pop(key)) + + return cls(**kwargs) + + +# WEBHOOK FUNCTION +@functions_framework.cloud_event +def entrypoint(cloud_event): + event_id = cloud_event["id"] + event_type = cloud_event["type"] + bucket = cloud_event.data["bucket"] + name = cloud_event.data["name"] + metageneration = cloud_event.data["metageneration"] + timeCreated = coerce_datetime_zulu(cloud_event.data["timeCreated"]) + updated = coerce_datetime_zulu(cloud_event.data["updated"]) + + extracted_text = async_document_extract(bucket, name) + summary = predict_large_language_model_hack( + project_id="velociraptor-16p1-src", + model_name="text-bison-001", + temperature=0.2, + max_decode_steps=1024, + top_p=0.8, + top_k=40, + content=f'Summarize:\n{extracted_text}', + location="us-central1", + ) + + output_filename = f'{name.replace(".pdf", "")}_summary.txt' + upload_to_gcs( + bucket, + output_filename, + summary, + ) + + return { + 'summary': summary, + 'output_filename': output_filename, + } diff --git a/src/requirements.txt b/src/requirements.txt new file mode 100644 index 0000000..aa36645 --- /dev/null +++ b/src/requirements.txt @@ -0,0 +1,5 @@ +functions-framework +google-auth +google-cloud-aiplatform==1.25.0 +google-cloud-storage +google-cloud-vision \ No newline at end of file diff --git a/src/storage.py b/src/storage.py new file mode 100644 index 0000000..b8bd910 --- /dev/null +++ b/src/storage.py @@ -0,0 +1,30 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import storage + + +def upload_to_gcs(bucket: str, name: str, data: str): + """Upload a string to Google Cloud Storage bucket. + + Args: + bucket (str): the name of the Storage bucket. Do not include "gs://" + name (str): the name of the file to create in the bucket + data (str): the data to store + + """ + client = storage.Client() + bucket = client.get_bucket(bucket) + blob = bucket.blob(name) + blob.upload_from_string(data) diff --git a/src/vertex_llm.py b/src/vertex_llm.py new file mode 100644 index 0000000..bc8dc9f --- /dev/null +++ b/src/vertex_llm.py @@ -0,0 +1,129 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/vertex_llm.py b/src/vertex_llm.py
new file mode 100644
index 0000000..bc8dc9f
--- /dev/null
+++ b/src/vertex_llm.py
@@ -0,0 +1,129 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google import auth
+import google.auth.transport.requests
+from google.cloud import aiplatform
+from google.cloud.aiplatform import TextGenerationModel
+
+import datetime
+import requests
+from requests.adapters import HTTPAdapter
+import urllib3
+
+
+def predict_large_language_model(
+    project_id: str,
+    model_name: str,
+    temperature: float,
+    max_decode_steps: int,
+    top_p: float,
+    top_k: int,
+    content: str,
+    location: str = "us-central1",
+    tuned_model_name: str = "",
+) -> str:
+    """Predict using a Large Language Model.
+
+    Args:
+        project_id (str): the Google Cloud project ID
+        model_name (str): the name of the LLM model to use
+        temperature (float): TODO(nicain)
+        max_decode_steps (int): TODO(nicain)
+        top_p (float): TODO(nicain)
+        top_k (int): TODO(nicain)
+        content (str): the text to summarize
+        location (str): the Google Cloud region to run in
+        tuned_model_name (str): TODO(nicain)
+
+    Returns:
+        The summarization of the content
+    """
+    aiplatform.init(project=project_id, location=location)
+    model = TextGenerationModel.from_pretrained(model_name)
+    if tuned_model_name:
+        model = model.get_tuned_model(tuned_model_name)
+    response = model.predict(
+        content,
+        temperature=temperature,
+        max_output_tokens=max_decode_steps,
+        top_k=top_k,
+        top_p=top_p,
+    )
+    return response.text
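+
+# Example call (sketch; the project ID, model name, and prompt are placeholders):
+#   summary = predict_large_language_model(
+#       project_id='my-project', model_name='text-bison@001', temperature=0.2,
+#       max_decode_steps=256, top_p=0.8, top_k=40, content='Summarize: ...')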
+
+
+def predict_large_language_model_hack(
+    project_id: str,
+    model_name: str,
+    temperature: float,
+    max_decode_steps: int,
+    top_p: float,
+    top_k: int,
+    content: str,
+    location: str = "us-central1",
+    tuned_model_name: str = "",
+) -> str:
+    """Predict using a Large Language Model.
+
+    Args:
+        project_id (str): the Google Cloud project ID
+        model_name (str): the name of the LLM model to use
+        temperature (float): TODO(nicain)
+        max_decode_steps (int): TODO(nicain)
+        top_p (float): TODO(nicain)
+        top_k (int): TODO(nicain)
+        content (str): the text to summarize
+        location (str): the Google Cloud region to run in
+        tuned_model_name (str): TODO(nicain)
+
+    Returns:
+        The summarization of the content
+    """
+    credentials, project_id = auth.default()
+    request = auth.transport.requests.Request()
+    credentials.refresh(request)
+
+    audience = f'https://us-central1-aiplatform.googleapis.com/v1/projects/cloud-large-language-models/locations/us-central1/endpoints/{model_name}:predict'
+    s = requests.Session()
+    retries = urllib3.util.Retry(
+        connect=10,
+        read=1,
+        backoff_factor=0.1,
+        status_forcelist=[429, 500],
+    )
+
+    headers = {}
+    headers["Content-type"] = "application/json"
+    headers["Authorization"] = f"Bearer {credentials.token}"
+
+    json_data = {
+        "instances": [
+            {"content": content},
+        ],
+        "parameters": {
+            "temperature": temperature,
+            "maxDecodeSteps": max_decode_steps,
+            "topP": top_p,
+            "topK": top_k,
+        }
+    }
+
+    s.mount('https://', HTTPAdapter(max_retries=retries))
+    response = s.post(
+        audience,
+        headers=headers,
+        timeout=datetime.timedelta(minutes=15).total_seconds(),
+        json=json_data,
+    )
+
+    return response.json()['predictions'][0]['content']
\ No newline at end of file
diff --git a/test/document_extract_test.py b/test/document_extract_test.py
new file mode 100644
index 0000000..071dfc5
--- /dev/null
+++ b/test/document_extract_test.py
@@ -0,0 +1,14 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/test/function_test.py b/test/function_test.py
new file mode 100644
index 0000000..fc4ff5b
--- /dev/null
+++ b/test/function_test.py
@@ -0,0 +1,51 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+
+@dataclass
+class CloudEventDataMock:
+    bucket: str
+    name: str
+    metageneration: str
+    timeCreated: str
+    updated: str
+
+    def __getitem__(self, key):
+        return self.__getattribute__(key)
+
+
+@dataclass
+class CloudEventMock:
+    data: str
+    id: str
+    type: str
+
+    def __getitem__(self, key):
+        if key == 'id':
+            return self.id
+        elif key == 'type':
+            return self.type
+        else:
+            raise RuntimeError(f'Unknown key: {key}')
+
+
+MOCK_CLOUD_EVENT = CloudEventMock(
+    id='7631145714375969',
+    type='google.cloud.storage.object.v1.finalized',
+    data=CloudEventDataMock(
+        bucket='velociraptor-16p1-mock-users-bucket',
+        name='9404001v1.pdf',
+        metageneration='1',
+        timeCreated='2023-05-08T19:28:55.255Z',
+        updated='2023-05-08T19:28:55.255Z',
+    )
+)
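+
+
+# Sketch of a first test over the mock event. Exercising entrypoint() itself
+# would also require patching async_document_extract,
+# predict_large_language_model_hack, and upload_to_gcs, so only the mock
+# plumbing is checked here.
+def test_mock_cloud_event_lookup():
+    assert MOCK_CLOUD_EVENT['id'] == '7631145714375969'
+    assert MOCK_CLOUD_EVENT['type'] == 'google.cloud.storage.object.v1.finalized'
+    assert MOCK_CLOUD_EVENT.data['bucket'] == 'velociraptor-16p1-mock-users-bucket'
+    assert MOCK_CLOUD_EVENT.data['name'] == '9404001v1.pdf'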