Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add method to add more dependencies to specific endpoints #406

Merged
merged 5 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion src/titiler/core/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import attr
import morecantile
import pytest
from requests.auth import HTTPBasicAuth
from rio_tiler.io import BaseReader, COGReader, MultiBandReader, STACReader

from titiler.core.dependencies import TMSParams, WebMercatorTMSParams
Expand All @@ -23,7 +25,7 @@

from .conftest import DATA_DIR, mock_rasterio_open, parse_img

from fastapi import FastAPI, Query
from fastapi import Depends, FastAPI, HTTPException, Query, security, status

from starlette.testclient import TestClient

Expand Down Expand Up @@ -1131,3 +1133,101 @@ def test_TMSFactory():
body = response.json()
assert body["type"] == "TileMatrixSetType"
assert body["identifier"] == "WebMercatorQuad"


def test_TilerFactory_WithDependencies():
"""Test TilerFactory class."""

http_basic = security.HTTPBasic()

def must_be_bob(credentials: security.HTTPBasicCredentials = Depends(http_basic)):
print(credentials)
if credentials.username == "bob":
return True
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="You're not Bob",
headers={"WWW-Authenticate": "Basic"},
)

cog = TilerFactory(
route_dependencies=[
(
[
{"path": "/bounds", "method": "GET"},
{"path": "/tiles/{z}/{x}/{y}", "method": "GET"},
],
[Depends(must_be_bob)],
),
],
)
assert len(cog.router.routes) == 25
assert cog.tms_dependency == TMSParams

app = FastAPI()
app.include_router(cog.router, prefix="/something")
client = TestClient(app)

response = client.get(f"/something/tilejson.json?url={DATA_DIR}/cog.tif")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert response.json()["tilejson"]

auth_bob = HTTPBasicAuth(username="bob", password="ILoveSponge")
auth_notbob = HTTPBasicAuth(username="notbob", password="IHateSponge")

response = client.get(
f"/something/bounds?url={DATA_DIR}/cog.tif&rescale=0,1000", auth=auth_bob
)
assert response.status_code == 200

with pytest.raises(HTTPException):
client.get(
f"/something/bounds?url={DATA_DIR}/cog.tif&rescale=0,1000", auth=auth_notbob
)

# response = client.get(f"/tiles/8/87/48?url={DATA_DIR}/cog.tif&rescale=0,1000")
# assert response.status_code == 200
# assert response.headers["content-type"] == "image/jpeg"
# timing = response.headers["server-timing"]
# assert "dataread;dur" in timing
# assert "postprocess;dur" in timing
# assert "format;dur" in timing

# response = client.get(
# f"/tiles/8/87/48?url={DATA_DIR}/cog.tif&rescale=-3.4028235e+38,3.4028235e+38"
# )
# assert response.status_code == 200
# assert response.headers["content-type"] == "image/jpeg"
# timing = response.headers["server-timing"]
# assert "dataread;dur" in timing
# assert "postprocess;dur" in timing
# assert "format;dur" in timing

# response = client.get(
# f"/tiles/8/87/48.tif?url={DATA_DIR}/cog.tif&bidx=1&bidx=1&bidx=1&return_mask=false"
# )
# assert response.status_code == 200
# assert response.headers["content-type"] == "image/tiff; application=geotiff"
# meta = parse_img(response.content)
# assert meta["dtype"] == "uint16"
# assert meta["count"] == 3
# assert meta["width"] == 256
# assert meta["height"] == 256

# response = client.get(
# f"/tiles/8/87/48.tif?url={DATA_DIR}/cog.tif&expression=b1,b1,b1&return_mask=false"
# )
# assert response.status_code == 200
# assert response.headers["content-type"] == "image/tiff; application=geotiff"
# meta = parse_img(response.content)
# assert meta["dtype"] == "int32"
# assert meta["count"] == 3
# assert meta["width"] == 256
# assert meta["height"] == 256

# response = client.get(
# f"/tiles/8/84/47?url={DATA_DIR}/cog.tif&bidx=1&rescale=0,1000&colormap_name=viridis"
# )
# assert response.status_code == 200
# assert response.headers["content-type"] == "image/png"
41 changes: 39 additions & 2 deletions src/titiler/core/titiler/core/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from urllib.parse import urlencode

import rasterio
Expand Down Expand Up @@ -46,14 +46,17 @@
Statistics,
StatisticsGeoJSON,
)
from titiler.core.models.routes import EndpointScope
from titiler.core.resources.enums import ImageType, MediaType, OptionalHeader
from titiler.core.resources.responses import GeoJSONResponse, JSONResponse, XMLResponse
from titiler.core.utils import Timer

from fastapi import APIRouter, Body, Depends, Path, Query
from fastapi import APIRouter, Body, Depends, Path, Query, params
from fastapi.dependencies.utils import get_parameterless_sub_dependant

from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Match
from starlette.templating import Jinja2Templates

try:
Expand Down Expand Up @@ -149,10 +152,18 @@ class BaseTilerFactory(metaclass=abc.ABCMeta):
# add additional headers in response
optional_headers: List[OptionalHeader] = field(default_factory=list)

# add dependencies to specific routes
route_dependencies: List[Tuple[List[EndpointScope], List[params.Depends]]] = field(
default_factory=list
)

def __post_init__(self):
"""Post Init: register route and configure specific options."""
self.register_routes()

for scopes, dependencies in self.route_dependencies:
self.add_route_dependencies(scopes=scopes, dependencies=dependencies)

@abc.abstractmethod
def register_routes(self):
"""Register Tiler Routes."""
Expand All @@ -166,6 +177,32 @@ def url_for(self, request: Request, name: str, **path_params: Any) -> str:
base_url += self.router_prefix.lstrip("/")
return url_path.make_absolute_url(base_url=base_url)

def add_route_dependencies(
self,
scopes: List[EndpointScope],
dependencies=List[params.Depends],
):
"""Add dependencies to routes.

Allows a developer to add dependencies to a route after the route has been defined.

"""
for route in self.router.routes:
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
for scope in scopes:
match, _ = route.matches({"type": "http", **scope})
if match != Match.FULL:
continue

# Mimicking how APIRoute handles dependencies:
# https://github.com/tiangolo/fastapi/blob/1760da0efa55585c19835d81afa8ca386036c325/fastapi/routing.py#L408-L412
for depends in dependencies[::-1]:
route.dependant.dependencies.insert( # type: ignore
alukach marked this conversation as resolved.
Show resolved Hide resolved
0,
get_parameterless_sub_dependant(
depends=depends, path=route.path_format # type: ignore
),
)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved

alukach marked this conversation as resolved.
Show resolved Hide resolved

@dataclass
class TilerFactory(BaseTilerFactory):
Expand Down
13 changes: 13 additions & 0 deletions src/titiler/core/titiler/core/models/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""models for routes."""

from typing import Optional, TypedDict


class EndpointScope(TypedDict, total=False):
"""Define endpoint."""

# More strict version of Starlette's Scope
# https://github.com/encode/starlette/blob/6af5c515e0a896cbf3f86ee043b88f6c24200bcf/starlette/types.py#L3
path: str
method: str
type: Optional[str] # http or websocket