From 2f112621568fa7f8e90e59176dea905603986609 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Sun, 7 Feb 2021 08:44:45 +0100 Subject: db/manager: rewrite db write operation API Allows to apply multiple operations together as a single unit. --- alot/db/manager.py | 158 ++++++++++++++++++++++++----------------------------- 1 file changed, 71 insertions(+), 87 deletions(-) diff --git a/alot/db/manager.py b/alot/db/manager.py index 56f71bd1..148cfe2a 100644 --- a/alot/db/manager.py +++ b/alot/db/manager.py @@ -2,11 +2,10 @@ # This file is released under the GNU GPL, version 3 or a later revision. # For further details see the COPYING file -import abc import asyncio from concurrent.futures import ThreadPoolExecutor from contextlib import closing -from functools import partialmethod +from functools import partial, partialmethod import logging import os @@ -28,19 +27,20 @@ def _is_subdir_of(subpath, superpath): return os.path.commonprefix([subpath, superpath]) == superpath # DB write operations -class _DBOperation(abc.ABC): - _dbman = None - _tags = None +class _DBWriteList: + """ + A list of database write operations. + """ - future = None + _dbman = None + _ops = None + _future = None - def __init__(self, dbman, tags): + def __init__(self, dbman): self._dbman = dbman - self._tags = tags - - self.future = dbman._loop.create_future() + self._ops = [] - def apply(self): + def _do_apply(self): logging.debug('Performing DB write: %s', self) try: @@ -48,68 +48,67 @@ class _DBOperation(abc.ABC): logging.debug('got writeable DB') with db.atomic(): - self._apply(db) + for op in self._ops: + op(db) except Exception as e: logging.exception(e) - self.future.set_exception(e) + self._future.set_exception(e) else: logging.debug('DB write completed: %s', self) - self.future.set_result(True) - - @abc.abstractmethod - def _apply(self, db): - pass - - def __str__(self): - return '%s:%s' % (self.__class__.__name__, self._tags) - -class _DBOperationTagAdd(_DBOperation): - _query = None - - def __init__(self, dbman, tags, query): - self._query = query + self._future.set_result(True) - super().__init__(dbman, tags) - - def _apply(self, db): - for msg in db.messages(self._query): + def _do_tag_add(self, db, tags, query): + for msg in db.messages(query): msg_tags = msg.tags - msg_tags |= self._tags - - def __str__(self): - return '%s:%s' % (super().__str__(), self._query) + msg_tags |= tags -class _DBOperationTagRemove(_DBOperationTagAdd): - def _apply(self, db): - for msg in db.messages(self._query): + def _do_tag_remove(self, db, tags, query): + for msg in db.messages(query): msg_tags = msg.tags - msg_tags -= self._tags + msg_tags -= tags -class _DBOperationTagSet(_DBOperationTagAdd): - def _apply(self, db): - for msg in db.messages(self._query): + def _do_tag_set(self, db, tags, query): + for msg in db.messages(query): msg_tags = msg.tags property_tags = msg_tags & self._dbman._property_tags msg_tags.clear() - msg_tags |= self._tags | property_tags + msg_tags |= tags | property_tags + + def _do_msg_add(self, db, tags, path): + msg, _ = db.add(path, sync_flags = self._dbman._sync_flags) -class _DBOperationMsgAdd(_DBOperation): - _path = None + msg_tags = msg.tags + msg_tags |= tags + + def queue_tag_add(self, tags, query): + if tags: + self._ops.append(partial(self._do_tag_add, query = query, tags = tags)) + def queue_tag_remove(self, tags, query): + if tags: + self._ops.append(partial(self._do_tag_remove, query = query, tags = tags)) + def queue_tag_set(self, tags, query): + if tags: + self._ops.append(partial(self._do_tag_set, query = query, tags = tags)) + def queue_msg_add(self, path, tags): + self._ops.append(partial(self._do_msg_add, path = path, tags = tags)) - def __init__(self, dbman, tags, path): - self._path = path + def apply(self): + if self._future: + raise ValueError('Multiple applies on a write list') - super().__init__(dbman, tags) + self._ops = tuple(self._ops) + self._future = self._dbman._loop.create_future() - def _apply(self, db): - msg, _ = db.add(self._path, sync_flags = self._dbman._sync_flags) + if self._ops: + self._dbman._write_queue.put_nowait(self) + else: + self._future.set_result(True) - msg_tags = msg.tags - msg_tags |= self._tags + return self._future def __str__(self): - return '%s:%s' % (super().__str__(), self._path) + return '%s:%s' % (self.__class__.__name__, self._ops) class DBManager: """ @@ -243,9 +242,15 @@ class DBManager: logging.debug('submitting write task: %s', cur_item) - await self._loop.run_in_executor(executor, cur_item.apply) + await self._loop.run_in_executor(executor, cur_item._do_apply) self._write_queue.task_done() + def db_write_create(self): + if self.ro: + raise DatabaseROError() + + return _DBWriteList(self) + def tags_add(self, query, tags): """ Asynchronously add tags to messages matching `querystring`. @@ -255,17 +260,9 @@ class DBManager: :param tags: a set of tags to be added :type tags: set of str """ - if self.ro: - raise DatabaseROError() - - if not tags: - ret = self._loop.create_future() - ret.set_result(True) - return ret - - op = _DBOperationTagAdd(self, tags, query) - self._write_queue.put_nowait(op) - return op.future + ret = self.db_write_create() + ret.queue_tag_add(tags, query) + return ret.apply() def tags_remove(self, query, tags): """ @@ -276,17 +273,9 @@ class DBManager: :param tags: a set of tags to be added :type tags: set of str """ - if self.ro: - raise DatabaseROError() - - if not tags: - ret = self._loop.create_future() - ret.set_result(True) - return ret - - op = _DBOperationTagRemove(self, tags, query) - self._write_queue.put_nowait(op) - return op.future + ret = self.db_write_create() + ret.queue_tag_remove(tags, query) + return ret.apply() def tags_set(self, query, tags): """ @@ -297,12 +286,9 @@ class DBManager: :param tags: a set of tags to be added :type tags: set of str """ - if self.ro: - raise DatabaseROError() - - op = _DBOperationTagSet(self, tags, query) - self._write_queue.put_nowait(op) - return op.future + ret = self.db_write_create() + ret.queue_tag_set(tags, query) + return ret.apply() def msg_add(self, path, tags): """ @@ -313,14 +299,12 @@ class DBManager: :param tags: tagstrings to add :type tags: list of str """ - if self.ro: - raise DatabaseROError() if not _is_subdir_of(path, self.path): msg = 'message path %s ' % path msg += ' is not below notmuchs ' msg += 'root path (%s)' % self.path raise DatabaseError(msg) - op = _DBOperationMsgAdd(self, tags, path) - self._write_queue.put_nowait(op) - return op.future + ret = self.db_write_create() + ret.queue_msg_add(path, tags) + return ret.apply() -- cgit v1.2.3