diff --git a/shopinvader_api_cart/routers/cart.py b/shopinvader_api_cart/routers/cart.py index 04bf8b242c..5598a5bbc2 100644 --- a/shopinvader_api_cart/routers/cart.py +++ b/shopinvader_api_cart/routers/cart.py @@ -2,7 +2,7 @@ # Copyright 2024 Camptocamp (http://www.camptocamp.com). # @author Simone Orsi # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). -from collections import OrderedDict +from collections import defaultdict, namedtuple from typing import Annotated from uuid import UUID @@ -95,6 +95,30 @@ class ShopinvaderApiCartRouterHelper(models.AbstractModel): _name = "shopinvader_api_cart.cart_router.helper" _description = "ShopInvader API Cart Router Helper" + def _get_transaction_key(self, transaction: CartTransaction): + """ + Return a namedtuple of values that identify a transaction and match it + against a cart line. + + To override this method create a new namedtuple combining the super fields + and the new ones and return it with the values that identify the transaction. + + Example: + ```python + + def _get_transaction_key(self, transaction: CartTransaction): + key = super()._get_transaction_key(transaction) + return namedtuple( + key.__class__.__name__, key._fields + ('my_field',) + )( + *key, + my_field=transaction.my_field, + ) + ``` + + """ + return namedtuple("TransactionKey", ["product_id"])(transaction.product_id) + @api.model def _check_transactions(self, transactions: list[CartTransaction]): """Check if the transactions info are valid. @@ -109,21 +133,15 @@ def _check_transactions(self, transactions: list[CartTransaction]): ) @api.model - def _group_transactions_by_product_id(self, transactions: list[CartTransaction]): + def _group_transactions(self, transactions: list[CartTransaction]): """ - Gather together transactions that are linked to the same product. + Gather together transactions that are linked to the same transaction key. """ - # take an ordered dict to ensure to create lines into the same - # order as the transactions list - transactions_by_product_id = OrderedDict() - for trans in transactions: - product_id = trans.product_id - transactions = transactions_by_product_id.get(product_id) - if not transactions: - transactions = [] - transactions_by_product_id[product_id] = transactions - transactions.append(trans) - return transactions_by_product_id + grouped_transactions = defaultdict(list) + for transaction in transactions: + key = self._get_transaction_key(transaction) + grouped_transactions[key].append(transaction) + return grouped_transactions @api.model def _apply_transactions_on_existing_cart_line( @@ -231,17 +249,15 @@ def _apply_transactions(self, cart, transactions: list[CartTransaction]): return cart.ensure_one() self._check_transactions(transactions=transactions) - transactions_by_product_id = self._group_transactions_by_product_id( - transactions=transactions - ) - update_cmds = [] # prefetch all products - self.env["product.product"].browse(transactions_by_product_id.keys()) + self.env["product.product"].browse({tx.product_id for tx in transactions}) + grouped_transactions = self._group_transactions(transactions=transactions) + update_cmds = [] # here we avoid that each on change on a line trigger all the # recompute methods on the SO. These methods will be triggered # by the orm into the 'write' process - for product_id, trxs in transactions_by_product_id.items(): - line = cart._get_cart_line(product_id) + for key, trxs in grouped_transactions.items(): + line = cart._get_cart_line(**key._asdict()) if line: cmd = self._apply_transactions_on_existing_cart_line(line, trxs) else: diff --git a/shopinvader_sale_cart/models/sale_order.py b/shopinvader_sale_cart/models/sale_order.py index 69208b4067..49b298a830 100644 --- a/shopinvader_sale_cart/models/sale_order.py +++ b/shopinvader_sale_cart/models/sale_order.py @@ -6,7 +6,6 @@ class SaleOrder(models.Model): - _inherit = "sale.order" uuid = fields.Char(string="EShop Unique identifier", readonly=True) @@ -75,31 +74,26 @@ def _create_empty_cart(self, partner_id): vals = self._prepare_cart(partner_id) return self.create(vals) - def _get_cart_line(self, product_id): + def _get_cart_line(self, **kwargs): """ - Return the sale order line of the cart associated to the given product. + Return the sale order line of the cart associated to the given fields. """ self.ensure_one() - return self.order_line.filtered( - lambda l, product_id=product_id: l.product_id.id == product_id - )[:1] + return self.order_line.filtered(lambda sol: sol._match_cart_line(**kwargs))[:1] def _update_cart_lines_from_cart(self, cart): self.ensure_one() update_cmds = [] for cart_line in cart.order_line: - line = self._get_cart_line(cart_line.product_id.id) + line = self._get_cart_line(**cart_line.read(load=False)[0]) if line: new_qty = line.product_uom_qty + cart_line.product_uom_qty vals = {"product_uom_qty": new_qty} vals.update(line._play_onchanges_cart_line(vals)) cmd = (1, line.id, vals) else: - vals = { - "order_id": self.id, - "product_id": cart_line.product_id.id, - "product_uom_qty": cart_line.product_uom_qty, - } + vals = cart_line._prepare_cart_line_transfer_values() + vals["order_id"] = self.id vals.update(self.env["sale.order.line"]._play_onchanges_cart_line(vals)) cmd = (0, None, vals) update_cmds.append(cmd) diff --git a/shopinvader_sale_cart/models/sale_order_line.py b/shopinvader_sale_cart/models/sale_order_line.py index 638fc34411..ba3e7daec8 100644 --- a/shopinvader_sale_cart/models/sale_order_line.py +++ b/shopinvader_sale_cart/models/sale_order_line.py @@ -5,9 +5,24 @@ class SaleOrderLine(models.Model): - _inherit = "sale.order.line" @api.model def _play_onchanges_cart_line(self, vals): return self.sudo().play_onchanges(vals, vals.keys()) + + def _prepare_cart_line_transfer_values(self): + """ + Prepare the values to create a new cart line from a given cart line, + in case of a cart transfer for example. + """ + return { + "product_id": self.product_id.id, + "product_uom_qty": self.product_uom_qty, + } + + def _match_cart_line(self, product_id, **kwargs): + """ + Return True if the given sale order line matches the given fields. + """ + return self.product_id.id == product_id