summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2021-02-07 08:44:45 +0100
committerAnton Khirnov <anton@khirnov.net>2021-02-07 08:44:45 +0100
commit2f112621568fa7f8e90e59176dea905603986609 (patch)
tree4cba711ca48d300fc88958dea8c44c7d325e20d1
parent9137c13fe863840a83ed5afd8da10a6aa9582bef (diff)
db/manager: rewrite db write operation API
Allows to apply multiple operations together as a single unit.
-rw-r--r--alot/db/manager.py158
1 files changed, 71 insertions, 87 deletions
diff --git a/alot/db/manager.py b/alot/db/manager.py
index 56f71bd1..148cfe2a 100644
--- a/alot/db/manager.py
+++ b/alot/db/manager.py
@@ -2,11 +2,10 @@
# This file is released under the GNU GPL, version 3 or a later revision.
# For further details see the COPYING file
-import abc
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing
-from functools import partialmethod
+from functools import partial, partialmethod
import logging
import os
@@ -28,19 +27,20 @@ def _is_subdir_of(subpath, superpath):
return os.path.commonprefix([subpath, superpath]) == superpath
# DB write operations
-class _DBOperation(abc.ABC):
- _dbman = None
- _tags = None
+class _DBWriteList:
+ """
+ A list of database write operations.
+ """
- future = None
+ _dbman = None
+ _ops = None
+ _future = None
- def __init__(self, dbman, tags):
+ def __init__(self, dbman):
self._dbman = dbman
- self._tags = tags
-
- self.future = dbman._loop.create_future()
+ self._ops = []
- def apply(self):
+ def _do_apply(self):
logging.debug('Performing DB write: %s', self)
try:
@@ -48,68 +48,67 @@ class _DBOperation(abc.ABC):
logging.debug('got writeable DB')
with db.atomic():
- self._apply(db)
+ for op in self._ops:
+ op(db)
except Exception as e:
logging.exception(e)
- self.future.set_exception(e)
+ self._future.set_exception(e)
else:
logging.debug('DB write completed: %s', self)
- self.future.set_result(True)
-
- @abc.abstractmethod
- def _apply(self, db):
- pass
-
- def __str__(self):
- return '%s:%s' % (self.__class__.__name__, self._tags)
-
-class _DBOperationTagAdd(_DBOperation):
- _query = None
-
- def __init__(self, dbman, tags, query):
- self._query = query
+ self._future.set_result(True)
- super().__init__(dbman, tags)
-
- def _apply(self, db):
- for msg in db.messages(self._query):
+ def _do_tag_add(self, db, tags, query):
+ for msg in db.messages(query):
msg_tags = msg.tags
- msg_tags |= self._tags
-
- def __str__(self):
- return '%s:%s' % (super().__str__(), self._query)
+ msg_tags |= tags
-class _DBOperationTagRemove(_DBOperationTagAdd):
- def _apply(self, db):
- for msg in db.messages(self._query):
+ def _do_tag_remove(self, db, tags, query):
+ for msg in db.messages(query):
msg_tags = msg.tags
- msg_tags -= self._tags
+ msg_tags -= tags
-class _DBOperationTagSet(_DBOperationTagAdd):
- def _apply(self, db):
- for msg in db.messages(self._query):
+ def _do_tag_set(self, db, tags, query):
+ for msg in db.messages(query):
msg_tags = msg.tags
property_tags = msg_tags & self._dbman._property_tags
msg_tags.clear()
- msg_tags |= self._tags | property_tags
+ msg_tags |= tags | property_tags
+
+ def _do_msg_add(self, db, tags, path):
+ msg, _ = db.add(path, sync_flags = self._dbman._sync_flags)
-class _DBOperationMsgAdd(_DBOperation):
- _path = None
+ msg_tags = msg.tags
+ msg_tags |= tags
+
+ def queue_tag_add(self, tags, query):
+ if tags:
+ self._ops.append(partial(self._do_tag_add, query = query, tags = tags))
+ def queue_tag_remove(self, tags, query):
+ if tags:
+ self._ops.append(partial(self._do_tag_remove, query = query, tags = tags))
+ def queue_tag_set(self, tags, query):
+ if tags:
+ self._ops.append(partial(self._do_tag_set, query = query, tags = tags))
+ def queue_msg_add(self, path, tags):
+ self._ops.append(partial(self._do_msg_add, path = path, tags = tags))
- def __init__(self, dbman, tags, path):
- self._path = path
+ def apply(self):
+ if self._future:
+ raise ValueError('Multiple applies on a write list')
- super().__init__(dbman, tags)
+ self._ops = tuple(self._ops)
+ self._future = self._dbman._loop.create_future()
- def _apply(self, db):
- msg, _ = db.add(self._path, sync_flags = self._dbman._sync_flags)
+ if self._ops:
+ self._dbman._write_queue.put_nowait(self)
+ else:
+ self._future.set_result(True)
- msg_tags = msg.tags
- msg_tags |= self._tags
+ return self._future
def __str__(self):
- return '%s:%s' % (super().__str__(), self._path)
+ return '%s:%s' % (self.__class__.__name__, self._ops)
class DBManager:
"""
@@ -243,9 +242,15 @@ class DBManager:
logging.debug('submitting write task: %s', cur_item)
- await self._loop.run_in_executor(executor, cur_item.apply)
+ await self._loop.run_in_executor(executor, cur_item._do_apply)
self._write_queue.task_done()
+ def db_write_create(self):
+ if self.ro:
+ raise DatabaseROError()
+
+ return _DBWriteList(self)
+
def tags_add(self, query, tags):
"""
Asynchronously add tags to messages matching `querystring`.
@@ -255,17 +260,9 @@ class DBManager:
:param tags: a set of tags to be added
:type tags: set of str
"""
- if self.ro:
- raise DatabaseROError()
-
- if not tags:
- ret = self._loop.create_future()
- ret.set_result(True)
- return ret
-
- op = _DBOperationTagAdd(self, tags, query)
- self._write_queue.put_nowait(op)
- return op.future
+ ret = self.db_write_create()
+ ret.queue_tag_add(tags, query)
+ return ret.apply()
def tags_remove(self, query, tags):
"""
@@ -276,17 +273,9 @@ class DBManager:
:param tags: a set of tags to be added
:type tags: set of str
"""
- if self.ro:
- raise DatabaseROError()
-
- if not tags:
- ret = self._loop.create_future()
- ret.set_result(True)
- return ret
-
- op = _DBOperationTagRemove(self, tags, query)
- self._write_queue.put_nowait(op)
- return op.future
+ ret = self.db_write_create()
+ ret.queue_tag_remove(tags, query)
+ return ret.apply()
def tags_set(self, query, tags):
"""
@@ -297,12 +286,9 @@ class DBManager:
:param tags: a set of tags to be added
:type tags: set of str
"""
- if self.ro:
- raise DatabaseROError()
-
- op = _DBOperationTagSet(self, tags, query)
- self._write_queue.put_nowait(op)
- return op.future
+ ret = self.db_write_create()
+ ret.queue_tag_set(tags, query)
+ return ret.apply()
def msg_add(self, path, tags):
"""
@@ -313,14 +299,12 @@ class DBManager:
:param tags: tagstrings to add
:type tags: list of str
"""
- if self.ro:
- raise DatabaseROError()
if not _is_subdir_of(path, self.path):
msg = 'message path %s ' % path
msg += ' is not below notmuchs '
msg += 'root path (%s)' % self.path
raise DatabaseError(msg)
- op = _DBOperationMsgAdd(self, tags, path)
- self._write_queue.put_nowait(op)
- return op.future
+ ret = self.db_write_create()
+ ret.queue_msg_add(path, tags)
+ return ret.apply()