Skip to content

Commit

Permalink
feat: upload and refactor of code
Browse files Browse the repository at this point in the history
  • Loading branch information
telpirion committed May 11, 2023
1 parent 7f22f33 commit 19530e3
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 1 deletion.
Empty file added src/bigquery.py
Empty file.
2 changes: 1 addition & 1 deletion src/document_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def async_document_extract(
bucket: str,
name: str,
timeout: int = 420,
):
) -> tuple(str, str):
"""Perform OCR with PDF/TIFF as source files on GCS.
Original sample is here:
Expand Down
96 changes: 96 additions & 0 deletions src/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functions_framework

from dataclasses import dataclass
import datetime
import re

from document_extract import async_document_extract
from storage import upload_to_gcs
from vertex_llm import predict_large_language_model_hack


def coerce_datetime_zulu(input_datetime: str) -> datetime.datetime:
    """Parse an ISO 8601 timestamp ending in "Z" into an aware datetime.

    Args:
        input_datetime (str): the timestamp text to coerce, for example
            "2023-05-08T19:28:55.255Z"

    Returns:
        The parsed datetime, carrying an explicit +00:00 offset.

    Raises:
        TypeError: if input_datetime is not a string
        RuntimeError: if input_datetime does not end with "Z"
        ValueError: if the text before the "Z" is rejected by
            datetime.datetime.fromisoformat
    """
    if not isinstance(input_datetime, str):
        raise TypeError(
            'input_datetime must be a str, '
            f'got {type(input_datetime).__name__}'
        )

    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if not input_datetime.endswith('Z'):
        raise RuntimeError(
            'The input datetime is not in the expected format. '
            'Please check format of the input datetime. Expected "Z" at the end'
        )

    # Swap the trailing "Z" for a numeric offset before handing the text to
    # fromisoformat.
    return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00')


@dataclass
class CloudEventData:
    """Container for the identifying fields of a storage CloudEvent.

    The field names match the values that `entrypoint` reads from an
    incoming event.
    """

    event_id: str
    event_type: str
    bucket: str
    name: str
    metageneration: str
    # Declared as str, but `read_datetimes` stores the datetime objects
    # returned by `coerce_datetime_zulu` in the two fields below.
    timeCreated: str
    updated: str

    @classmethod
    def read_datetimes(cls, kwargs):
        """Build an instance, converting the timestamp fields to datetimes.

        Args:
            kwargs (dict): constructor arguments keyed by field name;
                'timeCreated' and 'updated' must be "Z"-suffixed strings

        Returns:
            A new instance whose timeCreated and updated fields hold the
            values produced by coerce_datetime_zulu.

        Raises:
            KeyError: if 'timeCreated' or 'updated' is missing from kwargs
        """
        # Work on a copy so the caller's mapping is left untouched.
        fields = dict(kwargs)
        for key in ('timeCreated', 'updated'):
            fields[key] = coerce_datetime_zulu(fields[key])

        return cls(**fields)


# WEBHOOK FUNCTION
@functions_framework.cloud_event
def entrypoint(cloud_event):
    """Summarize an uploaded document and store the summary beside it.

    Extracts text from the object named in the event, asks the language
    model for a summary, and uploads the summary to the same bucket as
    "<name without .pdf>_summary.txt".

    Args:
        cloud_event: the event passed in by functions_framework; read via
            cloud_event["id"], cloud_event["type"] and cloud_event.data[...]

    Returns:
        A dict with the keys 'summary' and 'output_filename'.
    """
    # These values are not used further down, but reading them up front
    # keeps the original behaviour of failing on a malformed event before
    # any remote call is made.
    event_id = cloud_event["id"]
    event_type = cloud_event["type"]
    bucket = cloud_event.data["bucket"]
    name = cloud_event.data["name"]
    metageneration = cloud_event.data["metageneration"]
    timeCreated = coerce_datetime_zulu(cloud_event.data["timeCreated"])
    updated = coerce_datetime_zulu(cloud_event.data["updated"])

    # NOTE(review): async_document_extract is annotated as returning a tuple;
    # if so, the whole tuple is interpolated into the prompt below - confirm
    # which element carries the extracted text.
    extracted_text = async_document_extract(bucket, name)
    summary = predict_large_language_model_hack(
        project_id="velociraptor-16p1-src",
        model_name="text-bison-001",
        temperature=0.2,
        max_decode_steps=1024,
        top_p=0.8,
        top_k=40,
        content=f'Summarize:\n{extracted_text}',
        location="us-central1",
    )

    # Strip only a trailing ".pdf"; str.replace would also remove ".pdf"
    # appearing in the middle of the object name.
    pdf_suffix = ".pdf"
    stem = name[:-len(pdf_suffix)] if name.endswith(pdf_suffix) else name
    output_filename = f'{stem}_summary.txt'
    upload_to_gcs(
        bucket,
        output_filename,
        summary,
    )

    return {
        'summary': summary,
        'output_filename': output_filename,
    }
5 changes: 5 additions & 0 deletions src/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
functions-framework
google-auth
google-cloud-aiplatform==1.25.0
google-cloud-storage
google-cloud-vision
30 changes: 30 additions & 0 deletions src/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import storage


def upload_to_gcs(bucket: str, name: str, data: str):
    """Write a string into an object in a Google Cloud Storage bucket.

    Args:
        bucket (str): the name of the Storage bucket. Do not include "gs://"
        name (str): the name of the file to create in the bucket
        data (str): the data to store
    """
    gcs_client = storage.Client()
    # Resolve the bucket first, then address the object inside it.
    target_blob = gcs_client.get_bucket(bucket).blob(name)
    target_blob.upload_from_string(data)
129 changes: 129 additions & 0 deletions src/vertex_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google import auth
from google.cloud import aiplatform
from google.cloud.aiplatform import TextGenerationModel

import datetime
import requests
from requests.adapters import HTTPAdapter
import urllib3


def predict_large_language_model(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    content: str,
    location: str = "us-central1",
    tuned_model_name: str = "",
) -> str:
    """Run a text prediction through the aiplatform SDK.

    Args:
        project_id (str): the Google Cloud project ID
        model_name (str): the name of the pretrained LLM model to load
        temperature (float): forwarded to predict() as `temperature`
        max_decode_steps (int): forwarded to predict() as `max_output_tokens`
        top_p (float): forwarded to predict() as `top_p`
        top_k (int): forwarded to predict() as `top_k`
        content (str): the prompt text sent to the model
        location (str): the Google Cloud region to run in
        tuned_model_name (str): when non-empty, the tuned model fetched via
            get_tuned_model() and used in place of the pretrained one

    Returns:
        The `text` attribute of the prediction response.
    """
    aiplatform.init(project=project_id, location=location)

    base_model = TextGenerationModel.from_pretrained(model_name)
    llm = (
        base_model.get_tuned_model(tuned_model_name)
        if tuned_model_name
        else base_model
    )

    generation_params = {
        'temperature': temperature,
        'max_output_tokens': max_decode_steps,
        'top_k': top_k,
        'top_p': top_p,
    }
    prediction = llm.predict(content, **generation_params)
    return prediction.text


def predict_large_language_model_hack(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    content: str,
    location: str = "us-central1",
    tuned_model_name: str = "",
) -> str:
    """Run a text prediction by calling the REST predict endpoint directly.

    The request URL is hard-coded to the "cloud-large-language-models"
    project in us-central1, so `project_id`, `location` and
    `tuned_model_name` are accepted for signature parity with
    `predict_large_language_model` but do not influence the request.

    Args:
        project_id (str): the Google Cloud project ID (currently unused)
        model_name (str): the model name placed in the endpoint URL
        temperature (float): sent as the "temperature" parameter
        max_decode_steps (int): sent as the "maxDecodeSteps" parameter
        top_p (float): sent as the "topP" parameter
        top_k (int): sent as the "topK" parameter
        content (str): the prompt text sent as the single instance
        location (str): the Google Cloud region (currently unused)
        tuned_model_name (str): currently unused

    Returns:
        The "content" field of the first prediction in the response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status
    """
    # Import the transport submodule explicitly; `from google import auth`
    # on its own is not guaranteed to have loaded it.
    import google.auth.transport.requests

    # The project returned alongside the credentials is not needed here, so
    # it is discarded instead of overwriting the `project_id` argument.
    credentials, _ = auth.default()
    credentials.refresh(google.auth.transport.requests.Request())

    audience = f'https://us-central1-aiplatform.googleapis.com/v1/projects/cloud-large-language-models/locations/us-central1/endpoints/{model_name}:predict'
    retries = urllib3.util.Retry(
        connect=10,
        read=1,
        backoff_factor=0.1,
        status_forcelist=[429, 500],
    )

    headers = {
        "Content-type": "application/json",
        "Authorization": f"Bearer {credentials.token}",
    }

    json_data = {
        "instances": [
            {"content": content},
        ],
        "parameters": {
            "temperature": temperature,
            "maxDecodeSteps": max_decode_steps,
            "topP": top_p,
            "topK": top_k,
        }
    }

    # The context manager closes the session's pooled connections.
    with requests.Session() as session:
        session.mount('https://', HTTPAdapter(max_retries=retries))
        response = session.post(
            audience,
            headers=headers,
            timeout=datetime.timedelta(minutes=15).total_seconds(),
            json=json_data,
        )
        # Surface HTTP failures directly rather than as a KeyError raised
        # while indexing an error payload below.
        response.raise_for_status()

    return response.json()['predictions'][0]['content']
14 changes: 14 additions & 0 deletions test/document_extract_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

51 changes: 51 additions & 0 deletions test/function_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass


@dataclass
class CloudEventDataMock:
    """Stand-in for the `data` payload of a storage CloudEvent.

    Supports `mock["field"]` lookups by mapping the key onto the attribute
    of the same name.
    """

    bucket: str
    name: str
    metageneration: str
    timeCreated: str
    updated: str

    def __getitem__(self, key):
        """Return the attribute named `key`.

        Raises:
            AttributeError: if no attribute named `key` exists
        """
        return getattr(self, key)


@dataclass
class CloudEventMock:
    """Stand-in for a CloudEvent exposing `data`, `id` and `type`.

    Only 'id' and 'type' can be read with `mock[key]`; the payload is
    reached through the `data` attribute.
    """

    # Holds the payload mock, not a string; the annotation is a forward
    # reference so this class does not depend on definition order.
    data: "CloudEventDataMock"
    id: str
    type: str

    def __getitem__(self, key):
        """Return the `id` or `type` attribute.

        Raises:
            RuntimeError: if `key` is anything other than 'id' or 'type'
        """
        if key in ('id', 'type'):
            return getattr(self, key)
        raise RuntimeError(f'Unknown key: {key}')

# Sample "object finalized" storage event for a PDF upload, used as test input.
MOCK_CLOUD_EVENT = CloudEventMock(
    data=CloudEventDataMock(
        name='9404001v1.pdf',
        bucket='velociraptor-16p1-mock-users-bucket',
        metageneration='1',
        updated='2023-05-08T19:28:55.255Z',
        timeCreated='2023-05-08T19:28:55.255Z',
    ),
    type='google.cloud.storage.object.v1.finalized',
    id='7631145714375969',
)

0 comments on commit 19530e3

Please sign in to comment.