Skip to content

Commit

Permalink
Change API class TAGs value to be a tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
yugokato committed Sep 30, 2024
1 parent 3fc4959 commit 434df7e
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 34 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ from openapi_test_client.libraries.api.api_functions import endpoint


class AuthAPI(DemoAppBaseAPI):
TAGs = ["Auth"]
TAGs = ("Auth",)

@endpoint.is_public
@endpoint.post("/v1/auth/login")
Expand Down Expand Up @@ -441,10 +441,10 @@ Some attributes available from the API class:
```pycon
>>> # Get tag data
>>> client.AUTH.TAGs
['Auth']
('Auth',)
>>> # Get available endpoints under this API class
>>> pprint(client.AUTH.endpoints)
[Endpoint(tags=['Auth'],
[Endpoint(tags=('Auth',),
api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>,
method='post',
path='/v1/auth/login',
Expand All @@ -455,7 +455,7 @@ Some attributes available from the API class:
is_public=True,
is_documented=True,
is_deprecated=False),
Endpoint(tags=['Auth'],
Endpoint(tags=('Auth',),
api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>,
method='get',
path='/v1/auth/logout',
Expand Down Expand Up @@ -513,7 +513,7 @@ Various endpoint data is available from the endpoint function via `endpoint` pro
>>> print(client.AUTH.login.endpoint)
POST /v1/auth/login
>>> pprint(client.AUTH.login.endpoint)
Endpoint(tags=['Auth'],
Endpoint(tags=('Auth',),
api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>,
method='post',
path='/v1/auth/login',
Expand All @@ -540,7 +540,7 @@ True
>>> print(AuthAPI.login.endpoint)
POST /v1/auth/login
>>> pprint(AuthAPI.login.endpoint)
Endpoint(tags=['Auth'],
Endpoint(tags=('Auth',),
api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>,
method='post',
path='/v1/auth/login',
Expand Down Expand Up @@ -619,7 +619,7 @@ from ..models.users import Metadata


class UsersAPI(DemoAppBaseAPI):
TAGs = ["Users"]
TAGs = ("Users",)

@endpoint.post("/v1/users")
def create_user(
Expand Down
Binary file modified images/generate.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion src/openapi_test_client/clients/demo_app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class AuthAPI(DemoAppBaseAPI):
TAGs = ["Auth"]
TAGs = ("Auth",)

@endpoint.is_public
@endpoint.post("/v1/auth/login")
Expand Down
2 changes: 1 addition & 1 deletion src/openapi_test_client/clients/demo_app/api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class UsersAPI(DemoAppBaseAPI):
TAGs = ["Users"]
TAGs = ("Users",)

@endpoint.post("/v1/users")
def create_user(
Expand Down
2 changes: 1 addition & 1 deletion src/openapi_test_client/libraries/api/api_classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, api_client: APIClientType):
@property
@classmethod
@abstractmethod
def TAGs(cls) -> list[str]:
def TAGs(cls) -> tuple[str, ...]:
"""API Tags defined in the swagger doc. Every API class MUST have this attribute"""
raise NotImplementedError

Expand Down
12 changes: 6 additions & 6 deletions src/openapi_test_client/libraries/api/api_client_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def generate_api_class(
"\n".join([f"from {_get_package(m)} import {m.__name__}" for m in [base_class, endpoint, RestResponse]])
+ "\n\n"
)
code += f'class {class_name}({base_class.__name__}):\n{TAB}TAGs = ["{tag}"]\n\n'
code += f"class {class_name}({base_class.__name__}):\n{TAB}TAGs = {tuple([tag])}\n\n"
code = format_code(code, remove_unused_imports=False)
if is_temp_client:
api_class_file_path.parent.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -255,7 +255,7 @@ def update_endpoint_functions(
>>>
>>>
>>> class SomeDemoAPI(DemoAppBaseAPI):
>>> TAGs = ["Some Tag"]
>>> TAGs = ("Some Tag",)
>>>
>>> @endpoint.get("/v1/something/{uuid}")
>>> def do_something(self, uuid: str, /, *, param1: str = None, param2: int = None, **kwargs) -> RestResponse:
Expand All @@ -269,7 +269,7 @@ def update_endpoint_functions(
# Regex for API class definition
regex_api_class = re.compile(rf"class {api_class.__name__}\(\S+{BASE_API_CLASS_NAME_SUFFIX}\):")
# Regex for TAGs and for individual tag inside TAGs
regex_tags = re.compile(r"TAGs = \[[^]]*\]", flags=re.MULTILINE)
regex_tags = re.compile(r"TAGs = \([^)]*\)", flags=re.MULTILINE)
regex_tag = re.compile(r'"(?P<tag>[^"]*)"', flags=re.MULTILINE)
# Regex for each endpoint function block
tab = f"(?:{TAB}|\t)"
Expand Down Expand Up @@ -423,17 +423,17 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class):
tags_in_class = re.search(regex_tags, original_api_cls_code)
if tags_in_class:
defined_tags = re.findall(regex_tag, tags_in_class.group(0))
if defined_tags:
if defined_tags or (not defined_tags and tags_in_class):
# Update TAGs only when none of defined tags match with documented tags. Note that when multiple tags are
# documented, the updated tags may not what you exactly want. If that is the case you'll need to remove
# tags that is not needed for this API class
if not set(defined_tags).intersection(api_spec_tags):
new_code = re.sub(regex_tags, f"TAGs = {list(api_spec_tags)}", new_code)
new_code = re.sub(regex_tags, f"TAGs = {tuple(api_spec_tags)}", new_code)
else:
api_class_matched = re.search(regex_api_class, original_api_cls_code)
defined_api_class = api_class_matched.group(0)
new_code = re.sub(
regex_api_class, f"{defined_api_class}\n{TAB}TAGs = {list(api_spec_tags)}\n", new_code
regex_api_class, f"{defined_api_class}\n{TAB}TAGs = {tuple(api_spec_tags)}\n", new_code
)

# Update code (if code changes)
Expand Down
35 changes: 19 additions & 16 deletions src/openapi_test_client/libraries/api/api_functions/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, update_wrapper, wraps
from threading import RLock
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar, cast

from common_libs.ansi_colors import ColorCodes, color
Expand Down Expand Up @@ -42,7 +43,7 @@ class Endpoint:
This is accessible via an EndpointFunc object (see docstrings of the `endpoint` class below).
"""

tags: list[str]
tags: tuple[str, ...]
api_class: type[APIClassType]
method: str
path: str
Expand Down Expand Up @@ -112,9 +113,9 @@ class endpoint:
>>> isinstance(client.AUTH.login, EndpointFunc) and isinstance(AuthAPI.login, EndpointFunc)
True
>>> client.AUTH.login.endpoint
Endpoint(tags=['Auth'], api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>, method='post', path='/v1/auth/login', func_name='login', model=<class 'types.AuthAPILoginEndpointModel'>, url='http://127.0.0.1:5000/v1/auth/login', content_type=None, is_public=False, is_documented=True, is_deprecated=False)
Endpoint(tags=('Auth',), api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>, method='post', path='/v1/auth/login', func_name='login', model=<class 'types.AuthAPILoginEndpointModel'>, url='http://127.0.0.1:5000/v1/auth/login', content_type=None, is_public=False, is_documented=True, is_deprecated=False)
>>> AuthAPI.login.endpoint
Endpoint(tags=['Auth'], api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>, method='post', path='/v1/auth/login', func_name='login', model=<class 'types.AuthAPILoginEndpointModel'>, url=None, content_type=None, is_public=False, is_documented=True, is_deprecated=False)
Endpoint(tags=('Auth',), api_class=<class 'openapi_test_client.clients.demo_app.api.auth.AuthAPI'>, method='post', path='/v1/auth/login', func_name='login', model=<class 'types.AuthAPILoginEndpointModel'>, url=None, content_type=None, is_public=False, is_documented=True, is_deprecated=False)
>>> str(client.AUTH.login.endpoint)
'POST /v1/auth/login'
>>> str(AuthAPI.login.endpoint)
Expand Down Expand Up @@ -387,6 +388,7 @@ class EndpointHandler:

# cache endpoint function objects
_endpoint_functions = {}
_lock = RLock()

def __init__(
self,
Expand All @@ -410,17 +412,18 @@ def __init__(
def __get__(self, instance: Optional[APIClassType], owner: type[APIClassType]) -> EndpointFunc:
"""Return an EndpointFunc object"""
key = (self.original_func.__name__, instance, owner)
if not (endpoint_func := EndpointHandler._endpoint_functions.get(key)):
endpoint_func_name = (
f"{owner.__name__}{generate_class_name(self.original_func.__name__, suffix=EndpointFunc.__name__)}"
)
endpoint_func_class = type(
endpoint_func_name,
(EndpointFunc,),
{},
)
endpoint_func = endpoint_func_class(self, instance, owner)
EndpointHandler._endpoint_functions[key] = endpoint_func
with EndpointHandler._lock:
if not (endpoint_func := EndpointHandler._endpoint_functions.get(key)):
endpoint_func_name = (
f"{owner.__name__}{generate_class_name(self.original_func.__name__, suffix=EndpointFunc.__name__)}"
)
endpoint_func_class = type(
endpoint_func_name,
(EndpointFunc,),
{},
)
endpoint_func = endpoint_func_class(self, instance, owner)
EndpointHandler._endpoint_functions[key] = endpoint_func
return cast(EndpointFunc, update_wrapper(endpoint_func, self.original_func))

@property
Expand Down Expand Up @@ -467,8 +470,8 @@ def __init__(self, endpoint_handler: EndpointHandler, instance: Optional[APIClas
# API class. To make the sorting of endpoint objects during an initialization of API
# classes work using (endpoint.tag, endpoint.method, endpoint.path) key, assign an empty
# list if TAGs is not defined
if not isinstance(tags := (instance or owner).TAGs, list):
tags = []
if isinstance(tags := (instance or owner).TAGs, property):
tags = ()
self.endpoint = Endpoint(
tags,
owner,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_generate_api_class_code(
assert NewAPIClass.__name__ == api_class_name
assert NewAPIClass.__module__.endswith(".test_something")
assert NewAPIClass.app_name == temp_api_client.app_name
assert NewAPIClass.TAGs == ["Test"]
assert NewAPIClass.TAGs == ("Test",)
assert (Path(inspect.getfile(NewAPIClass)).parent / "__init__.py").exists()
# API class generation will trigger the initialization of API classes, which will update the `endpoints` attr
assert NewAPIClass.endpoints is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_endpiont_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_endpoint_handler(mocker: MockerFixture, api_client: DemoAppAPIClient, w
"""Verify the basic capability around EndpointHandler"""

class TestAPI(APIBase):
TAGs = ["Test"]
TAGs = ("Test",)
app_name = api_client.app_name

def do_something(self):
Expand Down

0 comments on commit 434df7e

Please sign in to comment.