maps by keyword #10

Merged
1,821 changes: 1,821 additions & 0 deletions experimental/document-viewer.ipynb

Large diffs are not rendered by default.

522 changes: 521 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -12,6 +12,9 @@ jupyterlab = "^4.2.5"
datasets = "^3.0.1"
duckdb = "^1.1.1"
seaborn = "^0.13.2"
geopandas = "^1.0.1"
geodatasets = "^2024.8.0"
streamlit = "^1.39.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.8.0"
26 changes: 26 additions & 0 deletions streamlit_apps/data_helpers.py
@@ -0,0 +1,26 @@
from typing import Optional
from huggingface_hub import snapshot_download

REPO_NAME = "ClimatePolicyRadar/all-document-text-data-weekly"
REPO_URL = f"https://huggingface.co/datasets/{REPO_NAME}"
CACHE_DIR = "../cache"

# INTERNAL NOTE: use this commit hash until the weekly pipeline run that produces the data
# is stable.
REVISION = "bd0abf24ae34d3150bdd8ac66f36a28e47f3ee93" # Use this to set a commit hash. Recommended!


def download_data(cache_dir: str, revision: Optional[str] = None) -> None:
    """
    Download the data to the cache directory.

    :param cache_dir: the directory to save the data to
    :param revision: optional commit hash, defaults to None
    """
    snapshot_download(
        repo_id=REPO_NAME,
        repo_type="dataset",
        local_dir=cache_dir,
        revision=revision,
        allow_patterns=["*.parquet"],
    )
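For context, a minimal usage sketch of this helper, using the module's own `CACHE_DIR` and `REVISION` constants; the standalone script wrapper is illustrative, not part of this PR:

```python
# Illustrative usage only; CACHE_DIR and REVISION are the constants defined above.
from data_helpers import CACHE_DIR, REVISION, download_data

if __name__ == "__main__":
    # Fetch just the parquet files from the HF dataset repo into ../cache,
    # pinned to the recommended REVISION commit hash.
    download_data(cache_dir=CACHE_DIR, revision=REVISION)
```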

Large diffs are not rendered by default.

@@ -0,0 +1 @@
5.1.1
@@ -0,0 +1 @@
UTF-8
Binary file not shown.
@@ -0,0 +1 @@
GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]
Binary file not shown.
Binary file not shown.
231 changes: 231 additions & 0 deletions streamlit_apps/searchable_world_map.py
@@ -0,0 +1,231 @@
from pathlib import Path
import duckdb
import geopandas as gpd
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
import streamlit as st


from data_helpers import download_data

CACHE_DIR = Path(__file__).parent / "../cache"


@st.cache_resource
def load_data():
    # INTERNAL NOTE: use this commit hash until the weekly pipeline run that produces the
    # data is stable.
    download_data(
        cache_dir=str(CACHE_DIR),
        revision="bd0abf24ae34d3150bdd8ac66f36a28e47f3ee93",
    )

    db = duckdb.connect()

    # Authenticate (needed if loading a private dataset).
    # You'll need to log in using `huggingface-cli login` in your terminal first.
    db.execute("CREATE SECRET hf_token (TYPE HUGGINGFACE, PROVIDER credential_chain);")

    # Create a view called 'open_data' over all parquet files in the cache directory
    db.execute(
        f"CREATE VIEW open_data AS SELECT * FROM read_parquet('{CACHE_DIR}/*.parquet')"
    )

    return db
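A quick way to sanity-check the view this function creates (illustrative only; the `"text_block.language"` column name is taken from the query further down in this file):

```python
# Illustrative only: peek at the languages present in the 'open_data' view.
db = load_data()
print(
    db.sql(
        """
        SELECT "text_block.language", COUNT(*) AS n
        FROM open_data
        GROUP BY 1
        ORDER BY n DESC
        LIMIT 5
        """
    ).to_df()
)
```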


def get_geography_count_for_texts(texts: list[str]) -> pd.DataFrame:
    """
    Get the number of paragraphs containing any of the given texts, grouped by geography.

    Returns a dataframe with columns 'geography ISO' and 'count'.
    """
    texts = [f"\\b{text.lower()}\\b" for text in texts]
    regex = f".*({'|'.join(texts)}).*"
    results_df = db.sql(
        f"""
        SELECT "document_metadata.geographies", COUNT(*)
        FROM open_data
        WHERE "text_block.language" = 'en'
        AND (lower("text_block.text") SIMILAR TO '{regex}')
        AND "document_metadata.geographies" IS NOT NULL
        AND "document_metadata.geographies" <> ['XAA']
        GROUP BY "document_metadata.geographies"
        ORDER BY COUNT(*) DESC
        """
    ).to_df()

    # Each row stores a list of geographies; keep only the first one for mapping
    results_df["document_metadata.geographies"] = results_df[
        "document_metadata.geographies"
    ].apply(lambda x: x[0])

    results_df = results_df.rename(
        columns={
            "document_metadata.geographies": "geography ISO",
            "count_star()": "count",
        }
    )

    return results_df
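To make the query above concrete, this is roughly the pattern the function hands to `SIMILAR TO` for a two-keyword search (the keywords here are made-up examples):

```python
# Illustrative only: how the regex used with SIMILAR TO is assembled from keywords.
texts = ["flood", "drought"]
wrapped = [f"\\b{t.lower()}\\b" for t in texts]  # word boundary around each keyword
regex = f".*({'|'.join(wrapped)}).*"
print(regex)  # .*(\bflood\b|\bdrought\b).*
```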


@st.cache_data
def get_num_paragraphs_in_db() -> int:
    return db.sql("SELECT COUNT(*) FROM open_data").to_df().iloc[0, 0]


@st.cache_data
def load_world_geometries():
    """
    Get world geometries in Eckert IV projection.

    Drop Antarctica and Seven seas (open ocean) geometries to make the map look nicer.
    """
    world = gpd.read_file(
        Path(__file__).parent
        / "./ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp"
    )
    world = world.to_crs(
        "+proj=eck4 +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs"
    )

    world = world[~world["ADMIN"].isin(["Antarctica", "Seven seas (open ocean)"])]

    return world


def plot_country_map(
    keywords: list[str],
    normalize_counts: bool = False,
    axis=None,
):
    """
    Plot a map of the world with countries colored by the number of paragraphs containing any of the given keywords.

    Returns the raw results.
    """
    results_df = get_geography_count_for_texts(keywords)

    # normalise by paragraph_count_by_geography
    if normalize_counts:
        results_df["count"] = (
            results_df["count"] / paragraph_count_by_geography["count"]
        )
        legend_label = "Relative frequency in dataset"
    else:
        legend_label = "Number of paragraphs"

    min_count, max_count = results_df["count"].min(), results_df["count"].max()
    num_geographies = results_df["geography ISO"].nunique()

    world_with_counts = world.merge(
        results_df, left_on="ADM0_A3", right_on="geography ISO", how="left"
    )

    if axis:
        fig = axis.get_figure()
    else:
        fig, axis = plt.subplots(figsize=(18, 9), dpi=300)

    world_with_counts.plot(
        column="count",
        legend=False,
        figsize=(20, 10),
        ax=axis,
        vmin=min_count,
        vmax=max_count,
        cmap="viridis_r",
        edgecolor="face",
        linewidth=0.3,  # helps small states stand out
        missing_kwds={"color": "darkgrey", "edgecolor": "white", "hatch": "///"},
    )

    divider = make_axes_locatable(axis)
    cax = divider.append_axes("bottom", size="5%", pad=0.05)
    fig.colorbar(
        mpl.cm.ScalarMappable(
            norm=mpl.colors.Normalize(vmin=min_count, vmax=max_count), cmap="viridis_r"
        ),
        cax=cax,
        orientation="horizontal",
        label=legend_label,
    )

    sns.despine(ax=axis, top=True, bottom=True, left=True, right=True)
    axis.set_xticks([])
    axis.set_yticks([])

    fig.tight_layout()

    axis.set_title(
        f"Number of paragraphs containing words '{', '.join(keywords)}'. {num_geographies} total geographies."
    )

    return results_df


def plot_normalised_unnormalised_subplots(
    kwds,
) -> tuple[plt.Figure, pd.DataFrame, pd.DataFrame]:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 9), dpi=300)

    df_unnorm = plot_country_map(
        kwds,
        normalize_counts=False,
        axis=ax1,
    )

    df_norm = plot_country_map(
        kwds,
        normalize_counts=True,
        axis=ax2,
    )

    return fig, df_unnorm, df_norm


if __name__ == "__main__":
    st.set_page_config(layout="wide")

    db = load_data()
    world = load_world_geometries()
    num_paragraphs_in_db = get_num_paragraphs_in_db()
    paragraph_count_by_geography = get_geography_count_for_texts([".*"])

    st.title("Searchable World Map")
    st.markdown(
        "Search for keywords in the dataset and see where they appear on a world map."
    )
    with st.expander("You can use regex! Open for examples"):
        st.markdown(r"""
- `natural(-|\s)resource`: match "natural-resource" and "natural resource"
- `fish(es)?`: match "fish" and "fishes"
- `elephants?`: match "elephant" and "elephants"
""")

    kwds = st.text_input(
        "Enter keywords separated by commas (spaces next to commas will be ignored)"
    )

    if kwds:
        kwds = [word.strip() for word in kwds.split(",")]

        st.markdown("## all keywords")
        fig, data1, data2 = plot_normalised_unnormalised_subplots(kwds)
        n_paragraphs = data1["count"].sum()
        percentage = round(n_paragraphs / num_paragraphs_in_db * 100, 2)
        st.markdown(f"Num paragraphs: {n_paragraphs:,} ({percentage}%)")
        st.write(fig)

        if len(kwds) > 1:
            for keyword in kwds:
                st.markdown(f"## {keyword}")
                fig, data1, data2 = plot_normalised_unnormalised_subplots([keyword])
                n_paragraphs = data1["count"].sum()
                percentage = round(n_paragraphs / num_paragraphs_in_db * 100, 2)
                st.markdown(f"Num paragraphs: {n_paragraphs:,} ({percentage}%)")
                st.write(fig)
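Since the script is guarded by `if __name__ == "__main__":`, it would presumably be launched with `streamlit run streamlit_apps/searchable_world_map.py`; the cache and shapefile paths are resolved relative to the file's own location via `Path(__file__).parent`.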