Skip to content

Commit

Permalink
add user id filters on post
Browse files Browse the repository at this point in the history
  • Loading branch information
yu23ki14 committed Nov 3, 2024
1 parent f01cd78 commit a412d89
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
4 changes: 4 additions & 0 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TopicId,
TwitterTimestamp,
UserEnrollment,
UserId,
)
from birdxplorer_common.storage import Storage

Expand Down Expand Up @@ -255,6 +256,7 @@ def get_posts(
request: Request,
post_ids: Union[List[PostId], None] = Query(default=None),
note_ids: Union[List[NoteId], None] = Query(default=None),
user_ids: Union[List[UserId], None] = Query(default=None),
created_at_from: Union[None, TwitterTimestamp, str] = Query(
default=None, **V1DataPostsDocs.params["created_at_from"]
),
Expand All @@ -275,6 +277,7 @@ def get_posts(
storage.get_posts(
post_ids=post_ids,
note_ids=note_ids,
user_ids=user_ids,
start=created_at_from,
end=created_at_to,
search_text=search_text,
Expand All @@ -288,6 +291,7 @@ def get_posts(
total_count = storage.get_number_of_posts(
post_ids=post_ids,
note_ids=note_ids,
user_ids=user_ids,
start=created_at_from,
end=created_at_to,
search_text=search_text,
Expand Down
6 changes: 5 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def _get_number_of_notes(
def _get_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[str], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -403,6 +404,8 @@ def _get_posts(
note.note_id in note_ids and note.post_id == post.post_id for note in note_samples
):
continue
if user_ids is not None and post.x_user_id not in user_ids:
continue
if start is not None and post.created_at < start:
continue
if end is not None and post.created_at >= end:
Expand All @@ -426,12 +429,13 @@ def _get_posts(
def _get_number_of_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[str], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
) -> int:
return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url)))
return len(list(_get_posts(post_ids, note_ids, user_ids, start, end, search_text, search_url)))

mock.get_number_of_posts.side_effect = _get_number_of_posts

Expand Down
6 changes: 6 additions & 0 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def get_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[UserId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -452,6 +453,8 @@ def get_posts(
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if user_ids is not None:
query = query.filter(PostRecord.user_id.in_(user_ids))
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
Expand All @@ -474,6 +477,7 @@ def get_number_of_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[UserId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -487,6 +491,8 @@ def get_number_of_posts(
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if user_ids is not None:
query = query.filter(PostRecord.user_id.in_(user_ids))
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
Expand Down

0 comments on commit a412d89

Please sign in to comment.