Skip to content

Commit

Permalink
Merge pull request #2988 from FichteForks/pr/item-album-fallback
Browse files Browse the repository at this point in the history
Add fallback for item access to album's attributes
  • Loading branch information
sampsyo authored Mar 7, 2021
2 parents 6644316 + 09a6ec4 commit 3e82613
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 29 deletions.
69 changes: 51 additions & 18 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ class FormattedMapping(Mapping):
are replaced.
"""

def __init__(self, model, for_path=False):
def __init__(self, model, for_path=False, compute_keys=True):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
if compute_keys:
self.model_keys = model.keys(True)

def __getitem__(self, key):
if key in self.model_keys:
Expand Down Expand Up @@ -257,6 +258,11 @@ class Model(object):
value is the same as the old value (e.g., `o.f = o.f`).
"""

_revision = -1
"""A revision number from when the model was loaded from or written
to the database.
"""

@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
Expand Down Expand Up @@ -309,9 +315,11 @@ def __repr__(self):

def clear_dirty(self):
"""Mark all fields as *clean* (i.e., not needing to be stored to
the database).
the database). Also update the revision.
"""
self._dirty = set()
if self._db:
self._revision = self._db.revision

def _check_db(self, need_id=True):
"""Ensure that this object is associated with a database row: it
Expand Down Expand Up @@ -351,9 +359,9 @@ def _type(cls, key):
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT

def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
def _get(self, key, default=None, raise_=False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
getters = self._getters()
if key in getters: # Computed.
Expand All @@ -365,8 +373,18 @@ def __getitem__(self, key):
return self._type(key).null
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
else:
elif raise_:
raise KeyError(key)
else:
return default

get = _get

def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""
return self._get(key, raise_=True)

def _setitem(self, key, value):
"""Assign the value for a field, return whether new and old value
Expand Down Expand Up @@ -441,19 +459,10 @@ def items(self):
for key in self:
yield key, self[key]

def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default

def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(True)
return key in self.keys(computed=True)

def __iter__(self):
"""Iterate over the available field names (excluding computed
Expand Down Expand Up @@ -538,8 +547,14 @@ def store(self, fields=None):

def load(self):
"""Refresh the object's metadata from the library database.
If check_revision is true, the database is only queried loaded when a
transaction has been committed since the item was last loaded.
"""
self._check_db()
if not self._dirty and self._db.revision == self._revision:
# Exit early
return
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
self._values_fixed = LazyConvertDict(self)
Expand Down Expand Up @@ -794,6 +809,12 @@ class Transaction(object):
"""A context manager for safe, concurrent access to the database.
All SQL commands should be executed through a transaction.
"""

_mutated = False
"""A flag storing whether a mutation has been executed in the
current transaction.
"""

def __init__(self, db):
self.db = db

Expand All @@ -815,12 +836,15 @@ def __exit__(self, exc_type, exc_value, traceback):
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
# Beware of races; currently secured by db._db_lock
self.db.revision += self._mutated
with self.db._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
# Ending a "root" transaction. End the SQLite transaction.
self.db._connection().commit()
self._mutated = False
self.db._db_lock.release()

def query(self, statement, subvals=()):
Expand All @@ -836,7 +860,6 @@ def mutate(self, statement, subvals=()):
"""
try:
cursor = self.db._connection().execute(statement, subvals)
return cursor.lastrowid
except sqlite3.OperationalError as e:
# In two specific cases, SQLite reports an error while accessing
# the underlying database file. We surface these exceptions as
Expand All @@ -846,9 +869,14 @@ def mutate(self, statement, subvals=()):
raise DBAccessError(e.args[0])
else:
raise
else:
self._mutated = True
return cursor.lastrowid

def script(self, statements):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
self.db._connection().executescript(statements)


Expand All @@ -864,6 +892,11 @@ class Database(object):
supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
"""Whether or not the current version of SQLite supports extensions"""

revision = 0
"""The current revision of the database. To be increased whenever
data is written in a transaction.
"""

def __init__(self, path, timeout=5.0):
self.path = path
self.timeout = timeout
Expand Down
2 changes: 1 addition & 1 deletion beets/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def record_replaced(self, lib):
if (not dup_item.album_id or
dup_item.album_id in replaced_album_ids):
continue
replaced_album = dup_item.get_album()
replaced_album = dup_item._cached_album
if replaced_album:
replaced_album_ids.add(dup_item.album_id)
self.replaced_albums[replaced_album.path] = replaced_album
Expand Down
68 changes: 64 additions & 4 deletions beets/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,11 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
"""

def __init__(self, item, for_path=False):
super(FormattedItemMapping, self).__init__(item, for_path)
# We treat album and item keys specially here,
# so exclude transitive album keys from the model's keys.
super(FormattedItemMapping, self).__init__(item, for_path,
compute_keys=False)
self.model_keys = item.keys(computed=True, with_album=False)
self.item = item

@lazy_property
Expand All @@ -386,15 +390,15 @@ def all_keys(self):
def album_keys(self):
album_keys = []
if self.album:
for key in self.album.keys(True):
for key in self.album.keys(computed=True):
if key in Album.item_keys \
or key not in self.item._fields.keys():
album_keys.append(key)
return album_keys

@lazy_property
@property
def album(self):
return self.item.get_album()
return self.item._cached_album

def _get(self, key):
"""Get the value for a key, either from the album or the item.
Expand Down Expand Up @@ -545,6 +549,29 @@ class Item(LibModel):

_format_config_key = 'format_item'

__album = None
"""Cached album object. Read-only."""

@property
def _cached_album(self):
"""The Album object that this item belongs to, if any, or
None if the item is a singleton or is not associated with a
library.
The instance is cached and refreshed on access.
DO NOT MODIFY!
If you want a copy to modify, use :meth:`get_album`.
"""
if not self.__album and self._db:
self.__album = self._db.get_album(self)
elif self.__album:
self.__album.load()
return self.__album

@_cached_album.setter
def _cached_album(self, album):
self.__album = album

@classmethod
def _getters(cls):
getters = plugins.item_field_getters()
Expand All @@ -571,12 +598,45 @@ def __setitem__(self, key, value):
value = bytestring_path(value)
elif isinstance(value, BLOB_TYPE):
value = bytes(value)
elif key == 'album_id':
self._cached_album = None

changed = super(Item, self)._setitem(key, value)

if changed and key in MediaFile.fields():
self.mtime = 0 # Reset mtime on dirty.

def __getitem__(self, key):
"""Get the value for a field, falling back to the album if
necessary. Raise a KeyError if the field is not available.
"""
try:
return super(Item, self).__getitem__(key)
except KeyError:
if self._cached_album:
return self._cached_album[key]
raise

def keys(self, computed=False, with_album=True):
"""Get a list of available field names. `with_album`
controls whether the album's fields are included.
"""
keys = super(Item, self).keys(computed=computed)
if with_album and self._cached_album:
keys += self._cached_album.keys(computed=computed)
return keys

def get(self, key, default=None, with_album=True):
"""Get the value for a given key or `default` if it does not
exist. Set `with_album` to false to skip album fallback.
"""
try:
return self._get(key, default, raise_=with_album)
except KeyError:
if self._cached_album:
return self._cached_album.get(key, default)
return default

def update(self, values):
"""Set all key/value pairs in the mapping. If mtime is
specified, it is not reset (as it might otherwise be).
Expand Down
9 changes: 7 additions & 2 deletions beets/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,13 @@ def _setup(options, lib=None):
plugins.send("library_opened", lib=lib)

# Add types and queries defined by plugins.
library.Item._types.update(plugins.types(library.Item))
library.Album._types.update(plugins.types(library.Album))
plugin_types_album = plugins.types(library.Album)
library.Album._types.update(plugin_types_album)
item_types = plugin_types_album.copy()
item_types.update(library.Item._types)
item_types.update(plugins.types(library.Item))
library.Item._types = item_types

library.Item._queries.update(plugins.named_queries(library.Item))
library.Album._queries.update(plugins.named_queries(library.Album))

Expand Down
2 changes: 1 addition & 1 deletion beetsplug/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def convert_item(self, dest_dir, keep_new, path_formats, fmt,
item.store() # Store new path and audio data.

if self.config['embed'] and not linked:
album = item.get_album()
album = item._cached_album
if album and album.artpath:
self._log.debug(u'embedding album art from {}',
util.displayable_path(album.artpath))
Expand Down
11 changes: 11 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ New features:
* :doc:`/plugins/replaygain` now does its analysis in parallel when using
the ``command`` or ``ffmpeg`` backends.
:bug:`3478`
* Fields in queries now fall back to an item's album and check its fields too.
Notably, this allows querying items by an album flex attribute, also in path
configuration.
Thanks to :user:`FichteFoll`.
:bug:`2797` :bug:`2988`
* Removes usage of the bs1770gain replaygain backend.
Thanks to :user:`SamuelCook`.
* Added ``trackdisambig`` which stores the recording disambiguation from
Expand Down Expand Up @@ -344,6 +349,12 @@ For plugin developers:
:bug:`3355`
* The autotag hooks have been modified such that they now take 'bpm',
'musical_key' and a per-track based 'genre' as attributes.
* Item (and attribute) access on an item now falls back to the album's
attributes as well. If you specifically want to access an item's attributes,
use ``Item.get(key, with_album=False)``. :bug:`2988`
* ``Item.keys`` also has a ``with_album`` argument now, defaulting to ``True``.
* A ``revision`` attribute has been added to ``Database``. It is increased on
every transaction that mutates it. :bug:`2988`

For packagers:

Expand Down
Loading

0 comments on commit 3e82613

Please sign in to comment.