Skip to content

Commit

Permalink
get session from object if not exist
Browse files Browse the repository at this point in the history
  • Loading branch information
uralbash committed Sep 14, 2015
1 parent 1884205 commit 12ed878
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions sqlalchemy_mptt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ def _node_to_dict(cls, node, json, json_fields):

@classmethod
def _base_query(cls, session=None):
# get orm session
if not session:
session = object_session(cls)

# handle custom query
return session.query(cls)

def _base_query_obj(self, session=None):
if not session:
session = object_session(self)
return self._base_query(session)

@classmethod
def _base_order(cls, query, order=asc):
return query.order_by(order(cls.tree_id))\
Expand Down Expand Up @@ -243,8 +243,10 @@ def get_node_id(node):
nodes_of_level[get_node_id(node)] = tree[-1]
return tree

def _drilldown_query(self, nodes):
def _drilldown_query(self, nodes=None):
table = self.__class__
if not nodes:
nodes = self._base_query_obj()
return nodes.filter(table.tree_id == self.tree_id)\
.filter(table.left >= self.left)\
.filter(table.right <= self.right)
Expand Down Expand Up @@ -275,6 +277,8 @@ def drilldown_tree(self, session=None, json=False, json_fields=None):
* :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_drilldown_tree`
"""
if not session:
session = object_session(self)
return self.get_tree(session, json=json, json_fields=json_fields,
query=self._drilldown_query)

Expand Down Expand Up @@ -304,7 +308,7 @@ def path_to_root(self, session=None):
-------------
"""
table = self.__class__
query = table._base_query(session)
query = self._base_query_obj(session=session)
query = query.filter(table.tree_id == self.tree_id)\
.filter(table.left <= self.left)\
.filter(table.right >= self.right)
Expand Down

0 comments on commit 12ed878

Please sign in to comment.