
Commit

Merge pull request #7 from AYLIEN/implement-semantic-filtering
adds semantic filters base implementation
chrishokamp authored Oct 29, 2024
2 parents 2dc3d67 + 1d56dc0 commit f597ed1
Showing 4 changed files with 78 additions and 1 deletion.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.7.0
0.7.1
20 changes: 20 additions & 0 deletions news_signals/semantic_filters.py
@@ -0,0 +1,20 @@
from typing import Any


class SemanticFilter:

    def __call__(self, item: Any) -> bool:
        raise NotImplementedError("Subclasses must implement this method")

    @property
    def name(self) -> str:
        return self.__class__.__name__


class StoryKeywordMatchFilter(SemanticFilter):

    def __init__(self, keywords: list[str]):
        self.keywords = keywords

    def __call__(self, item: dict) -> bool:
        return any(kw in item['title'] for kw in self.keywords)
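For orientation, a minimal usage sketch of the new classes follows; the story dict, keyword list, and the TitleLengthFilter subclass are illustrative examples, not part of this commit:

from news_signals.semantic_filters import SemanticFilter, StoryKeywordMatchFilter

# Built-in filter: a story matches if any keyword appears in its title.
keyword_filter = StoryKeywordMatchFilter(keywords=['Million', 'funding'])
print(keyword_filter({'title': 'Startup raises $5 Million'}))  # True


# Custom filters subclass SemanticFilter and implement __call__.
class TitleLengthFilter(SemanticFilter):

    def __init__(self, min_length: int = 20):
        self.min_length = min_length

    def __call__(self, item: dict) -> bool:
        return len(item.get('title', '')) >= self.min_length


print(TitleLengthFilter()({'title': 'Short headline'}))  # False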
17 changes: 17 additions & 0 deletions news_signals/signals.py
@@ -23,6 +23,7 @@
from .anomaly_detection import SigmaAnomalyDetector
from .aql_builder import params_to_aql
from .summarization import Summarizer
from .semantic_filters import SemanticFilter
from .exogenous_signals import (
    wikidata_id_to_wikimedia_pageviews_timeseries,
    wikidata_id_to_current_events
@@ -769,7 +770,23 @@ def sample_stories(self, num_stories=10, **kwargs):
            **kwargs
        )
        return self

    def filter_stories(self, filter_model: SemanticFilter, delete_filtered: bool = True, **kwargs):
        """
        Filter stories in the signal with a semantic filter model. Each filter's
        verdict is recorded on the story under `filter_model_outputs`; if
        `delete_filtered` is True, non-matching stories are removed from the
        `stories` column of `feeds_df`.
        """
        for index, tick_stories in self.feeds_df['stories'].items():
            filtered_stories = []
            for story in tick_stories:
                keep = filter_model(story)
                if story.get('filter_model_outputs') is None:
                    story['filter_model_outputs'] = []
                story['filter_model_outputs'].append((filter_model.name, keep))
                if keep or not delete_filtered:
                    filtered_stories.append(story)
            self.feeds_df.at[index, 'stories'] = filtered_stories

        return self

    @staticmethod
    def normalize_aylien_story(story):
        """
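To show how the new method is meant to be called, here is a rough sketch; it assumes `signal` is an already-built signal (e.g. an AylienSignal) whose `feeds_df` has a populated `stories` column, and the keyword is only an example:

from news_signals.semantic_filters import StoryKeywordMatchFilter

# `signal` is assumed to already exist with a populated 'stories' column.
filter_model = StoryKeywordMatchFilter(keywords=['acquisition'])

# Default behaviour: drop stories that do not match the filter.
signal = signal.filter_stories(filter_model=filter_model)

# Alternatively (on a fresh signal), keep every story and only record each
# filter's verdict under story['filter_model_outputs'].
signal = signal.filter_stories(filter_model=filter_model, delete_filtered=False)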
40 changes: 40 additions & 0 deletions news_signals/test_semantic_filters.py
@@ -0,0 +1,40 @@
import os
from pathlib import Path

from news_signals.test_signals import SignalTest
from news_signals.semantic_filters import StoryKeywordMatchFilter
from news_signals.log import create_logger


logger = create_logger(__name__)

path_to_file = Path(os.path.dirname(os.path.abspath(__file__)))
resources = Path(os.environ.get(
    'RESOURCES', path_to_file / '../resources/test'))


class TestFilterSignal(SignalTest):

    def test_filter_signal(self):
        example_signal = self.aylien_signals()[0]
        orig_stories_per_tick = [len(tick) for tick in example_signal['stories']]

        keywords = ['Million']
        filter_model = StoryKeywordMatchFilter(keywords=keywords)
        filtered_signal = example_signal.filter_stories(filter_model=filter_model)
        filtered_stories_per_tick = [len(tick) for tick in filtered_signal['stories']]
        assert sum(orig_stories_per_tick) > sum(filtered_stories_per_tick)
        for tick_stories in filtered_signal['stories']:
            for s in tick_stories:
                assert any(kw in s['title'] for kw in keywords)

        # delete_filtered=False keeps all stories, only recording filter outputs
        example_signal = self.aylien_signals()[0]
        filtered_signal = example_signal.filter_stories(filter_model=filter_model, delete_filtered=False)
        filtered_stories_per_tick = [len(tick) for tick in filtered_signal['stories']]
        assert sum(orig_stories_per_tick) == sum(filtered_stories_per_tick)
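One more illustrative sketch, mimicking what `filter_stories` records on each story when `delete_filtered=False`; the story dicts are invented for this example:

from news_signals.semantic_filters import StoryKeywordMatchFilter

filter_model = StoryKeywordMatchFilter(keywords=['Million'])
stories = [
    {'title': 'Company raises $10 Million'},
    {'title': 'Quarterly results announced'},
]

# Record each filter's verdict the same way filter_stories does.
for story in stories:
    keep = filter_model(story)
    story.setdefault('filter_model_outputs', []).append((filter_model.name, keep))

for story in stories:
    print(story['title'], story['filter_model_outputs'])
# Company raises $10 Million [('StoryKeywordMatchFilter', True)]
# Quarterly results announced [('StoryKeywordMatchFilter', False)]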




