Skip to content

Commit

Permalink
swtich to fastapi uvicorn
Browse files Browse the repository at this point in the history
Signed-off-by: Steve Taylor <[email protected]>
  • Loading branch information
sbtaylor15 committed Apr 22, 2024
1 parent 37e48e3 commit bf48a7f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
40 changes: 30 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
import argparse
import os
import re
import json

import joblib
import stanza
from flask import Flask, jsonify, request
import uvicorn
from fastapi import FastAPI, Response
from pydantic import BaseModel # pylint: disable=E0611
from mitreattack.stix20 import MitreAttackData
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
Expand Down Expand Up @@ -49,8 +52,26 @@
vectorizer = TfidfVectorizer()
mitre_data = [] # type: ignore

app = Flask(__name__)

# Init FastAPI
app = FastAPI(
title=__name__,
description="RestAPI endpoint for retrieving Mitre Techiques",
version="10.0.0",
license_info={
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
servers=[{"url": "http://localhost:8080", "description": "Local Server"}],
contact={
"name": "Ortelius Open Source Project",
"url": "https://github.com/ortelius/ortelius/issues",
"email": "[email protected]",
},
debug=True,
)

class CveText(BaseModel):
cvetext: str

def preprocess(text):
text = text.replace("<code>", " ").replace("</code>", " ")
Expand Down Expand Up @@ -129,13 +150,11 @@ def load_mitre(nlp, mitre_data_file):


# Define the Flask route for the /mitre endpoint
@app.route("/msapi/mitre", methods=["POST"])
def mitremap():
# Get JSON payload
data = request.get_json()
@app.post("/msapi/mitre")
def mitremap(data: CveText):

# Extract cvetext from the payload
cvetext = data.get("cvetext", "")
cvetext = data.cvetext
cvedoc = nlp(preprocess(cvetext))
cvedoc_processed = process_document(cvedoc)
cvedoc_words_weight = calculate_capitalized_words_weight(cvedoc)
Expand All @@ -153,7 +172,8 @@ def mitremap():
sorted_dict = dict(sorted(scoring.items(), key=lambda item: item[1], reverse=True)[:2])

sorted_dict = dict(sorted(sorted_dict.items(), key=lambda item: item[1], reverse=True))
return jsonify(sorted_dict)
json_str = json.dumps(sorted_dict, indent=4, default=str)
return Response(content=json_str, media_type='application/json')


if __name__ == "__main__":
Expand All @@ -167,4 +187,4 @@ def mitremap():

# Check if --loaddata flag is provided
if not args.loaddata:
app.run(host="0.0.0.0", port=8080)
uvicorn.run(app, port=8080)
3 changes: 2 additions & 1 deletion requirements.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Flask==3.0.3
fastapi==0.110.2
joblib==1.4.0
mitreattack-python==3.0.3
numpy==1.26.4 # This is a common dependency for scikit-learn
scikit-learn==1.4.2
stanza==1.8.2
uvicorn==0.29.0

0 comments on commit bf48a7f

Please sign in to comment.