# Copyright (C) 2011-2012  Patrick Totzke
# This file is released under the GNU GPL, version 3 or a later revision.
# For further details see the COPYING file

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing
from functools import partial, partialmethod
import logging

from notmuch2 import Database, NotmuchError

from .errors import (DatabaseROError, NonexistantObjectError, QueryError)
from .sort import ORDER
from .thread import Thread
from ..settings.const import settings


# DB write operations
class _DBWriteList:
    """
    A list of database write operations.

    Operations are queued with the ``queue_*`` methods and submitted with
    :meth:`apply`, which hands the whole list to the owning
    :class:`DBManager`'s write queue. All queued operations are then
    executed inside a single ``db.atomic()`` section on a writable database.
    """

    # owning DBManager; supplies DB paths, the event loop and the write queue
    _dbman = None

    # queued operations: callables taking a writable Database; converted to
    # a tuple by apply() so that no further operations can be queued
    _ops = None
    # asyncio future completed once the write has been performed (or failed)
    _future = None

    def __init__(self, dbman):
        self._dbman = dbman

        self._ops = []

    def _resolve_future(self, exc):
        """
        Complete :attr:`_future` with `exc`, or with ``True`` if `exc` is None.

        Must run on the event loop thread (it is scheduled there through
        ``call_soon_threadsafe`` by :meth:`_do_apply`).
        """
        # The caller awaiting the future may have cancelled it in the
        # meantime; resolving a done future would raise InvalidStateError.
        if self._future.done():
            return
        if exc is None:
            self._future.set_result(True)
        else:
            self._future.set_exception(exc)

    def _do_apply(self):
        """
        Perform all queued operations in one atomic transaction.

        This is run in a worker thread (see DBManager._db_write_task), so the
        outcome is reported back to the event loop via
        ``call_soon_threadsafe``: asyncio futures are not thread-safe and must
        not be completed directly from another thread.
        """
        logging.debug('Performing DB write: %s', self)

        loop = self._dbman._loop
        try:
            with Database(path = self._dbman._db_path,
                          config = self._dbman._config_path,
                          mode = Database.MODE.READ_WRITE) as db:
                logging.debug('got writeable DB')

                with db.atomic():
                    for op in self._ops:
                        op(db)
        except Exception as e:
            logging.exception(e)
            loop.call_soon_threadsafe(self._resolve_future, e)
        else:
            logging.debug('DB write completed: %s', self)
            loop.call_soon_threadsafe(self._resolve_future, None)

    def _do_tag_add(self, db, tags, query):
        for msg in db.messages(query):
            msg_tags = msg.tags
            # only add the tags the message does not carry yet
            op_tags  = tags - msg_tags
            msg_tags |= op_tags

    def _do_tag_remove(self, db, tags, query):
        for msg in db.messages(query):
            msg_tags = msg.tags
            # only remove the tags the message actually carries
            op_tags  = tags & msg_tags
            msg_tags -= op_tags

    def _do_tag_set(self, db, tags, query):
        for msg in db.messages(query):
            msg_tags = msg.tags
            # tags listed in the 'property_tags' setting survive a tag set
            property_tags = msg_tags & self._dbman._property_tags
            msg_tags.clear()
            msg_tags |= tags | property_tags

    def _do_msg_add(self, db, tags, path):
        msg, _ = db.add(path, sync_flags = self._dbman._sync_flags)
        msg_tags = msg.tags
        msg_tags |= tags

    def queue_tag_add(self, tags, query):
        """Queue adding `tags` to all messages matching `query`."""
        if tags:
            self._ops.append(partial(self._do_tag_add, query = query, tags = tags))

    def queue_tag_remove(self, tags, query):
        """Queue removing `tags` from all messages matching `query`."""
        if tags:
            self._ops.append(partial(self._do_tag_remove, query = query, tags = tags))

    def queue_tag_set(self, tags, query):
        """
        Queue replacing the tags of all messages matching `query` with `tags`
        (property tags are kept).
        """
        # NOTE(review): an empty `tags` is skipped here, so setting the tags
        # to an empty set does not clear anything -- confirm this is intended.
        if tags:
            self._ops.append(partial(self._do_tag_set, query = query, tags = tags))

    def queue_msg_add(self, path, tags):
        """Queue indexing the file at `path` and adding `tags` to it."""
        self._ops.append(partial(self._do_msg_add, path = path, tags = tags))

    def apply(self):
        """
        Submit the queued operations to the DB write queue.

        :returns: an asyncio future that resolves to True once the write is
                  done (immediately if nothing was queued), or carries the
                  exception that made the write fail
        :raises ValueError: if called more than once on the same list
        """
        if self._future:
            raise ValueError('Multiple applies on a write list')

        # freeze the op list: further queue_* calls will fail
        self._ops    = tuple(self._ops)
        self._future = self._dbman._loop.create_future()

        if self._ops:
            self._dbman._write_queue.put_nowait(self)
        else:
            self._future.set_result(True)

        return self._future

    def __str__(self):
        return '%s:%s' % (self.__class__.__name__, self._ops)


class DBManager:
    """
    Keeps track of your index parameters, maintains a write-queue and
    lets you look up threads and messages directly to the persistent wrapper
    classes.
    """

    _db_path     = None
    _config_path = None

    _loop = None

    _sync_flags    = None
    _exclude_tags  = None
    _property_tags = None

    _write_task  = None
    _write_queue = None

    def __init__(self, loop, ro = False, db_path = None,
                 config_path = Database.CONFIG.SEARCH):
        self.ro = ro

        self._db_path     = db_path
        self._config_path = config_path

        self._loop = loop

        # read notmuch's config regarding imap flag synchronization
        self._sync_flags = settings.get_notmuch_setting('maildir', 'synchronize_flags')

        self._exclude_tags  = frozenset(settings.get('exclude_tags'))
        self._property_tags = frozenset(settings.get('property_tags'))

    def _db_ro(self):
        """Return a context manager yielding a read-only Database handle."""
        return closing(Database(path = self._db_path, config = self._config_path,
                                mode = Database.MODE.READ_ONLY))

    def _count(self, what, querystring):
        """
        Count 'messages' or 'threads' (selected by `what`) matching
        `querystring`; returns -1 on a notmuch error.
        """
        try:
            with self._db_ro() as db:
                func = getattr(db, 'count_' + what)
                return func(querystring, exclude_tags = self._exclude_tags)
        except NotmuchError:
            return -1

    count_messages = partialmethod(_count, 'messages')
    """returns number of messages that match `querystring`"""
    count_threads = partialmethod(_count, 'threads')
    """returns number of threads that match `querystring`"""

    def _get_notmuch_thread(self, db, tid):
        """
        returns the notmuch2 thread object with given id, looked up in `db`

        :raises NonexistantObjectError: if no such thread exists
        """
        querystr = 'thread:' + tid
        try:
            return next(db.threads(querystr, exclude_tags = self._exclude_tags))
        except StopIteration:
            errmsg = 'no thread with id %s exists!' % tid
            raise NonexistantObjectError(errmsg)

    def get_thread(self, tid):
        """returns :class:`Thread` with given thread id (str)"""
        with self._db_ro() as db:
            return Thread(self, self._get_notmuch_thread(db, tid))

    def get_all_tags(self):
        """
        returns all tag strings used in the database

        :rtype: set of str
        """
        with self._db_ro() as db:
            return set(db.tags)

    def get_named_queries(self):
        """
        returns the named queries stored in the database.

        :rtype: dict (str -> str) mapping alias to full query string
        """
        q_prefix = 'query.'
        with self._db_ro() as db:
            queries = filter(lambda k: k.startswith(q_prefix), db.config)
            return { q[len(q_prefix):] : db.config[q] for q in queries }

    def get_threads(self, querystring, sort = ORDER.NEWEST_FIRST,
                    exclude_tags = frozenset()):
        """
        lazily look up thread ids matching `querystring`.

        This is a generator: the read-only database stays open until the
        returned iterator is exhausted or closed.

        :param querystring: The query string to use for the lookup
        :type querystring: str.
        :param sort: Sort order.
        :type sort: alot.db.sort.ORDER
        :param exclude_tags: Tags to exclude by default unless included in the
                             search, in addition to the configured ones
        :type exclude_tags: iterable of str
        :returns: iterator over thread ids
        :raises QueryError: if notmuch fails to evaluate the query
        """
        with self._db_ro() as db:
            # union() accepts any iterable, not only sets
            exclude_tags = self._exclude_tags.union(exclude_tags)
            try:
                for t in db.threads(querystring, sort = sort,
                                    exclude_tags = exclude_tags):
                    yield t.threadid
            except NotmuchError as e:
                raise QueryError from e

    async def startup(self):
        """Create the write queue and start the DB writer task."""
        self._write_queue = asyncio.Queue()
        self._write_task  = asyncio.create_task(self._db_write_task())

    async def shutdown(self):
        """Stop the DB writer task after all queued writes are processed."""
        if self._write_task:
            # None is the sentinel that makes _db_write_task exit
            await self._write_queue.put(None)
            await self._write_task

    async def _db_write_task(self):
        # this task serialises write operations on the database and
        # sends them off to a thread so they do not block the event loop
        # one worker, as there can be only one DB writer at any moment
        with ThreadPoolExecutor(max_workers = 1) as executor:
            while True:
                cur_item = await self._write_queue.get()
                if cur_item is None:
                    self._write_queue.task_done()
                    break

                logging.debug('submitting write task: %s', cur_item)

                await self._loop.run_in_executor(executor, cur_item._do_apply)
                self._write_queue.task_done()

    def db_write_create(self):
        """
        Create a new, empty write list bound to this manager.

        :raises DatabaseROError: if the manager is read-only
        """
        if self.ro:
            raise DatabaseROError()
        return _DBWriteList(self)

    def tags_add(self, query, tags):
        """
        Asynchronously add tags to messages matching `query`.

        :param query: notmuch search string
        :type query: str
        :param tags: a set of tags to be added
        :type tags: set of str
        :returns: future resolved once the write is done
        """
        ret = self.db_write_create()
        ret.queue_tag_add(tags, query)
        return ret.apply()

    def tags_remove(self, query, tags):
        """
        Asynchronously remove tags from messages matching `query`.

        :param query: notmuch search string
        :type query: str
        :param tags: a set of tags to be removed
        :type tags: set of str
        :returns: future resolved once the write is done
        """
        ret = self.db_write_create()
        ret.queue_tag_remove(tags, query)
        return ret.apply()

    def tags_set(self, query, tags):
        """
        Asynchronously set the tags of messages matching `query`, replacing
        their current tags (configured property tags are kept).

        :param query: notmuch search string
        :type query: str
        :param tags: the new set of tags
        :type tags: set of str
        :returns: future resolved once the write is done
        """
        ret = self.db_write_create()
        ret.queue_tag_set(tags, query)
        return ret.apply()

    def msg_add(self, path, tags):
        """
        Asynchronously add a file to the notmuch index.

        :param path: path to the file
        :type path: str
        :param tags: tagstrings to add
        :type tags: list of str
        :returns: future resolved once the write is done
        """
        ret = self.db_write_create()
        ret.queue_msg_add(path, tags)
        return ret.apply()