diff --git a/shopinvader/models/shopinvader_variant.py b/shopinvader/models/shopinvader_variant.py index 208f1caf9e..fc08d11eb5 100644 --- a/shopinvader/models/shopinvader_variant.py +++ b/shopinvader/models/shopinvader_variant.py @@ -289,11 +289,38 @@ def _get_price( ) return res + @api.model + def _get_main_product_read_fields(cls): + product_model = cls.env["product.product"] + order_by = [x.strip() for x in product_model._order.split(",")] + return ["shopinvader_product_id", "backend_id", "lang_id"] + order_by + + @api.model + def _get_main_product_sorted_variants(cls, variants): + # NOTE: if the order is changed by adding `asc/desc` this can be broken + # but it's very unlikely that the default order for product.product + # will be changed. + order_by = [x.strip() for x in cls.env["product.product"]._order.split(",")] + + def get_value(record, key): + field_type = cls._fields[key].type + value = record[key] + if value is False and field_type in ("char", "text"): + return "" + else: + return value + + return sorted(variants, key=lambda var: [get_value(var, x) for x in order_by]) + + @api.model + def _pick_main_variant(cls, variants): + ordered = cls._get_main_product_sorted_variants(variants) + return ordered[0].get("id") if ordered else None + def _compute_main_product(self): # Respect same order. - order_by = [x.strip() for x in self.env["product.product"]._order.split(",")] backends = self.mapped("backend_id") - fields_to_read = ["shopinvader_product_id", "backend_id", "lang_id"] + order_by + fields_to_read = self._get_main_product_read_fields() product_ids = self.mapped("shopinvader_product_id").ids # Use sudo to bypass permissions (we don't care) _variants = self.sudo().search( @@ -310,23 +337,8 @@ def _compute_main_product(self): lambda x: (x["shopinvader_product_id"], x["backend_id"], x["lang_id"]), ) - def pick_1st_variant(variants): - # NOTE: if the order is changed by adding `asc/desc` this can be broken - # but it's very unlikely that the default order for product.product - # will be changed. - def get_value(record, key): - if record[key] is False and self._fields[key].type in ("char", "text"): - return "" - else: - return record[key] - - ordered = sorted( - variants, key=lambda var: [get_value(var, x) for x in order_by] - ) - return ordered[0].get("id") if ordered else None - main_by_product = { - product: pick_1st_variant(tuple(variants)) + product: self._pick_main_variant(variants) for product, variants in var_by_product } for record in self: