forked from opensourceware/Neural-ParsCit
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from kylase/feature/api
REST API for ParsCit
- Loading branch information
Showing
21 changed files
with
369 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import logging | ||
from flask import Flask, Blueprint, jsonify, g | ||
from flask_restful_swagger_2 import Api, get_swagger_blueprint | ||
from flask_swagger_ui import get_swaggerui_blueprint | ||
from app.resources.parscit import Parse, ParseBatch | ||
from utils import get_model | ||
|
||
|
||
def create_app(config): | ||
""" | ||
Wrapper function for Flask app | ||
params: | ||
config: Config | ||
""" | ||
app = Flask(__name__) | ||
app.config.from_object(config) | ||
|
||
model_path = os.path.abspath(os.getenv('MODEL_PATH', | ||
default='models/neuralParscit/')) | ||
word_emb_path = os.path.abspath(os.getenv('WORD_EMB_PATH', | ||
default='vectors_with_unk.kv')) | ||
|
||
with app.app_context(): | ||
logging.info("Loading model from {} and using word embeddings from {}".format(model_path, word_emb_path)) | ||
model, inference = get_model(model_path, word_emb_path) | ||
setattr(app, 'model', model) | ||
setattr(app, 'inference', inference) | ||
setattr(app, 'word_to_id', {v:i for i, v in model.id_to_word.items()}) | ||
setattr(app, 'char_to_id', {v:i for i, v in model.id_to_char.items()}) | ||
|
||
API_DOC_PATH = '/docs' | ||
SWAGGER_PATH = '/swagger' | ||
|
||
api_bp = Blueprint('api', __name__) | ||
api = Api(api_bp, add_api_spec_resource=False) | ||
api.add_resource(Parse, '/parscit/parse') | ||
api.add_resource(ParseBatch, '/parscit/parse/batch') | ||
|
||
docs = [api.get_swagger_doc()] | ||
|
||
swagger_ui_blueprint = get_swaggerui_blueprint( | ||
API_DOC_PATH, | ||
SWAGGER_PATH + '.json', | ||
config={ | ||
'app_name': 'ParsCit API' | ||
} | ||
) | ||
|
||
app.register_blueprint(api.blueprint) | ||
app.register_blueprint(get_swagger_blueprint(docs, SWAGGER_PATH, | ||
title='ParsCit API', | ||
api_version='1.0', | ||
base_path='/')) | ||
app.register_blueprint(swagger_ui_blueprint, url_prefix=API_DOC_PATH) | ||
|
||
@app.errorhandler(404) | ||
def not_found(error): | ||
""" | ||
Handles URLs that are not specified | ||
""" | ||
return jsonify({ | ||
'message': "API doesn't exist" | ||
}), 404 | ||
|
||
return app |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import print_function | ||
import numpy as np | ||
from flask import abort, current_app, g | ||
from flask_restful import reqparse | ||
from flask_restful_swagger_2 import swagger, Resource | ||
from app.resources.schemas import Entity, ParseResponse, ParseBatchResponse | ||
from app.utils import get_model | ||
from utils import create_input | ||
from loader import prepare_dataset | ||
|
||
class Parse(Resource): | ||
""" | ||
""" | ||
parser = reqparse.RequestParser() | ||
parser.add_argument('string', type=unicode, trim=True, required=True, location='json') | ||
@swagger.doc({ | ||
'description': 'Parse a single string and return the associated entity for each token in the string.', | ||
'reqparser': { | ||
'name': 'Single Submission Request', | ||
'parser': parser | ||
}, | ||
'responses': { | ||
'200': { | ||
'description': 'Successfully parsed provided string.', | ||
'schema': ParseResponse | ||
} | ||
} | ||
}) | ||
|
||
def post(self): | ||
""" | ||
Parse a single string and return the associated entity for each token in the string. | ||
""" | ||
args = self.parser.parse_args() | ||
ref_string = args.get('string') | ||
if ref_string is None or ref_string == "": | ||
# Hackish way as reqparse can't catch empty string | ||
abort(400, description='string is empty or not provided.') | ||
|
||
tokens = ref_string.split(" ") | ||
|
||
data = prepare_dataset([[[token] for token in tokens]], | ||
current_app.word_to_id, | ||
current_app.char_to_id, | ||
current_app.model.parameters['lower'], | ||
True) | ||
|
||
model_inputs = create_input(data[0], current_app.model.parameters, False) | ||
y_pred = np.array(current_app.inference[1](*model_inputs))[1:-1] | ||
tags = [current_app.model.id_to_tag[y_pred[i]] for i in range(len(y_pred))] | ||
|
||
response = ParseResponse(reference_string=ref_string, | ||
data=[Entity(term=term, entity=entity) | ||
for term, entity in zip(tokens, tags)]) | ||
return response | ||
|
||
class ParseBatch(Resource): | ||
parser = reqparse.RequestParser() | ||
parser.add_argument('strings', type=unicode, action='append', required=True, location='json') | ||
@swagger.doc({ | ||
'description': 'Parse multiple string and return the associated entity for each token in each string.', | ||
'reqparser': { | ||
'name': 'Mutliple Submission Request', | ||
'parser': parser | ||
}, | ||
'responses': { | ||
'200': { | ||
'description': 'Successfully parsed provided strings.', | ||
'schema': ParseBatchResponse | ||
} | ||
} | ||
}) | ||
def post(self): | ||
""" | ||
Parse multiple string and return the associated entity for each token in each string. | ||
""" | ||
args = self.parser.parse_args() | ||
ref_strings = args.get('strings') | ||
|
||
tokens = [[[token] for token in ref_string.split(" ")] for ref_string in ref_strings] | ||
data = prepare_dataset(tokens, | ||
current_app.word_to_id, | ||
current_app.char_to_id, | ||
current_app.model.parameters['lower'], | ||
True) | ||
|
||
tagged = [] | ||
|
||
for index, datum in enumerate(data): | ||
model_inputs = create_input(datum, current_app.model.parameters, False) | ||
y_pred = np.array(current_app.inference[1](*model_inputs))[1:-1] | ||
tags = [current_app.model.id_to_tag[y_pred[i]] for i in range(len(y_pred))] | ||
|
||
tagged.append([Entity(term=term, entity=entity) | ||
for term, entity in zip(ref_strings[index].split(" "), tags)]) | ||
|
||
response = ParseBatchResponse(reference_strings=ref_strings, | ||
data=tagged) | ||
return response |
Oops, something went wrong.