Skip to content

Commit

Permalink
blender rollout; changed configs a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
ledovsky committed Oct 8, 2024
1 parent 0b7dbcb commit d583f36
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 117 deletions.
98 changes: 7 additions & 91 deletions src/recommendations/meme_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,90 +79,7 @@ async def generate_cold_start_recommendations(user_id, limit=10):
await redis.add_memes_to_queue_by_key(queue_key, candidates)


async def generate_recommendations(user_id: int, limit: int):
    """Pick meme candidates for a user and append them to the user's queue.

    Users for whom ``(user_id + 25) % 100 < 50`` are delegated to
    ``generate_with_blender`` and nothing else happens here. For everyone
    else a single random draw ``r`` plus the user's ``nmemes_sent`` counter
    select one candidate engine; if it yields nothing, ``get_lr_smoothed``
    is tried, and for users with more than 1000 memes sent
    ``less_seen_meme_and_source`` is tried after that.

    Args:
        user_id: Id of the user the queue belongs to.
        limit: Passed through as ``limit`` to every candidate engine.

    Returns:
        None. Candidates, when any were found, are written to the queue via
        ``redis.add_memes_to_queue_by_key``; otherwise the queue is left
        untouched.
    """
    if (user_id + 25) % 100 < 50:
        await generate_with_blender(user_id, limit)
        return

    queue_key = redis.get_meme_queue_key(user_id)
    memes_in_queue = await redis.get_all_memes_in_queue_by_key(queue_key)
    meme_ids_in_queue = [meme["id"] for meme in memes_in_queue]

    user_info = await get_user_info(user_id)
    nmemes_sent = user_info["nmemes_sent"]

    # Every engine receives the same arguments; memes that are already
    # queued are excluded so they are not recommended twice.
    engine_kwargs = {"limit": limit, "exclude_meme_ids": meme_ids_in_queue}

    candidates = []

    r = random.random()

    if r < 0.2:
        candidates = await get_fast_dopamine(user_id, **engine_kwargs)

    elif nmemes_sent < 30:
        candidates = await get_selected_sources(user_id, **engine_kwargs)

        if not candidates:
            candidates = await get_best_memes_from_each_source(
                user_id, **engine_kwargs
            )

    elif nmemes_sent < 100:
        # NOTE(review): r >= 0.2 is guaranteed here because the first branch
        # of the outer chain already handled r < 0.2, so the uploaded_memes
        # branch below can never run. Kept as-is until the intended
        # distribution is confirmed.
        if r < 0.2:
            candidates = await uploaded_memes(user_id, **engine_kwargs)
        elif r < 0.4:
            candidates = await get_fast_dopamine(user_id, **engine_kwargs)
        elif r < 0.6:
            candidates = await get_best_memes_from_each_source(
                user_id, **engine_kwargs
            )
        else:
            candidates = await get_lr_smoothed(user_id, **engine_kwargs)

    else:
        if r < 0.3:
            candidates = await uploaded_memes(user_id, **engine_kwargs)
        # Must be `elif`: with a plain `if`, the uploaded_memes result above
        # was always overwritten, since r < 0.3 also satisfies r < 0.6.
        elif r < 0.6:
            candidates = await like_spread_and_recent_memes(
                user_id, **engine_kwargs
            )
        else:
            candidates = await get_lr_smoothed(user_id, **engine_kwargs)

    if not candidates:
        candidates = await get_lr_smoothed(user_id, **engine_kwargs)

    if not candidates and nmemes_sent > 1000:
        candidates = await less_seen_meme_and_source(user_id, **engine_kwargs)

    if not candidates:
        # TODO: fallback to some algo which will always return something
        return

    # TODO:
    # inference ML api
    # select the best LIMIT memes -> save them to queue

    await redis.add_memes_to_queue_by_key(queue_key, candidates)


async def generate_with_blender(
async def generate_recommendations(
user_id: int,
limit: int,
nmemes_sent: Optional[int] = None,
Expand Down Expand Up @@ -193,8 +110,10 @@ async def generate_with_blender(
async def get_candidates(user_id, limit):
"""A helper function to avoid copy-paste"""

print('get_candidates')
# <30 is treated as cold start. no blending
if nmemes_sent < 30:
print('less 30')
candidates = await retriever.get_candidates(
'fast_dopamine', user_id, limit, exclude_mem_ids=meme_ids_in_queue)

Expand All @@ -213,12 +132,10 @@ async def get_candidates(user_id, limit):
'lr_smoothed': 0.4,
}

engines = ['uploaded_memes', 'fast_dopamine',
'best_memes_from_each_source', 'lr_smoothed']
candidates_dict = await retriever.get_candidates_dict(
engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue)
weights.keys(), user_id, limit, exclude_mem_ids=meme_ids_in_queue)

fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'}
fixed_pos = {0: 'lr_smoothed'}
return blend(candidates_dict, weights, fixed_pos, limit, random_seed)

# >=100
Expand All @@ -228,11 +145,10 @@ async def get_candidates(user_id, limit):
'lr_smoothed': 0.4,
}

engines = ['uploaded_memes', 'like_spread_and_recent_memes', 'lr_smoothed']
candidates_dict = await retriever.get_candidates_dict(
engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue)
weights.keys(), user_id, limit, exclude_mem_ids=meme_ids_in_queue)

fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'}
fixed_pos = {0: 'lr_smoothed'}
candidates = blend(candidates_dict, weights, fixed_pos, limit, random_seed)

if len(candidates) == 0 and nmemes_sent > 1000:
Expand Down
3 changes: 2 additions & 1 deletion src/tgbot/senders/next_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def get_next_meme_for_user(user_id: int) -> MemeData | None:
while True:
meme = await meme_queue.get_next_meme_for_user(user_id)
if not meme: # no memes in queue
await meme_queue.generate_recommendations(user_id, limit=5)
await meme_queue.generate_recommendations(user_id, limit=7)
meme = await meme_queue.get_next_meme_for_user(user_id)
if not meme:
return None
Expand All @@ -76,6 +76,7 @@ async def next_message(
if popup:
return await send_popup(user_id, popup)

print("here")
meme = await get_next_meme_for_user(user_id)
if not meme:
asyncio.create_task(meme_queue.check_queue(user_id))
Expand Down
38 changes: 13 additions & 25 deletions tests/recommendations/test_meme_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pytest

from src.recommendations.candidates import CandidatesRetriever
from src.recommendations.meme_queue import generate_with_blender
from src.recommendations.meme_queue import generate_recommendations


@pytest.mark.asyncio
async def test_generate_with_blender_below_30():
async def test_generate_below_30():
async def get_fast_dopamine(
self,
user_id: int,
Expand Down Expand Up @@ -44,7 +44,7 @@ class TestRetriever(CandidatesRetriever):
'best_meme_from_each_source': get_best_memes_from_each_source,
}

candidates = await generate_with_blender(1, 10, 10, TestRetriever())
candidates = await generate_recommendations(1, 10, 10, TestRetriever())
assert len(candidates) == 2
assert candidates[0]['id'] == 1
assert candidates[1]['id'] == 2
Expand All @@ -55,14 +55,14 @@ class TestRetriever(CandidatesRetriever):
'best_memes_from_each_source': get_best_memes_from_each_source,
}

candidates = await generate_with_blender(1, 10, 10, TestRetriever())
candidates = await generate_recommendations(1, 10, 10, TestRetriever())
assert len(candidates) == 2
assert candidates[0]['id'] == 3
assert candidates[1]['id'] == 4


@pytest.mark.asyncio
async def test_generate_with_blender_below_100():
async def test_generate_below_100():
async def uploaded_memes(
self,
user_id: int,
Expand Down Expand Up @@ -115,18 +115,12 @@ class TestRetriever(CandidatesRetriever):
'lr_smoothed': get_lr_smoothed,
}

candidates = await generate_with_blender(1, 10, 40, TestRetriever())
candidates = await generate_recommendations(1, 10, 40, TestRetriever())
assert len(candidates) == 10
# hardcoded values
assert candidates[0]['id'] == 7
assert candidates[1]['id'] == 8
assert candidates[2]['id'] == 9
assert candidates[3]['id'] == 1
assert candidates[4]['id'] == 3
assert candidates[5]['id'] == 4
assert candidates[0]['id'] in [7, 8, 9, 10]

@pytest.mark.asyncio
async def test_generate_with_blender_above_100():
async def test_generate_above_100():
async def uploaded_memes(
self,
user_id: int,
Expand Down Expand Up @@ -170,19 +164,13 @@ class TestRetriever(CandidatesRetriever):
'lr_smoothed': get_lr_smoothed,
}

candidates = await generate_with_blender(1, 10, 200, TestRetriever(), random_seed=102)
candidates = await generate_recommendations(1, 10, 200, TestRetriever(), random_seed=102)
assert len(candidates) == 10
# hardcoded values
assert candidates[0]['id'] == 7
assert candidates[1]['id'] == 8
assert candidates[2]['id'] == 1
assert candidates[3]['id'] == 9
assert candidates[4]['id'] == 2
assert candidates[5]['id'] == 10
assert candidates[0]['id'] in [7, 8, 9, 10]


@pytest.mark.asyncio
async def test_generate_with_blender_empty_above_100():
async def test_generate_empty_above_100():
async def uploaded_memes(
self,
user_id: int,
Expand Down Expand Up @@ -239,10 +227,10 @@ class TestRetriever(CandidatesRetriever):
'best_memes_from_each_source': get_best_memes_from_each_source,
}

candidates = await generate_with_blender(1, 10, 200, TestRetriever())
candidates = await generate_recommendations(1, 10, 200, TestRetriever())
assert len(candidates) == 2
assert candidates[0]['id'] == 3

candidates = await generate_with_blender(1, 10, 1200, TestRetriever())
candidates = await generate_recommendations(1, 10, 1200, TestRetriever())
assert len(candidates) == 2
assert candidates[0]['id'] == 1

0 comments on commit d583f36

Please sign in to comment.