Skip to content

Commit

Permalink
add test to ensure only calling async when supported for the push mode
Browse files Browse the repository at this point in the history
Signed-off-by: Rob Howley <[email protected]>
  • Loading branch information
robhowley committed Oct 18, 2024
1 parent c53664a commit 344171c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ async def push(body=Depends(get_body)):
and to in [PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE]
)
if should_push_async:
print("im in async?")
await store.push_async(**push_params)
else:
store.push(**push_params)
Expand Down
18 changes: 18 additions & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.infra.provider import Provider
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.supported_async_methods import (
ProviderAsyncMethods,
SupportedAsyncMethods,
)
from feast.online_response import OnlineResponse
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import RepeatedValue
Expand All @@ -30,6 +34,20 @@


class FooProvider(Provider):
@staticmethod
def with_async_support(online_read=False, online_write=False):
class _FooProvider(FooProvider):
@property
def async_supported(self):
return ProviderAsyncMethods(
online=SupportedAsyncMethods(
read=online_read,
write=online_write,
)
)

return _FooProvider(None)

def __init__(self, config: RepoConfig):
pass

Expand Down
48 changes: 48 additions & 0 deletions sdk/python/tests/unit/test_feature_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from feast import FeatureStore
from feast.data_source import PushMode
from feast.feature_server import get_app
from feast.utils import _utc_now
from tests.foo_provider import FooProvider


@pytest.mark.parametrize(
"online_write,push_mode,async_count",
[
(True, PushMode.ONLINE_AND_OFFLINE, 1),
(True, PushMode.OFFLINE, 0),
(True, PushMode.ONLINE, 1),
(False, PushMode.ONLINE_AND_OFFLINE, 0),
(False, PushMode.OFFLINE, 0),
(False, PushMode.ONLINE, 0),
],
)
def test_push_online_async_supported(online_write, push_mode, async_count, environment):
push_payload = json.dumps(
{
"push_source_name": "location_stats_push_source",
"df": {
"location_id": [1],
"temperature": [100],
"event_timestamp": [str(_utc_now())],
"created": [str(_utc_now())],
},
"to": push_mode.name.lower(),
}
)

provider = FooProvider.with_async_support(online_write=online_write)
print(provider.async_supported.online.write)
with patch.object(FeatureStore, "_get_provider", return_value=provider):
fs = environment.feature_store
fs.push = MagicMock()
fs.push_async = AsyncMock()
client = TestClient(get_app(fs))
client.post("/push", data=push_payload)
assert fs.push.call_count == 1 - async_count
assert fs.push_async.await_count == async_count

0 comments on commit 344171c

Please sign in to comment.