From 180d6c5e439bfca2af479b445afff9b18e28df5a Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Wed, 20 Jan 2021 20:31:49 +0100 Subject: db: make sure to close the read-only database instances --- alot/db/manager.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) (limited to 'alot/db/manager.py') diff --git a/alot/db/manager.py b/alot/db/manager.py index ba9271fc..728dcddb 100644 --- a/alot/db/manager.py +++ b/alot/db/manager.py @@ -4,6 +4,7 @@ import abc import asyncio +from contextlib import closing import logging import os @@ -151,37 +152,41 @@ class DBManager: self._property_tags = frozenset(settings.get('property_tags')) def _db_ro(self): - return Database(path = self.path, mode = Database.MODE.READ_ONLY) + return closing(Database(path = self.path, mode = Database.MODE.READ_ONLY)) def count_messages(self, querystring): """returns number of messages that match `querystring`""" - return self._db_ro().count_messages(querystring, exclude_tags = self._exclude_tags) + with self._db_ro() as db: + return db.count_messages(querystring, exclude_tags = self._exclude_tags) def count_threads(self, querystring): """returns number of threads that match `querystring`""" - return self._db_ro().count_threads(querystring, exclude_tags = self._exclude_tags) + with self._db_ro() as db: + return db.count_threads(querystring, exclude_tags = self._exclude_tags) - def _get_notmuch_thread(self, tid): + def _get_notmuch_thread(self, db, tid): """returns :class:`notmuch.database.Thread` with given id""" querystr = 'thread:' + tid try: - return next(self._db_ro().threads(querystr, exclude_tags = self._exclude_tags)) + return next(db.threads(querystr, exclude_tags = self._exclude_tags)) except StopIteration: errmsg = 'no thread with id %s exists!' % tid raise NonexistantObjectError(errmsg) def get_thread(self, tid): """returns :class:`Thread` with given thread id (str)""" - return Thread(self, self._get_notmuch_thread(tid)) + with self._db_ro() as db: + return Thread(self, self._get_notmuch_thread(db, tid)) def get_all_tags(self): """ returns all tagsstrings used in the database :rtype: set of str """ - return self._db_ro().tags + with self._db_ro() as db: + return set(db.tags) def get_named_queries(self): """ @@ -190,9 +195,9 @@ class DBManager: """ q_prefix = 'query.' - db = self._db_ro() - queries = filter(lambda k: k.startswith(q_prefix), db.config) - return { q[len(q_prefix):] : db.config[q] for q in queries } + with self._db_ro() as db: + queries = filter(lambda k: k.startswith(q_prefix), db.config) + return { q[len(q_prefix):] : db.config[q] for q in queries } def get_threads(self, querystring, sort='newest_first', exclude_tags = frozenset()): """ @@ -211,11 +216,12 @@ class DBManager: # TODO: use a symbolic constant for this assert sort in self._sort_orders - db = self._db_ro() - sort = self._sort_orders[sort] - exclude_tags = self._exclude_tags | exclude_tags + with self._db_ro() as db: + sort = self._sort_orders[sort] + exclude_tags = self._exclude_tags | exclude_tags - return (t.threadid for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags)) + for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags): + yield t.threadid async def startup(self): self._write_queue = asyncio.Queue() -- cgit v1.2.3