-
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add dependencies with yield (used as context managers) (#595)
* ➕ Add development/testing dependencies for Python 3.6 * ✨ Add concurrency submodule with contextmanager_in_threadpool * ✨ Add AsyncExitStack to ASGI scope in FastAPI app call * ✨ Use async stack for contextmanager-able dependencies including running in threadpool sync dependencies * ✅ Add tests for contextmanager dependencies including internal raise checks when exceptions should be handled and when not * ✅ Add test for fake asynccontextmanager raiser * 🐛 Fix mypy errors and coverage * 🔇 Remove development logs and prints * ✅ Add tests for sub-contextmanagers, background tasks, and sync functions * 🐛 Fix mypy errors for Python 3.7 * 💬 Fix error texts for clarity * 📝 Add docs for dependencies with yield * ✨ Update SQL with SQLAlchemy tutorial to use dependencies with yield and add an alternative with a middleware (from the old tutorial) * ✅ Update SQL tests to remove DB file during the same tests * ✅ Add tests for example with middleware as a copy from the tests with dependencies with yield, removing the DB in the tests * ✏️ Fix typos with suggestions from code review Co-Authored-By: dmontagu <[email protected]>
- Loading branch information
Showing
19 changed files
with
1,238 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,6 @@ | ||
from fastapi import Depends, FastAPI | ||
|
||
app = FastAPI() | ||
|
||
|
||
class FixedContentQueryChecker: | ||
def __init__(self, fixed_content: str): | ||
self.fixed_content = fixed_content | ||
|
||
def __call__(self, q: str = ""): | ||
if q: | ||
return self.fixed_content in q | ||
return False | ||
|
||
|
||
checker = FixedContentQueryChecker("bar") | ||
|
||
|
||
@app.get("/query-checker/") | ||
async def read_query_check(fixed_content_included: bool = Depends(checker)): | ||
return {"fixed_content_in_query": fixed_content_included} | ||
async def get_db(): | ||
db = DBSession() | ||
try: | ||
yield db | ||
finally: | ||
db.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from fastapi import Depends | ||
|
||
|
||
async def dependency_a(): | ||
dep_a = generate_dep_a() | ||
try: | ||
yield dep_a | ||
finally: | ||
dep_a.close() | ||
|
||
|
||
async def dependency_b(dep_a=Depends(dependency_a)): | ||
dep_b = generate_dep_b() | ||
try: | ||
yield dep_b | ||
finally: | ||
dep_b.close(dep_a) | ||
|
||
|
||
async def dependency_c(dep_b=Depends(dependency_b)): | ||
dep_c = generate_dep_c() | ||
try: | ||
yield dep_c | ||
finally: | ||
dep_c.close(dep_b) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from fastapi import Depends | ||
|
||
|
||
async def dependency_a(): | ||
dep_a = generate_dep_a() | ||
try: | ||
yield dep_a | ||
finally: | ||
dep_a.close() | ||
|
||
|
||
async def dependency_b(dep_a=Depends(dependency_a)): | ||
dep_b = generate_dep_b() | ||
try: | ||
yield dep_b | ||
finally: | ||
dep_b.close(dep_a) | ||
|
||
|
||
async def dependency_c(dep_b=Depends(dependency_b)): | ||
dep_c = generate_dep_c() | ||
try: | ||
yield dep_c | ||
finally: | ||
dep_c.close(dep_b) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
class MySuperContextManager: | ||
def __init__(self): | ||
self.db = DBSession() | ||
|
||
def __enter__(self): | ||
return self.db | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self.db.close() | ||
|
||
|
||
async def get_db(): | ||
with MySuperContextManager() as db: | ||
yield db |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from fastapi import Depends, FastAPI | ||
|
||
app = FastAPI() | ||
|
||
|
||
class FixedContentQueryChecker: | ||
def __init__(self, fixed_content: str): | ||
self.fixed_content = fixed_content | ||
|
||
def __call__(self, q: str = ""): | ||
if q: | ||
return self.fixed_content in q | ||
return False | ||
|
||
|
||
checker = FixedContentQueryChecker("bar") | ||
|
||
|
||
@app.get("/query-checker/") | ||
async def read_query_check(fixed_content_included: bool = Depends(checker)): | ||
return {"fixed_content_in_query": fixed_content_included} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import List | ||
|
||
from fastapi import Depends, FastAPI, HTTPException | ||
from sqlalchemy.orm import Session | ||
from starlette.requests import Request | ||
from starlette.responses import Response | ||
|
||
from . import crud, models, schemas | ||
from .database import SessionLocal, engine | ||
|
||
models.Base.metadata.create_all(bind=engine) | ||
|
||
app = FastAPI() | ||
|
||
|
||
@app.middleware("http") | ||
async def db_session_middleware(request: Request, call_next): | ||
response = Response("Internal server error", status_code=500) | ||
try: | ||
request.state.db = SessionLocal() | ||
response = await call_next(request) | ||
finally: | ||
request.state.db.close() | ||
return response | ||
|
||
|
||
# Dependency | ||
def get_db(request: Request): | ||
return request.state.db | ||
|
||
|
||
@app.post("/users/", response_model=schemas.User) | ||
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)): | ||
db_user = crud.get_user_by_email(db, email=user.email) | ||
if db_user: | ||
raise HTTPException(status_code=400, detail="Email already registered") | ||
return crud.create_user(db=db, user=user) | ||
|
||
|
||
@app.get("/users/", response_model=List[schemas.User]) | ||
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): | ||
users = crud.get_users(db, skip=skip, limit=limit) | ||
return users | ||
|
||
|
||
@app.get("/users/{user_id}", response_model=schemas.User) | ||
def read_user(user_id: int, db: Session = Depends(get_db)): | ||
db_user = crud.get_user(db, user_id=user_id) | ||
if db_user is None: | ||
raise HTTPException(status_code=404, detail="User not found") | ||
return db_user | ||
|
||
|
||
@app.post("/users/{user_id}/items/", response_model=schemas.Item) | ||
def create_item_for_user( | ||
user_id: int, item: schemas.ItemCreate, db: Session = Depends(get_db) | ||
): | ||
return crud.create_user_item(db=db, item=item, user_id=user_id) | ||
|
||
|
||
@app.get("/items/", response_model=List[schemas.Item]) | ||
def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): | ||
items = crud.get_items(db, skip=skip, limit=limit) | ||
return items |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.