summaryrefslogtreecommitdiff
path: root/alot/db
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2021-01-20 20:31:49 +0100
committerAnton Khirnov <anton@khirnov.net>2021-01-20 20:31:49 +0100
commit180d6c5e439bfca2af479b445afff9b18e28df5a (patch)
treee87a716cf1ae05a3aec4f6c635edf7bfb27da7d6 /alot/db
parentcc1ee21e0704d7f02552f1617737f6cf7471fa52 (diff)
db: make sure to close the read-only database instances
Diffstat (limited to 'alot/db')
-rw-r--r--alot/db/manager.py34
-rw-r--r--alot/db/thread.py12
2 files changed, 27 insertions, 19 deletions
diff --git a/alot/db/manager.py b/alot/db/manager.py
index ba9271fc..728dcddb 100644
--- a/alot/db/manager.py
+++ b/alot/db/manager.py
@@ -4,6 +4,7 @@
import abc
import asyncio
+from contextlib import closing
import logging
import os
@@ -151,37 +152,41 @@ class DBManager:
self._property_tags = frozenset(settings.get('property_tags'))
def _db_ro(self):
- return Database(path = self.path, mode = Database.MODE.READ_ONLY)
+ return closing(Database(path = self.path, mode = Database.MODE.READ_ONLY))
def count_messages(self, querystring):
"""returns number of messages that match `querystring`"""
- return self._db_ro().count_messages(querystring, exclude_tags = self._exclude_tags)
+ with self._db_ro() as db:
+ return db.count_messages(querystring, exclude_tags = self._exclude_tags)
def count_threads(self, querystring):
"""returns number of threads that match `querystring`"""
- return self._db_ro().count_threads(querystring, exclude_tags = self._exclude_tags)
+ with self._db_ro() as db:
+ return db.count_threads(querystring, exclude_tags = self._exclude_tags)
- def _get_notmuch_thread(self, tid):
+ def _get_notmuch_thread(self, db, tid):
"""returns :class:`notmuch.database.Thread` with given id"""
querystr = 'thread:' + tid
try:
- return next(self._db_ro().threads(querystr, exclude_tags = self._exclude_tags))
+ return next(db.threads(querystr, exclude_tags = self._exclude_tags))
except StopIteration:
errmsg = 'no thread with id %s exists!' % tid
raise NonexistantObjectError(errmsg)
def get_thread(self, tid):
"""returns :class:`Thread` with given thread id (str)"""
- return Thread(self, self._get_notmuch_thread(tid))
+ with self._db_ro() as db:
+ return Thread(self, self._get_notmuch_thread(db, tid))
def get_all_tags(self):
"""
returns all tagsstrings used in the database
:rtype: set of str
"""
- return self._db_ro().tags
+ with self._db_ro() as db:
+ return set(db.tags)
def get_named_queries(self):
"""
@@ -190,9 +195,9 @@ class DBManager:
"""
q_prefix = 'query.'
- db = self._db_ro()
- queries = filter(lambda k: k.startswith(q_prefix), db.config)
- return { q[len(q_prefix):] : db.config[q] for q in queries }
+ with self._db_ro() as db:
+ queries = filter(lambda k: k.startswith(q_prefix), db.config)
+ return { q[len(q_prefix):] : db.config[q] for q in queries }
def get_threads(self, querystring, sort='newest_first', exclude_tags = frozenset()):
"""
@@ -211,11 +216,12 @@ class DBManager:
# TODO: use a symbolic constant for this
assert sort in self._sort_orders
- db = self._db_ro()
- sort = self._sort_orders[sort]
- exclude_tags = self._exclude_tags | exclude_tags
+ with self._db_ro() as db:
+ sort = self._sort_orders[sort]
+ exclude_tags = self._exclude_tags | exclude_tags
- return (t.threadid for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags))
+ for t in db.threads(querystring, sort = sort, exclude_tags = exclude_tags):
+ yield t.threadid
async def startup(self):
self._write_queue = asyncio.Queue()
diff --git a/alot/db/thread.py b/alot/db/thread.py
index 1f9ae7f0..37cc5fc7 100644
--- a/alot/db/thread.py
+++ b/alot/db/thread.py
@@ -61,13 +61,10 @@ class Thread:
self.message_list = []
self.messages = {}
- self.refresh(thread)
+ self._refresh(thread)
- def refresh(self, thread = None):
+ def _refresh(self, thread):
"""refresh thread metadata from the index"""
- if not thread:
- thread = self._dbman._get_notmuch_thread(self.id)
-
self.total_messages = len(thread)
self._notmuch_authors_string = thread.authors
@@ -97,6 +94,11 @@ class Thread:
self.messages, self.toplevel_messages, self.message_list = self._gather_messages(thread)
+ def refresh(self):
+ with self._dbman._db_ro() as db:
+ thread = self._dbman._get_notmuch_thread(db, self.id)
+ self._refresh(thread)
+
def _gather_messages(self, thread):
msgs = {}
msg_tree = []