diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f88ec01..80c3a3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,12 @@ ci: autofix_prs: false autoupdate_schedule: quarterly repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.6.8" + hooks: + - id: ruff + args: ["--fix"] + types_or: [python] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index f847c0c..2a89ac2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,20 +10,22 @@ dependencies = [ "inflect==7.0.0", "phonenumbers==8.13.45", "pydantic-extra-types==2.9.0", - "pre-commit==3.6.2", "pydantic[email]==2.9.2", "PyYAML==6.0.1", ] readme = "README.md" [project.optional-dependencies] +dev = ["pre-commit==3.6.2", "ruff==0.6.8"] app = [ + "common-libs[dev]", "Quart==0.19.4", "quart-auth==0.9.0", "quart-schema[pydantic]==0.19.1", ] test = [ + "common-libs[dev]", "openapi-test-client[app]", "pytest==8.3.2", "pytest-lazy-fixtures==1.1.1", @@ -49,3 +51,23 @@ profile = "black" [tool.black] line_length = 120 + +[tool.ruff] +line-length = 120 +indent-width = 4 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", +] +ignore = ["E731", "E741", "F403"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +# We currently use `Optional` in a special way +"**/{clients,}/**/{api,models}/*" = ["UP007"] diff --git a/src/demo_app/__init__.py b/src/demo_app/__init__.py index 8c15ebd..cb13918 100644 --- a/src/demo_app/__init__.py +++ b/src/demo_app/__init__.py @@ -33,7 +33,7 @@ def _register_blueprints(app, version: int): from demo_app.handlers.error_handlers import bp_error_handler from demo_app.handlers.request_handlers import bp_request_handler - bp_api = Blueprint(f"demo_app", __name__, url_prefix=f"/v{version}") + bp_api = Blueprint("demo_app", __name__, url_prefix=f"/v{version}") bp_api.register_blueprint(bp_auth, name=bp_auth.name) bp_api.register_blueprint(bp_user, name=bp_user.name) diff --git a/src/demo_app/api/user/models.py b/src/demo_app/api/user/models.py index 5e2ceb7..cb5cd2e 100644 --- a/src/demo_app/api/user/models.py +++ b/src/demo_app/api/user/models.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional from pydantic import AnyUrl, BaseModel, EmailStr, Field from quart_schema.pydantic import File @@ -18,27 +17,27 @@ class UserTheme(Enum): class UserQuery(BaseModel): - id: Optional[int] = None - email: Optional[EmailStr] = None - role: Optional[UserRole] = None + id: int | None = None + email: EmailStr | None = None + role: UserRole | None = None class SocialLinks(BaseModel): - facebook: Optional[AnyUrl] = None - instagram: Optional[AnyUrl] = None - linkedin: Optional[AnyUrl] = None - github: Optional[AnyUrl] = None + facebook: AnyUrl | None = None + instagram: AnyUrl | None = None + linkedin: AnyUrl | None = None + github: AnyUrl | None = None class Preferences(BaseModel): - theme: Optional[UserTheme] = UserTheme.LIGHT_MODE.value - language: Optional[str] = None - font_size: Optional[int] = Field(None, ge=8, le=40, multiple_of=2) + theme: UserTheme | None = UserTheme.LIGHT_MODE.value + language: str | None = None + font_size: int | None = Field(None, ge=8, le=40, multiple_of=2) class Metadata(BaseModel): - preferences: Optional[Preferences] = None - social_links: Optional[SocialLinks] = None + preferences: Preferences | None = None + social_links: SocialLinks | None = None class UserRequest(BaseModel): @@ -46,7 +45,7 @@ class UserRequest(BaseModel): last_name: str = Field(..., min_length=1, max_length=255) email: EmailStr role: UserRole - metadata: Optional[Metadata] = Field(default_factory=dict) + metadata: Metadata | None = Field(default_factory=dict) class User(UserRequest): @@ -55,4 +54,4 @@ class User(UserRequest): class UserImage(BaseModel): file: File - description: Optional[str] = None + description: str | None = None diff --git a/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py b/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py index 4639763..ffcbb60 100644 --- a/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py +++ b/src/openapi_test_client/clients/demo_app/api/request_hooks/request_wrapper.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Callable, ParamSpec +from typing import TYPE_CHECKING, ParamSpec from common_libs.clients.rest_client import RestResponse @@ -15,9 +16,12 @@ def do_something_before_and_after_request(f: Callable[P, RestResponse]) -> Callable[P, RestResponse]: """This is a template of the request wrapper that will decorate an API request" - To enable this hook, add this function to the parent class's `request_wrapper` inside the base API class's pre_request_hook(): + To enable this hook, add this function to the parent class's `request_wrapper` inside the base API class's + pre_request_hook(): >>> from typing import Callable - >>> from openapi_test_client.clients.demo_app.api.request_hooks.request_wrapper import do_something_before_and_after_request + >>> from openapi_test_client.clients.demo_app.api.request_hooks.request_wrapper import ( + >>> do_something_before_and_after_request + >>> ) >>> >>> def request_wrapper(self) -> list[Callable]: >>> request_wrappers = super().request_wrapper() # noqa diff --git a/src/openapi_test_client/libraries/api/api_classes/__init__.py b/src/openapi_test_client/libraries/api/api_classes/__init__.py index 5dc40dc..7a8f341 100644 --- a/src/openapi_test_client/libraries/api/api_classes/__init__.py +++ b/src/openapi_test_client/libraries/api/api_classes/__init__.py @@ -33,7 +33,7 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy Endpoint(tag='Auth', api_class=, method='get', path='/v1/auth/logout', func_name='logout', model=) ] - """ + """ # noqa: E501 from openapi_test_client.libraries.api import Endpoint previous_frame = inspect.currentframe().f_back diff --git a/src/openapi_test_client/libraries/api/api_classes/base.py b/src/openapi_test_client/libraries/api/api_classes/base.py index 6db1522..f5e1dc1 100644 --- a/src/openapi_test_client/libraries/api/api_classes/base.py +++ b/src/openapi_test_client/libraries/api/api_classes/base.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING from common_libs.clients.rest_client import RestResponse from common_libs.logging import get_logger @@ -18,10 +19,10 @@ class APIBase(metaclass=ABCMeta): """Base API class""" - app_name: Optional[str] = None + app_name: str | None = None is_documented: bool = True is_deprecated: bool = False - endpoints: Optional[list[Endpoint]] = None + endpoints: list[Endpoint] | None = None def __init__(self, api_client: APIClientType): if self.app_name != api_client.app_name: @@ -51,8 +52,8 @@ def pre_request_hook(self, endpoint: Endpoint, *path_params, **params): def post_request_hook( self, endpoint: Endpoint, - response: Optional[RestResponse], - request_exception: Optional[RequestException], + response: RestResponse | None, + request_exception: RequestException | None, *path_params, **params, ): diff --git a/src/openapi_test_client/libraries/api/api_client_generator.py b/src/openapi_test_client/libraries/api/api_client_generator.py index 5eb3864..b0c29f3 100644 --- a/src/openapi_test_client/libraries/api/api_client_generator.py +++ b/src/openapi_test_client/libraries/api/api_client_generator.py @@ -232,7 +232,8 @@ def update_endpoint_functions( ) -> bool | tuple[str, Exception]: '''Update endpoint functions (signature and docstring) and API TAGs based on the definition of the latest API spec - When no exception is thrown during the process, a boolean flag to indicate whether update is required or not is returned. + When no exception is thrown during the process, a boolean flag to indicate whether update is required or not is + returned. If an exception is thrown, API class name and the exception will be returned. :param api_class: API class @@ -258,7 +259,9 @@ def update_endpoint_functions( >>> TAGs = ("Some Tag",) >>> >>> @endpoint.get("/v1/something/{uuid}") - >>> def do_something(self, uuid: str, /, *, param1: str = None, param2: int = None, **kwargs) -> RestResponse: + >>> def do_something( + >>> self, uuid: str, /, *, param1: str = None, param2: int = None, **kwargs + >>> ) -> RestResponse: >>> """Do something""" >>> ... >>> @@ -369,7 +372,8 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class): # Collect all param models for this endpoint param_models.extend(param_model_util.get_param_models(endpoint_model)) - # Fill missing imports (typing and custom param model classes). Duplicates will be removed by black at the end + # Fill missing imports (typing and custom param model classes). Duplicates will be removed by black at + # the end if missing_imports_code := param_model_util.generate_imports_code_from_model(api_class, endpoint_model): new_code = missing_imports_code + new_code @@ -424,9 +428,9 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class): if tags_in_class: defined_tags = re.findall(regex_tag, tags_in_class.group(0)) if defined_tags or (not defined_tags and tags_in_class): - # Update TAGs only when none of defined tags match with documented tags. Note that when multiple tags are - # documented, the updated tags may not what you exactly want. If that is the case you'll need to remove - # tags that is not needed for this API class + # Update TAGs only when none of defined tags match with documented tags. Note that when multiple tags + # are documented, the updated tags may not what you exactly want. If that is the case you'll need to + # remove tags that is not needed for this API class if not set(defined_tags).intersection(api_spec_tags): new_code = re.sub(regex_tags, f"TAGs = {tuple(api_spec_tags)}", new_code) else: diff --git a/src/openapi_test_client/libraries/api/api_functions/decorators.py b/src/openapi_test_client/libraries/api/api_functions/decorators.py index 4341814..5824f17 100644 --- a/src/openapi_test_client/libraries/api/api_functions/decorators.py +++ b/src/openapi_test_client/libraries/api/api_functions/decorators.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Callable, ParamSpec +from typing import ParamSpec from common_libs.clients.rest_client import RestResponse from common_libs.logging import get_logger diff --git a/src/openapi_test_client/libraries/api/api_functions/endpoints.py b/src/openapi_test_client/libraries/api/api_functions/endpoints.py index 65a5646..215723a 100644 --- a/src/openapi_test_client/libraries/api/api_functions/endpoints.py +++ b/src/openapi_test_client/libraries/api/api_functions/endpoints.py @@ -1,10 +1,11 @@ from __future__ import annotations +from collections.abc import Callable, Sequence from copy import deepcopy from dataclasses import dataclass from functools import partial, update_wrapper, wraps from threading import RLock -from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast from common_libs.ansi_colors import ColorCodes, color from common_libs.clients.rest_client import RestResponse @@ -49,8 +50,8 @@ class Endpoint: path: str func_name: str model: type[EndpointModel] - url: Optional[str] = None # Available only for an endpoint object accessed via an API client instance - content_type: Optional[str] = None + url: str | None = None # Available only for an endpoint object accessed via an API client instance + content_type: str | None = None is_public: bool = False is_documented: bool = True is_deprecated: bool = False @@ -125,7 +126,7 @@ class endpoint: >>> client.AUTH.login.endpoint.url 'http://127.0.0.1:5000/v1/auth/login' - """ + """ # noqa: E501 @staticmethod def get(path: str, **requests_lib_options) -> Callable[P, OriginalFunc | EndpointFunc]: @@ -409,7 +410,7 @@ def __init__( self.is_deprecated = False self.__decorators = [] - def __get__(self, instance: Optional[APIClassType], owner: type[APIClassType]) -> EndpointFunc: + def __get__(self, instance: APIClassType | None, owner: type[APIClassType]) -> EndpointFunc: """Return an EndpointFunc object""" key = (self.original_func.__name__, instance, owner) with EndpointHandler._lock: @@ -442,14 +443,14 @@ class EndpointFunc: All parameters passed to the original API class function call will be passed through to the __call__() """ - def __init__(self, endpoint_handler: EndpointHandler, instance: Optional[APIClassType], owner: type[APIClassType]): + def __init__(self, endpoint_handler: EndpointHandler, instance: APIClassType | None, owner: type[APIClassType]): """Initialize endpoint function""" if not issubclass(owner, APIBase): raise NotImplementedError(f"Unsupported API class: {owner}") self.method = endpoint_handler.method self.path = endpoint_handler.path - self.rest_client: Optional[RestClient] + self.rest_client: RestClient | None if instance: self.api_client = instance.api_client self.rest_client = self.api_client.rest_client @@ -643,7 +644,7 @@ def with_retry( f = retry_on(condition, num_retry=num_retry, retry_after=retry_after, safe_methods_only=False)(self) return f(*args, **kwargs) - def get_usage(self) -> Optional[str]: + def get_usage(self) -> str | None: """Get OpenAPI spec definition for the endpoint""" if self.api_client and self.endpoint.is_documented: return self.api_client.api_spec.get_endpoint_usage(self.endpoint) diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py index 6281ff1..7d771ee 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_function.py @@ -3,7 +3,7 @@ import json import re from collections import OrderedDict -from typing import TYPE_CHECKING, Annotated, Any, Optional, get_args, get_origin +from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin from common_libs.clients.rest_client.utils import get_supported_request_parameters from common_libs.logging import get_logger @@ -39,7 +39,8 @@ def check_params(endpoint: Endpoint, params: dict[str, Any]): if unexpected_params: msg = ( f"The request contains one or more parameters " - f"{endpoint.api_class.__name__}.{endpoint.func_name}() does not expect:\n{list_items(unexpected_params)}" + f"{endpoint.api_class.__name__}.{endpoint.func_name}() does not expect:\n" + f"{list_items(unexpected_params)}" ) logger.warning(msg) @@ -224,7 +225,7 @@ def generate_rest_func_params( # We will set the Content-type value using from the OpenAPI specs for this case, unless the header is explicitly # set by a user. Otherwise, requests lib will automatically handle this part if (data := rest_func_params.get("data")) and ( - isinstance(data, (str, bytes)) and not specified_content_type_header and endpoint.content_type + isinstance(data, str | bytes) and not specified_content_type_header and endpoint.content_type ): rest_func_params.setdefault("headers", {}).update({"Content-Type": endpoint.content_type}) @@ -233,7 +234,7 @@ def generate_rest_func_params( def _get_specified_content_type_header( requests_lib_options: dict[str, Any], session_headers: dict[str, str] -) -> Optional[str]: +) -> str | None: """Get Content-Type header value set for the request or for the current session""" request_headers = requests_lib_options.get("headers", {}) content_type_header = ( diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py index 79c5a01..3ce65ab 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py @@ -5,7 +5,7 @@ import re from copy import deepcopy from dataclasses import MISSING, Field, field, make_dataclass -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from common_libs.logging import get_logger @@ -131,7 +131,7 @@ def _parse_parameter_objects( method: str, parameter_objects: list[dict[str, Any]], path_param_fields: list[tuple[str, Any]], - body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]], + body_or_query_param_fields: list[tuple[str, Any, Field | None]], ): """Parse parameter objects @@ -205,8 +205,8 @@ def _parse_parameter_objects( def _parse_request_body_object( - request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Optional[Field]]] -) -> Optional[str]: + request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field | None]] +) -> str | None: """Parse request body object https://swagger.io/specification/#request-body-object @@ -249,7 +249,7 @@ def parse_schema_obj(obj: dict[str, Any]): if _is_file_param(content_type, param_def): param_type = File if not param_def.is_required: - param_type = Optional[param_type] + param_type = param_type | None body_or_query_param_fields.append((param_name, param_type, field(default=None))) else: existing_param_names = [x[0] for x in body_or_query_param_fields] diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py index fba8eb6..c61d53c 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/param_model.py @@ -48,7 +48,7 @@ def _is_param_model(obj: Any) -> bool: return _is_param_model(inner_type) -def get_param_model(annotated_type: Any) -> Optional[ParamModel | list[ParamModel]]: +def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None: """Returns a param model from the annotated type, if there is any :param annotated_type: Annotated type @@ -91,7 +91,7 @@ def get_reserved_model_names() -> list[str]: custom_param_annotation_names = [ x.__name__ for x in mod.__dict__.values() - if inspect.isclass(x) and issubclass(x, (ParamAnnotationType, DataclassModel)) + if inspect.isclass(x) and issubclass(x, ParamAnnotationType | DataclassModel) ] typing_class_names = [x.__name__ for x in [Any, Optional, Annotated, Literal, Union]] return custom_param_annotation_names + typing_class_names @@ -106,7 +106,7 @@ def create_model_from_param_def( :param model_name: The model name :param param_def: ParamDef generated from an OpenAPI parameter object """ - if not isinstance(param_def, (ParamDef, ParamDef.ParamGroup, ParamDef.UnknownType)): + if not isinstance(param_def, ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType): raise ValueError(f"Invalid param_def type: {type(param_def)}") if isinstance(param_def, ParamDef) and param_def.is_array and "items" in param_def: @@ -273,7 +273,7 @@ def visit(model_name: str): return sorted(models, key=lambda x: sorted_models_names.index(x.__name__)) -def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Optional[Field]]]): +def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field | None]]): """Clean illegal model field name and annotate the field type with Alias class :param param_fields: fields value to be passed to make_dataclass() diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py b/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py index 5bdd13c..d8f1b1f 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/param_type.py @@ -1,10 +1,11 @@ import inspect +from collections.abc import Sequence from dataclasses import asdict from functools import reduce from operator import or_ from types import NoneType, UnionType from typing import _AnnotatedAlias # noqa -from typing import Annotated, Any, Literal, Optional, Sequence, Union, get_args, get_origin +from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin from common_libs.logging import get_logger @@ -41,7 +42,8 @@ def get_type_annotation_as_str(tp: Any) -> str: return f"{Optional.__name__}[{get_type_annotation_as_str(inner_types[0])}]" else: inner_types_union = " | ".join(get_type_annotation_as_str(x) for x in inner_types) - # Note: This is actually Union[tp1, ..., None] in Python, but we annotate this as Optional[tp1 | ...] in code + # Note: This is actually Union[tp1, ..., None] in Python, but we annotate this as + # Optional[tp1 | ...] in code return f"{Optional.__name__}[{inner_types_union}]" else: return " | ".join(get_type_annotation_as_str(x) for x in args) @@ -50,7 +52,7 @@ def get_type_annotation_as_str(tp: Any) -> str: return f"{tp.__origin__.__name__}[{inner_types}]" elif get_origin(tp) is Literal: return repr(tp).replace("typing.", "") - elif isinstance(tp, (Alias, Format)): + elif isinstance(tp, Alias | Format): return f"{type(tp).__name__}({repr(tp.value)})" elif isinstance(tp, Constraint): const = ", ".join( @@ -71,7 +73,7 @@ def get_type_annotation_as_str(tp: Any) -> str: def resolve_type_annotation( param_name: str, param_def: ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType, - _is_required: Optional[bool] = None, + _is_required: bool | None = None, _is_array: bool = False, ) -> Any: """Resolve type annotation for the given parameter definition @@ -85,7 +87,8 @@ def resolve_type_annotation( def resolve(param_type: str, param_format: str = None): """Resolve type annotation - NOTE: Some OpenAPI spec use a wrong param type value (eg. string v.s. str). We handle these scenarios accordingly + NOTE: Some OpenAPI spec use a wrong param type value (eg. string v.s. str). + We handle these scenarios accordingly """ if param_type in STR_PARAM_TYPES: if param_format: @@ -134,7 +137,7 @@ def resolve(param_type: str, param_format: str = None): else: raise NotImplementedError(f"Unsupported type: {param_type}") - if not isinstance(param_def, (ParamDef, ParamDef.ParamGroup, ParamDef.UnknownType)): + if not isinstance(param_def, ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType): # for inner obj param_def = ParamDef.from_param_obj(param_def) @@ -173,7 +176,8 @@ def resolve(param_type: str, param_format: str = None): type_annotation = generate_optional_type(type_annotation) if num_optional_types := repr(type_annotation).count("Optional"): - # Sanity check for Optional type. If it is annotated with `Optional`, we want it to appear as the origin type only. If this check fails, it means the logic is broke somewhere + # Sanity check for Optional type. If it is annotated with `Optional`, we want it to appear as the origin type + # only. If this check fails, it means the logic is broke somewhere if num_optional_types > 1: raise RuntimeError(f"{Optional} should not appear more than once: {type_annotation}") if type_annotation.__name__ != Optional.__name__: @@ -236,7 +240,7 @@ def replace_inner_type(tp: Any, new_type: Any, replace_container_type: bool = Fa args = get_args(tp) if is_union_type(tp): if is_optional_type(tp): - return Optional[replace_inner_type(args[0], new_type)] + return Optional[replace_inner_type(args[0], new_type)] # noqa: UP007 else: return replace_inner_type(args, new_type) elif origin_type is Annotated: @@ -355,7 +359,7 @@ def generate_optional_type(tp: Any) -> Any: if is_optional_type(tp): return tp else: - return Union[tp, None] + return Union[tp, None] # noqa: UP007 def generate_annotated_type(tp: Any, metadata: Any): @@ -365,12 +369,12 @@ def generate_annotated_type(tp: Any, metadata: Any): """ if is_optional_type(tp): inner_type = get_args(tp)[0] - return Optional[Annotated[inner_type, metadata]] + return Optional[Annotated[inner_type, metadata]] # noqa: UP007 else: return Annotated[tp, metadata] -def get_annotated_type(tp: Any) -> Optional[_AnnotatedAlias]: +def get_annotated_type(tp: Any) -> _AnnotatedAlias | None: """Get annotated type definition :param tp: Type annotation diff --git a/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py b/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py index fb3df6c..ab499b7 100644 --- a/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py +++ b/src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py @@ -4,7 +4,7 @@ from datetime import date, datetime, time, timedelta from pathlib import Path from types import EllipsisType -from typing import Any, Optional, TypeVar, get_origin +from typing import Any, TypeVar, get_origin from uuid import UUID from pydantic import ( @@ -72,7 +72,7 @@ def in_validation_mode(): def generate_pydantic_model_fields( original_model: type[DataclassModel | EndpointModel | ParamModel], field_type: Any -) -> tuple[str, Optional[EllipsisType | FieldInfo]]: +) -> tuple[str, EllipsisType | FieldInfo | None]: """Generate Pydantic field definition for validation mode :param original_model: The original model @@ -134,22 +134,22 @@ def generate_pydantic_model_fields( if default_value is not None and constraint.nullable: # Required and nullable = Optional - field_type = Optional[field_type] + field_type = field_type | None - # For query parameters,each parameter may be allowed to use multiple times with different values. Our client will - # support this scenario by taking values as a list. To prevent a validation error to occur when giving a list, - # adjust the model type to also allow list. + # For query parameters,each parameter may be allowed to use multiple times with different values. Our client + # will support this scenario by taking values as a list. To prevent a validation error to occur when giving a + # list, adjust the model type to also allow list. if is_query_param or ( issubclass(original_model, EndpointModel) and original_model.endpoint_func.method.upper() == "GET" ): inner_type = param_type_util.get_inner_type(field_type) - if not get_origin(inner_type) is list: + if get_origin(inner_type) is not list: field_type = param_type_util.replace_inner_type(field_type, inner_type | list[inner_type]) return (field_type, field_value) -def filter_annotated_metadata(annotated_type: Any, target_class: type[T]) -> Optional[T]: +def filter_annotated_metadata(annotated_type: Any, target_class: type[T]) -> T | None: """Get a metadata for the target class from annotated type :param annotated_type: Type annotation with Annotated[] diff --git a/src/openapi_test_client/libraries/api/api_spec.py b/src/openapi_test_client/libraries/api/api_spec.py index 3fdd74a..137e449 100644 --- a/src/openapi_test_client/libraries/api/api_spec.py +++ b/src/openapi_test_client/libraries/api/api_spec.py @@ -4,7 +4,7 @@ import json import re from functools import lru_cache, reduce -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import requests import yaml @@ -29,7 +29,7 @@ def __init__(self, api_client: APIClientType, doc_path: str): self._spec = None @lru_cache - def get_api_spec(self, url: str = None) -> Optional[dict[str, Any]]: + def get_api_spec(self, url: str = None) -> dict[str, Any] | None: """Return OpenAPI spec""" if self._spec is None: if url: @@ -64,7 +64,7 @@ def get_api_spec(self, url: str = None) -> Optional[dict[str, Any]]: else: logger.warning("API spec is not available") - def get_endpoint_usage(self, endpoint: Endpoint) -> Optional[str]: + def get_endpoint_usage(self, endpoint: Endpoint) -> str | None: """Return usage of the endpoint :param endpoint: Endpoint object diff --git a/src/openapi_test_client/libraries/api/multipart_form_data.py b/src/openapi_test_client/libraries/api/multipart_form_data.py index 24d22c3..57fb78a 100644 --- a/src/openapi_test_client/libraries/api/multipart_form_data.py +++ b/src/openapi_test_client/libraries/api/multipart_form_data.py @@ -12,7 +12,7 @@ class MultipartFormData(MutableMapping): {'logo': ('logo.png', b'content', 'image/png'), 'favicon': ('fabicon.png', b'content', 'image/png')} NOTE: The File obj can be a dictionary instead - """ + """ # noqa: E501 def __init__(self, **files: File | dict[str, str | bytes | Any]): self._files = dict( diff --git a/src/openapi_test_client/libraries/api/types.py b/src/openapi_test_client/libraries/api/types.py index a95c5ba..746f898 100644 --- a/src/openapi_test_client/libraries/api/types.py +++ b/src/openapi_test_client/libraries/api/types.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from collections.abc import Callable, Mapping, Sequence from dataclasses import _DataclassParams # noqa from dataclasses import MISSING, Field, asdict, astuple, dataclass, field, is_dataclass, make_dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from common_libs.decorators import freeze_args from common_libs.hash import HashableDict @@ -52,7 +53,7 @@ def type(self) -> str: return self["type"] @property - def format(self) -> Optional[str]: + def format(self) -> str | None: return self.get("format") @property @@ -104,7 +105,7 @@ def from_param_obj( """Convert the parameter object to a ParamDef""" def convert(obj: Any): - if isinstance(obj, (ParamDef, ParamDef.ParamGroup)): + if isinstance(obj, ParamDef | ParamDef.ParamGroup): return obj else: if "oneOf" in obj: @@ -192,7 +193,7 @@ def to_pydantic(cls) -> type[PydanticModel]: class EndpointModel(DataclassModel): - content_type: Optional[str] + content_type: str | None endpoint_func: EndpointFunc @@ -364,7 +365,7 @@ def setdefault(self, key: str, default: Any = None) -> Any: @classmethod def recreate( - cls, current_class: type[ParamModel], new_fields: list[tuple[str, Any, Optional[field]]] + cls, current_class: type[ParamModel], new_fields: list[tuple[str, Any, field | None]] ) -> type[ParamModel]: """Recreate the model with the new fields diff --git a/src/openapi_test_client/libraries/common/json_encoder.py b/src/openapi_test_client/libraries/common/json_encoder.py index f46eeea..cd788c3 100644 --- a/src/openapi_test_client/libraries/common/json_encoder.py +++ b/src/openapi_test_client/libraries/common/json_encoder.py @@ -9,7 +9,7 @@ class CustomJsonEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, (UUID, Decimal)): + if isinstance(obj, UUID | Decimal): return str(obj) elif isinstance(obj, datetime): return obj.isoformat() diff --git a/tests/integration/test_script.py b/tests/integration/test_script.py index ae43b77..20e9fea 100644 --- a/tests/integration/test_script.py +++ b/tests/integration/test_script.py @@ -112,7 +112,7 @@ def test_update_client(temp_app_client: OpenAPIClient, dry_run: bool, option: st ( f"{TAB * 2}{create_user_func_docstring}\n" f"{TAB * 2}# fake custom func logic\n" - f"{TAB * 2}params = dict(first_name=first_name, last_name=last_name, email=email, role=role, metadata=metadata)\n" + f"{TAB * 2}params = dict(first_name=first_name, last_name=last_name, email=email, role=role, metadata=metadata)\n" # noqa: E501 f"{TAB * 2}return self.{UsersAPI.create_user.__name__}(**params)" ), )