author	Anton Khirnov <anton@khirnov.net>	2023-12-27 10:25:20 +0100
committer	Anton Khirnov <anton@khirnov.net>	2023-12-27 10:25:20 +0100
commit	df98562922d401d4357099cdc21f9178a01ca1f9 (patch)
tree	897371d95148d9ae083ebd249ae9ba9ce10011f5
Initial commit.
-rwxr-xr-x	naros.py	474
1 file changed, 474 insertions, 0 deletions
diff --git a/naros.py b/naros.py
new file mode 100755
index 0000000..170752c
--- /dev/null
+++ b/naros.py
@@ -0,0 +1,474 @@
+#!/usr/bin/python3
+
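+# naros: firewall management daemon.
+#
+# Watches 'nft -j monitor' for changes to nftables sets whose names match a
+# configurable regex, mirrors their elements (including timeouts) in memory,
+# and periodically writes them to a gzipped JSON state file so they can be
+# restored when the sets are re-created after a firewall reload or reboot.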
+import argparse
+from contextlib import closing, ExitStack
+from datetime import datetime, timedelta, timezone
+from copy import deepcopy
+import gzip
+import ipaddress
+import json
+import logging
+import logging.handlers
+import os
+import re
+import selectors
+import signal
+import sys
+import subprocess
+import time
+
+CMD_ADD = 'add'
+CMD_DELETE = 'delete'
+
+OBJ_SET = 'set'
+OBJ_ELEMENT = 'element'
+
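+# Pop the single key/value pair from a one-entry nftables JSON object,
+# e.g. {'add': {...}} -> ('add', {...}); raise ValueError otherwise.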
+def nft_json_pop(data):
+ v = iter(data.items())
+ try:
+ key, val = next(v)
+ except StopIteration:
+ raise ValueError('JSON data is empty')
+
+ if next(v, None) is not None:
+ raise ValueError('JSON data has more than one element')
+
+ return key, val
+
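+# A single set element: a plain address, a prefix, or an address range,
+# optionally carrying an expiry time derived from the nft 'expires' field.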
+class NftSetElem:
+
+ TYPE_ADDR = 'addr'
+ TYPE_PREFIX = 'prefix'
+ TYPE_RANGE = 'range'
+
+ _type = None
+ _data = None
+ _expires = None
+
+ def __init__(self, data):
+ try:
+ elem = data['elem']
+ except (TypeError, KeyError):
+ elem = data
+
+ if isinstance(elem, dict) and 'val' in elem:
+ if 'expires' in elem:
+ expires = timedelta(seconds = elem['expires'])
+ self._expires = datetime.now(timezone.utc) + expires
+ elif 'expires_abs' in elem:
+ # not in nft, but used by our on-disk format
+ self._expires = datetime.fromisoformat(elem['expires_abs'])
+
+ elem = elem['val']
+
+ if isinstance(elem, str):
+ self._type = self.TYPE_ADDR
+ self._data = elem
+ elif 'prefix' in elem:
+ p = elem['prefix']
+ self._type = self.TYPE_PREFIX
+ self._data = (p['addr'], p['len'])
+ elif 'range' in elem:
+ start, end = elem['range']
+ self._type = self.TYPE_RANGE
+ self._data = (start, end)
+ else:
+            raise ValueError('Unsupported set element: %s' % (elem,))
+
+ def expired(self, t):
+ return self._expires is not None and self._expires <= t
+
+ def __eq__(self, other):
+ return self._type == other._type and self._data == other._data
+
+ def __hash__(self):
+ return hash((self._type, self._data))
+
+ def to_string(self, t):
+ if self.expired(t):
+ return ''
+
+ if self._type == self.TYPE_ADDR:
+ ret = self._data
+ elif self._type == self.TYPE_PREFIX:
+ ret = '%s/%d' % (self._data[0], self._data[1])
+ elif self._type == self.TYPE_RANGE:
+ ret = '%s-%s' % (self._data[0], self._data[1])
+ else:
+ raise ValueError
+
+ if self._expires:
+ ret += ' timeout %ds' % int((self._expires - t).total_seconds())
+
+ return ret
+
+ def to_tree(self, t):
+ if self.expired(t):
+ return None
+
+ ret = {}
+
+ if self._type == self.TYPE_ADDR:
+ d = self._data
+ elif self._type == self.TYPE_PREFIX:
+ d = { 'prefix' : { 'addr' : self._data[0], 'len' : self._data[1] } }
+ elif self._type == self.TYPE_RANGE:
+ d = { 'range' : (self._data[0], self._data[1]) }
+ else:
+ raise ValueError
+
+ ret['val'] = d
+
+ if self._expires:
+ ret['expires_abs'] = self._expires.isoformat()
+
+ return ret
+
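+# In-memory mirror of one nftables set: its definition (minus the elements)
+# plus the tracked elements, serializable to nft syntax and to a JSON tree.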
+class NftSet:
+
+ # True when the set is active in the kernel
+ active = None
+
+ _logger = None
+ _data = None
+ _elems = None
+
+ def __init__(self, logger, data):
+ self._data = deepcopy(data)
+ self._data.pop('elem', None)
+
+ if self.type not in ('ipv4_addr', 'ipv6_addr'):
+            raise ValueError('Persistent sets of type %s not supported' % self.type)
+
+ self.active = False
+ self._elems = set()
+ self._logger = logger.getChild(str(self))
+
+ # add initial set elements
+ self.update(data)
+ self._logger.info('Added %d initial elements', len(self._elems))
+
+ def __getattr__(self, key):
+ return self._data[key]
+
+ def __str__(self):
+ return '%s(%s, %s, %s, %s)(%d elems)' % \
+ (self.__class__.__name__, self.family, self.table, self.name,
+ self.type, len(self._elems))
+
+ def update(self, data):
+        if 'elem' not in data:
+ return
+
+ for elem in data['elem']:
+ self._elems.add(NftSetElem(elem))
+
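+    # drop expired elements; return the current time used for the check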
+ def gc(self):
+ before = len(self._elems)
+ now = datetime.now(timezone.utc)
+ self._elems = { e for e in self._elems if not e.expired(now) }
+
+ self._logger.debug('gc, %d elements expired', before - len(self._elems))
+
+ return now
+
+ def add(self, elem):
+ self._logger.debug('Adding element: %s', elem)
+ self._elems.add(NftSetElem(elem))
+
+ def remove(self, elem):
+ self._logger.debug('Removing element: %s', elem)
+ self._elems.discard(NftSetElem(elem))
+
+ def to_string(self):
+ now = self.gc()
+ return ','.join(filter(None, (e.to_string(now) for e in self._elems)))
+
+ def to_tree(self):
+ now = self.gc()
+
+ data = deepcopy(self._data)
+ data['elem'] = tuple(filter(None, (e.to_tree(now) for e in self._elems)))
+
+ return data
+
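+# Tracks all sets matching the persistence regex: mirrors kernel changes into
+# them, restores stored elements into re-created sets, and loads/saves the
+# whole state as gzipped JSON in the state directory.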
+class NftPersistentSets:
+
+ _STATE_FNAME = 'persistent_sets.json.gz'
+
+ _sets = None
+ _logger = None
+ _regex = None
+ _fw_modify = None
+ _state_dir = None
+
+ _flush_period = None
+ _flush_next = None
+
+ def __init__(self, logger, state_dir, persistent_set_regex,
+ flush_period, fw_modify = True):
+ self._logger = logger.getChild(str(self))
+ self._state_dir = state_dir
+ self._fw_modify = fw_modify
+ self._flush_period = flush_period
+ self._regex = re.compile(persistent_set_regex)
+
+ self._sets = {}
+ self._load()
+
+ # read initial state from the kernel
+ self._logger.info('Reading initial sets from the kernel')
+ res = subprocess.run(['nft', '-j', 'list', 'sets'], capture_output = True,
+ check = True)
+ data = json.loads(res.stdout)
+
+ count = 0
+ for it in data['nftables']:
+ key, val = next(iter(it.items()))
+ if key == OBJ_SET and self._regex.search(val['name']):
+ uid = self._set_key(val)
+ self._logger.info('Parsing set from the kernel: %s', uid)
+ self._set_added(self._set_key(val), val)
+ count += 1
+
+ self._logger.info('Read %d initial sets from the kernel', count)
+
+ if self._flush_period > 0:
+ self._flush_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._flush_period
+
+ def __str__(self):
+ return self.__class__.__name__
+
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.flush()
+
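+    # load previously stored sets from the gzipped JSON state file, if present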
+ def _load(self):
+ src_path = os.path.join(self._state_dir, self._STATE_FNAME)
+
+ try:
+ with closing(gzip.open(src_path, 'rt')) as f:
+ data = json.load(f)
+ sets = data['sets']
+ except (FileNotFoundError, KeyError):
+ return
+
+ self._logger.info('Loading sets from storage')
+
+ for s in sets:
+ uid = self._set_key(s)
+ if uid in self._sets:
+                raise ValueError('Duplicate set in stored data: %s' % (s,))
+
+ self._logger.info('Loading set from storage: %s', uid)
+ self._sets[uid] = NftSet(self._logger, s)
+
+ self._logger.info('Loaded %d sets from storage', len(self._sets))
+
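+    # write all tracked sets to disk, keeping the previous file as *.prev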
+ def flush(self):
+ self._logger.info('Writing persistent sets to disk...')
+
+ data = { 'sets' : [s.to_tree() for s in self._sets.values()] }
+ dst_path = os.path.join(self._state_dir, self._STATE_FNAME)
+ prev_path = dst_path + '.prev'
+
+ if os.path.exists(dst_path):
+ os.replace(dst_path, prev_path)
+
+ with closing(gzip.open(dst_path, 'wt', encoding = 'utf-8')) as f:
+ json.dump(data, f, indent = 1)
+
+ if self._flush_period > 0:
+ self._flush_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._flush_period
+
+ def flush_timeout(self):
+ if self._flush_period > 0:
+ return self._flush_next - time.clock_gettime(time.CLOCK_BOOTTIME)
+ return None
+
+ def auto_flush(self):
+ t = self.flush_timeout()
+ if t is not None and t <= 0:
+ self.flush()
+
+ def _set_key(self, d):
+ return '@'.join((d['family'], d['table'], d['name']))
+
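+    # dispatch a single nft monitor event to the matching set/element handler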
+ def process(self, cmd, data):
+ if cmd not in (CMD_ADD, CMD_DELETE):
+ return
+
+ obj, val = nft_json_pop(data)
+
+ if (obj not in (OBJ_SET, OBJ_ELEMENT) or
+ not self._regex.search(val['name'])):
+ return
+
+ uid = self._set_key(val)
+ func = (self._set_added if (cmd == CMD_ADD and obj == OBJ_SET) else
+ self._set_removed if (cmd == CMD_DELETE and obj == OBJ_SET) else
+ self._elem_added if (cmd == CMD_ADD and obj == OBJ_ELEMENT) else
+ self._elem_removed if (cmd == CMD_DELETE and obj == OBJ_ELEMENT) else None)
+ func(uid, val)
+
+ def _set_added(self, uid, data):
+ if uid in self._sets and self._sets[uid].type != data['type']:
+ self._logger.warning('Set %s created with a mismatching type %s',
+ self._sets[uid], data['type'])
+ del self._sets[uid]
+
+ existing = False
+ try:
+ s = self._sets[uid]
+ s.update(data)
+ existing = True
+ except KeyError:
+ s = NftSet(self._logger, data)
+ self._sets[uid] = s
+
+ self._logger.info('%s set %s', 'Updated' if existing else 'Created', s)
+
+ elems = s.to_string()
+ if existing and elems:
+ cmd = 'add element %s %s %s { %s }' % (s.family, s.table, s.name, elems)
+ self._logger.info('%sdding elements to new set: %s',
+ 'A' if self._fw_modify else 'Not a',
+                              cmd if self._logger.isEnabledFor(logging.DEBUG) else '')
+
+ if self._fw_modify:
+ subprocess.run(['nft', '-f', '-'], input = cmd, check = True, text = True)
+
+ if not s.active:
+ self._logger.debug('Activating set %s', s)
+ s.active = True
+
+ def _set_removed(self, uid, data):
+        if uid not in self._sets:
+ return
+
+ s = self._sets[uid]
+
+ # FIXME: is this useful for anything?
+ self._logger.debug('Deactivating set %s', s)
+ s.active = False
+
+ def _elem_added(self, uid, data):
+        if uid not in self._sets:
+ self._logger.warning('Element added to non-tracked set %s', uid)
+ return
+
+ s = self._sets[uid]
+ for elem in data['elem']['set']:
+ self._logger.debug('Element added to %s: %s', s, elem)
+ s.add(elem)
+
+    def _elem_removed(self, uid, data):
+        if uid not in self._sets:
+ self._logger.warning('Element removed from non-tracked set %s', uid)
+ return
+
+ s = self._sets[uid]
+ for elem in data['elem']['set']:
+ self._logger.debug('Element removed from %s: %s', s, elem)
+ s.remove(elem)
+
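+# Thin wrapper around an 'nft -j monitor' child process: exposes its stdout
+# for polling and feeds each JSON event line to NftPersistentSets.process().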
+class NftMonitor:
+
+ _logger = None
+ _child = None
+
+ def __init__(self, logger, state_dir):
+ self._logger = logger.getChild(str(self))
+
+ self._child = subprocess.Popen(['nft', '-j', 'monitor'], stdout = subprocess.PIPE)
+ os.set_blocking(self._child.stdout.fileno(), False)
+
+ def __str__(self):
+ return self.__class__.__name__
+
+ def fileno(self):
+ return self._child.stdout.fileno()
+
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
+ def close(self):
+ self._child.terminate()
+
+ try:
+ self._child.wait(timeout = 5)
+ except subprocess.TimeoutExpired:
+ self._child.kill()
+
+ def process(self, persistent_sets):
+ for line in self._child.stdout:
+ cmd, val = nft_json_pop(json.loads(line))
+
+ self._logger.debug('Nft event: %s %s', cmd, val)
+
+ persistent_sets.process(cmd, val)
+
+parser = argparse.ArgumentParser(description = 'Firewall management daemon')
+
+parser.add_argument('-d', '--debug', action = 'store_true')
+parser.add_argument('-v', '--verbose', action = 'count', default = 0)
+parser.add_argument('-q', '--quiet', action = 'count', default = 0)
+parser.add_argument('-n', '--dry-run', action = 'store_true')
+parser.add_argument('-s', '--syslog', action = 'store_true')
+parser.add_argument('-f', '--flush-period', type = int, default = 60 * 60)
+parser.add_argument('--state-dir', default = '/var/lib/naros')
+parser.add_argument('--persistent-set-regex', default = '^set_persist_')
+
+args = parser.parse_args(sys.argv[1:])
+
+progname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
+
+logger = logging.getLogger(progname)
+# default to 30 (WARNING), every -q goes a level up, every -v a level down
+loglevel = max(10 * (3 + args.quiet - args.verbose), 0)
+logger.setLevel(loglevel)
+
+formatter = logging.Formatter(fmt = progname + ' %(name)s: %(message)s')
+
+handlers = [logging.StreamHandler()]
+
+if args.syslog:
+ handlers.append(logging.handlers.SysLogHandler('/dev/log'))
+
+for h in handlers:
+ h.setFormatter(formatter)
+ logger.addHandler(h)
+
+# log uncaught top-level exception
+def excepthook(t, v, tb, logger = logger):
+ logger.error('Uncaught top-level exception', exc_info = (t, v, tb))
+sys.excepthook = excepthook
+
+if not os.path.isdir(args.state_dir):
+ logger.error('State dir "%s" does not exist', args.state_dir)
+ sys.exit(1)
+
+with ExitStack() as stack:
+ sel = selectors.DefaultSelector()
+
+ nft_monitor = stack.enter_context(NftMonitor(logger, args.state_dir))
+ sel.register(nft_monitor, selectors.EVENT_READ)
+
+ psets = stack.enter_context(NftPersistentSets(logger, args.state_dir,
+ args.persistent_set_regex,
+ args.flush_period,
+ not args.dry_run))
+ # use SIGHUP to force flush
+    signal.signal(signal.SIGHUP, lambda sig, frame: psets.flush())
+
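+    # main loop: wait for nft monitor events, flush state when the period expires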
+ while True:
+ logger.debug('polling...')
+ events = sel.select(psets.flush_timeout())
+ logger.debug('got events')
+
+ for key, mask in events:
+ key.fileobj.process(psets)
+
+ psets.auto_flush()