#!/usr/bin/python3
#
# naros.py -- keeps the contents of selected nftables sets on disk and
# re-adds them to the kernel when such a set is (re)created.
#
# Provenance: reconstructed from the patch "Initial commit."
# (df98562922d401d4357099cdc21f9178a01ca1f9) by Anton Khirnov,
# Wed, 27 Dec 2023 10:25:20 +0100, which created naros.py (mode 100755).

import argparse
from contextlib import closing, ExitStack
from datetime import datetime, timedelta, timezone
from copy import deepcopy
import gzip
import ipaddress
import json
import logging
import logging.handlers
import os
import re
import selectors
import signal
import sys
import subprocess
import time

# commands reported by 'nft -j monitor' that we react to
CMD_ADD = 'add'
CMD_DELETE = 'delete'

# nft JSON object kinds that we react to
OBJ_SET = 'set'
OBJ_ELEMENT = 'element'


def nft_json_pop(data):
    """
    Return the only (key, value) pair of a single-entry mapping.

    Raises ValueError when the mapping is empty or has more than one entry.
    """
    v = iter(data.items())
    try:
        key, val = next(v)
    except StopIteration:
        raise ValueError('JSON data is empty') from None

    if next(v, None) is not None:
        raise ValueError('JSON data has more than one element')

    return key, val


class NftSetElem:
    """
    One element of an nft set: a single address, a prefix or a range,
    optionally with an absolute expiration time.

    Two elements are equal when their type and payload match; the expiration
    time does not take part in equality or hashing.
    """

    TYPE_ADDR = 'addr'
    TYPE_PREFIX = 'prefix'
    TYPE_RANGE = 'range'

    # one of the TYPE_* constants
    _type = None
    # str for TYPE_ADDR, 2-tuple for TYPE_PREFIX (addr, len) and TYPE_RANGE (start, end)
    _data = None
    # timezone-aware datetime of expiration, or None when the element never expires
    _expires = None

    def __init__(self, data):
        """
        Parse an element from its JSON tree.

        Accepted shapes: a plain string, a {'prefix': ...} or {'range': ...}
        dict, any of those wrapped as {'val': ..., 'expires': seconds} or
        {'val': ..., 'expires_abs': isoformat}, and any of the above wrapped
        once more as {'elem': ...}.

        Raises ValueError for anything else.
        """
        try:
            elem = data['elem']
        except (TypeError, KeyError):
            elem = data

        if isinstance(elem, dict) and 'val' in elem:
            if 'expires' in elem:
                expires = timedelta(seconds = elem['expires'])
                self._expires = datetime.now(timezone.utc) + expires
            elif 'expires_abs' in elem:
                # not in nft, but used by our on-disk format
                self._expires = datetime.fromisoformat(elem['expires_abs'])

            elem = elem['val']

        if isinstance(elem, str):
            self._type = self.TYPE_ADDR
            self._data = elem
        elif isinstance(elem, dict) and 'prefix' in elem:
            p = elem['prefix']
            self._type = self.TYPE_PREFIX
            self._data = (p['addr'], p['len'])
        elif isinstance(elem, dict) and 'range' in elem:
            start, end = elem['range']
            self._type = self.TYPE_RANGE
            self._data = (start, end)
        else:
            # exceptions do not interpolate logging-style arguments, so the
            # message has to be formatted explicitly
            raise ValueError('Unsupported set element: %s' % (elem,))

    def expired(self, t):
        """Return True when the element has an expiration time that is <= t."""
        return self._expires is not None and self._expires <= t

    def __eq__(self, other):
        if not isinstance(other, NftSetElem):
            return NotImplemented
        return self._type == other._type and self._data == other._data

    def __hash__(self):
        return hash((self._type, self._data))

    def to_string(self, t):
        """
        Format the element in nft command syntax, relative to time t.

        Returns an empty string when the element is expired at t.
        """
        if self.expired(t):
            return ''

        if self._type == self.TYPE_ADDR:
            ret = self._data
        elif self._type == self.TYPE_PREFIX:
            ret = '%s/%d' % (self._data[0], self._data[1])
        elif self._type == self.TYPE_RANGE:
            ret = '%s-%s' % (self._data[0], self._data[1])
        else:
            raise ValueError

        if self._expires is not None:
            remaining = int((self._expires - t).total_seconds())
            # The element is not expired, but truncation yields 0 when less
            # than a second is left; round that up so we never emit
            # 'timeout 0s'.
            # NOTE(review): nft may treat a zero timeout as "never expires" -- confirm
            ret += ' timeout %ds' % max(remaining, 1)

        return ret

    def to_tree(self, t):
        """
        Build the on-disk JSON tree for the element, relative to time t.

        Returns None when the element is expired at t.
        """
        if self.expired(t):
            return None

        ret = {}

        if self._type == self.TYPE_ADDR:
            d = self._data
        elif self._type == self.TYPE_PREFIX:
            d = { 'prefix' : { 'addr' : self._data[0], 'len' : self._data[1] } }
        elif self._type == self.TYPE_RANGE:
            d = { 'range' : (self._data[0], self._data[1]) }
        else:
            raise ValueError

        ret['val'] = d

        if self._expires is not None:
            ret['expires_abs'] = self._expires.isoformat()

        return ret


class NftSet:
    """
    A tracked nft set: its JSON description (without elements) plus the set
    of NftSetElem objects we know about.

    Keys of the JSON description (family, table, name, type, ...) are
    readable as attributes.
    """

    # True when the set is active in the kernel
    active = None

    _logger = None
    # deep copy of the set's JSON description with 'elem' removed
    _data = None
    # set of NftSetElem
    _elems = None

    def __init__(self, logger, data):
        """
        logger -- parent logger, a child is derived from it
        data   -- JSON tree of the set, optionally with initial elements in 'elem'

        Raises ValueError when the set type is not ipv4_addr or ipv6_addr.
        """
        self._data = deepcopy(data)
        self._data.pop('elem', None)

        if self.type not in ('ipv4_addr', 'ipv6_addr'):
            raise ValueError('Persistent sets of type %s not supported' % (self.type,))

        self.active = False
        self._elems = set()
        # name the logger after the set identity only; str(self) also carries
        # the element count, which would be frozen at "0 elems" here
        self._logger = logger.getChild(self._describe())

        # add initial set elements
        self.update(data)
        self._logger.info('Added %d initial elements', len(self._elems))

    def __getattr__(self, key):
        # Only called for names not found the normal way. Report missing keys
        # as AttributeError so that hasattr()/getattr(..., default) and other
        # attribute probing work as expected.
        try:
            return self._data[key]
        except (KeyError, TypeError):
            raise AttributeError(key) from None

    def _describe(self):
        # identity of the set, without the (changing) element count
        return '%s(%s, %s, %s, %s)' % \
            (self.__class__.__name__, self.family, self.table, self.name,
             self.type)

    def __str__(self):
        return '%s(%d elems)' % (self._describe(), len(self._elems))

    def update(self, data):
        """Add all elements listed under data['elem'], if any."""
        if 'elem' not in data:
            return

        for elem in data['elem']:
            self._elems.add(NftSetElem(elem))

    def gc(self):
        """Drop expired elements; return the timestamp used for the check."""
        before = len(self._elems)
        now = datetime.now(timezone.utc)
        self._elems = { e for e in self._elems if not e.expired(now) }

        self._logger.debug('gc, %d elements expired', before - len(self._elems))

        return now

    def add(self, elem):
        """Add a single element given as a JSON tree."""
        self._logger.debug('Adding element: %s', elem)
        self._elems.add(NftSetElem(elem))

    def remove(self, elem):
        """Remove a single element given as a JSON tree; unknown elements are ignored."""
        self._logger.debug('Removing element: %s', elem)
        self._elems.discard(NftSetElem(elem))

    def to_string(self):
        """Return the live elements as a comma-separated list in nft syntax."""
        now = self.gc()
        return ','.join(filter(None, (e.to_string(now) for e in self._elems)))

    def to_tree(self):
        """Return the set description with live elements under 'elem', for storage."""
        now = self.gc()

        data = deepcopy(self._data)
        data['elem'] = tuple(filter(None, (e.to_tree(now) for e in self._elems)))

        return data


class NftPersistentSets:
    """
    Collection of tracked sets, keyed by 'family@table@name'.

    Sets are loaded from the state file on construction, merged with what the
    kernel currently reports, updated from monitor events via process() and
    written back by flush(). As a context manager it flushes on exit.
    """

    _STATE_FNAME = 'persistent_sets.json.gz'

    # dict: uid -> NftSet
    _sets = None
    _logger = None
    # compiled regex; only sets whose name matches are tracked
    _regex = None
    # when False, never run nft commands that modify the firewall
    _fw_modify = None
    _state_dir = None

    # seconds between automatic flushes, <= 0 disables them
    _flush_period = None
    # CLOCK_BOOTTIME timestamp of the next automatic flush
    _flush_next = None

    def __init__(self, logger, state_dir, persistent_set_regex,
                 flush_period, fw_modify = True):
        self._logger = logger.getChild(str(self))
        self._state_dir = state_dir
        self._fw_modify = fw_modify
        self._flush_period = flush_period
        self._regex = re.compile(persistent_set_regex)

        self._sets = {}
        self._load()

        # read initial state from the kernel
        self._logger.info('Reading initial sets from the kernel')
        res = subprocess.run(['nft', '-j', 'list', 'sets'], capture_output = True,
                             check = True)
        data = json.loads(res.stdout)

        count = 0
        for it in data['nftables']:
            key, val = next(iter(it.items()))
            if key == OBJ_SET and self._regex.search(val['name']):
                uid = self._set_key(val)
                self._logger.info('Parsing set from the kernel: %s', uid)
                self._set_added(uid, val)
                count += 1

        self._logger.info('Read %d initial sets from the kernel', count)

        if self._flush_period > 0:
            self._flush_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._flush_period

    def __str__(self):
        return self.__class__.__name__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flush()

    def _load(self):
        """
        Load sets from the state file into self._sets.

        A missing file, or one without a 'sets' key, is silently treated as
        empty state. Raises ValueError on duplicate sets in the file.
        """
        src_path = os.path.join(self._state_dir, self._STATE_FNAME)

        try:
            # flush() writes UTF-8, so read it back as UTF-8 regardless of locale
            with closing(gzip.open(src_path, 'rt', encoding = 'utf-8')) as f:
                data = json.load(f)
            sets = data['sets']
        except (FileNotFoundError, KeyError):
            return

        self._logger.info('Loading sets from storage')

        for s in sets:
            uid = self._set_key(s)
            if uid in self._sets:
                raise ValueError('Duplicate set in stored data: %s' % (s,))

            self._logger.info('Loading set from storage: %s', uid)
            self._sets[uid] = NftSet(self._logger, s)

        self._logger.info('Loaded %d sets from storage', len(self._sets))

    def flush(self):
        """
        Write all tracked sets to the state file and re-arm the flush timer.

        The previous state file, if any, is kept with a '.prev' suffix.
        """
        self._logger.info('Writing persistent sets to disk...')

        data = { 'sets' : [s.to_tree() for s in self._sets.values()] }
        dst_path = os.path.join(self._state_dir, self._STATE_FNAME)
        prev_path = dst_path + '.prev'
        tmp_path = dst_path + '.tmp'

        # Write the new state to a temporary file first and only then rename
        # it into place, so that a crash in the middle of writing can not
        # leave a truncated state file behind.
        try:
            with closing(gzip.open(tmp_path, 'wt', encoding = 'utf-8')) as f:
                json.dump(data, f, indent = 1)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

        if os.path.exists(dst_path):
            os.replace(dst_path, prev_path)
        os.replace(tmp_path, dst_path)

        if self._flush_period > 0:
            self._flush_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._flush_period

    def flush_timeout(self):
        """
        Return seconds until the next automatic flush (may be <= 0 when it is
        due), or None when automatic flushing is disabled.
        """
        if self._flush_period > 0:
            return self._flush_next - time.clock_gettime(time.CLOCK_BOOTTIME)
        return None

    def auto_flush(self):
        """Flush if automatic flushing is enabled and due."""
        t = self.flush_timeout()
        if t is not None and t <= 0:
            self.flush()

    def _set_key(self, d):
        return '@'.join((d['family'], d['table'], d['name']))

    def process(self, cmd, data):
        """
        Handle one monitor event.

        cmd  -- event command; anything but add/delete is ignored
        data -- single-entry mapping {object kind: description}; anything but
                sets and elements whose set name matches the regex is ignored
        """
        if cmd not in (CMD_ADD, CMD_DELETE):
            return

        handlers = {
            (CMD_ADD,    OBJ_SET)     : self._set_added,
            (CMD_DELETE, OBJ_SET)     : self._set_removed,
            (CMD_ADD,    OBJ_ELEMENT) : self._elem_added,
            (CMD_DELETE, OBJ_ELEMENT) : self._elem_removed,
        }

        obj, val = nft_json_pop(data)

        func = handlers.get((cmd, obj))
        if func is None or not self._regex.search(val['name']):
            return

        func(self._set_key(val), val)

    def _set_added(self, uid, data):
        s = self._sets.get(uid)
        if s is not None and s.type != data['type']:
            self._logger.warning('Set %s created with a mismatching type %s',
                                 s, data['type'])
            del self._sets[uid]
            s = None

        # Decide on "existing" by lookup only, so that an error raised while
        # parsing the new elements is not mistaken for an unknown set.
        existing = s is not None
        if existing:
            s.update(data)
        else:
            s = NftSet(self._logger, data)
            self._sets[uid] = s

        self._logger.info('%s set %s', 'Updated' if existing else 'Created', s)

        # a set we already knew about was (re)created in the kernel: push the
        # elements we remember back into it
        elems = s.to_string()
        if existing and elems:
            cmd = 'add element %s %s %s { %s }' % (s.family, s.table, s.name, elems)
            self._logger.info('%sdding elements to new set: %s',
                              'A' if self._fw_modify else 'Not a',
                              cmd if self._logger.isEnabledFor(logging.DEBUG) else '')

            if self._fw_modify:
                subprocess.run(['nft', '-f', '-'], input = cmd, check = True, text = True)

        if not s.active:
            self._logger.debug('Activating set %s', s)
            s.active = True

    def _set_removed(self, uid, data):
        if uid not in self._sets:
            return

        s = self._sets[uid]

        # FIXME: is this useful for anything?
        self._logger.debug('Deactivating set %s', s)
        s.active = False

    def _elem_added(self, uid, data):
        if uid not in self._sets:
            self._logger.warning('Element added to non-tracked set %s', uid)
            return

        s = self._sets[uid]
        for elem in data['elem']['set']:
            self._logger.debug('Element added to %s: %s', s, elem)
            s.add(elem)

    def _elem_removed(self, uid, data):
        if uid not in self._sets:
            self._logger.warning('Element removed from non-tracked set %s', uid)
            return

        s = self._sets[uid]
        for elem in data['elem']['set']:
            self._logger.debug('Element removed from %s: %s', s, elem)
            s.remove(elem)


class NftMonitor:
    """
    Wrapper around a 'nft -j monitor' child process.

    Exposes fileno() so it can be registered with a selector; process() is to
    be called when the descriptor becomes readable.
    """

    _logger = None
    _child = None
    # bytes received from the child that do not yet form a complete line
    _buf = None

    def __init__(self, logger, state_dir):
        # state_dir is accepted but not used
        self._logger = logger.getChild(str(self))
        self._buf = b''

        self._child = subprocess.Popen(['nft', '-j', 'monitor'], stdout = subprocess.PIPE)
        os.set_blocking(self._child.stdout.fileno(), False)

    def __str__(self):
        return self.__class__.__name__

    def fileno(self):
        return self._child.stdout.fileno()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Terminate the child; kill it if it does not exit within 5 seconds."""
        self._child.terminate()

        try:
            self._child.wait(timeout = 5)
        except subprocess.TimeoutExpired:
            self._child.kill()

    def process(self, persistent_sets):
        """
        Read everything currently available from the child and pass each
        complete event line to persistent_sets.process().

        Raises RuntimeError when the child has closed its output.
        """
        # Read the raw descriptor ourselves: line iteration over a buffered
        # reader on a non-blocking descriptor may hand out partial lines, and
        # does not let us tell "no data right now" from end-of-file.
        fd = self.fileno()
        eof = False
        while True:
            try:
                chunk = os.read(fd, 65536)
            except BlockingIOError:
                break

            if not chunk:
                eof = True
                break
            self._buf += chunk

        # everything after the last newline is an incomplete line, keep it
        # for the next call
        *lines, self._buf = self._buf.split(b'\n')

        for line in lines:
            if not line.strip():
                continue

            cmd, val = nft_json_pop(json.loads(line))

            self._logger.debug('Nft event: %s %s', cmd, val)

            persistent_sets.process(cmd, val)

        if eof:
            # without this the descriptor stays readable forever and the
            # main loop would spin
            raise RuntimeError('nft monitor closed its output')


def main(argv = None):
    """
    Daemon entry point: parse arguments, set up logging and run the event loop.

    argv -- argument list without the program name, defaults to sys.argv[1:]
    """
    # the text is the description; passed positionally it would become 'prog'
    parser = argparse.ArgumentParser(description = 'Firewall management daemon')

    # NOTE(review): --debug is accepted but not used anywhere
    parser.add_argument('-d', '--debug', action = 'store_true')
    parser.add_argument('-v', '--verbose', action = 'count', default = 0)
    parser.add_argument('-q', '--quiet', action = 'count', default = 0)
    parser.add_argument('-n', '--dry-run', action = 'store_true')
    parser.add_argument('-s', '--syslog', action = 'store_true')
    parser.add_argument('-f', '--flush-period', type = int, default = 60 * 60)
    parser.add_argument('--state-dir', default = '/var/lib/naros')
    parser.add_argument('--persistent-set-regex', default = '^set_persist_')

    args = parser.parse_args(sys.argv[1:] if argv is None else argv)

    progname = os.path.splitext(os.path.basename(sys.argv[0]))[0]

    logger = logging.getLogger(progname)
    # default to 30 (WARNING), every -q goes a level up, every -v a level down;
    # clamp at DEBUG, since level 0 (NOTSET) would defer to the parent logger
    # instead of enabling everything
    loglevel = max(10 * (3 + args.quiet - args.verbose), logging.DEBUG)
    logger.setLevel(loglevel)

    formatter = logging.Formatter(fmt = progname + ' %(name)s: %(message)s')

    handlers = [logging.StreamHandler()]

    if args.syslog:
        handlers.append(logging.handlers.SysLogHandler('/dev/log'))

    for h in handlers:
        h.setFormatter(formatter)
        logger.addHandler(h)

    # log uncaught top-level exception
    def excepthook(t, v, tb, logger = logger):
        logger.error('Uncaught top-level exception', exc_info = (t, v, tb))
    sys.excepthook = excepthook

    if not os.path.isdir(args.state_dir):
        logger.error('State dir "%s" does not exist', args.state_dir)
        sys.exit(1)

    with ExitStack() as stack:
        sel = stack.enter_context(selectors.DefaultSelector())

        nft_monitor = stack.enter_context(NftMonitor(logger, args.state_dir))
        sel.register(nft_monitor, selectors.EVENT_READ)

        psets = stack.enter_context(NftPersistentSets(logger, args.state_dir,
                                                      args.persistent_set_regex,
                                                      args.flush_period,
                                                      not args.dry_run))
        # use SIGHUP to force flush
        signal.signal(signal.SIGHUP, lambda sig, frame: psets.flush())
        # turn SIGTERM into a regular exit, so that the stack is unwound and
        # the state is flushed instead of the process being killed outright
        signal.signal(signal.SIGTERM, lambda sig, frame: sys.exit(0))

        while True:
            logger.debug('polling...')
            events = sel.select(psets.flush_timeout())
            logger.debug('got events')

            for key, mask in events:
                key.fileobj.process(psets)

            psets.auto_flush()


if __name__ == '__main__':
    main()