From 86c2c82a9c9cac8888913d99c7cc52638fa3667a Mon Sep 17 00:00:00 2001
From: yu23ki14
Date: Mon, 23 Dec 2024 12:20:18 +0900
Subject: [PATCH] async function for estimation

---
 compose.yml                                 |   2 +-
 etl/src/birdxplorer_etl/extract.py          |  29 ++---
 .../lib/openapi/open_ai_service.py          |   4 +-
 etl/src/birdxplorer_etl/lib/x/postlookup.py |  25 +++++
 etl/src/birdxplorer_etl/main.py             |   4 +-
 etl/src/birdxplorer_etl/transform.py        | 100 ++++++++++++------
 6 files changed, 117 insertions(+), 47 deletions(-)

diff --git a/compose.yml b/compose.yml
index df2676f..f50fd75 100644
--- a/compose.yml
+++ b/compose.yml
@@ -14,7 +14,7 @@ services:
       timeout: 5s
       retries: 5
     ports:
-      - "5432:5432"
+      - "5434:5432"
     volumes:
       - postgres_data:/var/lib/postgresql/data
   app:
diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py
index 4aa2719..90c7b5e 100644
--- a/etl/src/birdxplorer_etl/extract.py
+++ b/etl/src/birdxplorer_etl/extract.py
@@ -83,9 +83,9 @@ def extract_data(sqlite: Session, postgresql: Session):
                 status = (
                     sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
                 )
-                if status is None or status.created_at_millis > int(datetime.now().timestamp() * 1000):
+                if status is not None:
                     sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
-                    rows_to_add.append(RowNoteStatusRecord(**row))
+                rows_to_add.append(RowNoteStatusRecord(**row))
                 if index % 1000 == 0:
                     sqlite.bulk_save_objects(rows_to_add)
                     rows_to_add = []
@@ -184,18 +184,21 @@ def extract_data(sqlite: Session, postgresql: Session):
             logging.error(f"Error: {e}")
             postgresql.rollback()
 
-        media_recs = [
-            RowPostMediaRecord(
-                media_key=f"{m['media_key']}-{post['data']['id']}",
-                type=m["type"],
-                url=m.get("url") or (m["variants"][0]["url"] if "variants" in m and m["variants"] else ""),
-                width=m["width"],
-                height=m["height"],
-                post_id=post["data"]["id"],
+        for m in media_data:
+            media_key = f"{m['media_key']}-{post['data']['id']}"
+            is_media_exist = (
+                postgresql.query(RowPostMediaRecord).filter(RowPostMediaRecord.media_key == media_key).first()
             )
-            for m in media_data
-        ]
-        postgresql.add_all(media_recs)
+            if is_media_exist is None:
+                media_rec = RowPostMediaRecord(
+                    media_key=media_key,
+                    type=m["type"],
+                    url=m.get("url") or (m["variants"][0]["url"] if "variants" in m and m["variants"] else ""),
+                    width=m["width"],
+                    height=m["height"],
+                    post_id=post["data"]["id"],
+                )
+                postgresql.add(media_rec)
 
         if "entities" in post["data"] and "urls" in post["data"]["entities"]:
             for url in post["data"]["entities"]["urls"]:
diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
index 048a4f5..115ac17 100644
--- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
+++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
@@ -29,14 +29,14 @@ def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]:
                 # topics[label] = topic_id
         return topics
 
-    def detect_language(self, text: str) -> str:
+    async def detect_language(self, text: str) -> str:
         prompt = (
             "Detect the language of the following text and return only the language code "
             f"from this list: en, es, ja, pt, de, fr. Text: {text}. "
             "Respond with only the language code, nothing else."
         )
 
-        response = self.client.chat.completions.create(
+        response = await self.client.chat.completions.create(
             model="gpt-4o-mini",
             messages=[
                 {"role": "system", "content": "You are a helpful assistant."},
diff --git a/etl/src/birdxplorer_etl/lib/x/postlookup.py b/etl/src/birdxplorer_etl/lib/x/postlookup.py
index 343c9f2..f6df635 100644
--- a/etl/src/birdxplorer_etl/lib/x/postlookup.py
+++ b/etl/src/birdxplorer_etl/lib/x/postlookup.py
@@ -29,6 +29,31 @@ def bearer_oauth(r):
     return r
 
 
+# def connect_to_endpoint(url, current_token_index=0, wait_until=0):
+#     tokens = settings.X_BEARER_TOKEN.split(",")
+#     token_max_index = len(tokens) - 1
+
+#     logging.info(f"token_max_index: {token_max_index}")
+#     logging.info(f"current_token_index: {current_token_index}")
+
+#     response = requests.request("GET", url, auth=lambda r: bearer_oauth(r, tokens[current_token_index]))
+#     if response.status_code == 429:
+#         if current_token_index == token_max_index:
+#             logging.warning(f"Rate limit exceeded. Waiting until: {wait_until - int(time.time()) + 1}")
+#             time.sleep(wait_until - int(time.time()) + 1)
+#             data = connect_to_endpoint(url, 0, 0)
+#             return data
+#         else:
+#             reset_time = int(response.headers["x-rate-limit-reset"])
+#             fastestReset = wait_until == 0 and reset_time or min(wait_until, reset_time)
+#             logging.warning("Rate limit exceeded. Waiting until: {}".format(fastestReset))
+#             data = connect_to_endpoint(url, current_token_index + 1, fastestReset)
+#             return data
+#     elif response.status_code != 200:
+#         raise Exception("Request returned an error: {} {}".format(response.status_code, response.text))
+#     return response.json()
+
+
 def connect_to_endpoint(url):
     response = requests.request("GET", url, auth=bearer_oauth)
     if response.status_code == 429:
diff --git a/etl/src/birdxplorer_etl/main.py b/etl/src/birdxplorer_etl/main.py
index b484e39..0ff1fa8 100644
--- a/etl/src/birdxplorer_etl/main.py
+++ b/etl/src/birdxplorer_etl/main.py
@@ -2,6 +2,7 @@
 from extract import extract_data
 from load import load_data
 from transform import transform_data
+import asyncio
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -10,5 +11,6 @@
     sqlite = init_sqlite()
     postgresql = init_postgresql()
     extract_data(sqlite, postgresql)
-    transform_data(sqlite, postgresql)
+    asyncio.run(transform_data(sqlite, postgresql))
     load_data()
+    logging.info("ETL process completed")
diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py
index e9ee5aa..2889fb8 100644
--- a/etl/src/birdxplorer_etl/transform.py
+++ b/etl/src/birdxplorer_etl/transform.py
@@ -9,6 +9,8 @@
 from sqlalchemy import Integer, Numeric, and_, func, select
 from sqlalchemy.orm import Session
 
+import asyncio
+
 from birdxplorer_common.storage import (
     RowNoteRecord,
     RowNoteStatusRecord,
@@ -17,14 +19,14 @@
     RowPostRecord,
     RowUserRecord,
 )
-from lib.ai_model.ai_model_interface import get_ai_service
+from lib.ai_model.ai_model_interface import get_ai_service, AIModelInterface
 from settings import (
     TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
     TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
 )
 
 
-def transform_data(sqlite: Session, postgresql: Session):
+async def transform_data(sqlite: Session, postgresql: Session):
 
     logging.info("Transforming data")
 
@@ -36,7 +38,7 @@ def transform_data(sqlite: Session, postgresql: Session):
         os.remove("./data/transformed/note.csv")
     with open("./data/transformed/note.csv", "a") as file:
         writer = csv.writer(file)
-        writer.writerow(["note_id", "post_id", "summary", "current_status", "created_at", "language"])
+        writer.writerow(["note_id", "post_id", "summary", "current_status", "locked_status", "created_at", "language"])
 
     offset = 0
     limit = 1000
@@ -63,6 +65,7 @@ def transform_data(sqlite: Session, postgresql: Session):
                 RowNoteRecord.row_post_id,
                 RowNoteRecord.summary,
                 RowNoteStatusRecord.current_status,
+                RowNoteStatusRecord.locked_status,
                 func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"),
             )
             .filter(
@@ -76,12 +79,18 @@ def transform_data(sqlite: Session, postgresql: Session):
             .offset(offset)
         )
 
-        for note in notes:
-            note_as_list = list(note)
-            note_as_list.append(ai_service.detect_language(note[2]))
-            note_as_list.append("ja")
-            writer = csv.writer(file)
-            writer.writerow(note_as_list)
+        notes_list = list(notes)
+        note_chunks = [notes_list[i : i + 10] for i in range(0, len(notes_list), 10)]
+        writer = csv.writer(file)
+        for chunk in note_chunks:
+            estimated_notes_list = await asyncio.gather(
+                *[estimate_language_of_note(ai_service, note) for note in chunk]
+            )
+
+            for note_as_list in estimated_notes_list:
+                writer.writerow(note_as_list)
+
+            await asyncio.sleep(1)
         offset += limit
 
     # Transform row post data and generate post.csv
@@ -165,11 +174,18 @@ def transform_data(sqlite: Session, postgresql: Session):
 
     # Transform row post embed url data and generate post_embed_url.csv
     generate_topic()
-    generate_note_topic(sqlite)
+    await generate_note_topic(sqlite)
 
     return
 
 
+async def estimate_language_of_note(ai_service: AIModelInterface, note: RowNoteRecord) -> list:
+    note_list = list(note)
+    language = await asyncio.to_thread(ai_service.detect_language, note[2])
+    note_list.append(language)
+    return note_list
+
+
 def write_media_csv(postgresql: Session) -> None:
     media_csv_path = Path("./data/transformed/media.csv")
     post_media_association_csv_path = Path("./data/transformed/post_media_association.csv")
@@ -283,7 +299,7 @@ def generate_topic():
             writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
 
 
-def generate_note_topic(sqlite: Session):
+async def generate_note_topic(sqlite: Session):
     output_csv_file_path = "./data/transformed/note_topic_association.csv"
     ai_service = get_ai_service()
 
@@ -299,7 +315,18 @@ def generate_note_topic(sqlite: Session):
     offset = 0
     limit = 1000
 
-    num_of_notes = sqlite.query(func.count(RowNoteRecord.row_post_id)).scalar()
+    num_of_notes = (
+        sqlite.query(func.count(RowNoteRecord.note_id))
+        .filter(
+            and_(
+                RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
+                RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
+            )
+        )
+        .scalar()
+    )
+
+    logging.info(f"Transforming note data: {num_of_notes}")
 
     while offset < num_of_notes:
         topicEstimationTargetNotes = sqlite.execute(
@@ -315,24 +342,32 @@ def generate_note_topic(sqlite: Session):
            .offset(offset)
        )
 
-        for index, note in enumerate(topicEstimationTargetNotes):
-            note_id = note.note_id
-            summary = note.summary
-            topics_info = ai_service.detect_topic(note_id, summary)
-            if topics_info:
-                for topic in topics_info.get("topics", []):
-                    record = {"note_id": note_id, "topic_id": topic}
-                    records.append(record)
-            if index % 100 == 0:
-                for record in records:
-                    writer.writerow(
-                        {
-                            "note_id": record["note_id"],
-                            "topic_id": record["topic_id"],
-                        }
-                    )
-                records = []
-            print(index)
+        topicEstimationTargetNotes_list = list(topicEstimationTargetNotes)
+        note_chunks = [
+            topicEstimationTargetNotes_list[i : i + 10] for i in range(0, len(topicEstimationTargetNotes_list), 10)
+        ]
+
+        for index, chunk in enumerate(note_chunks):
+            logging.info(f"Processing chunk {index}")
+            topic_info_list = await asyncio.gather(*[estimate_topic_of_note(ai_service, note) for note in chunk])
+
+            for topic_info in topic_info_list:
+                if topic_info:
+                    for topic in topic_info.get("topics", []):
+                        record = {"note_id": topic_info["note_id"], "topic_id": topic}
+                        records.append(record)
+
+            if index % 10 == 0:
+                for record in records:
+                    writer.writerow(
+                        {
+                            "note_id": record["note_id"],
+                            "topic_id": record["topic_id"],
+                        }
+                    )
+                records = []
+            await asyncio.sleep(1)
+
         offset += limit
 
     for record in records:
@@ -346,5 +381,10 @@ def generate_note_topic(sqlite: Session):
     print(f"New CSV file has been created at {output_csv_file_path}")
 
 
+async def estimate_topic_of_note(ai_service: AIModelInterface, note: RowNoteRecord) -> dict:
+    res = await asyncio.to_thread(ai_service.detect_topic, note[0], note[2])
+    return {"note_id": note[0], "topics": res["topics"]}
+
+
 if __name__ == "__main__":
     generate_note_topic()
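
Reviewer note, not part of the patch: a minimal, runnable sketch of the chunk-then-gather pattern the transform.py hunks above rely on. The estimate coroutine, the items argument, and the estimate_all helper are hypothetical stand-ins for estimate_language_of_note / estimate_topic_of_note; only the chunk size of 10, the asyncio.to_thread wrapping, and the one-second pause mirror the diff itself (Python 3.9+ assumed).

# Illustrative sketch only (not code from this repository): run estimations in
# chunks of 10 concurrently, pausing between chunks to stay under API rate limits.
import asyncio


async def estimate(item: str) -> str:
    # Stand-in for a blocking AI call; the patch wraps such calls with asyncio.to_thread.
    return await asyncio.to_thread(str.upper, item)


async def estimate_all(items: list[str], chunk_size: int = 10) -> list[str]:
    results: list[str] = []
    chunks = [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]
    for chunk in chunks:
        # One chunk runs concurrently; then sleep before starting the next chunk.
        results.extend(await asyncio.gather(*[estimate(item) for item in chunk]))
        await asyncio.sleep(1)
    return results


if __name__ == "__main__":
    print(asyncio.run(estimate_all(["ja", "en", "es"])))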