Add people and parliament #137

Merged
55 commits, merged on Nov 18, 2024

Commits
3846fcd
Add ner_script.py
carschno Oct 31, 2024
a46a5c6
Add map_places.py
carschno Oct 31, 2024
037ea46
Cache Spacy model.
carschno Oct 31, 2024
4bff264
Output HeatMapWithTime.
carschno Oct 31, 2024
682237e
Handle geocoding with caching in class.
carschno Oct 31, 2024
a79e27b
Move scripts to sub-directories.
carschno Oct 31, 2024
080fec3
Add docstrings, fix minor things.
carschno Oct 31, 2024
3cf2e11
Use date as index for HeatMapWithTime.
carschno Oct 31, 2024
8bbcf6e
Use rolling window for heatmap.
carschno Oct 31, 2024
0a6e655
Filter out place names with less than 3 letters.
carschno Oct 31, 2024
74adba4
Skip caching in case of exception.
carschno Nov 4, 2024
409895a
Rename heat map layer, hide pins by default.
carschno Nov 4, 2024
6f0e15d
Add title argument, add locations only once.
carschno Nov 4, 2024
5c7a4f6
Reduce title width to 50%.
carschno Nov 4, 2024
3fe155f
Add all spellings and dates to each location pin.
carschno Nov 4, 2024
a67d26f
Extract read_data_list()
carschno Nov 4, 2024
963ae02
Extract create_markers()
carschno Nov 4, 2024
62241ba
Extract create_smoothed_heat_data()
carschno Nov 4, 2024
fd68d9f
Refactor create_markers() to add_markers().
carschno Nov 4, 2024
65eb1a1
Rename to extract_places.py
carschno Nov 4, 2024
af792f8
Remove default limit.
carschno Nov 4, 2024
8dd088a
Fix: skip redundant opening of output file.
carschno Nov 4, 2024
231848f
Log name of the file that is being processed.
carschno Nov 4, 2024
85598fa
Fix command line argument name.
carschno Nov 4, 2024
8643e1c
Catch more generic GeocoderServiceError.
carschno Nov 4, 2024
3b5c74e
Handle csv parsing errors.
carschno Nov 5, 2024
f6357e8
Update scikit-learn dependency to ~1.5.2
carschno Nov 5, 2024
de47d88
Add --resume flag.
carschno Nov 6, 2024
897856a
Add TestGeocoder, rename internal Geocoder methods.
carschno Nov 6, 2024
511cf36
Move ner dependencies to install_requires, download Spacy models wher…
carschno Nov 6, 2024
c27dccb
Remove explicit matplotlib dependency, sort dependencies.
carschno Nov 6, 2024
4cc78e2
Update ChromaDB dependency for NumPy 2.0 compatibility.
carschno Nov 6, 2024
662bfe1
Refine load_spacy_model.
carschno Nov 6, 2024
e20acb6
Add year range option.
carschno Nov 7, 2024
23d1591
Add docstrings, type hints, TODOs.
carschno Nov 8, 2024
49adf6d
Exclude end year.
carschno Nov 9, 2024
f40b52b
Use nl_core_news_lg for NL.
carschno Nov 11, 2024
6d0ff1e
Add PeopleAndParliament corpus.
carschno Nov 12, 2024
66e9acc
Add max_corpus_size argument, make Segmenter and CorpusReader yield P…
carschno Nov 13, 2024
6d2fbbf
Set alias for id field.
carschno Nov 13, 2024
966150a
Pin SciPy dependency.
carschno Nov 13, 2024
496eec9
Revert "Pin SciPy dependency."
carschno Nov 13, 2024
92aaa7d
Use date field, datetime format for year filtering by default.
carschno Nov 14, 2024
4540eaf
Add doc_frequencies_per_year().
carschno Nov 14, 2024
2ec9644
Derive fields from metadata by default.
carschno Nov 15, 2024
3cd418e
Filter derived metadata fields.
carschno Nov 15, 2024
1ee7f72
Use date instead of year, add timezone to dates.
carschno Nov 18, 2024
bafaa5b
Adjust term_frequencies notebook.
carschno Nov 18, 2024
8433657
Do not include pins into map by default.
carschno Nov 18, 2024
e017ebf
Fix: replace global logger with __name__ (invalid in Python 3.9).
carschno Nov 18, 2024
30f43b0
Add Passage.Metadata.model_field_names() to get model field names, wo…
carschno Nov 18, 2024
f472ad4
Fix logger name.
carschno Nov 18, 2024
1eb7618
Extract YearSpan._to_datetimes()
carschno Nov 18, 2024
f354854
Fix left-over comments/returns.
carschno Nov 18, 2024
e538563
Update term_frequencies notebook.
carschno Nov 18, 2024
661 changes: 271 additions & 390 deletions notebooks/term_frequency.ipynb

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions scripts/extract_places.py
@@ -0,0 +1,137 @@
import argparse
import csv
import logging
import sys
from functools import lru_cache
from pathlib import Path

import spacy
import spacy.cli
from spacy.language import Language
from tqdm import tqdm

from tempo_embeddings.io.corpus_reader import CorpusReader

MODEL_NAMES: dict[str, str] = {"en": "en_core_web_sm", "nl": "nl_core_news_lg"}


@lru_cache(maxsize=None)
def load_spacy_model(language: str, *, download: bool = True) -> Language:
"""Load SpaCy model for a given language.

Args:
language (str): Language code.
download (bool): Whether to download the model if not available.
Raises:
ValueError: If no model is available for the given language.
OSError: If the model cannot be loaded and 'download' is False.
"""

try:
model_name = MODEL_NAMES[language]
model: Language = spacy.load(model_name)
except KeyError as e:
raise ValueError(
f"No SpaCy model available for language '{language}'. Available languages are: {list(MODEL_NAMES.keys())}"
) from e
except OSError as e:
if download:
logging.warning(
f"Failed to load Spacy model for language '{language}': '{e}. '{e}'. Downloading and re-trying."
)
spacy.cli.download(model_name)

# retry loading the model, but don't retry downloading:
model = load_spacy_model(language, download=False)
else:
            raise  # re-raise the OSError, as documented in the docstring
return model
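
# Usage sketch (hedged): thanks to the lru_cache decorator, repeated calls for
# the same language return the same Language object without reloading; a
# first-time call may download the model, so network access is assumed.
#
#     nlp = load_spacy_model("nl")          # loads nl_core_news_lg, downloading if missing
#     assert load_spacy_model("nl") is nlp  # cache hit: same object, no reload
#     load_spacy_model("fr")                # raises ValueError: no model configured for "fr"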


def extract_years_from_csv(csvfile: Path) -> set[str]:
years = set()
with csvfile.open(mode="r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
date = row["date"]
if date != "unknown":
year = date.split("-")[0]
years.add(year)
return years
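
# Worked example (hedged, hypothetical rows): for an existing output file
#
#     date,source,place_name
#     1950-03-01,PeopleAndParliament,Amsterdam
#     unknown,PeopleAndParliament,Den Haag
#
# extract_years_from_csv returns {"1950"}; rows with an "unknown" date are skipped.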


def main(corpora, csvfile: Path, resume: bool):
file_exists = csvfile.exists()

if resume and file_exists:
years_to_skip = extract_years_from_csv(csvfile)
logging.info(f"Skipping years: {years_to_skip}")
else:
years_to_skip = set()

fieldnames = ["date", "source", "place_name"]
with csvfile.open(mode="a", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)

if not file_exists:
writer.writeheader()

corpus_reader = CorpusReader(corpora=corpora)

for corpus_name in corpora:
corpus_config = corpus_reader[corpus_name]
nlp = load_spacy_model(corpus_config.language)

skip_files: set[str] = {
file.name
for file in corpus_config.files()
if any(year in file.name for year in years_to_skip)
}
logging.debug(f"Skipping files: {skip_files}")

for corpus in corpus_config.build_corpora(
filter_terms=[], skip_files=skip_files
):
try:
provenance = corpus.passages[0].metadata.get("provenance")
except IndexError:
logging.warning(f"Empty corpus: {corpus_name}")
continue
rows = [
{
"date": passage.metadata["date"],
"source": corpus_name,
"place_name": ent.text,
}
for passage in tqdm(
corpus.passages, desc=provenance, unit="passage"
)
for ent in nlp(passage.text).ents
if ent.label_ == "GPE"
]
writer.writerows(rows)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Perform NER on corpora and extract place names."
)
parser.add_argument("--corpora", nargs="+", help="List of corpora to process")
parser.add_argument(
"--output",
"-o",
type=Path,
default=Path(sys.stdout.name),
help="Output CSV file",
)
parser.add_argument(
"--resume",
action="store_true",
help="Resume from the last run by reading the existing output file",
)
args = parser.parse_args()

if not args.resume and args.output.exists():
parser.error(f"Output file already exists: {args.output}")

main(args.corpora, args.output, args.resume)
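
For reference, a hedged sketch of how the script might be invoked (the corpus name is taken from this PR's commits; the output path is hypothetical):

    python scripts/extract_places.py --corpora PeopleAndParliament --output places.csv
    python scripts/extract_places.py --corpora PeopleAndParliament --output places.csv --resume

The second call resumes an interrupted run, skipping files for years already present in places.csv.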
243 changes: 243 additions & 0 deletions scripts/map_places.py
@@ -0,0 +1,243 @@
import csv
import re
from collections import defaultdict, deque
from typing import Optional

import folium
import pandas as pd
from folium.plugins import HeatMapWithTime
from tqdm import tqdm

from tempo_embeddings.io.geocoder import Geocoder

# TODO: use (named) tuple for coordinates
# TODO simplify read_data_list to use a single loop and single return variable


def read_data_list(
input_csv: str,
limit: Optional[int],
geocoder: Geocoder,
start_year: Optional[int],
end_year: Optional[int],
) -> tuple[list[list], dict[str, list[list[float]]]]:
"""
Reads data from a CSV file and filters it based on the provided criteria.

Args:
input_csv (str): Path to the input CSV file.
limit (Optional[int]): Maximum number of rows to process.
geocoder (Geocoder): Geocoder instance for geocoding place names.
start_year (Optional[int]): Start year for filtering data.
end_year (Optional[int]): End year for filtering data.

Returns:
tuple: A tuple containing the filtered data and heatmap data.
"""
data = []
heat_data = defaultdict(list)
with open(input_csv, mode="r", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
        total_lines = sum(1 for _ in reader)  # DictReader already consumed the header
if limit:
total_lines = min(total_lines, limit)
csvfile.seek(0) # Reset file pointer to the beginning
next(reader) # Skip header

for i, row in enumerate(
tqdm(reader, unit="row", desc="Processing places", total=total_lines)
):
if limit and i >= limit:
break
place_name = row["place_name"]
date = row["date"][:10] # Extract the date part (YYYY-MM-DD)
year = int(date[:4])

if (
(start_year is None or year >= start_year)
and (end_year is None or year < end_year)
and len(re.findall(r"[a-zA-Z]", place_name)) >= 3 # valid place name?
):
latitude, longitude = geocoder.geocode_place(place_name)
if latitude and longitude:
data.append([place_name, latitude, longitude, date])
heat_data[date].append([latitude, longitude])

return data, heat_data


def add_markers(
    data: list[list], pins_group: folium.FeatureGroup
) -> None:
"""
Adds markers to the map for each unique location.

Args:
data (list): List of place data.
pins_group (folium.Element): Folium feature group to add the markers to.
"""
df = pd.DataFrame(data, columns=["place_name", "latitude", "longitude", "date"])
grouped = (
df.groupby(["latitude", "longitude"])
.agg(
{
"place_name": lambda x: list(set(x)),
"date": lambda x: list(sorted(set(x))),
}
)
.reset_index()
)

for _, row in grouped.iterrows():
table_html = """
<div style="width: 300px;">
<table style="width: 100%;">
<tr><th>Place Name</th><th>Dates</th></tr>
"""
for place_name in row["place_name"]:
place_dates = df[
(df["latitude"] == row["latitude"])
& (df["longitude"] == row["longitude"])
& (df["place_name"] == place_name)
]["date"].tolist()
table_html += f"<tr><td>{place_name}</td><td>{', '.join(sorted(set(place_dates)))}</td></tr>"
table_html += "</table></div>"
folium.Marker([row["latitude"], row["longitude"]], popup=table_html).add_to(
pins_group
)
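
# Example (hedged, hypothetical spellings): if the CSV contains both "Den Haag"
# and "'s-Gravenhage" geocoded to the same coordinates, the groupby above
# collapses them into a single marker whose popup table lists both spellings
# with their dates (cf. commit 3fe155f, "Add all spellings and dates to each
# location pin").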


def create_smoothed_heat_data(
heat_data: dict[str, list[list[float]]], window_size: int
) -> tuple[list[list[list[float]]], list[str]]:
"""
Creates smoothed heatmap data using a sliding window.

Args:
heat_data (dict): Heatmap data.
window_size (int): Size of the sliding window.

Returns:
tuple: A tuple containing the smoothed heatmap data and sorted dates.
"""
sorted_dates = sorted(heat_data)
smoothed_heat_data = []
window = deque(maxlen=window_size)

for date in sorted_dates:
window.append(heat_data[date])
combined_data = [coord for day_data in window for coord in day_data]
smoothed_heat_data.append(combined_data)

return smoothed_heat_data, sorted_dates
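
# Worked example (hedged, hypothetical coordinates): with window_size=2 and
#
#     heat_data = {"1950-01-01": [[52.4, 4.9]], "1950-01-02": [[51.9, 4.5]]}
#
# the resulting frames are
#
#     [[[52.4, 4.9]], [[52.4, 4.9], [51.9, 4.5]]]
#
# i.e. each HeatMapWithTime frame shows the union of the last window_size dates.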


def create_map(
input_csv: str,
output: str,
title: Optional[str],
limit: Optional[int],
window_size: int,
start_year: Optional[int],
end_year: Optional[int],
include_markers: bool,
) -> None:
"""
Creates a map with location pins and a time-space heatmap.

Args:
input_csv (str): Path to the input CSV file.
output (str): Path to the output HTML file.
title (Optional[str]): Title to be included in the map.
limit (Optional[int]): Maximum number of rows to process.
window_size (int): Size of the sliding window for smoothing the heatmap.
start_year (Optional[int]): Start year for filtering data.
end_year (Optional[int]): End year for filtering data.
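        include_markers (bool): Whether to add individual location markers to the map.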
"""
geocoder = Geocoder() # Initialize the Geocoder
map_ = folium.Map(location=[52.3676, 4.9041], zoom_start=6) # Centered on Amsterdam

# Add a title to the map if provided
if title:
title_html = f"""
<div style="position: fixed;
top: 10px; left: 50%; transform: translateX(-50%); width: auto; height: 50px;
background-color: white; z-index: 9999; font-size: 24px;">
<center>{title}</center>
</div>
"""
map_.get_root().html.add_child(folium.Element(title_html))

# Create a feature group for the location pins
pins_group = folium.FeatureGroup(name="Location Pins", show=False)

data, heat_data = read_data_list(input_csv, limit, geocoder, start_year, end_year)

if include_markers:
add_markers(data, pins_group)

smoothed_heat_data, sorted_dates = create_smoothed_heat_data(heat_data, window_size)

HeatMapWithTime(
smoothed_heat_data, index=sorted_dates, name="Time-Space Heat Map"
).add_to(map_)
pins_group.add_to(map_)
folium.LayerControl().add_to(map_) # Add layer control to toggle pins
map_.save(output) # Save the map to the file


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(
description="Create a map of places from a CSV file."
)
parser.add_argument("--input", "-i", help="Input CSV file with place names")
parser.add_argument(
"--output",
"-o",
type=argparse.FileType("x"),
required=True,
help="Output HTML file for the map",
)
parser.add_argument(
"--title", help="Title to be included in the map", required=False
)
parser.add_argument(
"--limit",
type=int,
required=False,
help="Limit the number of places to process",
)
parser.add_argument(
"--window-size",
type=int,
default=7,
help="Window size for smoothing the heatmap",
)
parser.add_argument(
"--start-year", "--start", type=int, help="Start year to include in the map"
)
parser.add_argument(
"--end-year", "--end", type=int, help="End year to include in the map"
)
parser.add_argument(
"--include-markers",
action="store_true",
help="Include indivdual location markers",
)
args = parser.parse_args()

if args.start_year and args.end_year and args.start_year >= args.end_year:
parser.error("START_YEAR must be smaller than END_YEAR")

create_map(
args.input,
args.output.name,
args.title,
args.limit,
args.window_size,
args.start_year,
args.end_year,
args.include_markers,
)
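
A hedged invocation sketch for the mapping script (file names, title, and year range are hypothetical):

    python scripts/map_places.py --input places.csv --output map.html \
        --title "Place mentions, 1950-1960" --start-year 1950 --end-year 1960 \
        --window-size 7 --include-markers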