Skip to content

Commit

Permalink
[RFR] base_rest: use routing attribute constant throughout
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanRijnhart committed Dec 7, 2022
1 parent 5a201b1 commit 6cbef43
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 44 deletions.
7 changes: 4 additions & 3 deletions base_rest/apispec/base_rest_service_apispec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from apispec import APISpec

from ..core import _rest_services_databases
from ..tools import ROUTING_DECORATOR_ATTR
from .rest_method_param_plugin import RestMethodParamPlugin
from .rest_method_security_plugin import RestMethodSecurityPlugin
from .restapi_method_route_plugin import RestApiMethodRoutePlugin
Expand Down Expand Up @@ -62,18 +63,18 @@ def _get_plugins(self):

def _add_method_path(self, method):
description = textwrap.dedent(method.__doc__ or "")
routing = method.original_routing
routing = getattr(method, ROUTING_DECORATOR_ATTR)
for paths, method in routing["routes"]:
for path in paths:
self.path(
path,
operations={method.lower(): {"summary": description}},
original_routing=routing,
**{ROUTING_DECORATOR_ATTR: routing},
)

def generate_paths(self):
for _name, method in inspect.getmembers(self._service, inspect.ismethod):
routing = getattr(method, "original_routing", None)
routing = getattr(method, ROUTING_DECORATOR_ATTR, None)
if not routing:
continue
self._add_method_path(method)
3 changes: 2 additions & 1 deletion base_rest/apispec/rest_method_param_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from apispec import BasePlugin

from ..restapi import RestMethodParam
from ..tools import ROUTING_DECORATOR_ATTR


class RestMethodParamPlugin(BasePlugin):
Expand All @@ -25,7 +26,7 @@ def init_spec(self, spec):
self.openapi_version = spec.openapi_version

def operation_helper(self, path=None, operations=None, **kwargs):
routing = kwargs.get("original_routing")
routing = kwargs.get(ROUTING_DECORATOR_ATTR)
if not routing:
super(RestMethodParamPlugin, self).operation_helper(
path, operations, **kwargs
Expand Down
4 changes: 3 additions & 1 deletion base_rest/apispec/rest_method_security_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from apispec import BasePlugin

from ..tools import ROUTING_DECORATOR_ATTR


class RestMethodSecurityPlugin(BasePlugin):
"""
Expand All @@ -23,7 +25,7 @@ def init_spec(self, spec):
spec.components.security_scheme("user", user_scheme)

def operation_helper(self, path=None, operations=None, **kwargs):
routing = kwargs.get("original_routing")
routing = kwargs.get(ROUTING_DECORATOR_ATTR)
if not routing:
super(RestMethodSecurityPlugin, self).operation_helper(
path, operations, **kwargs
Expand Down
5 changes: 3 additions & 2 deletions base_rest/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from odoo.addons.component.core import AbstractComponent

from ..apispec.base_rest_service_apispec import BaseRestServiceAPISpec
from ..tools import ROUTING_DECORATOR_ATTR

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,7 +94,7 @@ def _prepare_input_params(self, method, params):
method_name = method.__name__
if hasattr(method, "skip_secure_params"):
return params
routing = getattr(method, "original_routing", None)
routing = getattr(method, ROUTING_DECORATOR_ATTR, None)
if not routing:
_logger.warning(
"Method %s is not a public method of service %s",
Expand Down Expand Up @@ -122,7 +123,7 @@ def _prepare_response(self, method, result):
method_name = method.__name__
if hasattr(method, "skip_secure_response"):
return result
routing = getattr(method, "original_routing", None)
routing = getattr(method, ROUTING_DECORATOR_ATTR, None)
output_param = routing["output_param"]
if not output_param:
_logger.warning(
Expand Down
10 changes: 3 additions & 7 deletions base_rest/models/rest_service_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@
_rest_services_databases,
_rest_services_routes,
)
from ..tools import _inspect_methods

# Decorator attribute added on a route function (cfr Odoo's route)
ROUTING_DECORATOR_ATTR = "original_routing"

from ..tools import ROUTING_DECORATOR_ATTR, _inspect_methods

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -397,9 +393,9 @@ def _generate_methods(self):
path_sep = "/"
root_path = "{}{}{}".format(root_path, path_sep, self._service._usage)
for name, method in _inspect_methods(self._service.__class__):
if not hasattr(method, "original_routing"):
routing = getattr(method, ROUTING_DECORATOR_ATTR, None)
if routing is None:
continue
routing = method.original_routing
for routes, http_method in routing["routes"]:
method_name = "{}_{}".format(http_method.lower(), name)
default_route = routes[0]
Expand Down
4 changes: 2 additions & 2 deletions base_rest/restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from odoo import _, http
from odoo.exceptions import UserError, ValidationError

from .tools import cerberus_to_json
from .tools import ROUTING_DECORATOR_ATTR, cerberus_to_json


def method(routes, input_param=None, output_param=None, **kw):
Expand Down Expand Up @@ -104,7 +104,7 @@ def response_wrap(*args, **kw):
response = f(*args, **kw)
return response

response_wrap.original_routing = routing
setattr(response_wrap, ROUTING_DECORATOR_ATTR, routing)
response_wrap.original_func = f
return response_wrap

Expand Down
4 changes: 2 additions & 2 deletions base_rest/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_rest_controllers_per_module,
_rest_services_databases,
)
from ..tools import _inspect_methods
from ..tools import ROUTING_DECORATOR_ATTR, _inspect_methods


class RegistryMixin(object):
Expand Down Expand Up @@ -187,7 +187,7 @@ def _get_controller_for(service, addon="base_rest"):
def _get_controller_route_methods(controller):
methods = {}
for name, method in _inspect_methods(controller):
if hasattr(method, "original_routing"):
if hasattr(method, ROUTING_DECORATOR_ATTR):
methods[name] = method
return methods

Expand Down
68 changes: 42 additions & 26 deletions base_rest/tests/test_controller_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from odoo.addons.component.core import Component

from .. import restapi
from ..tools import ROUTING_DECORATOR_ATTR
from .common import TransactionRestServiceRegistryCase


Expand Down Expand Up @@ -109,7 +110,7 @@ def _validator_my_instance_method(self):
# the generated method_name is always the {http_method}_{method_name}
method = routes["get_get"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -126,7 +127,7 @@ def _validator_my_instance_method(self):

method = routes["get_search"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -140,7 +141,7 @@ def _validator_my_instance_method(self):

method = routes["post_update"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "public",
Expand All @@ -157,7 +158,7 @@ def _validator_my_instance_method(self):

method = routes["put_update"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["PUT"],
"auth": "public",
Expand All @@ -171,7 +172,7 @@ def _validator_my_instance_method(self):

method = routes["post_create"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "public",
Expand All @@ -185,7 +186,7 @@ def _validator_my_instance_method(self):

method = routes["post_delete"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "public",
Expand All @@ -199,7 +200,7 @@ def _validator_my_instance_method(self):

method = routes["delete_delete"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["DELETE"],
"auth": "public",
Expand All @@ -213,7 +214,7 @@ def _validator_my_instance_method(self):

method = routes["post_my_method"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "public",
Expand All @@ -227,7 +228,7 @@ def _validator_my_instance_method(self):

method = routes["post_my_instance_method"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "public",
Expand Down Expand Up @@ -295,7 +296,7 @@ def _get_partner_schema(self):

method = routes["get_get"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -312,7 +313,7 @@ def _get_partner_schema(self):

method = routes["get_get_name"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -326,7 +327,7 @@ def _get_partner_schema(self):

method = routes["post_update_name"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "user",
Expand Down Expand Up @@ -392,7 +393,7 @@ def update_name(self, _id, **params):

method = routes["get_get"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -409,7 +410,7 @@ def update_name(self, _id, **params):

method = routes["get_get_name"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["GET"],
"auth": "public",
Expand All @@ -423,7 +424,7 @@ def update_name(self, _id, **params):

method = routes["post_update_name"]
self.assertDictEqual(
method.original_routing,
getattr(method, ROUTING_DECORATOR_ATTR),
{
"methods": ["POST"],
"auth": "user",
Expand Down Expand Up @@ -510,26 +511,33 @@ def _validator_get(self):
("save_session", default_save_session),
]:
self.assertEqual(
routes["get_new_api_method_without"].original_routing[attr],
getattr(routes["get_new_api_method_without"], ROUTING_DECORATOR_ATTR)[
attr
],
default,
"wrong %s" % attr,
)
self.assertEqual(
routes["get_new_api_method_with"].original_routing["auth"], "public"
getattr(routes["get_new_api_method_with"], ROUTING_DECORATOR_ATTR)["auth"],
"public",
)
self.assertEqual(
routes["get_new_api_method_with"].original_routing["cors"], "http://my_site"
getattr(routes["get_new_api_method_with"], ROUTING_DECORATOR_ATTR)["cors"],
"http://my_site",
)
self.assertEqual(
routes["get_new_api_method_with"].original_routing["csrf"], not default_csrf
getattr(routes["get_new_api_method_with"], ROUTING_DECORATOR_ATTR)["csrf"],
not default_csrf,
)
self.assertEqual(
routes["get_new_api_method_with"].original_routing["save_session"],
getattr(routes["get_new_api_method_with"], ROUTING_DECORATOR_ATTR)[
"save_session"
],
not default_save_session,
)

self.assertEqual(
routes["get_get"].original_routing["auth"],
getattr(routes["get_get"], ROUTING_DECORATOR_ATTR)["auth"],
default_auth,
"wrong auth for get_get",
)
Expand All @@ -541,12 +549,14 @@ def _validator_get(self):
("save_session", default_save_session),
]:
self.assertEqual(
routes["my_controller_route_without"].original_routing[attr],
getattr(routes["my_controller_route_without"], ROUTING_DECORATOR_ATTR)[
attr
],
default,
"wrong %s" % attr,
)

routing = routes["my_controller_route_with"].original_routing
routing = getattr(routes["my_controller_route_with"], ROUTING_DECORATOR_ATTR)
for attr, value in [
("auth", "public"),
("cors", "http://with_cors"),
Expand All @@ -560,7 +570,9 @@ def _validator_get(self):
"wrong %s" % attr,
)
self.assertEqual(
routes["my_controller_route_without_auth_2"].original_routing["auth"],
getattr(
routes["my_controller_route_without_auth_2"], ROUTING_DECORATOR_ATTR
)["auth"],
None,
"wrong auth for my_controller_route_without_auth_2",
)
Expand Down Expand Up @@ -605,7 +617,9 @@ def _validator_get(self):
routes = self._get_controller_route_methods(controller)

self.assertEqual(
routes["get_new_api_method_with_public_or"].original_routing["auth"],
getattr(
routes["get_new_api_method_with_public_or"], ROUTING_DECORATOR_ATTR
)["auth"],
"public_or_my_default_auth",
)

Expand Down Expand Up @@ -643,7 +657,9 @@ def _validator_get(self):
routes = self._get_controller_route_methods(controller)

self.assertEqual(
routes["get_new_api_method_with_public_or"].original_routing["auth"],
getattr(
routes["get_new_api_method_with_public_or"], ROUTING_DECORATOR_ATTR
)["auth"],
"my_default_auth",
)

Expand Down
2 changes: 2 additions & 0 deletions base_rest/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

_logger = logging.getLogger(__name__)

# Decorator attribute added on a route function (cfr Odoo's route)
ROUTING_DECORATOR_ATTR = "original_routing"
SUPPORTED_META = ["title", "description", "example", "examples"]


Expand Down

0 comments on commit 6cbef43

Please sign in to comment.