Skip to content

Commit

Permalink
Merge pull request #2 from Thomas-S-Allen/master
Browse files Browse the repository at this point in the history
Fixed missing title bug
  • Loading branch information
Thomas-S-Allen authored Oct 30, 2024
2 parents 13b5fe7 + 314ef05 commit 296ca77
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 16 deletions.
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
API_TOKEN = ""


CLASSIFICATION_PRETRAINED_MODEL = ""
CLASSIFICATION_PRETRAINED_MODEL = "adsabs/ASTROBERT"
CLASSIFICATION_PRETRAINED_MODEL_REVISION = "SciX-Categorizer"
CLASSIFICATION_PRETRAINED_MODEL_TOKENIZER = "adsabs/ASTROBERT"

CLASSIFICATION_THRESHOLDS = [0.06, 0.03, 0.04, 0.02, 0.99, 0.02, 0.02, 0.99]
ADDITIONAL_EARTH_SCIENCE_PROCESSING = True
ADDITIONAL_EARTH_SCIENCE_PROCESSING = False
ADDITIONAL_EARTH_SCIENCE_PROCESSING_THRESHOLD = 0.015
19 changes: 5 additions & 14 deletions harvest_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,11 @@ def transform_r_json(r_json):
Extract the needed information from the json response from the solr query.
"""

# Bibcoded and titles are always present
bibcodes = [doc['bibcode'] for doc in r_json['response']['docs']]
titles = [doc['title'][0] for doc in r_json['response']['docs']] # without [0] it returns a list
# Abstracts are not always present
abstracts = []
record_list = []
for doc in r_json['response']['docs']:
if 'abstract' in doc:
abstracts.append(doc['abstract'])
else:
abstracts.append('')

record_list = [{'bibcode': bibcodes[i],
'title' : titles[i],
'abstract' : abstracts[i],
'text': f'{titles[i]} {abstracts[i]}'} for i in range(len(bibcodes))]
if ('title' in doc) or ('abstract' in doc):
doc['text'] = f"{doc['title']} {doc['abstract']}"
record_list.append(doc)


return record_list
4 changes: 4 additions & 0 deletions quick_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import os
import csv
import argparse
import sys
# sys.path.append("/Users/thomasallen/Code/ADS/classifier_script/venv/lib/python3.11/site-packages")

from transformers import AutoTokenizer, AutoModelForSequenceClassification

Expand Down Expand Up @@ -85,6 +87,8 @@ def write_batch_to_tsv(batch, header, filename, mode='w', include_header=True):
# Harvest Title and Abstract from Solr
bibcode_batch = bibcodes[output_idx:output_idx+output_batch]
records = harvest_solr(bibcode_batch, start_index=0, fields='bibcode, title, abstract')
if len(records) == 0:
sys.exit('No records returned from harvesting Solr - exiting')

for index, record in enumerate(records):
record = score_record(record)
Expand Down

0 comments on commit 296ca77

Please sign in to comment.