Skip to content

Commit

Permalink
option to export alternative classifier outputs with more information
Browse files Browse the repository at this point in the history
  • Loading branch information
kdutia committed Nov 22, 2023
1 parent 1455454 commit 73bd93f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ train_instruments_classifier:

# NOTE: these should be run against the *best* model artifact, not the latest (the targets below currently point at the `:latest` alias)
run_sector_classifier:
poetry run python classifiers/run_on_full_dataset.py --spans-csv-filename ${SPANS_CSV_FILENAME} --wandb-artifact-name climatepolicyradar/sector-text-classifier/sector-text-classifier:latest --output-dir ./concepts/sectors
poetry run python classifiers/run_on_full_dataset.py --spans-csv-filename ${SPANS_CSV_FILENAME} --wandb-artifact-name climatepolicyradar/sector-text-classifier/sector-text-classifier:latest --output-dir ./concepts/sectors --extra-output

run_instruments_classifier:
poetry run python classifiers/run_on_full_dataset.py --spans-csv-filename ${SPANS_CSV_FILENAME} --wandb-artifact-name climatepolicyradar/policy-instrument-text-classifier/policy-instrument-text-classifier:latest --output-dir ./concepts/policy-instruments
poetry run python classifiers/run_on_full_dataset.py --spans-csv-filename ${SPANS_CSV_FILENAME} --wandb-artifact-name climatepolicyradar/policy-instrument-text-classifier/policy-instrument-text-classifier:latest --output-dir ./concepts/policy-instruments --extra-output

# split spans csvs into smaller chunks that can be pushed to git
split_spans_csvs:
Expand Down
25 changes: 24 additions & 1 deletion classifiers/run_on_full_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,17 @@
default="spans.csv",
help="The filename to use for the spans CSV output file, including the .csv extension",
)
def cli(wandb_artifact_name: str, output_dir: Path, spans_csv_filename: str) -> None:
@click.option(
"--extra-output",
is_flag=True,
help="Whether to output an extra predictions.csv file. Output filename will use the same suffix as the spans CSV file.",
)
def cli(
wandb_artifact_name: str,
output_dir: Path,
spans_csv_filename: str,
extra_output: bool,
) -> None:
"""
Run a classifier from weights and biases on the full dataset.
Expand Down Expand Up @@ -93,6 +103,19 @@ def cli(wandb_artifact_name: str, output_dir: Path, spans_csv_filename: str) ->
spans_output_path.write_text(spans_df.to_csv(index=False))
LOGGER.info(f"Spans written to {spans_output_path}")

if extra_output:
predictions_df.columns = [
f"pred_{col}" if col in class_names else col
for col in predictions_df.columns
]

stem_suffix = spans_output_path.stem[len("spans") :]
predictions_output_path = output_dir / f"predictions{stem_suffix}.csv"
predictions_df.to_csv(predictions_output_path, index=False)
LOGGER.info(
f"Predictions in alternative format written to {predictions_output_path}"
)


if __name__ == "__main__":
cli()

0 comments on commit 73bd93f

Please sign in to comment.