From 0131ac72823bf754cd71373b4d58e0104441aa8c Mon Sep 17 00:00:00 2001 From: mrtedn21 Date: Wed, 4 Oct 2023 23:47:19 +0700 Subject: [PATCH] system of auto filters for sqlalchemy --- database.py | 30 ++++++++++++++++++++++++++---- http_headers.py | 24 +++++++++++++++++++----- main.py | 13 +++++++++++-- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/database.py b/database.py index 6a77fad..37f484e 100644 --- a/database.py +++ b/database.py @@ -36,13 +36,34 @@ class Base(AsyncAttrs, DeclarativeBase): } reverse_types_map = { - Str: str, - Int: int, - Date: date, - DateTime: datetime, + value: key for key, value in types_map.items() } +registered_models = {} + + +def register_model(model_class): + model_name = model_class.__name__ + # remove "Orm" postfix from model name + model_name = model_name[:-3] + model_name = model_name.lower() + registered_models[model_name] = model_class + + +def query_params_to_alchemy_filters(query_param, value): + """Example of query_param: + user__first_name__like=martin""" + model_name, field_name, method_name = query_param.split('__') + model_class = registered_models[model_name] + field_obj = getattr(model_class, field_name) + method_obj = getattr(field_obj, method_name) + if method_name == 'like': + return method_obj(f'%{value}%') + else: + return method_obj(value) + + class SqlAlchemyToMarshmallow(type(Base)): """Metaclass that get sql alchemy model fields, creates marshmallow schemas based on them and moreover, metaclass extends schemas of @@ -63,6 +84,7 @@ class NewModel(SomeSqlAlchemyModel, metaclass=SqlAlchemyToPydantic): def __new__(cls, name, bases, fields): origin_model = bases[0] + register_model(origin_model) # alchemy_fields variable needs to exclude these # properties from origin_model_field_names diff --git a/http_headers.py b/http_headers.py index cd875d7..9651856 100644 --- a/http_headers.py +++ b/http_headers.py @@ -24,16 +24,30 @@ def get_method_name(self) -> str: first_word = first_line.split(' ')[0] return first_word - def get_path(self): - # Path is second word of first line + def _get_path_and_query_params(self): first_line = self.lines_of_header[0] second_word = first_line.split(' ')[1] - return second_word + return second_word.split('?') + + def get_path(self): + url_parts = self._get_path_and_query_params() + return url_parts[0] + + def get_query_params(self): + url_parts = self._get_path_and_query_params() + if len(url_parts) == 1: + return + + query_params = {} + for query_param in url_parts[1].split('&'): + key, value = query_param.split('=') + query_params[key] = value + return query_params def get_body(self): position_of_body_starts = ( - self.message.find(self.line_break_char * 2) + - len(self.line_break_char * 2) + self.message.find(self.line_break_char * 2) + + len(self.line_break_char * 2) ) return self.message[position_of_body_starts:] diff --git a/main.py b/main.py index f568520..be9efd7 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ import asyncio +from inspect import signature from dacite import from_dict import dataclasses from operator import itemgetter import json import socket from asyncio import AbstractEventLoop +from database import query_params_to_alchemy_filters from sqlalchemy import select from sqlalchemy.orm import aliased @@ -112,7 +114,7 @@ class Message(MessageSchema, metaclass=MarshmallowToDataclass): '/users/', ('get', ), response=user_list_get_schema, ) -async def get_users() -> str: +async def get_users(query_params) -> str: async with db.create_session() as session: sql_query = ( select( @@ -120,6 +122,7 @@ async def get_users() -> str: ).select_from(UserOrm) .outerjoin(CityOrm).outerjoin(CountryOrm) .outerjoin(LanguageOrm).outerjoin(GenderOrm) + .filter(query_params) ) result = await session.execute(sql_query) return user_list_get_schema.dumps(map(itemgetter(0), result.fetchall())) @@ -230,6 +233,7 @@ async def handle_request( parser = HttpHeadersParser(message) path = parser.get_path() method = parser.get_method_name() + query_params = parser.get_query_params() if method == 'OPTIONS': headers = create_response_headers(200, 'application/json') @@ -265,7 +269,12 @@ async def handle_request( response = json.dumps(response) break else: - response: str = await controller() + if 'query_params' in list(dict(signature(controller).parameters).keys()): + key, value = list(query_params.items())[0] + response: str = await controller(query_params=query_params_to_alchemy_filters(key, value)) + else: + response: str = await controller() + if isinstance(response, list): response = json.dumps(response)