async function for estimation
yu23ki14 committed Dec 23, 2024
1 parent 7e0e6ad commit 86c2c82
Showing 6 changed files with 117 additions and 47 deletions.
compose.yml: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ services:
timeout: 5s
retries: 5
ports:
- "5432:5432"
- "5434:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
app:
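With the host port remapped, anything running on the host now reaches Postgres on 5434, while containers on the compose network still use 5432. A minimal connection check under that assumption; the user, password, and database name below are placeholders, not values taken from this repository:

```python
# Hypothetical local connection check against the remapped host port.
# Credentials and database name are placeholders, not values from this repo.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://postgres:postgres@localhost:5434/postgres")

with engine.connect() as conn:
    # The container still listens on 5432; only the host-side port moved to 5434.
    print(conn.execute(text("SELECT 1")).scalar())
```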
etl/src/birdxplorer_etl/extract.py: 16 additions & 13 deletions
@@ -83,9 +83,9 @@ def extract_data(sqlite: Session, postgresql: Session):
status = (
    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
)
-if status is None or status.created_at_millis > int(datetime.now().timestamp() * 1000):
-    rows_to_add.append(RowNoteStatusRecord(**row))
+if status is not None:
+    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
+rows_to_add.append(RowNoteStatusRecord(**row))
if index % 1000 == 0:
    sqlite.bulk_save_objects(rows_to_add)
    rows_to_add = []
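For context on the batching in this hunk: rows are accumulated and flushed with `bulk_save_objects` every 1,000 iterations. The tail of the loop is outside the visible hunk, so the following is only a sketch of the pattern with a hypothetical `Record` model (not the repo's classes), including the final flush that a partial last batch needs:

```python
# Sketch only: a stand-in model and the batch-flush pattern from the loop above.
# Names here (Record, example_table) are illustrative, not from the repo.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Record(Base):
    __tablename__ = "example_table"
    id = Column(Integer, primary_key=True)
    value = Column(String)


def save_in_batches(session: Session, rows, batch_size=1000):
    pending = []
    for index, row in enumerate(rows):
        pending.append(Record(**row))
        if (index + 1) % batch_size == 0:
            session.bulk_save_objects(pending)
            pending = []
    if pending:  # flush whatever is left over after the last full batch
        session.bulk_save_objects(pending)
    session.commit()
```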
@@ -184,18 +184,21 @@ def extract_data(sqlite: Session, postgresql: Session):
logging.error(f"Error: {e}")
postgresql.rollback()

-media_recs = [
-    RowPostMediaRecord(
-        media_key=f"{m['media_key']}-{post['data']['id']}",
-        type=m["type"],
-        url=m.get("url") or (m["variants"][0]["url"] if "variants" in m and m["variants"] else ""),
-        width=m["width"],
-        height=m["height"],
-        post_id=post["data"]["id"],
-    )
-    for m in media_data
-]
-postgresql.add_all(media_recs)
+for m in media_data:
+    media_key = f"{m['media_key']}-{post['data']['id']}"
+    is_media_exist = (
+        postgresql.query(RowPostMediaRecord).filter(RowPostMediaRecord.media_key == media_key).first()
+    )
+    if is_media_exist is None:
+        media_rec = RowPostMediaRecord(
+            media_key=media_key,
+            type=m["type"],
+            url=m.get("url") or (m["variants"][0]["url"] if "variants" in m and m["variants"] else ""),
+            width=m["width"],
+            height=m["height"],
+            post_id=post["data"]["id"],
+        )
+        postgresql.add(media_rec)

if "entities" in post["data"] and "urls" in post["data"]["entities"]:
for url in post["data"]["entities"]["urls"]:
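The media insert now checks whether a `media_key` already exists before adding a record, so re-running extraction over the same posts no longer attempts duplicate inserts. A generic sketch of that look-up-then-insert idea, using a stand-in `Media` model rather than the repo's `RowPostMediaRecord`:

```python
# Generic get-or-create sketch; Media is a stand-in model, not RowPostMediaRecord.
from sqlalchemy import Column, String
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Media(Base):
    __tablename__ = "media"
    media_key = Column(String, primary_key=True)
    url = Column(String)


def get_or_create_media(session: Session, media_key: str, url: str) -> Media:
    existing = session.query(Media).filter(Media.media_key == media_key).first()
    if existing is not None:
        return existing  # skip the insert, mirroring the `is_media_exist is None` check above
    media = Media(media_key=media_key, url=url)
    session.add(media)
    return media
```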
etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py: 2 additions & 2 deletions
@@ -29,14 +29,14 @@ def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]:
# topics[label] = topic_id
return topics

-def detect_language(self, text: str) -> str:
+async def detect_language(self, text: str) -> str:
prompt = (
"Detect the language of the following text and return only the language code "
f"from this list: en, es, ja, pt, de, fr. Text: {text}. "
"Respond with only the language code, nothing else."
)

-response = self.client.chat.completions.create(
+response = await self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
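Awaiting `chat.completions.create` requires the asynchronous OpenAI client; the synchronous client's method returns a plain response object that cannot be awaited. The client construction is outside this diff, so the following is only an assumption-laden sketch using `AsyncOpenAI`, not the repository's actual service class:

```python
# Sketch assuming the service holds an AsyncOpenAI client; the actual client
# setup is not shown in this diff, so this illustrates the pattern only.
from openai import AsyncOpenAI


class LanguageDetector:
    def __init__(self, api_key: str) -> None:
        self.client = AsyncOpenAI(api_key=api_key)

    async def detect_language(self, text: str) -> str:
        response = await self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Return only the language code for: {text}"},
            ],
        )
        return response.choices[0].message.content.strip()
```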
etl/src/birdxplorer_etl/lib/x/postlookup.py: 25 additions & 0 deletions
@@ -29,6 +29,31 @@ def bearer_oauth(r):
return r


# def connect_to_endpoint(url, current_token_index=0, wait_until=0):
# tokens = settings.X_BEARER_TOKEN.split(",")
# token_max_index = len(tokens) - 1

# logging.info(f"token_max_index: {token_max_index}")
# logging.info(f"current_token_index: {current_token_index}")

# response = requests.request("GET", url, auth=lambda r: bearer_oauth(r, tokens[current_token_index]))
# if response.status_code == 429:
# if current_token_index == token_max_index:
# logging.warning(f"Rate limit exceeded. Waiting until: {wait_until - int(time.time()) + 1}")
# time.sleep(wait_until - int(time.time()) + 1)
# data = connect_to_endpoint(url, 0, 0)
# return data
# else:
# reset_time = int(response.headers["x-rate-limit-reset"])
# fastestReset = wait_until == 0 and reset_time or min(wait_until, reset_time)
# logging.warning("Rate limit exceeded. Waiting until: {}".format(fastestReset))
# data = connect_to_endpoint(url, current_token_index + 1, fastestReset)
# return data
# elif response.status_code != 200:
# raise Exception("Request returned an error: {} {}".format(response.status_code, response.text))
# return response.json()


def connect_to_endpoint(url):
response = requests.request("GET", url, auth=bearer_oauth)
if response.status_code == 429:
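The commented-out version above rotated across several bearer tokens and slept until the earliest `x-rate-limit-reset`. The body of the new single-token `connect_to_endpoint` is truncated in this diff, so the following is a hedged sketch of how a single-token 429 branch can be handled, not the committed implementation:

```python
# Hedged sketch of a single-token 429 handler; the committed function body is
# truncated in this diff, so this illustrates the pattern rather than the repo's code.
import time
import requests


def connect_to_endpoint_sketch(url: str, auth) -> dict:
    response = requests.request("GET", url, auth=auth)
    if response.status_code == 429:
        reset_time = int(response.headers.get("x-rate-limit-reset", time.time() + 60))
        wait_seconds = max(reset_time - int(time.time()) + 1, 1)
        time.sleep(wait_seconds)  # wait for the rate-limit window to reset, then retry
        return connect_to_endpoint_sketch(url, auth)
    if response.status_code != 200:
        raise Exception(f"Request returned an error: {response.status_code} {response.text}")
    return response.json()
```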
etl/src/birdxplorer_etl/main.py: 3 additions & 1 deletion
@@ -2,6 +2,7 @@
from extract import extract_data
from load import load_data
from transform import transform_data
import asyncio
import logging

logging.basicConfig(level=logging.INFO)
@@ -10,5 +11,6 @@
sqlite = init_sqlite()
postgresql = init_postgresql()
extract_data(sqlite, postgresql)
-transform_data(sqlite, postgresql)
+asyncio.run(transform_data(sqlite, postgresql))
load_data()
logging.info("ETL process completed")
etl/src/birdxplorer_etl/transform.py: 70 additions & 30 deletions
@@ -9,6 +9,8 @@
from sqlalchemy import Integer, Numeric, and_, func, select
from sqlalchemy.orm import Session

import asyncio

from birdxplorer_common.storage import (
RowNoteRecord,
RowNoteStatusRecord,
@@ -17,14 +19,14 @@
RowPostRecord,
RowUserRecord,
)
-from lib.ai_model.ai_model_interface import get_ai_service
+from lib.ai_model.ai_model_interface import get_ai_service, AIModelInterface
from settings import (
TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
)


-def transform_data(sqlite: Session, postgresql: Session):
+async def transform_data(sqlite: Session, postgresql: Session):

logging.info("Transforming data")

@@ -36,7 +38,7 @@ def transform_data(sqlite: Session, postgresql: Session):
os.remove("./data/transformed/note.csv")
with open("./data/transformed/note.csv", "a") as file:
writer = csv.writer(file)
writer.writerow(["note_id", "post_id", "summary", "current_status", "created_at", "language"])
writer.writerow(["note_id", "post_id", "summary", "current_status", "locked_status", "created_at", "language"])

offset = 0
limit = 1000
@@ -63,6 +65,7 @@
RowNoteRecord.row_post_id,
RowNoteRecord.summary,
RowNoteStatusRecord.current_status,
RowNoteStatusRecord.locked_status,
func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"),
)
.filter(
@@ -76,12 +79,18 @@
.offset(offset)
)

-for note in notes:
-    note_as_list = list(note)
-    note_as_list.append(ai_service.detect_language(note[2]))
-    note_as_list.append("ja")
-    writer = csv.writer(file)
-    writer.writerow(note_as_list)
+notes_list = list(notes)
+note_chunks = [notes_list[i : i + 10] for i in range(0, len(notes_list), 10)]
+writer = csv.writer(file)
+for chunk in note_chunks:
+    estimated_notes_list = await asyncio.gather(
+        *[estimate_language_of_note(ai_service, note) for note in chunk]
+    )
+
+    for note_as_list in estimated_notes_list:
+        writer.writerow(note_as_list)
+
+    await asyncio.sleep(1)
offset += limit

# Transform row post data and generate post.csv
@@ -165,11 +174,18 @@ def transform_data(sqlite: Session, postgresql: Session):
# Transform row post embed url data and generate post_embed_url.csv
generate_topic()

-generate_note_topic(sqlite)
+await generate_note_topic(sqlite)

return


async def estimate_language_of_note(ai_service: AIModelInterface, note: RowNoteRecord) -> list:
    note_list = list(note)
    language = await asyncio.to_thread(ai_service.detect_language, note[2])
    note_list.append(language)
    return note_list
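The chunking above caps concurrency at ten in-flight language detections per `asyncio.gather` call, with a one-second pause between chunks, and `asyncio.to_thread` keeps a blocking detection call off the event loop. A standalone sketch of that fan-out pattern with a stand-in blocking function (the real `ai_service` is not used here):

```python
# Standalone sketch of the chunk + gather + to_thread pattern used above.
# detect_language_blocking is a stand-in for a synchronous AI call, not the repo's service.
import asyncio
import time


def detect_language_blocking(text: str) -> str:
    time.sleep(0.1)  # stand-in for a slow, blocking API call
    return "en"


async def detect_in_chunks(texts: list[str], chunk_size: int = 10) -> list[str]:
    results: list[str] = []
    for i in range(0, len(texts), chunk_size):
        chunk = texts[i : i + chunk_size]
        # Run up to chunk_size blocking calls concurrently in worker threads.
        results += await asyncio.gather(
            *[asyncio.to_thread(detect_language_blocking, t) for t in chunk]
        )
        await asyncio.sleep(1)  # pause between chunks, as in the ETL loop
    return results


if __name__ == "__main__":
    print(asyncio.run(detect_in_chunks(["hello", "こんにちは", "hola"])))
```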


def write_media_csv(postgresql: Session) -> None:
media_csv_path = Path("./data/transformed/media.csv")
post_media_association_csv_path = Path("./data/transformed/post_media_association.csv")
@@ -283,7 +299,7 @@ def generate_topic():
writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})


-def generate_note_topic(sqlite: Session):
+async def generate_note_topic(sqlite: Session):
output_csv_file_path = "./data/transformed/note_topic_association.csv"
ai_service = get_ai_service()

@@ -299,7 +315,18 @@
offset = 0
limit = 1000

num_of_notes = sqlite.query(func.count(RowNoteRecord.row_post_id)).scalar()
num_of_notes = (
sqlite.query(func.count(RowNoteRecord.note_id))
.filter(
and_(
RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
)
)
.scalar()
)

logging.info(f"Transforming note data: {num_of_notes}")

while offset < num_of_notes:
topicEstimationTargetNotes = sqlite.execute(
@@ -315,24 +342,32 @@ def generate_note_topic(sqlite: Session):
.offset(offset)
)

-for index, note in enumerate(topicEstimationTargetNotes):
-    note_id = note.note_id
-    summary = note.summary
-    topics_info = ai_service.detect_topic(note_id, summary)
-    if topics_info:
-        for topic in topics_info.get("topics", []):
-            record = {"note_id": note_id, "topic_id": topic}
-            records.append(record)
-    if index % 100 == 0:
-        for record in records:
-            writer.writerow(
-                {
-                    "note_id": record["note_id"],
-                    "topic_id": record["topic_id"],
-                }
-            )
-        records = []
-        print(index)
+topicEstimationTargetNotes_list = list(topicEstimationTargetNotes)
+note_chunks = [
+    topicEstimationTargetNotes_list[i : i + 10] for i in range(0, len(topicEstimationTargetNotes_list), 10)
+]
+
+for index, chunk in enumerate(note_chunks):
+    logging.info(f"Processing chunk {index}")
+    topic_info_list = await asyncio.gather(*[estimate_topic_of_note(ai_service, note) for note in chunk])
+
+    for topic_info in topic_info_list:
+        if topic_info:
+            for topic in topic_info.get("topics", []):
+                record = {"note_id": topic_info["note_id"], "topic_id": topic}
+                records.append(record)
+
+    if index % 10 == 0:
+        for record in records:
+            writer.writerow(
+                {
+                    "note_id": record["note_id"],
+                    "topic_id": record["topic_id"],
+                }
+            )
+        records = []
+    await asyncio.sleep(1)

offset += limit

for record in records:
@@ -346,5 +381,10 @@ def generate_note_topic(sqlite: Session):
print(f"New CSV file has been created at {output_csv_file_path}")


async def estimate_topic_of_note(ai_service: AIModelInterface, note: RowNoteRecord) -> dict:
    res = await asyncio.to_thread(ai_service.detect_topic, note[0], note[2])
    return {"note_id": note[0], "topics": res["topics"]}


if __name__ == "__main__":
generate_note_topic()
