-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
326 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import functions_framework | ||
|
||
from dataclasses import dataclass | ||
import datetime | ||
import re | ||
|
||
from document_extract import async_document_extract | ||
from storage import upload_to_gcs | ||
from vertex_llm import predict_large_language_model_hack | ||
|
||
|
||
def coerce_datetime_zulu(input_datetime: datetime.datetime): | ||
"""Force datetime into specific format. | ||
Args: | ||
input_datetime (datetime.datetime): the datetime to coerce | ||
""" | ||
regex = re.compile(r"(.*)(Z$)") | ||
regex_match = regex.search(input_datetime) | ||
if regex_match: | ||
assert input_datetime.startswith(regex_match.group(1)) | ||
assert input_datetime.endswith(regex_match.group(2)) | ||
return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00') | ||
raise RuntimeError( | ||
'The input datetime is not in the expected format. ' | ||
'Please check format of the input datetime. Expected "Z" at the end' | ||
) | ||
|
||
|
||
@dataclass | ||
class CloudEventData: | ||
event_id: str | ||
event_type: str | ||
bucket: str | ||
name: str | ||
metageneration: str | ||
timeCreated: str | ||
updated: str | ||
|
||
@classmethod | ||
def read_datetimes(cls, kwargs): | ||
for key in ['timeCreated', 'updated']: | ||
kwargs[key] = coerce_datetime_zulu(kwargs.pop(key)) | ||
|
||
return cls(**kwargs) | ||
|
||
|
||
# WEBHOOK FUNCTION | ||
@functions_framework.cloud_event | ||
def entrypoint(cloud_event): | ||
event_id = cloud_event["id"] | ||
event_type = cloud_event["type"] | ||
bucket = cloud_event.data["bucket"] | ||
name = cloud_event.data["name"] | ||
metageneration = cloud_event.data["metageneration"] | ||
timeCreated = coerce_datetime_zulu(cloud_event.data["timeCreated"]) | ||
updated = coerce_datetime_zulu(cloud_event.data["updated"]) | ||
|
||
extracted_text = async_document_extract(bucket, name) | ||
summary = predict_large_language_model_hack( | ||
project_id="velociraptor-16p1-src", | ||
model_name="text-bison-001", | ||
temperature=0.2, | ||
max_decode_steps=1024, | ||
top_p=0.8, | ||
top_k=40, | ||
content=f'Summarize:\n{extracted_text}', | ||
location="us-central1", | ||
) | ||
|
||
output_filename = f'{name.replace(".pdf", "")}_summary.txt' | ||
upload_to_gcs( | ||
bucket, | ||
output_filename, | ||
summary, | ||
) | ||
|
||
return { | ||
'summary': summary, | ||
'output_filename': output_filename, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
functions-framework | ||
google-auth | ||
google-cloud-aiplatform==1.25.0 | ||
google-cloud-storage | ||
google-cloud-vision |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from google.cloud import storage | ||
|
||
|
||
def upload_to_gcs(bucket: str, name: str, data: str): | ||
"""Upload a string to Google Cloud Storage bucket. | ||
Args: | ||
bucket (str): the name of the Storage bucket. Do not include "gs://" | ||
name (str): the name of the file to create in the bucket | ||
data (str): the data to store | ||
""" | ||
client = storage.Client() | ||
bucket = client.get_bucket(bucket) | ||
blob = bucket.blob(name) | ||
blob.upload_from_string(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from google import auth | ||
from google.cloud import aiplatform | ||
from google.cloud.aiplatform import TextGenerationModel | ||
|
||
import datetime | ||
import requests | ||
from requests.adapters import HTTPAdapter | ||
import urllib3 | ||
|
||
|
||
def predict_large_language_model( | ||
project_id: str, | ||
model_name: str, | ||
temperature: float, | ||
max_decode_steps: int, | ||
top_p: float, | ||
top_k: int, | ||
content: str, | ||
location: str = "us-central1", | ||
tuned_model_name: str = "", | ||
) -> str: | ||
"""Predict using a Large Language Model. | ||
Args: | ||
project_id (str): the Google Cloud project ID | ||
model_name (str): the name of the LLM model to use | ||
temperature (float): TODO(nicain) | ||
max_decode_steps (int): TODO(nicain) | ||
top_p (float): TODO(nicain) | ||
top_k (int): TODO(nicain) | ||
content (str): the text to summarize | ||
location (str): the Google Cloud region to run in | ||
tuned_model_name (str): TODO(nicain) | ||
Returns: | ||
The summarization of the content | ||
""" | ||
aiplatform.init(project=project_id, location=location) | ||
model = TextGenerationModel.from_pretrained(model_name) | ||
if tuned_model_name: | ||
model = model.get_tuned_model(tuned_model_name) | ||
response = model.predict( | ||
content, | ||
temperature=temperature, | ||
max_output_tokens=max_decode_steps, | ||
top_k=top_k, | ||
top_p=top_p,) | ||
return response.text | ||
|
||
|
||
def predict_large_language_model_hack( | ||
project_id: str, | ||
model_name: str, | ||
temperature: float, | ||
max_decode_steps: int, | ||
top_p: float, | ||
top_k: int, | ||
content: str, | ||
location: str = "us-central1", | ||
tuned_model_name: str = "", | ||
) -> str: | ||
"""Predict using a Large Language Model. | ||
Args: | ||
project_id (str): the Google Cloud project ID | ||
model_name (str): the name of the LLM model to use | ||
temperature (float): TODO(nicain) | ||
max_decode_steps (int): TODO(nicain) | ||
top_p (float): TODO(nicain) | ||
top_k (int): TODO(nicain) | ||
content (str): the text to summarize | ||
location (str): the Google Cloud region to run in | ||
tuned_model_name (str): TODO(nicain) | ||
Returns: | ||
The summarization of the content | ||
""" | ||
credentials, project_id = auth.default() | ||
request = auth.transport.requests.Request() | ||
credentials.refresh(request) | ||
|
||
audience = f'https://us-central1-aiplatform.googleapis.com/v1/projects/cloud-large-language-models/locations/us-central1/endpoints/{model_name}:predict' | ||
s = requests.Session() | ||
retries = urllib3.util.Retry( | ||
connect=10, | ||
read=1, | ||
backoff_factor=0.1, | ||
status_forcelist=[429, 500], | ||
) | ||
|
||
headers = {} | ||
headers["Content-type"] = "application/json" | ||
headers["Authorization"] = f"Bearer {credentials.token}" | ||
|
||
json_data = { | ||
"instances": [ | ||
{"content": content}, | ||
], | ||
"parameters": { | ||
"temperature": temperature, | ||
"maxDecodeSteps": max_decode_steps, | ||
"topP": top_p, | ||
"topK": top_k, | ||
} | ||
} | ||
|
||
s.mount('https://', HTTPAdapter(max_retries=retries)) | ||
response = s.post( | ||
audience, | ||
headers=headers, | ||
timeout=datetime.timedelta(minutes=15).total_seconds(), | ||
json=json_data, | ||
) | ||
|
||
return response.json()['predictions'][0]['content'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
@dataclass | ||
class CloudEventDataMock: | ||
bucket: str | ||
name: str | ||
metageneration: str | ||
timeCreated: str | ||
updated: str | ||
|
||
def __getitem__(self, key): | ||
return self.__getattribute__(key) | ||
|
||
|
||
@dataclass | ||
class CloudEventMock: | ||
data: str | ||
id: str | ||
type: str | ||
|
||
def __getitem__(self, key): | ||
if key == 'id': | ||
return self.id | ||
elif key == 'type': | ||
return self.type | ||
else: | ||
raise RuntimeError(f'Unknown key: {key}') | ||
|
||
MOCK_CLOUD_EVENT = CloudEventMock( | ||
id='7631145714375969', | ||
type='google.cloud.storage.object.v1.finalized', | ||
data=CloudEventDataMock( | ||
bucket='velociraptor-16p1-mock-users-bucket', | ||
name='9404001v1.pdf', | ||
metageneration='1', | ||
timeCreated='2023-05-08T19:28:55.255Z', | ||
updated='2023-05-08T19:28:55.255Z', | ||
) | ||
) |