-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·215 lines (172 loc) · 6.51 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#!/usr/bin/env python3
import aiosmtplib
import asyncio
import os
import typing as ty
import uvicorn
from dotenv import load_dotenv
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from fastapi import (
BackgroundTasks,
FastAPI,
File,
HTTPException,
Form,
Request,
UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from logger import logger
from rfam_batch import job_dispatcher as jd
from rfam_batch import api
# Origins handed to the CORS middleware below.
origins = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://rfam.org",
    "https://preview.rfam.org",
    "https://rfam.xfam.org",
    "https://batch.rfam.org",
]

# The ASGI application; the interactive docs URL is set to /docs.
app = FastAPI(docs_url="/docs")

# Allow the listed origins to call the API with credentials, using any
# HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load environment variables from the .env file (the EMAIL, SERVER and PORT
# variables are read later, when a notification email is sent).
load_dotenv()
@app.on_event("shutdown")
async def on_shutdown() -> None:
    """Run the Job Dispatcher's class-level shutdown hook when the app stops."""
    await jd.JobDispatcher.shutdown()
@app.on_event("startup")
async def on_startup() -> None:
    """Run the Job Dispatcher's class-level startup hook when the app starts."""
    await jd.JobDispatcher.startup()
@app.get("/result/{job_id}")
async def get_result(job_id: str) -> api.CmScanResult | api.MultipleSequences:
    """Return the parsed cmscan result for ``job_id``.

    The raw output, the submitted sequence and the tabular output are
    fetched from the Job Dispatcher one after another, then passed together
    with the job id to ``api.parse_cm_scan_result``.

    Raises:
        HTTPException: logged and re-raised when any of the fetches or the
            parsing step raises it.
    """
    try:
        raw_output = await jd.JobDispatcher().cmscan_result(job_id)
        submitted_sequence = await jd.JobDispatcher().cmscan_sequence(job_id)
        tabular_output = await jd.JobDispatcher().cmscan_tblout(job_id)
        return api.parse_cm_scan_result(
            raw_output, submitted_sequence, tabular_output, job_id
        )
    except HTTPException as exc:
        logger.error(f"Error fetching results for {job_id}. Error: {exc}")
        raise exc
@app.get("/result/{job_id}/tblout", response_class=PlainTextResponse)
async def get_tblout(job_id: str) -> PlainTextResponse:
    """Return the tabular (tblout) output of ``job_id`` as plain text.

    The response carries its own CORS headers, including a wildcard
    ``Access-Control-Allow-Origin``.

    Raises:
        HTTPException: logged and re-raised when the Job Dispatcher call
            raises it.
    """
    try:
        tabular_output = await jd.JobDispatcher().cmscan_tblout(job_id)
    except HTTPException as exc:
        logger.error(f"Error fetching TBLOUT results for {job_id}. Error: {exc}")
        raise exc

    # CORS headers attached directly to this response, in this order.
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type"),
    )
    reply = PlainTextResponse(content=tabular_output)
    for header_name, header_value in cors_headers:
        reply.headers[header_name] = header_value
    return reply
@app.get("/result/{job_id}/out", response_class=PlainTextResponse)
async def get_out(job_id: str) -> PlainTextResponse:
    """Return the raw cmscan output of ``job_id`` as plain text.

    The response carries its own CORS headers, including a wildcard
    ``Access-Control-Allow-Origin``.

    Raises:
        HTTPException: logged and re-raised when the Job Dispatcher call
            raises it.
    """
    try:
        raw_output = await jd.JobDispatcher().cmscan_result(job_id)
    except HTTPException as exc:
        logger.error(f"Error fetching OUT results for {job_id}. Error: {exc}")
        raise exc

    # CORS headers attached directly to this response, in this order.
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type"),
    )
    reply = PlainTextResponse(content=raw_output)
    for header_name, header_value in cors_headers:
        reply.headers[header_name] = header_value
    return reply
@app.get("/status/{job_id}", response_class=PlainTextResponse)
async def fetch_status(job_id: str) -> PlainTextResponse:
    """Return the Job Dispatcher status of ``job_id`` as plain text.

    The response carries its own CORS headers, including a wildcard
    ``Access-Control-Allow-Origin``.

    Raises:
        HTTPException: logged and re-raised when the Job Dispatcher call
            raises it.
    """
    try:
        job_status = await jd.JobDispatcher().cmscan_status(job_id)
    except HTTPException as exc:
        logger.error(f"Error fetching job status for {job_id}. Error: {exc}")
        raise exc

    # CORS headers attached directly to this response, in this order.
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type"),
    )
    reply = PlainTextResponse(content=job_status)
    for header_name, header_value in cors_headers:
        reply.headers[header_name] = header_value
    return reply
async def send_email(email_address: str, job_id: str, status: str, tblout: str):
    """Email the outcome of a batch search job to the submitter.

    The sender address, SMTP host and SMTP port are read from the ``EMAIL``,
    ``SERVER`` and ``PORT`` environment variables each time this is called.

    Args:
        email_address: Recipient address, used for the ``To`` header.
        job_id: Job identifier, included in the subject line.
        status: Final job status. ``"FINISHED"`` sends ``tblout`` as the
            message body; any other value sends a generic failure notice.
        tblout: Tabular output text attached for finished jobs.

    Raises:
        ValueError: if ``PORT`` is set to something that is not an integer.
    """
    sender_email = os.getenv("EMAIL")
    server = os.getenv("SERVER")

    # os.getenv always yields a string, while the SMTP client's ``port``
    # argument is an int (or None to let the client pick its default).
    # Convert explicitly and treat an unset/blank PORT as "not configured".
    raw_port = (os.getenv("PORT") or "").strip()
    port = int(raw_port) if raw_port else None

    msg = MIMEMultipart()
    msg["From"] = sender_email
    msg["To"] = email_address

    if status == "FINISHED":
        msg["Subject"] = f"Results for batch search job {job_id}"
        msg.attach(MIMEText(tblout, "plain"))
    else:
        msg["Subject"] = f"Error in batch search job {job_id}"
        body = (
            "There was a problem while running the search. Please try "
            "again or send us the job id if the problem persists."
        )
        msg.attach(MIMEText(body, "plain"))

    async with aiosmtplib.SMTP(hostname=server, port=port) as smtp:
        await smtp.send_message(msg)
async def check_status(email_address: str, job_id: str):
    """Poll the status of ``job_id`` and email the submitter when it ends.

    The status is fetched every 10 seconds. Polling stops when the status
    is ``FINISHED`` (the tabular output is fetched and emailed),
    ``FAILURE`` or ``ERROR`` (a failure notice is emailed), or
    ``NOT_FOUND`` (nothing is sent). Any other status keeps the loop going.
    """
    failure_states = ("FAILURE", "ERROR")

    while True:
        status = await jd.JobDispatcher().cmscan_status(job_id)

        if status == "NOT_FOUND":
            # I'm assuming we will never see this status after a POST
            return

        if status == "FINISHED":
            tblout = await jd.JobDispatcher().cmscan_tblout(job_id)
            await send_email(email_address, job_id, status, tblout)
            return

        if status in failure_states:
            await send_email(email_address, job_id, status, "")
            return

        # Not in a terminal state yet: wait before asking again.
        await asyncio.sleep(10)
@app.post("/submit-job")
async def submit_job(
    *,
    email_address: ty.Annotated[ty.Optional[str], Form()] = None,
    sequence_file: UploadFile = File(None),
    id: ty.Optional[str] = Form(None),
    request: Request,
    background_tasks: BackgroundTasks,
) -> api.SubmissionResponse:
    """Validate an uploaded FASTA file and submit it as a cmscan job.

    Args:
        email_address: Optional form field. When given, a background task
            polls the job and emails this address once it ends.
        sequence_file: Uploaded file holding the sequences.
        id: Optional form field copied onto the query as ``query.id``.
        request: Incoming request; its URL's netloc is used to build the
            result URL.
        background_tasks: Task registry used to schedule status polling.

    Returns:
        A submission response holding the job id and the result URL.

    Raises:
        HTTPException: with status 400 when no file was uploaded, or when
            decoding/parsing the upload raises ``ValueError``.
    """
    nothing_uploaded = sequence_file is None or sequence_file.filename == ""
    if nothing_uploaded:
        raise HTTPException(
            status_code=400, detail="Please upload a file in FASTA format"
        )

    # Validate the FASTA file
    try:
        raw_bytes = await sequence_file.read()
        parsed = api.SubmittedRequest.parse(raw_bytes.decode())
    except ValueError as exc:
        logger.error(f"Error parsing sequence. Error: {exc}")
        raise HTTPException(status_code=400, detail=str(exc))

    query = jd.Query()
    query.id = id
    query.sequences = "\n".join(parsed.sequences)
    # NOTE(review): the fallback literal below looks like an obfuscated or
    # placeholder address - confirm the intended value.
    query.email_address = email_address or "[email protected]"

    # Submit to Job Dispatcher
    job_id = await jd.JobDispatcher().submit_cmscan_job(query)

    if not email_address:
        logger.info(f"Job submitted programmatically: {job_id}")
    else:
        # Background task to check status
        background_tasks.add_task(check_status, email_address, job_id)
        logger.info(f"Job submitted: {job_id}")

    netloc = request.url.netloc
    return api.SubmissionResponse.build(
        result_url=f"https://{netloc}/result/{job_id}",
        job_id=job_id,
    )
def main():
    """Serve ``app`` with uvicorn, bound to 0.0.0.0 on port 8000."""
    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()