#!/usr/bin/python3
"""Firewall management daemon for nftables sets.

Follows the event stream of ``nft -j monitor`` and maintains two kinds of
named sets:

* persistent sets (set name matches ``--persistent-set-regex``): their
  elements are tracked in memory, periodically written to a state file and
  added back when a tracked set is created again;
* host sets (set name matches ``--host-set-regex``): the set name encodes a
  host name, which is resolved periodically; the set contents are replaced
  with the resolved addresses.
"""

import argparse
from contextlib import closing, ExitStack
from datetime import datetime, timedelta, timezone
from copy import deepcopy
import gzip
import ipaddress
import json
import logging
import logging.handlers
import os
import re
import selectors
import signal
from socket import getaddrinfo, gaierror, AF_INET, AF_INET6
import sys
import subprocess
import time

# commands seen in the nft monitor JSON stream
CMD_ADD = 'add'
CMD_DELETE = 'delete'

# object kinds seen in the nft monitor JSON stream
OBJ_SET = 'set'
OBJ_ELEMENT = 'element'


def nft_json_pop(data):
    """Return the single ``(key, value)`` pair of a one-item mapping.

    Raises ValueError when the mapping is empty or has more than one item.
    """
    v = iter(data.items())
    try:
        key, val = next(v)
    except StopIteration:
        raise ValueError('JSON data is empty')

    if next(v, None) is not None:
        raise ValueError('JSON data has more than one element')

    return key, val


def nft_set_uid(d):
    """Build the ``family@table@name`` identifier for a set description."""
    return '@'.join((d['family'], d['table'], d['name']))


def next_timeout(*timeouts):
    """Return the smallest timeout, ignoring ``None`` entries.

    ``None`` stands for "no deadline"; when every entry is ``None`` the
    result is ``None`` as well (i.e. wait without a timeout).
    """
    pending = [t for t in timeouts if t is not None]
    return min(pending) if pending else None


class NftSetElem:
    """A single set element: plain address, prefix or range.

    Equality and hashing consider only the element type and value, not the
    expiration time.
    """

    TYPE_ADDR = 'addr'
    TYPE_PREFIX = 'prefix'
    TYPE_RANGE = 'range'

    _type = None
    _data = None
    _expires = None

    def __init__(self, data):
        """Parse an element from its JSON form.

        ``data`` is either the element itself or a dict wrapping it under
        the ``'elem'`` key. Raises ValueError for unsupported elements.
        """
        try:
            elem = data['elem']
        except (TypeError, KeyError):
            elem = data

        if isinstance(elem, dict) and 'val' in elem:
            if 'expires' in elem:
                # relative expiration in seconds; store as absolute time
                expires = timedelta(seconds=elem['expires'])
                self._expires = datetime.now(timezone.utc) + expires
            elif 'expires_abs' in elem:
                # not in nft, but used by our on-disk format
                self._expires = datetime.fromisoformat(elem['expires_abs'])
            elem = elem['val']

        if isinstance(elem, str):
            self._type = self.TYPE_ADDR
            self._data = elem
        elif 'prefix' in elem:
            p = elem['prefix']
            self._type = self.TYPE_PREFIX
            self._data = (p['addr'], p['len'])
        elif 'range' in elem:
            start, end = elem['range']
            self._type = self.TYPE_RANGE
            self._data = (start, end)
        else:
            raise ValueError('Unsupported set element: %s' % (elem,))

    def expired(self, t):
        """Return True when the element has an expiration time <= ``t``."""
        return self._expires is not None and self._expires <= t

    def __eq__(self, other):
        if not isinstance(other, NftSetElem):
            return NotImplemented
        return self._type == other._type and self._data == other._data

    def __hash__(self):
        return hash((self._type, self._data))

    def to_string(self, t):
        """Format the element for an nft command; '' when expired at ``t``."""
        if self.expired(t):
            return ''

        if self._type == self.TYPE_ADDR:
            ret = self._data
        elif self._type == self.TYPE_PREFIX:
            ret = '%s/%d' % (self._data[0], self._data[1])
        elif self._type == self.TYPE_RANGE:
            ret = '%s-%s' % (self._data[0], self._data[1])
        else:
            raise ValueError

        if self._expires is not None:
            # NOTE(review): a remaining time below one second is emitted as
            # 'timeout 0s' -- confirm how nft interprets that
            ret += ' timeout %ds' % int((self._expires - t).total_seconds())

        return ret

    def to_tree(self, t):
        """Return the on-disk (JSON-able) form; None when expired at ``t``."""
        if self.expired(t):
            return None

        ret = {}

        if self._type == self.TYPE_ADDR:
            d = self._data
        elif self._type == self.TYPE_PREFIX:
            d = {'prefix': {'addr': self._data[0], 'len': self._data[1]}}
        elif self._type == self.TYPE_RANGE:
            d = {'range': (self._data[0], self._data[1])}
        else:
            raise ValueError

        ret['val'] = d
        if self._expires is not None:
            ret['expires_abs'] = self._expires.isoformat()

        return ret


class NftSet:
    """A tracked persistent set: its description plus known elements.

    Keys of the set description (``family``, ``table``, ``name``, ``type``,
    ...) are readable as attributes; a missing key raises KeyError.
    """

    # True when the set is active in the kernel
    active = None

    _logger = None
    _data = None
    _elems = None

    def __init__(self, logger, data):
        """Create the set from its JSON description ``data``.

        Elements under ``data['elem']``, if any, become the initial
        contents. Raises ValueError for set types other than ``ipv4_addr``
        and ``ipv6_addr``.
        """
        self._data = deepcopy(data)
        self._data.pop('elem', None)

        if self.type not in ('ipv4_addr', 'ipv6_addr'):
            raise ValueError('Persistent sets of type %s not supported'
                             % (self.type,))

        self.active = False
        self._elems = set()
        self._logger = logger.getChild(str(self))

        # add initial set elements
        self.update(data)
        self._logger.info('Added %d initial elements', len(self._elems))

    def __getattr__(self, key):
        return self._data[key]

    def __str__(self):
        return '%s(%s, %s, %s, %s)(%d elems)' % \
            (self.__class__.__name__, self.family, self.table, self.name,
             self.type, len(self._elems))

    def update(self, data):
        """Add every element listed under ``data['elem']``, if present."""
        if 'elem' not in data:
            return

        for elem in data['elem']:
            self._elems.add(NftSetElem(elem))

    def gc(self):
        """Drop expired elements; return the timestamp used for the check."""
        before = len(self._elems)
        now = datetime.now(timezone.utc)
        self._elems = {e for e in self._elems if not e.expired(now)}
        self._logger.debug('gc, %d elements expired',
                           before - len(self._elems))
        return now

    def add(self, elem):
        """Track a new element given in JSON form."""
        self._logger.debug('Adding element: %s', elem)
        self._elems.add(NftSetElem(elem))

    def remove(self, elem):
        """Stop tracking an element given in JSON form (no-op if unknown)."""
        self._logger.debug('Removing element: %s', elem)
        self._elems.discard(NftSetElem(elem))

    def to_string(self):
        """Return non-expired elements as a comma-separated nft list."""
        now = self.gc()
        return ','.join(filter(None, (e.to_string(now) for e in self._elems)))

    def to_tree(self):
        """Return the set description with its non-expired elements."""
        now = self.gc()
        data = deepcopy(self._data)
        data['elem'] = tuple(filter(None,
                                    (e.to_tree(now) for e in self._elems)))
        return data


class NftPersistentSets:
    """Tracks all persistent sets and stores them in the state directory.

    Usable as a context manager; the state is flushed to disk on exit.
    """

    _STATE_FNAME = 'persistent_sets.json.gz'

    _sets = None

    _logger = None
    _regex = None
    _fw_modify = None

    _state_dir = None
    _flush_period = None
    _flush_next = None

    def __init__(self, logger, init_data, state_dir, persistent_set_regex,
                 flush_period, fw_modify=True):
        """Load stored sets, then merge the sets currently in the kernel.

        logger               -- parent logger
        init_data            -- parsed output of ``nft -j list sets``
        state_dir            -- directory holding the state file
        persistent_set_regex -- sets whose name matches are tracked
        flush_period         -- seconds between periodic flushes; values
                                <= 0 disable periodic flushing
        fw_modify            -- when False, nft commands are only logged
        """
        self._logger = logger.getChild(str(self))
        self._state_dir = state_dir
        self._fw_modify = fw_modify
        self._flush_period = flush_period
        self._regex = re.compile(persistent_set_regex)

        self._sets = {}
        self._load()

        # read initial state from the kernel
        self._logger.info('Reading initial sets from the kernel')
        count = 0
        for it in init_data['nftables']:
            key, val = next(iter(it.items()))
            if key == OBJ_SET and self._regex.search(val['name']):
                uid = nft_set_uid(val)
                self._logger.info('Parsing set from the kernel: %s', uid)
                self._set_added(uid, val)
                count += 1
        self._logger.info('Read %d initial sets from the kernel', count)

        if self._flush_period > 0:
            self._flush_next = (time.clock_gettime(time.CLOCK_BOOTTIME)
                                + self._flush_period)

    def __str__(self):
        return self.__class__.__name__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flush()

    def _load(self):
        """Populate the tracked sets from the state file, if it exists."""
        src_path = os.path.join(self._state_dir, self._STATE_FNAME)
        try:
            with closing(gzip.open(src_path, 'rt')) as f:
                data = json.load(f)
            sets = data['sets']
        except (FileNotFoundError, KeyError):
            return

        self._logger.info('Loading sets from storage')
        for s in sets:
            uid = nft_set_uid(s)
            if uid in self._sets:
                raise ValueError('Duplicate set in stored data: %s' % (s,))
            self._logger.info('Loading set from storage: %s', uid)
            self._sets[uid] = NftSet(self._logger, s)

        self._logger.info('Loaded %d sets from storage', len(self._sets))

    def flush(self):
        """Write all tracked sets to the state file.

        The previous state file, if any, is kept with a ``.prev`` suffix.
        Also re-arms the periodic flush deadline.
        """
        self._logger.info('Writing persistent sets to disk...')

        data = {'sets': [s.to_tree() for s in self._sets.values()]}

        dst_path = os.path.join(self._state_dir, self._STATE_FNAME)
        prev_path = dst_path + '.prev'
        if os.path.exists(dst_path):
            os.replace(dst_path, prev_path)

        with closing(gzip.open(dst_path, 'wt', encoding='utf-8')) as f:
            json.dump(data, f, indent=1)

        if self._flush_period > 0:
            self._flush_next = (time.clock_gettime(time.CLOCK_BOOTTIME)
                                + self._flush_period)

    def flush_timeout(self):
        """Seconds until the next periodic flush (may be negative).

        Returns None when periodic flushing is disabled.
        """
        if self._flush_period > 0:
            return self._flush_next - time.clock_gettime(time.CLOCK_BOOTTIME)
        return None

    def auto_flush(self):
        """Flush when the periodic flush deadline has passed."""
        t = self.flush_timeout()
        if t is not None and t <= 0:
            self.flush()

    def process(self, cmd, data):
        """Handle one monitor event.

        Only add/delete events for sets and elements whose set name matches
        the persistent-set regex are acted upon.
        """
        if cmd not in (CMD_ADD, CMD_DELETE):
            return

        obj, val = nft_json_pop(data)
        if (obj not in (OBJ_SET, OBJ_ELEMENT) or
                not self._regex.search(val['name'])):
            return

        uid = nft_set_uid(val)

        handlers = {
            (CMD_ADD, OBJ_SET): self._set_added,
            (CMD_DELETE, OBJ_SET): self._set_removed,
            (CMD_ADD, OBJ_ELEMENT): self._elem_added,
            (CMD_DELETE, OBJ_ELEMENT): self._elem_removed,
        }
        handlers[(cmd, obj)](uid, val)

    def _set_added(self, uid, data):
        """Start tracking a set, or re-populate an already tracked one."""
        if uid in self._sets and self._sets[uid].type != data['type']:
            self._logger.warning('Set %s created with a mismatching type %s',
                                 self._sets[uid], data['type'])
            del self._sets[uid]

        # look the set up separately from updating it, so that a KeyError
        # raised while parsing elements is not mistaken for "not tracked"
        s = self._sets.get(uid)
        existing = s is not None
        if existing:
            s.update(data)
        else:
            s = NftSet(self._logger, data)
            self._sets[uid] = s

        self._logger.info('%s set %s',
                          'Updated' if existing else 'Created', s)

        elems = s.to_string()
        if existing and elems:
            cmd = 'add element %s %s %s { %s }' % \
                (s.family, s.table, s.name, elems)

            # the command itself is only shown when debug logging is on
            self._logger.info(
                '%sdding elements to new set: %s',
                'A' if self._fw_modify else 'Not a',
                cmd if self._logger.isEnabledFor(logging.DEBUG) else '')

            if self._fw_modify:
                subprocess.run(['nft', '-f', '-'], input=cmd,
                               check=True, text=True)

        if not s.active:
            self._logger.debug('Activating set %s', s)
            s.active = True

    def _set_removed(self, uid, data):
        """Mark a tracked set as inactive; its elements are kept."""
        if uid not in self._sets:
            return

        s = self._sets[uid]

        # FIXME: is this useful for anything?
        self._logger.debug('Deactivating set %s', s)
        s.active = False

    def _elem_added(self, uid, data):
        """Track the elements listed under ``data['elem']['set']``."""
        if uid not in self._sets:
            self._logger.warning('Element added to non-tracked set %s', uid)
            return

        s = self._sets[uid]
        for elem in data['elem']['set']:
            self._logger.debug('Element added to %s: %s', s, elem)
            s.add(elem)

    def _elem_removed(self, uid, data):
        """Untrack the elements listed under ``data['elem']['set']``."""
        if uid not in self._sets:
            self._logger.warning('Element removed from non-tracked set %s',
                                 uid)
            return

        s = self._sets[uid]
        for elem in data['elem']['set']:
            self._logger.debug('Element removed from %s: %s', s, elem)
            s.remove(elem)


class NftHostSet:
    """A set whose contents are the resolved addresses of one host name.

    Keys of the set description are readable as attributes; a missing key
    raises KeyError.
    """

    data = None
    hostname = None
    addrs = None
    active = None

    _logger = None

    def __init__(self, logger, data, hostname):
        """Create the host set from its JSON description.

        Raises ValueError for set types other than ``ipv4_addr`` and
        ``ipv6_addr``.
        """
        if data['type'] not in ('ipv4_addr', 'ipv6_addr'):
            raise ValueError('Unexpected type for a host set: %s'
                             % (data['type'],))

        self.data = data
        self.hostname = hostname
        self.addrs = set()
        self.active = False

        self._logger = logger.getChild(str(self))

        if 'elem' in data:
            self._logger.info('Loading initial host addresses: %s',
                              data['elem'])
            self.addrs.update(data['elem'])

    def __str__(self):
        return '%s(%s, %s, %s, %s):%s' % \
            (self.__class__.__name__, self.family, self.table, self.name,
             self.type, self.hostname)

    def refresh(self):
        """Resolve the host name again.

        Returns True when the address set changed, False when it did not
        or when resolution failed (the old addresses are then kept).
        """
        family = AF_INET if self.type == 'ipv4_addr' else AF_INET6

        try:
            info = getaddrinfo(self.hostname, None, family)
        except gaierror as e:
            self._logger.info('Error resolving %s: %s', self.hostname, str(e))
            return False

        addrs = {r[4][0] for r in info}
        ret = addrs != self.addrs
        self._logger.debug('Resolved %s: %s', self.hostname, ' '.join(addrs))
        self.addrs = addrs
        return ret

    def __getattr__(self, key):
        return self.data[key]

    def to_dict(self):
        """Return the set description with the sorted current addresses."""
        data = deepcopy(self.data)
        data['elem'] = sorted(self.addrs)
        return data


class NftHostSets:
    """Tracks all host sets and refreshes them periodically.

    Usable as a context manager; the state is flushed to disk on exit.
    """

    _STATE_FNAME = 'host_sets.json.gz'

    _logger = None
    _regex = None
    _fw_modify = None

    _state_dir = None

    _refresh_next = None
    _refresh_period = None

    def __init__(self, logger, init_data, state_dir, host_set_regex,
                 refresh_period, fw_modify=True):
        """Load stored host sets, read those in the kernel and refresh.

        logger         -- parent logger
        init_data      -- parsed output of ``nft -j list sets``
        state_dir      -- directory holding the state file
        host_set_regex -- matches host set names; its first capturing group
                          is the encoded host name
        refresh_period -- seconds between periodic refreshes
        fw_modify      -- when False, nft commands are only logged

        Raises ValueError when the regex has no capturing group.
        """
        self._logger = logger.getChild(str(self))
        self._state_dir = state_dir
        self._refresh_period = refresh_period
        self._fw_modify = fw_modify

        self._regex = re.compile(host_set_regex)
        if self._regex.groups < 1:
            raise ValueError('Host set regex must have a capturing group '
                             'for the encoded hostname')

        self._hosts = {}
        self._load()

        # read initial state from the kernel
        self._logger.info('Reading initial host sets from the kernel')
        count = 0
        for it in init_data['nftables']:
            key, val = next(iter(it.items()))
            if key != OBJ_SET:
                continue

            hostname = self._hostname_decode(val['name'])
            if not hostname:
                continue

            uid = nft_set_uid(val)
            self._logger.info('Parsing host set from the kernel: %s', uid)

            h = NftHostSet(self._logger, val, hostname)
            h.active = True
            self._hosts[uid] = h

            count += 1
        self._logger.info('Read %d initial host sets from the kernel', count)

        self._refresh_next = (time.clock_gettime(time.CLOCK_BOOTTIME)
                              + self._refresh_period)
        self.refresh()

    def __str__(self):
        return self.__class__.__name__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flush()

    def _hostname_decode(self, set_name):
        """Decode the host name embedded in a set name.

        Returns None when the set name does not match the host-set regex.
        In the captured group ``__`` stands for a literal underscore and
        ``_<decimal digits>`` for the character with that code point.
        Raises ValueError on a trailing or otherwise unhandled underscore.
        """
        m = self._regex.search(set_name)
        if not m:
            return None

        to_decode = m.group(1)
        chars = []
        i = 0
        while i < len(to_decode):
            n = to_decode[i]
            if n != '_':
                chars.append(n)
                i += 1
                continue

            if i == len(to_decode) - 1:
                raise ValueError('Trailing _ in host name: %s' % to_decode)

            n1 = to_decode[i + 1]
            if n1 == '_':
                chars.append('_')
                i += 2
                continue

            m = re.match('[0-9]+', to_decode[i + 1:])
            if m:
                chars.append(chr(int(m[0])))
                i += 1 + len(m[0])
                continue

            raise ValueError('Underscore followed by unhandled character: %s'
                             % to_decode[i + 1:])

        return ''.join(chars)

    def _load(self):
        """Populate the tracked host sets from the state file, if any."""
        src_path = os.path.join(self._state_dir, self._STATE_FNAME)
        try:
            with closing(gzip.open(src_path, 'rt')) as f:
                data = json.load(f)
            hosts = data['sets']
        except (FileNotFoundError, KeyError):
            return

        self._logger.info('Loading host sets from storage')
        for h in hosts:
            uid = nft_set_uid(h)
            if uid in self._hosts:
                raise ValueError('Duplicate host set in stored data: %s'
                                 % (h,))
            self._logger.info('Loading host set from storage: %s', uid)

            hostname = self._hostname_decode(h['name'])
            self._hosts[uid] = NftHostSet(self._logger, h, hostname)

        self._logger.info('Loaded %d host sets from storage',
                          len(self._hosts))

    def _fw_set_update(self, h):
        """Replace the kernel contents of ``h`` with its current addresses."""
        cmd = 'flush set %s %s %s;' % (h.family, h.table, h.name)
        if h.addrs:
            cmd += 'add element %s %s %s { %s };' % \
                (h.family, h.table, h.name, ','.join(sorted(h.addrs)))

        # the command itself is only shown when debug logging is on
        self._logger.info(
            '%seplacing set: %s',
            'R' if self._fw_modify else 'Not r',
            cmd if self._logger.isEnabledFor(logging.DEBUG) else '')

        if self._fw_modify:
            subprocess.run(['nft', '-f', '-'], input=cmd,
                           check=True, text=True)

    def refresh(self):
        """Re-resolve every active host set and push changes to the kernel.

        The state file is rewritten when at least one set changed. Re-arms
        the periodic refresh deadline.
        """
        need_flush = False
        for h in self._hosts.values():
            if h.active and h.refresh():
                self._fw_set_update(h)
                need_flush = True

        if need_flush:
            self.flush()

        self._refresh_next = (time.clock_gettime(time.CLOCK_BOOTTIME)
                              + self._refresh_period)

    def flush(self):
        """Write all tracked host sets to the state file.

        The previous state file, if any, is kept with a ``.prev`` suffix.
        """
        self._logger.info('Writing persistent sets to disk...')

        data = {'sets': [h.to_dict() for h in self._hosts.values()]}

        dst_path = os.path.join(self._state_dir, self._STATE_FNAME)
        prev_path = dst_path + '.prev'
        if os.path.exists(dst_path):
            os.replace(dst_path, prev_path)

        with closing(gzip.open(dst_path, 'wt', encoding='utf-8')) as f:
            json.dump(data, f, indent=1)

    def refresh_timeout(self):
        """Seconds until the next periodic refresh (may be negative)."""
        return self._refresh_next - time.clock_gettime(time.CLOCK_BOOTTIME)

    def auto_refresh(self):
        """Refresh when the periodic refresh deadline has passed."""
        t = self.refresh_timeout()
        if t <= 0:
            self.refresh()

    def process(self, cmd, data):
        """Handle one monitor event.

        Only add/delete events for sets whose name decodes to a host name
        are acted upon. A deleted set is marked inactive; an added set is
        (re)activated, resolved and pushed to the kernel.
        """
        if cmd not in (CMD_ADD, CMD_DELETE):
            return

        obj, val = nft_json_pop(data)
        if obj != OBJ_SET:
            return

        hostname = self._hostname_decode(val['name'])
        if not hostname:
            return

        uid = nft_set_uid(val)
        if cmd == CMD_DELETE:
            # a delete must never create or refresh a set: the set is gone
            # from the kernel, so there is nothing to update
            h = self._hosts.get(uid)
            if h is not None:
                h.active = False
            return

        h = self._hosts.get(uid)
        if h is None:
            h = NftHostSet(self._logger, val, hostname)
        h.active = True
        self._hosts[uid] = h

        h.refresh()
        self._fw_set_update(h)


class NftMonitor:
    """Wrapper around a ``nft -j monitor`` child process.

    Exposes ``fileno()`` so that it can be registered with a selector.
    Usable as a context manager; the child is terminated on exit.
    """

    _logger = None
    _child = None

    def __init__(self, logger, state_dir):
        """Start the monitor child; ``state_dir`` is currently unused."""
        self._logger = logger.getChild(str(self))
        self._child = subprocess.Popen(['nft', '-j', 'monitor'],
                                       stdout=subprocess.PIPE)
        # reads must not block the main loop
        os.set_blocking(self._child.stdout.fileno(), False)

    def __str__(self):
        return self.__class__.__name__

    def fileno(self):
        """File descriptor of the child's stdout."""
        return self._child.stdout.fileno()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Terminate the child, killing it if it does not exit in 5 s."""
        self._child.terminate()
        try:
            self._child.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self._child.kill()
            # reap the killed child so it does not linger as a zombie
            self._child.wait()

    def process(self, process_cb):
        """Decode pending events and call ``process_cb(cmd, val)`` for each.

        Lines that are not valid JSON are logged and skipped.
        """
        # NOTE(review): stdout is non-blocking -- confirm that a partially
        # written line cannot be observed here
        for line in self._child.stdout:
            try:
                data = json.loads(line)
            except json.decoder.JSONDecodeError:
                self._logger.error('Error decoding event: %s', line)
                continue

            cmd, val = nft_json_pop(data)
            self._logger.debug('Nft event: %s %s', cmd, val)
            process_cb(cmd, val)


def main(argv=None):
    """Parse the command line, set up logging and run the event loop.

    ``argv`` defaults to ``sys.argv[1:]``. Exits with status 1 when the
    state directory does not exist.
    """
    parser = argparse.ArgumentParser(description='Firewall management daemon')

    parser.add_argument('-v', '--verbose', action='count', default=0)
    parser.add_argument('-q', '--quiet', action='count', default=0)
    parser.add_argument('-n', '--dry-run', action='store_true')
    parser.add_argument('-s', '--syslog', action='store_true')
    parser.add_argument('-f', '--flush-period', type=int, default=60 * 60)
    parser.add_argument('--state-dir', default='/var/lib/naros')
    parser.add_argument('--persistent-set-regex', default='^set_persist_')
    parser.add_argument('--host-set-regex', default='^set_host.*__(.*)')
    parser.add_argument('--hosts-refresh-period', type=int,
                        default=60 * 60 * 24)

    args = parser.parse_args(sys.argv[1:] if argv is None else argv)

    progname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
    logger = logging.getLogger(progname)

    # default to 30 (WARNING), every -q goes a level up, every -v a level down
    loglevel = max(10 * (3 + args.quiet - args.verbose), 1)
    logger.setLevel(loglevel)

    formatter = logging.Formatter(fmt=progname + ' %(name)s: %(message)s')

    handlers = [logging.StreamHandler()]
    if args.syslog:
        handlers.append(logging.handlers.SysLogHandler('/dev/log'))

    for h in handlers:
        h.setFormatter(formatter)
        logger.addHandler(h)

    # log uncaught top-level exception
    def excepthook(t, v, tb, logger=logger):
        logger.error('Uncaught top-level exception', exc_info=(t, v, tb))
    sys.excepthook = excepthook

    if not os.path.isdir(args.state_dir):
        logger.error('State dir "%s" does not exist', args.state_dir)
        sys.exit(1)

    with ExitStack() as stack:
        sel = stack.enter_context(selectors.DefaultSelector())

        nft_monitor = stack.enter_context(NftMonitor(logger, args.state_dir))
        sel.register(nft_monitor, selectors.EVENT_READ)

        # read initial state from the kernel
        init_data = json.loads(
            subprocess.run(['nft', '-j', 'list', 'sets'],
                           capture_output=True, check=True).stdout)

        psets = stack.enter_context(
            NftPersistentSets(logger, init_data, args.state_dir,
                              args.persistent_set_regex, args.flush_period,
                              not args.dry_run))
        hsets = stack.enter_context(
            NftHostSets(logger, init_data, args.state_dir,
                        args.host_set_regex, args.hosts_refresh_period,
                        not args.dry_run))

        # use SIGHUP to force flush
        def sighup_handler(sig, frame):
            psets.flush()
            hsets.flush()
        signal.signal(signal.SIGHUP, sighup_handler)

        def process_cb(cmd, val):
            psets.process(cmd, val)
            hsets.process(cmd, val)

        while True:
            # flush_timeout() is None when periodic flushing is disabled
            timeout = next_timeout(psets.flush_timeout(),
                                   hsets.refresh_timeout())

            logger.debug('polling...')
            events = sel.select(timeout)
            logger.debug('got events')
            for key, mask in events:
                key.fileobj.process(process_cb)

            psets.auto_flush()
            hsets.auto_refresh()


if __name__ == '__main__':
    main()