diff --git a/src/streamlit_apps/searchable_world_map.py b/src/streamlit_apps/searchable_world_map.py
index c0c06a8..c418a3a 100644
--- a/src/streamlit_apps/searchable_world_map.py
+++ b/src/streamlit_apps/searchable_world_map.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from pathlib import Path
 import duckdb
 import geopandas as gpd
@@ -35,7 +36,9 @@ def load_data():
     return db


-def get_geography_count_for_texts(texts: list[str]) -> pd.DataFrame:
+def get_geography_count_for_texts(
+    texts: list[str], corpus_type_names: Optional[list[str]]
+) -> pd.DataFrame:
     """
     Get the number of paragraphs containing any of the given texts, grouped by geography.

@@ -43,8 +46,16 @@
     """

     texts = [f"\\b{text.lower()}\\b" for text in texts]
     regex = f".*({'|'.join(texts)}).*"
-    results_df = db.sql(
-        f"""
+
+    if corpus_type_names is None:
+        corpus_type_names_clause = ""
+    else:
+        corpus_type_names_string = (
+            "(" + ",".join([f"'{name}'" for name in corpus_type_names]) + ")"
+        )
+        corpus_type_names_clause = f"""AND "document_metadata.corpus_type_name" IN {corpus_type_names_string} """
+
+    sql_query = f"""
         SELECT "document_metadata.geographies", COUNT(*) FROM open_data
         WHERE "text_block.language" = 'en'
@@ -52,10 +63,14 @@
         AND "document_metadata.geographies" IS NOT NULL
         AND "document_metadata.geographies" <> ['XAA']
         AND ("text_block.type" = 'title' OR "text_block.type" = 'Text' OR "text_block.type" = 'sectionHeading')
+        {corpus_type_names_clause}
         GROUP BY "document_metadata.geographies"
         ORDER BY COUNT(*) DESC
         """
-    ).to_df()
+
+    print(sql_query)
+
+    results_df = db.sql(sql_query).to_df()

     results_df["document_metadata.geographies"] = results_df[
         "document_metadata.geographies"
@@ -76,6 +91,19 @@ def get_num_paragraphs_in_db() -> int:
     return db.sql("SELECT COUNT(*) FROM open_data").to_df().iloc[0, 0]


+@st.cache_data
+def get_unique_corpus_type_names() -> list[str]:
+    names = (
+        db.sql(
+            """SELECT DISTINCT "document_metadata.corpus_type_name" FROM open_data"""
+        )
+        .to_df()["document_metadata.corpus_type_name"]
+        .tolist()
+    )
+
+    return [n for n in names if n is not None]
+
+
 @st.cache_data
 def load_world_geometries():
     """
@@ -97,6 +125,7 @@

 def plot_country_map(
     keywords: list[str],
+    corpus_type_names: Optional[list[str]] = None,
     normalize_counts: bool = False,
     axis=None,
 ):
@@ -105,7 +134,7 @@

     Returns the raw results.
     """
-    results_df = get_geography_count_for_texts(keywords)
+    results_df = get_geography_count_for_texts(keywords, corpus_type_names)

     # normalise by paragraph_count_by_geography
     if normalize_counts:
@@ -172,18 +201,20 @@


 def plot_normalised_unnormalised_subplots(
-    kwds,
+    kwds, corpus_type_names: Optional[list[str]] = None
 ) -> tuple[plt.Figure, pd.DataFrame, pd.DataFrame]:  # type: ignore
     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 9), dpi=300)

     df_unnorm = plot_country_map(
         kwds,
+        corpus_type_names=corpus_type_names,
         normalize_counts=False,
         axis=ax1,
     )

     df_norm = plot_country_map(
         kwds,
+        corpus_type_names=corpus_type_names,
         normalize_counts=True,
         axis=ax2,
     )
@@ -197,7 +228,6 @@
 db = load_data()
 world = load_world_geometries()
 num_paragraphs_in_db = get_num_paragraphs_in_db()
-paragraph_count_by_geography = get_geography_count_for_texts([".*"])

 st.title("Searchable World Map")
 st.markdown(
@@ -210,6 +240,13 @@
 - `elephants?`: match "elephant" and "elephants"
 """)

+corpus_type_names = get_unique_corpus_type_names()
+st.markdown("## Select corpus type")
+corpus_types = st.multiselect(
+    "Select corpus types", corpus_type_names, default=corpus_type_names
+)
+paragraph_count_by_geography = get_geography_count_for_texts([".*"], corpus_types)
+
 kwds = st.text_input(
     "Enter keywords separated by commas (spaces next to commas will be ignored)"
 )
@@ -218,7 +255,7 @@
     kwds = [word.strip() for word in kwds.split(",")]

     st.markdown("## all keywords")
-    fig, data1, data2 = plot_normalised_unnormalised_subplots(kwds)
+    fig, data1, data2 = plot_normalised_unnormalised_subplots(kwds, corpus_types)
     n_paragraphs = data1["count"].sum()
     percentage = round(n_paragraphs / num_paragraphs_in_db * 100, 2)
     st.markdown(f"Num paragraphs: {n_paragraphs:,} ({percentage}%)")
@@ -227,7 +264,9 @@
     if len(kwds) > 1:
         for keyword in kwds:
             st.markdown(f"## {keyword}")
-            fig, data1, data2 = plot_normalised_unnormalised_subplots([keyword])
+            fig, data1, data2 = plot_normalised_unnormalised_subplots(
+                [keyword], corpus_types
+            )
             n_paragraphs = data1["count"].sum()
             percentage = round(n_paragraphs / num_paragraphs_in_db * 100, 2)
             st.markdown(f"Num paragraphs: {n_paragraphs:,} ({percentage}%)")