Skip to content

Commit

Permalink
system of auto filters for sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
mrtedn21 committed Oct 4, 2023
1 parent d77ab87 commit 0131ac7
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
30 changes: 26 additions & 4 deletions database.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,34 @@ class Base(AsyncAttrs, DeclarativeBase):
}

reverse_types_map = {
Str: str,
Int: int,
Date: date,
DateTime: datetime,
value: key for key, value in types_map.items()
}


registered_models = {}


def register_model(model_class):
model_name = model_class.__name__
# remove "Orm" postfix from model name
model_name = model_name[:-3]
model_name = model_name.lower()
registered_models[model_name] = model_class


def query_params_to_alchemy_filters(query_param, value):
"""Example of query_param:
user__first_name__like=martin"""
model_name, field_name, method_name = query_param.split('__')
model_class = registered_models[model_name]
field_obj = getattr(model_class, field_name)
method_obj = getattr(field_obj, method_name)
if method_name == 'like':
return method_obj(f'%{value}%')
else:
return method_obj(value)


class SqlAlchemyToMarshmallow(type(Base)):
"""Metaclass that get sql alchemy model fields, creates marshmallow
schemas based on them and moreover, metaclass extends schemas of
Expand All @@ -63,6 +84,7 @@ class NewModel(SomeSqlAlchemyModel, metaclass=SqlAlchemyToPydantic):

def __new__(cls, name, bases, fields):
origin_model = bases[0]
register_model(origin_model)

# alchemy_fields variable needs to exclude these
# properties from origin_model_field_names
Expand Down
24 changes: 19 additions & 5 deletions http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,30 @@ def get_method_name(self) -> str:
first_word = first_line.split(' ')[0]
return first_word

def get_path(self):
# Path is second word of first line
def _get_path_and_query_params(self):
first_line = self.lines_of_header[0]
second_word = first_line.split(' ')[1]
return second_word
return second_word.split('?')

def get_path(self):
url_parts = self._get_path_and_query_params()
return url_parts[0]

def get_query_params(self):
url_parts = self._get_path_and_query_params()
if len(url_parts) == 1:
return

query_params = {}
for query_param in url_parts[1].split('&'):
key, value = query_param.split('=')
query_params[key] = value
return query_params

def get_body(self):
position_of_body_starts = (
self.message.find(self.line_break_char * 2) +
len(self.line_break_char * 2)
self.message.find(self.line_break_char * 2) +
len(self.line_break_char * 2)
)
return self.message[position_of_body_starts:]

Expand Down
13 changes: 11 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
from inspect import signature
from dacite import from_dict
import dataclasses
from operator import itemgetter
import json
import socket
from asyncio import AbstractEventLoop
from database import query_params_to_alchemy_filters

from sqlalchemy import select
from sqlalchemy.orm import aliased
Expand Down Expand Up @@ -112,14 +114,15 @@ class Message(MessageSchema, metaclass=MarshmallowToDataclass):
'/users/', ('get', ),
response=user_list_get_schema,
)
async def get_users() -> str:
async def get_users(query_params) -> str:
async with db.create_session() as session:
sql_query = (
select(
UserOrm, CityOrm, CountryOrm, LanguageOrm, GenderOrm
).select_from(UserOrm)
.outerjoin(CityOrm).outerjoin(CountryOrm)
.outerjoin(LanguageOrm).outerjoin(GenderOrm)
.filter(query_params)
)
result = await session.execute(sql_query)
return user_list_get_schema.dumps(map(itemgetter(0), result.fetchall()))
Expand Down Expand Up @@ -230,6 +233,7 @@ async def handle_request(
parser = HttpHeadersParser(message)
path = parser.get_path()
method = parser.get_method_name()
query_params = parser.get_query_params()

if method == 'OPTIONS':
headers = create_response_headers(200, 'application/json')
Expand Down Expand Up @@ -265,7 +269,12 @@ async def handle_request(
response = json.dumps(response)
break
else:
response: str = await controller()
if 'query_params' in list(dict(signature(controller).parameters).keys()):
key, value = list(query_params.items())[0]
response: str = await controller(query_params=query_params_to_alchemy_filters(key, value))
else:
response: str = await controller()

if isinstance(response, list):
response = json.dumps(response)

Expand Down

0 comments on commit 0131ac7

Please sign in to comment.