summaryrefslogtreecommitdiff
path: root/alot/db/manager.py
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2021-01-20 20:31:49 +0100
committerAnton Khirnov <anton@khirnov.net>2021-01-20 20:31:49 +0100
commit180d6c5e439bfca2af479b445afff9b18e28df5a (patch)
treee87a716cf1ae05a3aec4f6c635edf7bfb27da7d6 /alot/db/manager.py
parentcc1ee21e0704d7f02552f1617737f6cf7471fa52 (diff)
db: make sure to close the read-only database instances
Diffstat (limited to 'alot/db/manager.py')
-rw-r--r--alot/db/manager.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/alot/db/manager.py b/alot/db/manager.py
index ba9271fc..728dcddb 100644
--- a/alot/db/manager.py
+++ b/alot/db/manager.py
@@ -4,6 +4,7 @@
import abc
import asyncio
+from contextlib import closing
import logging
import os
@@ -151,37 +152,41 @@ class DBManager:
self._property_tags = frozenset(settings.get('property_tags'))
def _db_ro(self):
- return Database(path = self.path, mode = Database.MODE.READ_ONLY)
+ return closing(Database(path = self.path, mode = Database.MODE.READ_ONLY))
def count_messages(self, querystring):
"""returns number of messages that match `querystring`"""
- return self._db_ro().count_messages(querystring, exclude_tags = self._exclude_tags)
+ with self._db_ro() as db:
+ return db.count_messages(querystring, exclude_tags = self._exclude_tags)
def count_threads(self, querystring):
"""returns number of threads that match `querystring`"""
- return self._db_ro().count_threads(querystring, exclude_tags = self._exclude_tags)
+ with self._db_ro() as db:
+ return db.count_threads(querystring, exclude_tags = self._exclude_tags)
- def _get_notmuch_thread(self, tid):
+ def _get_notmuch_thread(self, db, tid):
"""returns :class:`notmuch.database.Thread` with given id"""
querystr = 'thread:' + tid
try:
- return next(self._db_ro().threads(querystr, exclude_tags = self._exclude_tags))
+ return next(db.threads(querystr, exclude_tags = self._exclude_tags))
except StopIteration:
errmsg = 'no thread with id %s exists!' % tid
raise NonexistantObjectError(errmsg)
def get_thread(self, tid):
"""returns :class:`Thread` with given thread id (str)"""
- return Thread(self, self._get_notmuch_thread(tid))
+ with self._db_ro() as db:
+ return Thread(self, self._get_notmuch_thread(db, tid))
def get_all_tags(self):
"""
returns all tagsstrings used in the database
:rtype: set of str
"""
- return self._db_ro().tags
+ with self._db_ro() as db:
+ return set(db.tags)
def get_named_queries(self):
"""
@@ -190,9 +195,9 @@ class DBManager:
"""
q_prefix = 'query.'
- db = self._db_ro()
- queries = filter(lambda k: k.startswith(q_prefix), db.config)
- return { q[len(q_prefix):] : db.config[q] for q in queries }
+ with self._db_ro() as db:
+ queries = filter(lambda k: k.startswith(q_prefix), db.config)
+ return { q[len(q_prefix):] : db.config[q] for q in queries }
def get_threads(self, querystring, sort='newest_first', exclude_tags = frozenset()):
"""
@@ -211,11 +216,12 @@ class DBManager:
# TODO: use a symbolic constant for this
assert sort in self._sort_orders
- db = self._db_ro()
- sort = self._sort_orders[sort]
- exclude_tags = self._exclude_tags | exclude_tags
+ with self._db_ro() as db:
+ sort = self._sort_orders[sort]
+ exclude_tags = self._exclude_tags | exclude_tags
- return (t.threadid for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags))
+ for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags):
+ yield t.threadid
async def startup(self):
self._write_queue = asyncio.Queue()