diff --git a/docs/src/dependencies/tutorial007.py b/docs/src/dependencies/tutorial007.py index 5d22f6823785c..2e4ab4777b393 100644 --- a/docs/src/dependencies/tutorial007.py +++ b/docs/src/dependencies/tutorial007.py @@ -1,21 +1,6 @@ -from fastapi import Depends, FastAPI - -app = FastAPI() - - -class FixedContentQueryChecker: - def __init__(self, fixed_content: str): - self.fixed_content = fixed_content - - def __call__(self, q: str = ""): - if q: - return self.fixed_content in q - return False - - -checker = FixedContentQueryChecker("bar") - - -@app.get("/query-checker/") -async def read_query_check(fixed_content_included: bool = Depends(checker)): - return {"fixed_content_in_query": fixed_content_included} +async def get_db(): + db = DBSession() + try: + yield db + finally: + db.close() diff --git a/docs/src/dependencies/tutorial008.py b/docs/src/dependencies/tutorial008.py new file mode 100644 index 0000000000000..8472f642de850 --- /dev/null +++ b/docs/src/dependencies/tutorial008.py @@ -0,0 +1,25 @@ +from fastapi import Depends + + +async def dependency_a(): + dep_a = generate_dep_a() + try: + yield dep_a + finally: + dep_a.close() + + +async def dependency_b(dep_a=Depends(dependency_a)): + dep_b = generate_dep_b() + try: + yield dep_b + finally: + dep_b.close(dep_a) + + +async def dependency_c(dep_b=Depends(dependency_b)): + dep_c = generate_dep_c() + try: + yield dep_c + finally: + dep_c.close(dep_b) diff --git a/docs/src/dependencies/tutorial009.py b/docs/src/dependencies/tutorial009.py new file mode 100644 index 0000000000000..8472f642de850 --- /dev/null +++ b/docs/src/dependencies/tutorial009.py @@ -0,0 +1,25 @@ +from fastapi import Depends + + +async def dependency_a(): + dep_a = generate_dep_a() + try: + yield dep_a + finally: + dep_a.close() + + +async def dependency_b(dep_a=Depends(dependency_a)): + dep_b = generate_dep_b() + try: + yield dep_b + finally: + dep_b.close(dep_a) + + +async def dependency_c(dep_b=Depends(dependency_b)): + dep_c = generate_dep_c() + try: + yield dep_c + finally: + dep_c.close(dep_b) diff --git a/docs/src/dependencies/tutorial010.py b/docs/src/dependencies/tutorial010.py new file mode 100644 index 0000000000000..c27f1b1702721 --- /dev/null +++ b/docs/src/dependencies/tutorial010.py @@ -0,0 +1,14 @@ +class MySuperContextManager: + def __init__(self): + self.db = DBSession() + + def __enter__(self): + return self.db + + def __exit__(self, exc_type, exc_value, traceback): + self.db.close() + + +async def get_db(): + with MySuperContextManager() as db: + yield db diff --git a/docs/src/dependencies/tutorial011.py b/docs/src/dependencies/tutorial011.py new file mode 100644 index 0000000000000..5d22f6823785c --- /dev/null +++ b/docs/src/dependencies/tutorial011.py @@ -0,0 +1,21 @@ +from fastapi import Depends, FastAPI + +app = FastAPI() + + +class FixedContentQueryChecker: + def __init__(self, fixed_content: str): + self.fixed_content = fixed_content + + def __call__(self, q: str = ""): + if q: + return self.fixed_content in q + return False + + +checker = FixedContentQueryChecker("bar") + + +@app.get("/query-checker/") +async def read_query_check(fixed_content_included: bool = Depends(checker)): + return {"fixed_content_in_query": fixed_content_included} diff --git a/docs/src/sql_databases/sql_app/alt_main.py b/docs/src/sql_databases/sql_app/alt_main.py new file mode 100644 index 0000000000000..01b71333b408b --- /dev/null +++ b/docs/src/sql_databases/sql_app/alt_main.py @@ -0,0 +1,64 @@ +from typing import List + +from fastapi import Depends, FastAPI, HTTPException +from sqlalchemy.orm import Session +from starlette.requests import Request +from starlette.responses import Response + +from . import crud, models, schemas +from .database import SessionLocal, engine + +models.Base.metadata.create_all(bind=engine) + +app = FastAPI() + + +@app.middleware("http") +async def db_session_middleware(request: Request, call_next): + response = Response("Internal server error", status_code=500) + try: + request.state.db = SessionLocal() + response = await call_next(request) + finally: + request.state.db.close() + return response + + +# Dependency +def get_db(request: Request): + return request.state.db + + +@app.post("/users/", response_model=schemas.User) +def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)): + db_user = crud.get_user_by_email(db, email=user.email) + if db_user: + raise HTTPException(status_code=400, detail="Email already registered") + return crud.create_user(db=db, user=user) + + +@app.get("/users/", response_model=List[schemas.User]) +def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + users = crud.get_users(db, skip=skip, limit=limit) + return users + + +@app.get("/users/{user_id}", response_model=schemas.User) +def read_user(user_id: int, db: Session = Depends(get_db)): + db_user = crud.get_user(db, user_id=user_id) + if db_user is None: + raise HTTPException(status_code=404, detail="User not found") + return db_user + + +@app.post("/users/{user_id}/items/", response_model=schemas.Item) +def create_item_for_user( + user_id: int, item: schemas.ItemCreate, db: Session = Depends(get_db) +): + return crud.create_user_item(db=db, item=item, user_id=user_id) + + +@app.get("/items/", response_model=List[schemas.Item]) +def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + items = crud.get_items(db, skip=skip, limit=limit) + return items diff --git a/docs/src/sql_databases/sql_app/main.py b/docs/src/sql_databases/sql_app/main.py index 01b71333b408b..33f63d332c4d8 100644 --- a/docs/src/sql_databases/sql_app/main.py +++ b/docs/src/sql_databases/sql_app/main.py @@ -2,8 +2,6 @@ from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session -from starlette.requests import Request -from starlette.responses import Response from . import crud, models, schemas from .database import SessionLocal, engine @@ -13,20 +11,13 @@ app = FastAPI() -@app.middleware("http") -async def db_session_middleware(request: Request, call_next): - response = Response("Internal server error", status_code=500) +# Dependency +def get_db(): try: - request.state.db = SessionLocal() - response = await call_next(request) + db = SessionLocal() + yield db finally: - request.state.db.close() - return response - - -# Dependency -def get_db(request: Request): - return request.state.db + db.close() @app.post("/users/", response_model=schemas.User) diff --git a/docs/tutorial/dependencies/advanced-dependencies.md b/docs/tutorial/dependencies/advanced-dependencies.md index 903090f238a37..5a664db33c207 100644 --- a/docs/tutorial/dependencies/advanced-dependencies.md +++ b/docs/tutorial/dependencies/advanced-dependencies.md @@ -1,4 +1,4 @@ -!!! danger +!!! warning This is, more or less, an "advanced" chapter. If you are just starting with **FastAPI** you might want to skip this chapter and come back to it later. @@ -22,7 +22,7 @@ Not the class itself (which is already a callable), but an instance of that clas To do that, we declare a method `__call__`: ```Python hl_lines="10" -{!./src/dependencies/tutorial007.py!} +{!./src/dependencies/tutorial011.py!} ``` In this case, this `__call__` is what **FastAPI** will use to check for additional parameters and sub-dependencies, and this is what will be called to pass a value to the parameter in your *path operation function* later. @@ -32,7 +32,7 @@ In this case, this `__call__` is what **FastAPI** will use to check for addition And now, we can use `__init__` to declare the parameters of the instance that we can use to "parameterize" the dependency: ```Python hl_lines="7" -{!./src/dependencies/tutorial007.py!} +{!./src/dependencies/tutorial011.py!} ``` In this case, **FastAPI** won't ever touch or care about `__init__`, we will use it directly in our code. @@ -42,7 +42,7 @@ In this case, **FastAPI** won't ever touch or care about `__init__`, we will use We could create an instance of this class with: ```Python hl_lines="16" -{!./src/dependencies/tutorial007.py!} +{!./src/dependencies/tutorial011.py!} ``` And that way we are able to "parameterize" our dependency, that now has `"bar"` inside of it, as the attribute `checker.fixed_content`. @@ -60,7 +60,7 @@ checker(q="somequery") ...and pass whatever that returns as the value of the dependency in our path operation function as the parameter `fixed_content_included`: ```Python hl_lines="20" -{!./src/dependencies/tutorial007.py!} +{!./src/dependencies/tutorial011.py!} ``` !!! tip diff --git a/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/tutorial/dependencies/dependencies-with-yield.md new file mode 100644 index 0000000000000..687ecd39de180 --- /dev/null +++ b/docs/tutorial/dependencies/dependencies-with-yield.md @@ -0,0 +1,153 @@ +# Dependencies with `yield` + +FastAPI supports dependencies that do some extra steps after finishing. + +To do this, use `yield` instead of `return`, and write the extra steps after. + +!!! tip + Make sure to use `yield` one single time. + +!!! info + For this to work, you need to use **Python 3.7** or above, or in **Python 3.6**, install the "backports": + + ```bash + pip install async-exit-stack async-generator + ``` + + This installs async-exit-stack and async-generator. + +!!! note "Technical Details" + Any function that is valid to use with: + + * `@contextlib.contextmanager` or + * `@contextlib.asynccontextmanager` + + would be valid to use as a **FastAPI** dependency. + + In fact, FastAPI uses those two decorators internally. + +## A database dependency with `yield` + +For example, you could use this to create a database session and close it after finishing. + +Only the code prior to and including the `yield` statement is executed before sending a response: + +```Python hl_lines="2 3 4" +{!./src/dependencies/tutorial007.py!} +``` + +The yielded value is what is injected into *path operations* and other dependencies: + +```Python hl_lines="4" +{!./src/dependencies/tutorial007.py!} +``` + +The code following the `yield` statement is executed after the response has been delivered: + +```Python hl_lines="5 6" +{!./src/dependencies/tutorial007.py!} +``` + +!!! tip + You can use `async` or normal functions. + + **FastAPI** will do the right thing with each, the same as with normal dependencies. + +## A dependency with `yield` and `try` + +If you use a `try` block in a dependency with `yield`, you'll receive any exception that was thrown when using the dependency. + +For example, if some code at some point in the middle, in another dependency or in a *path operation*, made a database transaction "rollback" or create any other error, you will receive the exception in your dependency. + +So, you can look for that specific exception inside the dependency with `except SomeException`. + +In the same way, you can use `finally` to make sure the exit steps are executed, no matter if there was an exception or not. + +```Python hl_lines="3 5" +{!./src/dependencies/tutorial007.py!} +``` + +## Sub-dependencies with `yield` + +You can have sub-dependencies and "trees" of sub-dependencies of any size and shape, and any or all of them can use `yield`. + +**FastAPI** will make sure that the "exit code" in each dependency with `yield` is run in the correct order. + +For example, `dependency_c` can have a dependency on `dependency_b`, and `dependency_b` on `dependency_a`: + +```Python hl_lines="4 12 20" +{!./src/dependencies/tutorial008.py!} +``` + +And all of them can use `yield`. + +In this case `dependency_c`, to execute its exit code, needs the value from `dependency_b` (here named `dep_b`) to still be available. + +And, in turn, `dependency_b` needs the value from `dependency_a` (here named `dep_a`) to be available for its exit code. + +```Python hl_lines="16 17 24 25" +{!./src/dependencies/tutorial008.py!} +``` + +The same way, you could have dependencies with `yield` and `return` mixed. + +And you could have a single dependency that requires several other dependencies with `yield`, etc. + +You can have any combinations of dependencies that you want. + +**FastAPI** will make sure everything is run in the correct order. + +!!! note "Technical Details" + This works thanks to Python's Context Managers. + + **FastAPI** uses them internally to achieve this. + +## Context Managers + +### What are "Context Managers" + +"Context Managers" are any of those Python objects that you can use in a `with` statement. + +For example, you can use `with` to read a file: + +```Python +with open("./somefile.txt") as f: + contents = f.read() + print(contents) +``` + +Underneath, the `open("./somefile.txt")` returns an object that is a called a "Context Manager". + +When the `with` block finishes, it makes sure to close the file, even if there were exceptions. + +When you create a dependency with `yield`, **FastAPI** will internally convert it to a context manager, and combine it with some other related tools. + +### Using context managers in dependencies with `yield` + +!!! warning + This is, more or less, an "advanced" idea. + + If you are just starting with **FastAPI** you might want to skip it for now. + +In Python, you can create context managers by creating a class with two methods: `__enter__()` and `__exit__()`. + +You can also use them with **FastAPI** dependencies with `yield` by using +`with` or `async with` statements inside of the dependency function: + +```Python hl_lines="1 2 3 4 5 6 7 8 9 13" +{!./src/dependencies/tutorial010.py!} +``` + +!!! tip + Another way to create a context manager is with: + + * `@contextlib.contextmanager` or + * `@contextlib.asynccontextmanager` + + using them to decorate a function with a single `yield`. + + That's what **FastAPI** uses internally for dependencies with `yield`. + + But you don't have to use the decorators for FastAPI dependencies (and you shouldn't). + + FastAPI will do it for you internally. diff --git a/docs/tutorial/sql-databases.md b/docs/tutorial/sql-databases.md index 7d10888c4184c..edaf67385d1e0 100644 --- a/docs/tutorial/sql-databases.md +++ b/docs/tutorial/sql-databases.md @@ -427,21 +427,30 @@ And you would also use Alembic for "migrations" (that's its main job). A "migration" is the set of steps needed whenever you change the structure of your SQLAlchemy models, add a new attribute, etc. to replicate those changes in the database, add a new column, a new table, etc. -### Create a middleware to handle sessions +### Create a dependency + +!!! info + For this to work, you need to use **Python 3.7** or above, or in **Python 3.6**, install the "backports": + + ```bash + pip install async-exit-stack async-generator + ``` + + This installs async-exit-stack and async-generator. + + You can also use the alternative method with a "middleware" explained at the end. -Now use the `SessionLocal` class we created in the `sql_app/databases.py` file. +Now use the `SessionLocal` class we created in the `sql_app/databases.py` file to create a dependency. We need to have an independent database session/connection (`SessionLocal`) per request, use the same session through all the request and then close it after the request is finished. And then a new session will be created for the next request. -For that, we will create a new middleware. +For that, we will create a new dependency with `yield`, as explained before in the section about Dependencies with `yield`. -A "middleware" is a function that is always executed for each request, and have code before and after the request. +Our dependency will create a new SQLAlchemy `SessionLocal` that will be used in a single request, and then close it once the request is finished. -This middleware (just a function) will create a new SQLAlchemy `SessionLocal` for each request, add it to the request and then close it once the request is finished. - -```Python hl_lines="16 17 18 19 20 21 22 23 24" +```Python hl_lines="15 16 17 18 19 20" {!./src/sql_databases/sql_app/main.py!} ``` @@ -452,21 +461,11 @@ This middleware (just a function) will create a new SQLAlchemy `SessionLocal` fo This way we make sure the database session is always closed after the request. Even if there was an exception while processing the request. -#### About `request.state` - -`request.state` is a property of each Starlette `Request` object, it is there to store arbitrary objects attached to the request itself, like the database session in this case. - -For us in this case, it helps us ensuring a single database session is used through all the request, and then closed afterwards (in the middleware). - -### Create a dependency +And then, when using the dependency in a *path operation function*, we declare it with the type `Session` we imported directly from SQLAlchemy. -To simplify the code, reduce repetition and get better editor support, we will create a dependency that returns this same database session from the request. +This will then give us better editor support inside the *path operation function*, because the editor will know that the `db` parameter is of type `Session`: -And when using the dependency in a path operation function, we declare it with the type `Session` we imported directly from SQLAlchemy. - -This will then give us better editor support inside the path operation function, because the editor will know that the `db` parameter is of type `Session`. - -```Python hl_lines="28 29" +```Python hl_lines="24 32 38 47 53" {!./src/sql_databases/sql_app/main.py!} ``` @@ -479,22 +478,16 @@ This will then give us better editor support inside the path operation function, Now, finally, here's the standard **FastAPI** *path operations* code. -```Python hl_lines="32 33 34 35 36 37 40 41 42 43 46 47 48 49 50 51 54 55 56 57 58 61 62 63 64 65" +```Python hl_lines="23 24 25 26 27 28 31 32 33 34 37 38 39 40 41 42 45 46 47 48 49 52 53 54 55" {!./src/sql_databases/sql_app/main.py!} ``` -We are creating the database session before each request, attaching it to the request, and then closing it afterwards. - -All of this is done in the middleware explained above. +We are creating the database session before each request in the dependency with `yield`, and then closing it afterwards. -Then, in the dependency `get_db()` we are extracting the database session from the request. - -And then we can create the dependency in the path operation function, to get that session directly. +And then we can create the required dependency in the path operation function, to get that session directly. With that, we can just call `crud.get_user` directly from inside of the path operation function and use that session. -Having this 3-step process (middleware, dependency, path operation) you get better support/checks/completion in all the path operation functions while reducing code repetition. - !!! tip Notice that the values you return are SQLAlchemy models, or lists of SQLAlchemy models. @@ -507,7 +500,7 @@ Having this 3-step process (middleware, dependency, path operation) you get bett ### About `def` vs `async def` -Here we are using SQLAlchemy code inside of the path operation function, and, in turn, it will go and communicate with an external database. +Here we are using SQLAlchemy code inside of the path operation function and in the dependency, and, in turn, it will go and communicate with an external database. That could potentially require some "waiting". @@ -523,7 +516,7 @@ user = await db.query(User).first() user = db.query(User).first() ``` -Then we should declare the path operation without `async def`, just with a normal `def`, as: +Then we should declare the *path operation functions* and the dependency without `async def`, just with a normal `def`, as: ```Python hl_lines="2" @app.get("/users/{user_id}", response_model=schemas.User) @@ -548,8 +541,8 @@ For example, in a background task worker with You can also use an online SQLite browser like SQLite Viewer or ExtendsClass. + +## Alternative DB session with middleware + +If you can't use dependencies with `yield` -- for example, if you are not using **Python 3.7** and can't install the "backports" mentioned above for **Python 3.6** -- you can set up the session in a "middleware" in a similar way. + +A "middleware" is basically a function that is always executed for each request, with some code executed before, and some code executed after the endpoint function. + +### Create a middleware + +The middleware we'll add (just a function) will create a new SQLAlchemy `SessionLocal` for each request, add it to the request and then close it once the request is finished. + +```Python hl_lines="16 17 18 19 20 21 22 23 24" +{!./src/sql_databases/sql_app/alt_main.py!} +``` + +!!! info + We put the creation of the `SessionLocal()` and handling of the requests in a `try` block. + + And then we close it in the `finally` block. + + This way we make sure the database session is always closed after the request. Even if there was an exception while processing the request. + +### About `request.state` + +`request.state` is a property of each Starlette `Request` object. It is there to store arbitrary objects attached to the request itself, like the database session in this case. + +For us in this case, it helps us ensure a single database session is used through all the request, and then closed afterwards (in the middleware). + +### Dependencies with `yield` or middleware + +Adding a **middleware** here is similar to what a dependency with `yield` does, with some differences: + +* It requires more code and is a bit more complex. +* The middleware has to be an `async` function. + * If there is code in it that has to "wait" for the network, it could "block" your application there and degrade performance a bit. + * Although it's probably not very problematic here with the way `SQLAlchemy` works. + * But if you added more code to the middleware that had a lot of I/O waiting, it could then be problematic. +* A middleware is run for *every* request. + * So, a connection will be created for every request. + * Even when the *path operation* that handles that request didn't need the DB. + +!!! tip + It's probably better to use dependencies with `yield` when they are enough for the use case. + +!!! info + Dependencies with `yield` were added recently to **FastAPI**. + + A previous version of this tutorial only had the examples with a middleware and there are probably several applications using the middleware for database session management. diff --git a/fastapi/applications.py b/fastapi/applications.py index 811b90d4ad83e..000026b8384fc 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from fastapi import routing +from fastapi.concurrency import AsyncExitStack from fastapi.encoders import DictIntStrAny, SetIntStr from fastapi.exception_handlers import ( http_exception_handler, @@ -21,6 +22,7 @@ from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import BaseRoute +from starlette.types import Receive, Scope, Send class FastAPI(Starlette): @@ -130,6 +132,14 @@ async def redoc_html(req: Request) -> HTMLResponse: RequestValidationError, request_validation_exception_handler ) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if AsyncExitStack: + async with AsyncExitStack() as stack: + scope["fastapi_astack"] = stack + await super().__call__(scope, receive, send) + else: + await super().__call__(scope, receive, send) # pragma: no cover + def add_api_route( self, path: str, diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py new file mode 100644 index 0000000000000..7006c1ad0c2ff --- /dev/null +++ b/fastapi/concurrency.py @@ -0,0 +1,45 @@ +from typing import Any, Callable + +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool # noqa + +asynccontextmanager_error_message = """ +FastAPI's contextmanager_in_threadpool require Python 3.7 or above, +or the backport for Python 3.6, installed with: + pip install async-generator +""" + + +def _fake_asynccontextmanager(func: Callable) -> Callable: + def raiser(*args: Any, **kwargs: Any) -> Any: + raise RuntimeError(asynccontextmanager_error_message) + + return raiser + + +try: + from contextlib import asynccontextmanager # type: ignore +except ImportError: + try: + from async_generator import asynccontextmanager # type: ignore + except ImportError: # pragma: no cover + asynccontextmanager = _fake_asynccontextmanager + +try: + from contextlib import AsyncExitStack # type: ignore +except ImportError: + try: + from async_exit_stack import AsyncExitStack # type: ignore + except ImportError: # pragma: no cover + AsyncExitStack = None # type: ignore + + +@asynccontextmanager +async def contextmanager_in_threadpool(cm: Any) -> Any: + try: + yield await run_in_threadpool(cm.__enter__) + except Exception as e: + ok = await run_in_threadpool(cm.__exit__, type(e), e, None) + if not ok: + raise e + else: + await run_in_threadpool(cm.__exit__, None, None, None) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index d5d1145653d1a..437b3184af68b 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,5 +1,6 @@ import asyncio import inspect +from contextlib import contextmanager from copy import deepcopy from typing import ( Any, @@ -16,6 +17,12 @@ ) from fastapi import params +from fastapi.concurrency import ( + AsyncExitStack, + _fake_asynccontextmanager, + asynccontextmanager, + contextmanager_in_threadpool, +) from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -195,6 +202,18 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> return annotation +async_contextmanager_dependencies_error = """ +FastAPI dependencies with yield require Python 3.7 or above, +or the backports for Python 3.6, installed with: + pip install async-exit-stack async-generator +""" + + +def check_dependency_contextmanagers() -> None: + if AsyncExitStack is None or asynccontextmanager == _fake_asynccontextmanager: + raise RuntimeError(async_contextmanager_dependencies_error) # pragma: no cover + + def get_dependant( *, path: str, @@ -206,6 +225,8 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters + if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call): + check_dependency_contextmanagers() dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache) for param_name, param in signature_params.items(): if isinstance(param.default, params.Depends): @@ -338,6 +359,16 @@ def is_coroutine_callable(call: Callable) -> bool: return asyncio.iscoroutinefunction(call) +async def solve_generator( + *, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any] +) -> Any: + if inspect.isgeneratorfunction(call): + cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + elif inspect.isasyncgenfunction(call): + cm = asynccontextmanager(call)(**sub_values) + return await stack.enter_async_context(cm) + + async def solve_dependencies( *, request: Union[Request, WebSocket], @@ -410,6 +441,15 @@ async def solve_dependencies( continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] + elif inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call): + stack = request.scope.get("fastapi_astack") + if stack is None: + raise RuntimeError( + async_contextmanager_dependencies_error + ) # pragma: no cover + solved = await solve_generator( + call=call, stack=stack, sub_values=sub_values + ) elif is_coroutine_callable(call): solved = await call(**sub_values) else: diff --git a/mkdocs.yml b/mkdocs.yml index b5e85e5369064..9140e6a3f9444 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,7 @@ nav: - Classes as Dependencies: 'tutorial/dependencies/classes-as-dependencies.md' - Sub-dependencies: 'tutorial/dependencies/sub-dependencies.md' - Dependencies in path operation decorators: 'tutorial/dependencies/dependencies-in-path-operation-decorators.md' + - Dependencies with yield: 'tutorial/dependencies/dependencies-with-yield.md' - Advanced Dependencies: 'tutorial/dependencies/advanced-dependencies.md' - Security: - Security Intro: 'tutorial/security/intro.md' diff --git a/pyproject.toml b/pyproject.toml index e47cb965a4f71..21cade15cc51c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,9 @@ test = [ "email_validator", "sqlalchemy", "databases[sqlite]", - "orjson" + "orjson", + "async_exit_stack", + "async_generator" ] doc = [ "mkdocs", @@ -61,4 +63,6 @@ all = [ "ujson", "email_validator", "uvicorn", + "async_exit_stack", + "async_generator" ] diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py new file mode 100644 index 0000000000000..5771b0fd8671b --- /dev/null +++ b/tests/test_dependency_contextmanager.py @@ -0,0 +1,349 @@ +from typing import Dict + +import pytest +from fastapi import BackgroundTasks, Depends, FastAPI +from starlette.testclient import TestClient + +app = FastAPI() +state = { + "/async": "asyncgen not started", + "/sync": "generator not started", + "/async_raise": "asyncgen raise not started", + "/sync_raise": "generator raise not started", + "context_a": "not started a", + "context_b": "not started b", + "bg": "not set", + "sync_bg": "not set", +} + +errors = [] + + +async def get_state(): + return state + + +class AsyncDependencyError(Exception): + pass + + +class SyncDependencyError(Exception): + pass + + +class OtherDependencyError(Exception): + pass + + +async def asyncgen_state(state: Dict[str, str] = Depends(get_state)): + state["/async"] = "asyncgen started" + yield state["/async"] + state["/async"] = "asyncgen completed" + + +def generator_state(state: Dict[str, str] = Depends(get_state)): + state["/sync"] = "generator started" + yield state["/sync"] + state["/sync"] = "generator completed" + + +async def asyncgen_state_try(state: Dict[str, str] = Depends(get_state)): + state["/async_raise"] = "asyncgen raise started" + try: + yield state["/async_raise"] + except AsyncDependencyError: + errors.append("/async_raise") + finally: + state["/async_raise"] = "asyncgen raise finalized" + + +def generator_state_try(state: Dict[str, str] = Depends(get_state)): + state["/sync_raise"] = "generator raise started" + try: + yield state["/sync_raise"] + except SyncDependencyError: + errors.append("/sync_raise") + finally: + state["/sync_raise"] = "generator raise finalized" + + +async def context_a(state: dict = Depends(get_state)): + state["context_a"] = "started a" + try: + yield state + finally: + state["context_a"] = "finished a" + + +async def context_b(state: dict = Depends(context_a)): + state["context_b"] = "started b" + try: + yield state + finally: + state["context_b"] = f"finished b with a: {state['context_a']}" + + +@app.get("/async") +async def get_async(state: str = Depends(asyncgen_state)): + return state + + +@app.get("/sync") +async def get_sync(state: str = Depends(generator_state)): + return state + + +@app.get("/async_raise") +async def get_async_raise(state: str = Depends(asyncgen_state_try)): + assert state == "asyncgen raise started" + raise AsyncDependencyError() + + +@app.get("/sync_raise") +async def get_sync_raise(state: str = Depends(generator_state_try)): + assert state == "generator raise started" + raise SyncDependencyError() + + +@app.get("/async_raise_other") +async def get_async_raise_other(state: str = Depends(asyncgen_state_try)): + assert state == "asyncgen raise started" + raise OtherDependencyError() + + +@app.get("/sync_raise_other") +async def get_sync_raise_other(state: str = Depends(generator_state_try)): + assert state == "generator raise started" + raise OtherDependencyError() + + +@app.get("/context_b") +async def get_context_b(state: dict = Depends(context_b)): + return state + + +@app.get("/context_b_raise") +async def get_context_b_raise(state: dict = Depends(context_b)): + assert state["context_b"] == "started b" + assert state["context_a"] == "started a" + raise OtherDependencyError() + + +@app.get("/context_b_bg") +async def get_context_b_bg(tasks: BackgroundTasks, state: dict = Depends(context_b)): + async def bg(state: dict): + state["bg"] = f"bg set - b: {state['context_b']} - a: {state['context_a']}" + + tasks.add_task(bg, state) + return state + + +# Sync versions + + +@app.get("/sync_async") +def get_sync_async(state: str = Depends(asyncgen_state)): + return state + + +@app.get("/sync_sync") +def get_sync_sync(state: str = Depends(generator_state)): + return state + + +@app.get("/sync_async_raise") +def get_sync_async_raise(state: str = Depends(asyncgen_state_try)): + assert state == "asyncgen raise started" + raise AsyncDependencyError() + + +@app.get("/sync_sync_raise") +def get_sync_sync_raise(state: str = Depends(generator_state_try)): + assert state == "generator raise started" + raise SyncDependencyError() + + +@app.get("/sync_async_raise_other") +def get_sync_async_raise_other(state: str = Depends(asyncgen_state_try)): + assert state == "asyncgen raise started" + raise OtherDependencyError() + + +@app.get("/sync_sync_raise_other") +def get_sync_sync_raise_other(state: str = Depends(generator_state_try)): + assert state == "generator raise started" + raise OtherDependencyError() + + +@app.get("/sync_context_b") +def get_sync_context_b(state: dict = Depends(context_b)): + return state + + +@app.get("/sync_context_b_raise") +def get_sync_context_b_raise(state: dict = Depends(context_b)): + assert state["context_b"] == "started b" + assert state["context_a"] == "started a" + raise OtherDependencyError() + + +@app.get("/sync_context_b_bg") +async def get_sync_context_b_bg( + tasks: BackgroundTasks, state: dict = Depends(context_b) +): + async def bg(state: dict): + state[ + "sync_bg" + ] = f"sync_bg set - b: {state['context_b']} - a: {state['context_a']}" + + tasks.add_task(bg, state) + return state + + +client = TestClient(app) + + +def test_async_state(): + assert state["/async"] == f"asyncgen not started" + response = client.get("/async") + assert response.status_code == 200 + assert response.json() == f"asyncgen started" + assert state["/async"] == f"asyncgen completed" + + +def test_sync_state(): + assert state["/sync"] == f"generator not started" + response = client.get("/sync") + assert response.status_code == 200 + assert response.json() == f"generator started" + assert state["/sync"] == f"generator completed" + + +def test_async_raise_other(): + assert state["/async_raise"] == "asyncgen raise not started" + with pytest.raises(OtherDependencyError): + client.get("/async_raise_other") + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" not in errors + + +def test_sync_raise_other(): + assert state["/sync_raise"] == "generator raise not started" + with pytest.raises(OtherDependencyError): + client.get("/sync_raise_other") + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" not in errors + + +def test_async_raise(): + response = client.get("/async_raise") + assert response.status_code == 500 + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" in errors + errors.clear() + + +def test_context_b(): + response = client.get("/context_b") + data = response.json() + assert data["context_b"] == "started b" + assert data["context_a"] == "started a" + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + + +def test_context_b_raise(): + with pytest.raises(OtherDependencyError): + client.get("/context_b_raise") + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + + +def test_background_tasks(): + response = client.get("/context_b_bg") + data = response.json() + assert data["context_b"] == "started b" + assert data["context_a"] == "started a" + assert data["bg"] == "not set" + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + assert state["bg"] == "bg set - b: started b - a: started a" + + +def test_sync_raise(): + response = client.get("/sync_raise") + assert response.status_code == 500 + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" in errors + errors.clear() + + +def test_sync_async_state(): + response = client.get("/sync_async") + assert response.status_code == 200 + assert response.json() == f"asyncgen started" + assert state["/async"] == f"asyncgen completed" + + +def test_sync_sync_state(): + response = client.get("/sync_sync") + assert response.status_code == 200 + assert response.json() == f"generator started" + assert state["/sync"] == f"generator completed" + + +def test_sync_async_raise_other(): + with pytest.raises(OtherDependencyError): + client.get("/sync_async_raise_other") + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" not in errors + + +def test_sync_sync_raise_other(): + with pytest.raises(OtherDependencyError): + client.get("/sync_sync_raise_other") + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" not in errors + + +def test_sync_async_raise(): + response = client.get("/sync_async_raise") + assert response.status_code == 500 + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" in errors + errors.clear() + + +def test_sync_sync_raise(): + response = client.get("/sync_sync_raise") + assert response.status_code == 500 + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" in errors + errors.clear() + + +def test_sync_context_b(): + response = client.get("/sync_context_b") + data = response.json() + assert data["context_b"] == "started b" + assert data["context_a"] == "started a" + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + + +def test_sync_context_b_raise(): + with pytest.raises(OtherDependencyError): + client.get("/sync_context_b_raise") + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + + +def test_sync_background_tasks(): + response = client.get("/sync_context_b_bg") + data = response.json() + assert data["context_b"] == "started b" + assert data["context_a"] == "started a" + assert data["sync_bg"] == "not set" + assert state["context_b"] == "finished b with a: started a" + assert state["context_a"] == "finished a" + assert state["sync_bg"] == "sync_bg set - b: started b - a: started a" diff --git a/tests/test_fakeasync.py b/tests/test_fakeasync.py new file mode 100644 index 0000000000000..4e146b0ff2328 --- /dev/null +++ b/tests/test_fakeasync.py @@ -0,0 +1,12 @@ +import pytest +from fastapi.concurrency import _fake_asynccontextmanager + + +@_fake_asynccontextmanager +def never_run(): + pass # pragma: no cover + + +def test_fake_async(): + with pytest.raises(RuntimeError): + never_run() diff --git a/tests/test_tutorial/test_sql_databases/test_sql_databases.py b/tests/test_tutorial/test_sql_databases/test_sql_databases.py index 8f46a94376fc4..a85a7f91427d3 100644 --- a/tests/test_tutorial/test_sql_databases/test_sql_databases.py +++ b/tests/test_tutorial/test_sql_databases/test_sql_databases.py @@ -1,8 +1,7 @@ -from starlette.testclient import TestClient - -from sql_databases.sql_app.main import app +from pathlib import Path -client = TestClient(app) +import pytest +from starlette.testclient import TestClient openapi_schema = { "openapi": "3.0.2", @@ -282,13 +281,24 @@ } -def test_openapi_schema(): +@pytest.fixture(scope="module") +def client(): + # Import while creating the client to create the DB after starting the test session + from sql_databases.sql_app.main import app + + test_db = Path("./test.db") + with TestClient(app) as c: + yield c + test_db.unlink() + + +def test_openapi_schema(client): response = client.get("/openapi.json") assert response.status_code == 200 assert response.json() == openapi_schema -def test_create_user(): +def test_create_user(client): test_user = {"email": "johndoe@example.com", "password": "secret"} response = client.post("/users/", json=test_user) assert response.status_code == 200 @@ -299,7 +309,7 @@ def test_create_user(): assert response.status_code == 400 -def test_get_user(): +def test_get_user(client): response = client.get("/users/1") assert response.status_code == 200 data = response.json() @@ -307,12 +317,12 @@ def test_get_user(): assert "id" in data -def test_inexistent_user(): +def test_inexistent_user(client): response = client.get("/users/999") assert response.status_code == 404 -def test_get_users(): +def test_get_users(client): response = client.get("/users/") assert response.status_code == 200 data = response.json() @@ -320,7 +330,7 @@ def test_get_users(): assert "id" in data[0] -def test_create_item(): +def test_create_item(client): item = {"title": "Foo", "description": "Something that fights"} response = client.post("/users/1/items/", json=item) assert response.status_code == 200 @@ -343,7 +353,7 @@ def test_create_item(): assert item_to_check["description"] == item["description"] -def test_read_items(): +def test_read_items(client): response = client.get("/items/") assert response.status_code == 200 data = response.json() diff --git a/tests/test_tutorial/test_sql_databases/test_sql_databases_middleware.py b/tests/test_tutorial/test_sql_databases/test_sql_databases_middleware.py new file mode 100644 index 0000000000000..d5644d87651d1 --- /dev/null +++ b/tests/test_tutorial/test_sql_databases/test_sql_databases_middleware.py @@ -0,0 +1,363 @@ +from pathlib import Path + +import pytest +from starlette.testclient import TestClient + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/users/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "title": "Response_Read_Users_Users__Get", + "type": "array", + "items": {"$ref": "#/components/schemas/User"}, + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Users", + "operationId": "read_users_users__get", + "parameters": [ + { + "required": False, + "schema": {"title": "Skip", "type": "integer", "default": 0}, + "name": "skip", + "in": "query", + }, + { + "required": False, + "schema": {"title": "Limit", "type": "integer", "default": 100}, + "name": "limit", + "in": "query", + }, + ], + }, + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/User"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Create User", + "operationId": "create_user_users__post", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/UserCreate"} + } + }, + "required": True, + }, + }, + }, + "/users/{user_id}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/User"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read User", + "operationId": "read_user_users__user_id__get", + "parameters": [ + { + "required": True, + "schema": {"title": "User_Id", "type": "integer"}, + "name": "user_id", + "in": "path", + } + ], + } + }, + "/users/{user_id}/items/": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Create Item For User", + "operationId": "create_item_for_user_users__user_id__items__post", + "parameters": [ + { + "required": True, + "schema": {"title": "User_Id", "type": "integer"}, + "name": "user_id", + "in": "path", + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ItemCreate"} + } + }, + "required": True, + }, + } + }, + "/items/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "title": "Response_Read_Items_Items__Get", + "type": "array", + "items": {"$ref": "#/components/schemas/Item"}, + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Items", + "operationId": "read_items_items__get", + "parameters": [ + { + "required": False, + "schema": {"title": "Skip", "type": "integer", "default": 0}, + "name": "skip", + "in": "query", + }, + { + "required": False, + "schema": {"title": "Limit", "type": "integer", "default": 100}, + "name": "limit", + "in": "query", + }, + ], + } + }, + }, + "components": { + "schemas": { + "ItemCreate": { + "title": "ItemCreate", + "required": ["title"], + "type": "object", + "properties": { + "title": {"title": "Title", "type": "string"}, + "description": {"title": "Description", "type": "string"}, + }, + }, + "Item": { + "title": "Item", + "required": ["title", "id", "owner_id"], + "type": "object", + "properties": { + "title": {"title": "Title", "type": "string"}, + "description": {"title": "Description", "type": "string"}, + "id": {"title": "Id", "type": "integer"}, + "owner_id": {"title": "Owner_Id", "type": "integer"}, + }, + }, + "User": { + "title": "User", + "required": ["email", "id", "is_active"], + "type": "object", + "properties": { + "email": {"title": "Email", "type": "string"}, + "id": {"title": "Id", "type": "integer"}, + "is_active": {"title": "Is_Active", "type": "boolean"}, + "items": { + "title": "Items", + "type": "array", + "items": {"$ref": "#/components/schemas/Item"}, + "default": [], + }, + }, + }, + "UserCreate": { + "title": "UserCreate", + "required": ["email", "password"], + "type": "object", + "properties": { + "email": {"title": "Email", "type": "string"}, + "password": {"title": "Password", "type": "string"}, + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +@pytest.fixture(scope="module") +def client(): + # Import while creating the client to create the DB after starting the test session + from sql_databases.sql_app.alt_main import app + + test_db = Path("./test.db") + with TestClient(app) as c: + yield c + test_db.unlink() + + +def test_openapi_schema(client): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_create_user(client): + test_user = {"email": "johndoe@example.com", "password": "secret"} + response = client.post("/users/", json=test_user) + assert response.status_code == 200 + data = response.json() + assert test_user["email"] == data["email"] + assert "id" in data + response = client.post("/users/", json=test_user) + assert response.status_code == 400 + + +def test_get_user(client): + response = client.get("/users/1") + assert response.status_code == 200 + data = response.json() + assert "email" in data + assert "id" in data + + +def test_inexistent_user(client): + response = client.get("/users/999") + assert response.status_code == 404 + + +def test_get_users(client): + response = client.get("/users/") + assert response.status_code == 200 + data = response.json() + assert "email" in data[0] + assert "id" in data[0] + + +def test_create_item(client): + item = {"title": "Foo", "description": "Something that fights"} + response = client.post("/users/1/items/", json=item) + assert response.status_code == 200 + item_data = response.json() + assert item["title"] == item_data["title"] + assert item["description"] == item_data["description"] + assert "id" in item_data + assert "owner_id" in item_data + response = client.get("/users/1") + assert response.status_code == 200 + user_data = response.json() + item_to_check = [it for it in user_data["items"] if it["id"] == item_data["id"]][0] + assert item_to_check["title"] == item["title"] + assert item_to_check["description"] == item["description"] + response = client.get("/users/1") + assert response.status_code == 200 + user_data = response.json() + item_to_check = [it for it in user_data["items"] if it["id"] == item_data["id"]][0] + assert item_to_check["title"] == item["title"] + assert item_to_check["description"] == item["description"] + + +def test_read_items(client): + response = client.get("/items/") + assert response.status_code == 200 + data = response.json() + assert data + first_item = data[0] + assert "title" in first_item + assert "description" in first_item