diff --git a/config.py b/config.py index 5c4dcae..6111f5c 100644 --- a/config.py +++ b/config.py @@ -3,10 +3,10 @@ API_TOKEN = "" -CLASSIFICATION_PRETRAINED_MODEL = "" +CLASSIFICATION_PRETRAINED_MODEL = "adsabs/ASTROBERT" CLASSIFICATION_PRETRAINED_MODEL_REVISION = "SciX-Categorizer" CLASSIFICATION_PRETRAINED_MODEL_TOKENIZER = "adsabs/ASTROBERT" CLASSIFICATION_THRESHOLDS = [0.06, 0.03, 0.04, 0.02, 0.99, 0.02, 0.02, 0.99] -ADDITIONAL_EARTH_SCIENCE_PROCESSING = True +ADDITIONAL_EARTH_SCIENCE_PROCESSING = False ADDITIONAL_EARTH_SCIENCE_PROCESSING_THRESHOLD = 0.015 diff --git a/harvest_solr.py b/harvest_solr.py index b448a4c..5fbada4 100644 --- a/harvest_solr.py +++ b/harvest_solr.py @@ -105,20 +105,11 @@ def transform_r_json(r_json): Extract the needed information from the json response from the solr query. """ - # Bibcoded and titles are always present - bibcodes = [doc['bibcode'] for doc in r_json['response']['docs']] - titles = [doc['title'][0] for doc in r_json['response']['docs']] # without [0] it returns a list - # Abstracts are not always present - abstracts = [] + record_list = [] for doc in r_json['response']['docs']: - if 'abstract' in doc: - abstracts.append(doc['abstract']) - else: - abstracts.append('') - - record_list = [{'bibcode': bibcodes[i], - 'title' : titles[i], - 'abstract' : abstracts[i], - 'text': f'{titles[i]} {abstracts[i]}'} for i in range(len(bibcodes))] + if ('title' in doc) or ('abstract' in doc): + doc['text'] = f"{doc['title']} {doc['abstract']}" + record_list.append(doc) + return record_list diff --git a/quick_classifier.py b/quick_classifier.py index bc82f20..559e09a 100644 --- a/quick_classifier.py +++ b/quick_classifier.py @@ -13,6 +13,8 @@ import os import csv import argparse +import sys +# sys.path.append("/Users/thomasallen/Code/ADS/classifier_script/venv/lib/python3.11/site-packages") from transformers import AutoTokenizer, AutoModelForSequenceClassification @@ -85,6 +87,8 @@ def write_batch_to_tsv(batch, header, filename, mode='w', include_header=True): # Harvest Title and Abstract from Solr bibcode_batch = bibcodes[output_idx:output_idx+output_batch] records = harvest_solr(bibcode_batch, start_index=0, fields='bibcode, title, abstract') + if len(records) == 0: + sys.exit('No records returned from harvesting Solr - exiting') for index, record in enumerate(records): record = score_record(record)