diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..364bf9d --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] + max-line-length = 100 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 6ba7cb0..0415928 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -43,4 +43,6 @@ jobs: REGISTRY=$(grep "REGISTRY_URL := " $MAKEFILE | cut -d\ -f3) echo dev-tools=${REGISTRY}/${IMAGE}:${VERSION} >> "$GITHUB_OUTPUT" fi - - run: docker run --rm -v ${{ github.workspace }}:/workspace ${{ steps.variables.outputs.dev-tools }} /usr/local/bin/test_lint.sh + - run: docker run --rm -e EXCLUDE_LINT_DIRS -v ${{ github.workspace }}:/workspace ${{ steps.variables.outputs.dev-tools }} /usr/local/bin/test_lint.sh + env: + EXCLUDE_LINT_DIRS: '\./assets|\./.github|\./docs|\./env' diff --git a/Makefile b/Makefile index fc2d758..be11259 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ REGISTRY_URL := gcr.io/cloud-foundation-cicd # Enter docker container for local development .PHONY: docker_run docker_run: - docker run --rm -it \ + docker run --rm -it --pull=always \ -e SERVICE_ACCOUNT_JSON \ -v "$(CURDIR)":/workspace \ $(REGISTRY_URL)/${DOCKER_IMAGE_DEVELOPER_TOOLS}:${DOCKER_TAG_VERSION_DEVELOPER_TOOLS} \ @@ -34,7 +34,7 @@ docker_run: # Execute prepare tests within the docker container .PHONY: docker_test_prepare docker_test_prepare: - docker run --rm -it \ + docker run --rm -it --pull=always \ -e SERVICE_ACCOUNT_JSON \ -e TF_VAR_org_id \ -e TF_VAR_folder_id \ @@ -46,7 +46,7 @@ docker_test_prepare: # Clean up test environment within the docker container .PHONY: docker_test_cleanup docker_test_cleanup: - docker run --rm -it \ + docker run --rm -it --pull=always \ -e SERVICE_ACCOUNT_JSON \ -e TF_VAR_org_id \ -e TF_VAR_folder_id \ @@ -58,7 +58,7 @@ docker_test_cleanup: # Execute integration tests within the docker container .PHONY: docker_test_integration docker_test_integration: - docker run --rm -it \ + docker run --rm -it --pull=always \ -e SERVICE_ACCOUNT_JSON \ -v "$(CURDIR)":/workspace \ $(REGISTRY_URL)/${DOCKER_IMAGE_DEVELOPER_TOOLS}:${DOCKER_TAG_VERSION_DEVELOPER_TOOLS} \ @@ -67,7 +67,7 @@ docker_test_integration: # Execute lint tests within the docker container .PHONY: docker_test_lint docker_test_lint: - docker run --rm -it \ + docker run --rm -it --pull=always \ -e EXCLUDE_LINT_DIRS \ -v "$(CURDIR)":/workspace \ $(REGISTRY_URL)/${DOCKER_IMAGE_DEVELOPER_TOOLS}:${DOCKER_TAG_VERSION_DEVELOPER_TOOLS} \ @@ -76,7 +76,7 @@ docker_test_lint: # Generate documentation .PHONY: docker_generate_docs docker_generate_docs: - docker run --rm -it \ + docker run --rm -it --pull=always \ -v "$(CURDIR)":/workspace \ $(REGISTRY_URL)/${DOCKER_IMAGE_DEVELOPER_TOOLS}:${DOCKER_TAG_VERSION_DEVELOPER_TOOLS} \ /bin/bash -c 'source /usr/local/bin/task_helper_functions.sh && generate_docs' @@ -84,7 +84,7 @@ docker_generate_docs: # Generate metadata .PHONY: docker_generate_metadata_w_display docker_generate_metadata: - docker run --rm -it \ + docker run --rm -it --pull=always \ -v "$(CURDIR)":/workspace \ $(REGISTRY_URL)/${DOCKER_IMAGE_DEVELOPER_TOOLS}:${DOCKER_TAG_VERSION_DEVELOPER_TOOLS} \ /bin/bash -c 'source /usr/local/bin/task_helper_functions.sh && generate_metadata display' diff --git a/README.md b/README.md index 2142aee..6d5d464 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,12 @@ Functional examples are included in the | Name | Description | Type | Default | Required | |------|-------------|------|---------|:--------:| | bucket\_name | The name of the bucket to create | `string` | n/a | yes | -| project\_id | The project ID to deploy to | `string` | n/a | yes | +| gcf\_timeout\_seconds | GCF execution timeout | `number` | `900` | no | +| project\_id | The Google Cloud project ID to deploy to | `string` | n/a | yes | +| region | Google Cloud region | `string` | `"us-central1"` | no | +| time\_to\_enable\_apis | Time taken to enable APIs in new projects | `string` | `"300s"` | no | +| webhook\_name | Name of the webhook | `string` | `"webhook"` | no | +| webhook\_path | Path to the webhook directory | `string` | `"webhook"` | no | ## Outputs diff --git a/examples/simple_example/README.md b/examples/simple_example/README.md new file mode 100644 index 0000000..7f9f66e --- /dev/null +++ b/examples/simple_example/README.md @@ -0,0 +1,14 @@ +# Simple Example + + +## Inputs + +| Name | Description | Type | Default | Required | +|------|-------------|------|---------|:--------:| +| project\_id | GCP project for provisioning cloud resources. | `any` | n/a | yes | + +## Outputs + +No outputs. + + diff --git a/examples/simple_example/main.tf b/examples/simple_example/main.tf new file mode 100644 index 0000000..803696a --- /dev/null +++ b/examples/simple_example/main.tf @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +resource "random_id" "id" { + byte_length = 4 +} + +module "simple" { + source = "../../" + project_id = var.project_id + webhook_path = abspath("../../webhook") + bucket_name = "cft-test-${random_id.id.hex}" +} diff --git a/examples/simple_example/variables.tf b/examples/simple_example/variables.tf new file mode 100644 index 0000000..04b5602 --- /dev/null +++ b/examples/simple_example/variables.tf @@ -0,0 +1,19 @@ +/** + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +variable "project_id" { + description = "GCP project for provisioning cloud resources." +} diff --git a/main.tf b/main.tf index 1901453..9de1e90 100644 --- a/main.tf +++ b/main.tf @@ -14,13 +14,10 @@ * limitations under the License. */ -data "google_project" "project" { - project_id = var.project_id -} - module "project_services" { - source = "terraform-google-modules/project-factory/google//modules/project_services" - version = "~> 14.2" + source = "terraform-google-modules/project-factory/google//modules/project_services" + version = "~> 14.2" + disable_services_on_destroy = false project_id = var.project_id @@ -40,10 +37,28 @@ module "project_services" { ] } +data "google_project" "project" { + project_id = var.project_id + depends_on = [ + module.project_services, + ] +} + +# Gate resources till APIs are enabled +resource "null_resource" "previous_time" {} + +resource "time_sleep" "wait_for_apis" { + depends_on = [ + null_resource.previous_time, + module.project_services, + ] + + create_duration = var.time_to_enable_apis +} data "archive_file" "webhook" { type = "zip" - source_dir = "webhook" + source_dir = var.webhook_path output_path = abspath("./.tmp/${var.webhook_name}.zip") } @@ -54,7 +69,7 @@ resource "google_storage_bucket_object" "webhook" { } resource "google_service_account" "webhook" { - project = var.project_id + project = var.project_id account_id = "webhook-service-account" display_name = "Serverless Webhooks Service Account" depends_on = [ @@ -62,55 +77,24 @@ resource "google_service_account" "webhook" { ] } -resource "google_project_iam_member" "aiplatform_user" { - project = var.project_id - role = "roles/aiplatform.user" - member = "serviceAccount:${google_service_account.webhook.email}" - depends_on = [ - module.project_services, - ] -} - -resource "google_project_iam_member" "storage_admin" { - project = var.project_id - role = "roles/storage.admin" - member = "serviceAccount:${google_service_account.webhook.email}" - depends_on = [ - module.project_services, - ] -} - -resource "google_project_iam_member" "log_writer" { - project = var.project_id - role = "roles/logging.logWriter" - member = "serviceAccount:${google_service_account.webhook.email}" - depends_on = [ - module.project_services, - ] -} - -resource "google_project_iam_member" "data_editor" { - project = var.project_id - role = "roles/bigquery.dataEditor" - member = "serviceAccount:${google_service_account.webhook.email}" - depends_on = [ - module.project_services, - ] -} - -resource "google_project_iam_member" "artifactregistry_reader" { +resource "google_project_iam_member" "webhook_sa_roles" { project = var.project_id - role = "roles/artifactregistry.reader" + for_each = toset([ + "roles/cloudfunctions.invoker", + "roles/storage.admin", + "roles/logging.logWriter", + "roles/artifactregistry.reader", + "roles/bigquery.dataEditor", + "roles/aiplatform.user", + ]) + role = each.key member = "serviceAccount:${google_service_account.webhook.email}" - depends_on = [ - module.project_services, - ] } resource "google_cloudfunctions2_function" "webhook" { - project = var.project_id - name = var.webhook_name - location = var.region + project = var.project_id + name = var.webhook_name + location = var.region build_config { runtime = "python310" @@ -125,22 +109,25 @@ resource "google_cloudfunctions2_function" "webhook" { service_config { - service_account_email = google_service_account.webhook.email - max_instance_count = 100 - available_memory = "4G" - available_cpu = 2 + service_account_email = google_service_account.webhook.email + max_instance_count = 100 + available_memory = "4G" + available_cpu = 2 max_instance_request_concurrency = 16 - timeout_seconds = var.timeout_seconds + timeout_seconds = var.gcf_timeout_seconds environment_variables = { - PROJECT_ID = var.project_id - LOCATION = var.region + PROJECT_ID = var.project_id + LOCATION = var.region OUTPUT_BUCKET = google_storage_bucket.output.name - DATASET_ID = google_bigquery_dataset.default.dataset_id - TABLE_ID = google_bigquery_table.default.table_id + DATASET_ID = google_bigquery_dataset.default.dataset_id + TABLE_ID = google_bigquery_table.default.table_id } } depends_on = [ module.project_services, + time_sleep.wait_for_apis, + google_project_iam_member.webhook_sa_roles, + ] } @@ -153,10 +140,10 @@ resource "google_bigquery_dataset" "default" { } resource "google_bigquery_table" "default" { - dataset_id = google_bigquery_dataset.default.dataset_id - table_id = "summary_table" - project = var.project_id - deletion_protection=false + dataset_id = google_bigquery_dataset.default.dataset_id + table_id = "summary_table" + project = var.project_id + deletion_protection = false schema = < str: """Iterates over blobs in output bucket to get full OCR result. Arguments: gcs_destination_uri: the URI where the OCR output was saved. bucket_name: the name of the bucket where the output was saved. - + Returns the full text of the document. """ storage_client = storage.Client() - match = re.match(r'gs://([^/]+)/(.+)', gcs_destination_uri) + match = re.match(r"gs://([^/]+)/(.+)", gcs_destination_uri) prefix = match.group(2) bucket = storage_client.get_bucket(bucket_name) # List objects with the given prefix, filtering out folders. - blob_list = [blob for blob in list(bucket.list_blobs( - prefix=prefix)) if not blob.name.endswith('/')] + blob_list = [ + blob + for blob in list(bucket.list_blobs(prefix=prefix)) + if not blob.name.endswith("/") + ] # Concatenate all text from the blobs complete_text = "" for output in blob_list: - json_string = output.download_as_bytes().decode("utf-8") response = json.loads(json_string) # The actual response for the first page of the input file. - page_response = response['responses'][0] - annotation = page_response['fullTextAnnotation'] + page_response = response["responses"][0] + annotation = page_response["fullTextAnnotation"] - complete_text = complete_text + annotation['text'] + complete_text = complete_text + annotation["text"] return complete_text diff --git a/webhook/document_extract_test.py b/webhook/document_extract_test.py index c97d9f2..01da34e 100644 --- a/webhook/document_extract_test.py +++ b/webhook/document_extract_test.py @@ -25,6 +25,7 @@ _OUTPUT_BUCKET = f"{_PROJECT_ID}_output" _FILE_NAME = "9404001v1.pdf" + # System / integration test @backoff.on_exception(backoff.expo, Exception, max_tries=3) def test_async_document_extract_system(capsys): diff --git a/webhook/main.py b/webhook/main.py index 47322d9..7319eaf 100644 --- a/webhook/main.py +++ b/webhook/main.py @@ -18,27 +18,27 @@ import vertexai from vertexai.preview.language_models import TextGenerationModel -_FUNCTIONS_GCS_EVENT_LOGGER = 'function-triggered-by-storage' -_FUNCTIONS_VERTEX_EVENT_LOGGER = 'summarization-by-llm' - from bigquery import write_summarization_to_table from document_extract import async_document_extract from storage import upload_to_gcs from vertex_llm import predict_large_language_model from utils import coerce_datetime_zulu, truncate_complete_text -_PROJECT_ID = os.environ['PROJECT_ID'] -_OUTPUT_BUCKET = os.environ['OUTPUT_BUCKET'] -_LOCATION = os.environ['LOCATION'] -_MODEL_NAME = 'text-bison@001' +_FUNCTIONS_GCS_EVENT_LOGGER = "function-triggered-by-storage" +_FUNCTIONS_VERTEX_EVENT_LOGGER = "summarization-by-llm" + +_PROJECT_ID = os.environ["PROJECT_ID"] +_OUTPUT_BUCKET = os.environ["OUTPUT_BUCKET"] +_LOCATION = os.environ["LOCATION"] +_MODEL_NAME = "text-bison@001" _DEFAULT_PARAMETERS = { - "temperature": .2, + "temperature": 0.2, "max_output_tokens": 256, - "top_p": .95, + "top_p": 0.95, "top_k": 40, } -_DATASET_ID = os.environ['DATASET_ID'] -_TABLE_ID = os.environ['TABLE_ID'] +_DATASET_ID = os.environ["DATASET_ID"] +_TABLE_ID = os.environ["TABLE_ID"] def default_marshaller(o: object) -> str: @@ -60,8 +60,8 @@ def summarize_text(text: str, parameters: None | dict[str, int | float] = None) model = TextGenerationModel.from_pretrained("text-bison@001") response = model.predict( - f'Provide a summary with about two sentences for the following article: {text}\n' - 'Summary:', + f"Provide a summary with about two sentences for the following article: {text}\n" + "Summary:", **final_parameters, ) print(f"Response from Model: {response.text}") @@ -70,36 +70,34 @@ def summarize_text(text: str, parameters: None | dict[str, int | float] = None) def entrypoint(request: object) -> dict[str, str]: - data = request.get_json() - if data.get('kind', None) == 'storage#object': + if data.get("kind", None) == "storage#object": return cloud_event_entrypoint( - name = data['name'], - event_id = data["id"], - bucket = data["bucket"], - time_created = coerce_datetime_zulu(data["timeCreated"]), + name=data["name"], + event_id=data["id"], + bucket=data["bucket"], + time_created=coerce_datetime_zulu(data["timeCreated"]), ) else: return summarization_entrypoint( - name=data['name'], - extracted_text=data['text'], + name=data["name"], + extracted_text=data["text"], time_created=datetime.datetime.now(datetime.timezone.utc), - event_id='CURL_TRIGGER' + event_id="CURL_TRIGGER", ) def cloud_event_entrypoint(event_id, bucket, name, time_created): - orig_pdf_uri = f"gs://{bucket}/{name}" logging_client = logging.Client() logger = logging_client.logger(_FUNCTIONS_GCS_EVENT_LOGGER) - logger.log(f"cloud_event_id({event_id}): UPLOAD {orig_pdf_uri}", - severity="INFO") - + logger.log(f"cloud_event_id({event_id}): UPLOAD {orig_pdf_uri}", severity="INFO") + extracted_text = async_document_extract(bucket, name, output_bucket=_OUTPUT_BUCKET) - logger.log(f"cloud_event_id({event_id}): OCR gs://{bucket}/{name}", - severity="INFO") - + logger.log( + f"cloud_event_id({event_id}): OCR gs://{bucket}/{name}", severity="INFO" + ) + return summarization_entrypoint( name, extracted_text, @@ -110,12 +108,12 @@ def cloud_event_entrypoint(event_id, bucket, name, time_created): def summarization_entrypoint( - name, - extracted_text, - time_created, - bucket=None, - event_id=None, - ): + name, + extracted_text, + time_created, + bucket=None, + event_id=None, +): logging_client = logging.Client() logger = logging_client.logger(_FUNCTIONS_VERTEX_EVENT_LOGGER) @@ -125,9 +123,10 @@ def summarization_entrypoint( complete_text_filename, extracted_text, ) - logger.log(f"cloud_event_id({event_id}): FULLTEXT_UPLOAD {complete_text_filename}", - severity="INFO") - + logger.log( + f"cloud_event_id({event_id}): FULLTEXT_UPLOAD {complete_text_filename}", + severity="INFO", + ) extracted_text_trunc = truncate_complete_text(extracted_text) summary = predict_large_language_model( @@ -137,12 +136,10 @@ def summarization_entrypoint( max_decode_steps=1024, top_p=0.8, top_k=40, - content=f'Summarize:\n{extracted_text_trunc}', + content=f"Summarize:\n{extracted_text_trunc}", location="us-central1", ) - logger.log(f"cloud_event_id({event_id}): SUMMARY_COMPLETE", - severity="INFO") - + logger.log(f"cloud_event_id({event_id}): SUMMARY_COMPLETE", severity="INFO") output_filename = f'system-test/{name.replace(".pdf", "")}_summary.txt' upload_to_gcs( @@ -150,8 +147,9 @@ def summarization_entrypoint( output_filename, summary, ) - logger.log(f"cloud_event_id({event_id}): SUMMARY_UPLOAD {upload_to_gcs}", - severity="INFO") + logger.log( + f"cloud_event_id({event_id}): SUMMARY_UPLOAD {upload_to_gcs}", severity="INFO" + ) # If we have any errors, they'll be caught by the bigquery module errors = write_summarization_to_table( @@ -168,13 +166,13 @@ def summarization_entrypoint( ) if len(errors) > 0: - logger.log(f"cloud_event_id({event_id}): DB_WRITE_ERROR: {errors}", - severity="ERROR") + logger.log( + f"cloud_event_id({event_id}): DB_WRITE_ERROR: {errors}", severity="ERROR" + ) return errors - logger.log(f"cloud_event_id({event_id}): DB_WRITE", - severity="INFO") + logger.log(f"cloud_event_id({event_id}): DB_WRITE", severity="INFO") if errors: return errors - return {'summary': summary} + return {"summary": summary} diff --git a/webhook/main_test.py b/webhook/main_test.py index dc62ea7..8db81dd 100644 --- a/webhook/main_test.py +++ b/webhook/main_test.py @@ -19,19 +19,20 @@ from dataclasses import dataclass -_PROJECT_ID = os.environ["PROJECT_ID"] -_OUTPUT_BUCKET = f'{_PROJECT_ID}_output' -_LOCATION = os.environ["REGION"] +_PROJECT_ID = os.environ["PROJECT_ID"] +_OUTPUT_BUCKET = f"{_PROJECT_ID}_output" +_LOCATION = os.environ["REGION"] _DATASET_ID = "summary_dataset" _TABLE_ID = "summary_table" + @dataclass class CloudEventDataMock: - bucket: str - name: str - metageneration: str - timeCreated: str - updated: str + bucket: str + name: str + metageneration: str + timeCreated: str + updated: str def __getitem__(self, key): return self.__getattribute__(key) @@ -39,48 +40,48 @@ def __getitem__(self, key): @dataclass class CloudEventMock: - data: str - id: str - type: str + data: str + id: str + type: str def __getitem__(self, key): - if key == 'id': + if key == "id": return self.id - elif key == 'type': + elif key == "type": return self.type else: - raise RuntimeError(f'Unknown key: {key}') + raise RuntimeError(f"Unknown key: {key}") def get_json(self): return { - 'name': self.data.name, - 'kind': 'storage#object', - 'id': self.id, - 'bucket': self.data.bucket, - 'timeCreated': self.data.timeCreated, + "name": self.data.name, + "kind": "storage#object", + "id": self.id, + "bucket": self.data.bucket, + "timeCreated": self.data.timeCreated, } - + @pytest.fixture def cloud_event(): return CloudEventMock( - id='7631145714375969', - type='google.cloud.storage.object.v1.finalized', + id="7631145714375969", + type="google.cloud.storage.object.v1.finalized", data=CloudEventDataMock( - bucket='velociraptor-16p1-mock-users-bucket', - name='9404001v1.pdf', - metageneration='1', + bucket="velociraptor-16p1-mock-users-bucket", + name="9404001v1.pdf", + metageneration="1", timeCreated=f"{datetime.datetime.now().isoformat()}Z", updated=f"{datetime.datetime.now().isoformat()}Z", - ) + ), ) class RequestMock: def get_json(self): return { - 'name': 'MOCK_REQUEST_NAME', - 'text': 'abstract: mock text. conclusion: there is none', + "name": "MOCK_REQUEST_NAME", + "text": "abstract: mock text. conclusion: there is none", } @@ -89,27 +90,37 @@ def curl_request(): return RequestMock() -@mock.patch.dict(os.environ, { - "OUTPUT_BUCKET": _OUTPUT_BUCKET, - "PROJECT_ID": _PROJECT_ID, - "LOCATION": _LOCATION, - "DATASET_ID": _DATASET_ID, - "TABLE_ID": _TABLE_ID, -}, clear=True) +@mock.patch.dict( + os.environ, + { + "OUTPUT_BUCKET": _OUTPUT_BUCKET, + "PROJECT_ID": _PROJECT_ID, + "LOCATION": _LOCATION, + "DATASET_ID": _DATASET_ID, + "TABLE_ID": _TABLE_ID, + }, + clear=True, +) def test_function_entrypoint_cloud_event(cloud_event): from main import entrypoint - result = entrypoint(cloud_event) - assert 'summary' in result - -@mock.patch.dict(os.environ, { - "OUTPUT_BUCKET": _OUTPUT_BUCKET, - "PROJECT_ID": _PROJECT_ID, - "LOCATION": _LOCATION, - "DATASET_ID": _DATASET_ID, - "TABLE_ID": _TABLE_ID, -}, clear=True) + result = entrypoint(cloud_event) + assert "summary" in result + + +@mock.patch.dict( + os.environ, + { + "OUTPUT_BUCKET": _OUTPUT_BUCKET, + "PROJECT_ID": _PROJECT_ID, + "LOCATION": _LOCATION, + "DATASET_ID": _DATASET_ID, + "TABLE_ID": _TABLE_ID, + }, + clear=True, +) def test_function_entrypoint_curl(curl_request): from main import entrypoint + result = entrypoint(curl_request) - assert 'summary' in result \ No newline at end of file + assert "summary" in result diff --git a/webhook/requirements-test.txt b/webhook/requirements-test.txt index 66b8213..8b60cfc 100644 --- a/webhook/requirements-test.txt +++ b/webhook/requirements-test.txt @@ -1,4 +1,4 @@ backoff==2.2.1 mock pytest==7.3.1 -google-cloud-storage \ No newline at end of file +google-cloud-storage diff --git a/webhook/services_test.py b/webhook/services_test.py index 5fb09c1..f09d528 100644 --- a/webhook/services_test.py +++ b/webhook/services_test.py @@ -26,11 +26,11 @@ _PROJECT_ID = os.environ["PROJECT_ID"] _BUCKET_NAME = os.environ["BUCKET"] -_OUTPUT_BUCKET = f'{_PROJECT_ID}_output' +_OUTPUT_BUCKET = f"{_PROJECT_ID}_output" _DATASET_ID = "summary_dataset" _TABLE_ID = "summary_table" -_FILE_NAME = '9404001v1.pdf' -_MODEL_NAME = 'text-bison@001' +_FILE_NAME = "9404001v1.pdf" +_MODEL_NAME = "text-bison@001" def check_blob_exists(bucket, filename) -> bool: @@ -39,15 +39,18 @@ def check_blob_exists(bucket, filename) -> bool: blob = bucket.blob(filename) return blob.exists() + @backoff.on_exception(backoff.expo, Exception, max_tries=3) def test_up16_services(): - extracted_text = async_document_extract(_BUCKET_NAME, - _FILE_NAME, - output_bucket=_OUTPUT_BUCKET) + extracted_text = async_document_extract( + _BUCKET_NAME, _FILE_NAME, output_bucket=_OUTPUT_BUCKET + ) assert "Abstract" in extracted_text - complete_text_filename = f'system-test/{_FILE_NAME.replace(".pdf", "")}_fulltext.txt' + complete_text_filename = ( + f'system-test/{_FILE_NAME.replace(".pdf", "")}_fulltext.txt' + ) upload_to_gcs( _OUTPUT_BUCKET, complete_text_filename, @@ -65,7 +68,7 @@ def test_up16_services(): max_decode_steps=1024, top_p=0.8, top_k=40, - content=f'Summarize:\n{extracted_text_}', + content=f"Summarize:\n{extracted_text_}", location="us-central1", ) @@ -93,4 +96,4 @@ def test_up16_services(): timestamp=datetime.datetime.now(), ) - assert len(errors) == 0 \ No newline at end of file + assert len(errors) == 0 diff --git a/webhook/storage_test.py b/webhook/storage_test.py index 94a51bb..cfa26e0 100644 --- a/webhook/storage_test.py +++ b/webhook/storage_test.py @@ -24,6 +24,7 @@ _BUCKET_NAME = os.environ["BUCKET"] _FILE_NAME = "system-test/fake.text" + @backoff.on_exception(backoff.expo, Exception, max_tries=3) def test_upload_to_gcs(): want = datetime.datetime.now().isoformat() @@ -38,7 +39,7 @@ def test_upload_to_gcs(): @patch.object(storage.Client, "get_bucket") -def test_upload_to_gcs(mock_get_bucket): +def test_upload_to_gcs_mock(mock_get_bucket): mock_blob = MagicMock(spec=storage.Blob) mock_bucket = MagicMock(spec=storage.Bucket) mock_bucket.blob.return_value = mock_blob diff --git a/webhook/utils.py b/webhook/utils.py index d8c9f69..6d1d3ac 100644 --- a/webhook/utils.py +++ b/webhook/utils.py @@ -15,10 +15,11 @@ import datetime import re -ABSTRACT_LENGTH = 150 * 10 # Abstract recommended max word length * avg 10 letters long +ABSTRACT_LENGTH = 150 * 10 # Abstract recommended max word length * avg 10 letters long CONCLUSION_LENGTH = 200 * 10 # Conclusion max word legnth * avg 10 letters long -ABSTRACT_H1 = 'abstract' -CONCLUSION_H1 = 'conclusion' +ABSTRACT_H1 = "abstract" +CONCLUSION_H1 = "conclusion" + def coerce_datetime_zulu(input_datetime: datetime.datetime): """Force datetime into specific format. @@ -32,9 +33,9 @@ def coerce_datetime_zulu(input_datetime: datetime.datetime): if regex_match: assert input_datetime.startswith(regex_match.group(1)) assert input_datetime.endswith(regex_match.group(2)) - return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00') + return datetime.datetime.fromisoformat(f"{input_datetime[:-1]}+00:00") raise RuntimeError( - 'The input datetime is not in the expected format. ' + "The input datetime is not in the expected format. " 'Please check format of the input datetime. Expected "Z" at the end' ) @@ -63,6 +64,6 @@ def truncate_complete_text(complete_text: str) -> str: return f""" Abstract: {abstract} - + Conclusion: {conclusion} """ diff --git a/webhook/vertex_llm.py b/webhook/vertex_llm.py index e9e332e..2d4ec88 100644 --- a/webhook/vertex_llm.py +++ b/webhook/vertex_llm.py @@ -48,7 +48,7 @@ def predict_large_language_model( project=project_id, location=location, ) - print('FOO', vertexai.init) + print("FOO", vertexai.init) model = TextGenerationModel.from_pretrained(model_name) if tuned_model_name: model = model.get_tuned_model(tuned_model_name) @@ -57,8 +57,6 @@ def predict_large_language_model( temperature=temperature, max_output_tokens=max_decode_steps, top_k=top_k, - top_p=top_p,) + top_p=top_p, + ) return response.text - - - diff --git a/webhook/vertex_llm_test.py b/webhook/vertex_llm_test.py index f999ddb..2b7c41a 100644 --- a/webhook/vertex_llm_test.py +++ b/webhook/vertex_llm_test.py @@ -17,14 +17,13 @@ from unittest.mock import MagicMock, PropertyMock, patch from vertexai.preview.language_models import TextGenerationModel -from google.cloud import aiplatform import vertexai from vertex_llm import predict_large_language_model -_MODEL_NAME = 'text-bison@001' -_PROJECT_ID = os.environ['PROJECT_ID'] +_MODEL_NAME = "text-bison@001" +_PROJECT_ID = os.environ["PROJECT_ID"] extracted_text = """ arXiv:cmp-lg/9404001v1 4 Apr 1994 @@ -86,25 +85,25 @@ def test_predict_large_language_model(): max_decode_steps=1024, top_p=0.8, top_k=40, - content=f'Summarize:\n{extracted_text}', + content=f"Summarize:\n{extracted_text}", location="us-central1", ) assert summary != "" -@patch.object(vertexai, 'init') -@patch.object(TextGenerationModel, 'from_pretrained') -def test_predict_large_language_model(mock_get_model, mock_init): - project_id = 'fake-project' - model_name = 'fake@fake-orca' - temperature=0.2 - max_decode_steps=1024 - top_p=0.8 - top_k=40 - content=f'Summarize:\nAbstract: fake\nConclusion: it is faked\n' - location='us-central1' - want = 'This is a fake summary' +@patch.object(vertexai, "init") +@patch.object(TextGenerationModel, "from_pretrained") +def test_predict_large_language_model_mock(mock_get_model, mock_init): + project_id = "fake-project" + model_name = "fake@fake-orca" + temperature = 0.2 + max_decode_steps = 1024 + top_p = 0.8 + top_k = 40 + content = "Summarize:\nAbstract: fake\nConclusion: it is faked\n" + location = "us-central1" + want = "This is a fake summary" mock_response = MagicMock() mock_prop = PropertyMock(return_value=want) @@ -114,20 +113,24 @@ def test_predict_large_language_model(mock_get_model, mock_init): mock_get_model.return_value = mock_model # Act - got = predict_large_language_model(project_id, - model_name, - temperature, - max_decode_steps, - top_p, - top_k, - content, - location) + got = predict_large_language_model( + project_id, + model_name, + temperature, + max_decode_steps, + top_p, + top_k, + content, + location, + ) # Assert assert want in got - mock_init.assert_called_with(project=project_id, location=location) - mock_model.predict.assert_called_with(content, - temperature=temperature, - max_output_tokens=max_decode_steps, - top_k=top_k, - top_p=top_p) + mock_init.assert_called_with(project=project_id, location=location) + mock_model.predict.assert_called_with( + content, + temperature=temperature, + max_output_tokens=max_decode_steps, + top_k=top_k, + top_p=top_p, + )