Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add per-indicator thresholding and new dump #1073

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/ej/cmr_to_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def categorize_processing_level(level):
# remove existing data
EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete()

ej_dump = json.load(open("backups/ej_dump_20240815_112916.json"))
ej_dump = json.load(open("backups/ej_dump_20241017_133151.json.json"))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

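Double extension? The filename ends in `.json.json`; it should probably be `backups/ej_dump_20241017_133151.json`, matching the name used in the `scp` example in `create_ej_dump.py`.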

for dataset in ej_dump:
ej_row = EnvironmentalJusticeRow(
destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV,
Expand Down
37 changes: 26 additions & 11 deletions scripts/ej/create_ej_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
inferences are supplied by the classification model. the contact point is Bishwas
cmr is supplied by running
github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
move to the serve like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/
"""

import json
Expand All @@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
json.dump(data, file, indent=2)


def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]:
def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]:
"""
Process the predictions and classify as follows:
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
2. Filter classifications based on the threshold, excluding 'Not EJ'
3. Default to 'Not EJ' if no classifications meet the threshold
Process the predictions and classify based on the individual thresholds per indicator:
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
2. Filter classifications based on their individual thresholds, excluding 'Not EJ'.
3. Default to 'Not EJ' if no classifications meet the threshold.
"""
highest_prediction = max(predictions, key=lambda x: x["score"])

if highest_prediction["label"] == "Not EJ":
return ["Not EJ"]

classifications = [
pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
pred["label"]
for pred in predictions
if pred["score"] >= thresholds[pred["label"]] and pred["label"] != "Not EJ"
]

return classifications if classifications else ["Not EJ"]
Expand Down Expand Up @@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
def update_cmr_with_classifications(
inferences: list[dict[str, dict]],
cmr_dict: dict[str, dict[str, dict]],
threshold: float = 0.5,
thresholds: dict[str, float],
) -> list[dict[str, dict]]:
"""Update CMR data with valid classifications based on inferences."""

predicted_cmr = []

for inference in inferences:
classifications = process_classifications(predictions=inference["predictions"], threshold=threshold)
classifications = process_classifications(predictions=inference["predictions"], thresholds=thresholds)
classifications = remove_unauthorized_classifications(classifications)

if classifications:
Expand All @@ -84,17 +86,30 @@ def update_cmr_with_classifications(


def main():
inferences = load_json_file("cmr-inference.json")
thresholds = {
"Not EJ": 0.80,
"Climate Change": 0.95,
"Disasters": 0.80,
"Extreme Heat": 0.50,
"Food Availability": 0.80,
"Health & Air Quality": 0.90,
"Human Dimensions": 0.80,
"Urban Flooding": 0.50,
"Water Availability": 0.80,
}
Comment on lines +89 to +99

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move these per-indicator thresholds out of `main()` into a config file or a module-level constant, so they are more visible and easier to change.


inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json")
cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

cmr_dict = create_cmr_dict(cmr)

predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=0.8)
predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"ej_dump_{timestamp}.json"

save_to_json(predicted_cmr, file_name)
print(f"Saved to {file_name}")


if __name__ == "__main__":
Expand Down