From fb5f59d5897c1cf3f21322afd7367c738fa79d12 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Thu, 14 Mar 2024 19:13:01 +0100 Subject: Add host sets. --- naros.py | 306 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 288 insertions(+), 18 deletions(-) diff --git a/naros.py b/naros.py index 170752c..ae9b844 100755 --- a/naros.py +++ b/naros.py @@ -13,6 +13,7 @@ import os import re import selectors import signal +from socket import getaddrinfo, gaierror, AF_INET, AF_INET6 import sys import subprocess import time @@ -35,6 +36,9 @@ def nft_json_pop(data): return key, val +def nft_set_uid(d): + return '@'.join((d['family'], d['table'], d['name'])) + class NftSetElem: TYPE_ADDR = 'addr' @@ -205,7 +209,7 @@ class NftPersistentSets: _flush_period = None _flush_next = None - def __init__(self, logger, state_dir, persistent_set_regex, + def __init__(self, logger, init_data, state_dir, persistent_set_regex, flush_period, fw_modify = True): self._logger = logger.getChild(str(self)) self._state_dir = state_dir @@ -218,17 +222,14 @@ class NftPersistentSets: # read initial state from the kernel self._logger.info('Reading initial sets from the kernel') - res = subprocess.run(['nft', '-j', 'list', 'sets'], capture_output = True, - check = True) - data = json.loads(res.stdout) count = 0 - for it in data['nftables']: + for it in init_data['nftables']: key, val = next(iter(it.items())) if key == OBJ_SET and self._regex.search(val['name']): - uid = self._set_key(val) + uid = nft_set_uid(val) self._logger.info('Parsing set from the kernel: %s', uid) - self._set_added(self._set_key(val), val) + self._set_added(uid, val) count += 1 self._logger.info('Read %d initial sets from the kernel', count) @@ -257,7 +258,7 @@ class NftPersistentSets: self._logger.info('Loading sets from storage') for s in sets: - uid = self._set_key(s) + uid = nft_set_uid(s) if uid in self._sets: raise ValueError('Duplicate set in stored data: %s', s) @@ -292,9 +293,6 @@ class NftPersistentSets: if t is not None and t <= 0: self.flush() - def _set_key(self, d): - return '@'.join((d['family'], d['table'], d['name'])) - def process(self, cmd, data): if cmd not in (CMD_ADD, CMD_DELETE): return @@ -305,7 +303,7 @@ class NftPersistentSets: not self._regex.search(val['name'])): return - uid = self._set_key(val) + uid = nft_set_uid(val) func = (self._set_added if (cmd == CMD_ADD and obj == OBJ_SET) else self._set_removed if (cmd == CMD_DELETE and obj == OBJ_SET) else self._elem_added if (cmd == CMD_ADD and obj == OBJ_ELEMENT) else @@ -372,6 +370,257 @@ class NftPersistentSets: self._logger.debug('Element removed from %s: %s', s, elem) s.remove(elem) +class NftHostSet: + data = None + hostname = None + addrs = None + active = None + + _logger = None + + def __init__(self, logger, data, hostname): + if data['type'] not in ('ipv4_addr', 'ipv6_addr'): + raise ValueError('Unexpected type for a host set: %s', data['type']) + + self.data = data + self.hostname = hostname + self.addrs = set() + self.active = False + self._logger = logger.getChild(str(self)) + + if 'elem' in data: + self._logger.info('Loading initial host addresses: %s', data['elem']) + self.addrs.update(data['elem']) + + def __str__(self): + return '%s(%s, %s, %s, %s):%s' % \ + (self.__class__.__name__, self.family, self.table, self.name, + self.type, self.hostname) + + def refresh(self): + family = AF_INET if self.type == 'ipv4_addr' else AF_INET6 + + try: + info = getaddrinfo(self.hostname, None, family) + except gaierror as e: + self._logger.info('Error resolving %s: %s', self.hostname, str(e)) + return False + + addrs = set((r[4][0] for r in info)) + ret = addrs != self.addrs + + self._logger.debug('Resolved %s: %s', self.hostname, ' '.join(addrs)) + + self.addrs = addrs + + return ret + + def __getattr__(self, key): + return self.data[key] + + def to_dict(self): + data = deepcopy(self.data) + data['elem'] = sorted(self.addrs) + + return data + +class NftHostSets: + + _STATE_FNAME = 'host_sets.json.gz' + + _logger = None + _regex = None + _fw_modify = None + _state_dir = None + + _refresh_next = None + _refresh_period = None + + def __init__(self, logger, init_data, state_dir, host_set_regex, + refresh_period, fw_modify = True): + self._logger = logger.getChild(str(self)) + self._state_dir = state_dir + self._refresh_period = refresh_period + self._fw_modify = fw_modify + self._regex = re.compile(host_set_regex) + if self._regex.groups < 1: + raise ValueError('Host set regex must have a capturing group for the encoded hostname') + + self._hosts = {} + self._load() + + # read initial state from the kernel + self._logger.info('Reading initial host sets from the kernel') + + count = 0 + for it in init_data['nftables']: + key, val = next(iter(it.items())) + + if key != OBJ_SET: + continue + + hostname = self._hostname_decode(val['name']) + if not hostname: + continue + + uid = nft_set_uid(val) + self._logger.info('Parsing host set from the kernel: %s', uid) + + h = NftHostSet(self._logger, val, hostname) + h.active = True + + self._hosts[uid] = h + + count += 1 + + self._logger.info('Read %d initial host sets from the kernel', count) + + self._refresh_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._refresh_period + self.refresh() + + def __str__(self): + return self.__class__.__name__ + + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + self.flush() + + def _hostname_decode(self, set_name): + m = self._regex.search(set_name) + if not m: + return None + + to_decode = m.group(1) + + chars = [] + + i = 0 + while i < len(to_decode): + n = to_decode[i] + if n != '_': + chars.append(n) + i += 1 + continue + + if i == len(to_decode) - 1: + raise ValueError('Trailing _ in host name:', to_decode) + + n1 = to_decode[i + 1] + if n1 == '_': + chars.append('_') + i += 2 + continue + + m = re.match('[0-9]+', to_decode[i + 1:]) + if m: + chars.append(chr(int(m[0]))) + i += 1 + len(m[0]) + continue + + raise ValueError('Underscore followed by unhandled character:', to_decode[i + 1:]) + + return ''.join(chars) + + + def _load(self): + src_path = os.path.join(self._state_dir, self._STATE_FNAME) + + try: + with closing(gzip.open(src_path, 'rt')) as f: + data = json.load(f) + hosts = data['sets'] + except (FileNotFoundError, KeyError): + return + + self._logger.info('Loading host sets from storage') + + for h in hosts: + uid = nft_set_uid(h) + if uid in self._hosts: + raise ValueError('Duplicate host set in stored data: %s', h) + + self._logger.info('Loading host set from storage: %s', uid) + + hostname = self._hostname_decode(h['name']) + self._hosts[uid] = NftHostSet(self._logger, h, hostname) + + self._logger.info('Loaded %d host sets from storage', len(self._hosts)) + + def _fw_set_update(self, h): + cmd = 'flush set %s %s %s;' % (h.family, h.table, h.name) + + if h.addrs: + cmd += 'add element %s %s %s { %s };' % \ + (h.family, h.table, h.name, ','.join(h.addrs)) + + self._logger.info('%seplacing set: %s', + 'R' if self._fw_modify else 'Not r', + cmd * self._logger.isEnabledFor(logging.DEBUG)) + + if self._fw_modify: + subprocess.run(['nft', '-f', '-'], input = cmd, check = True, text = True) + + def refresh(self): + need_flush = False + + for h in self._hosts.values(): + if h.active and h.refresh(): + self._fw_set_update(h) + need_flush = True + + if need_flush: + self.flush() + + self._refresh_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._refresh_period + + def flush(self): + self._logger.info('Writing persistent sets to disk...') + + data = { 'sets' : [h.to_dict() for h in self._hosts.values()] } + dst_path = os.path.join(self._state_dir, self._STATE_FNAME) + prev_path = dst_path + '.prev' + + if os.path.exists(dst_path): + os.replace(dst_path, prev_path) + + with closing(gzip.open(dst_path, 'wt', encoding = 'utf-8')) as f: + json.dump(data, f, indent = 1) + + def refresh_timeout(self): + return self._refresh_next - time.clock_gettime(time.CLOCK_BOOTTIME) + + def auto_refresh(self): + t = self.refresh_timeout() + if t <= 0: + self.refresh() + + def process(self, cmd, data): + if cmd not in (CMD_ADD, CMD_DELETE): + return + + obj, val = nft_json_pop(data) + + if obj != OBJ_SET: + return + + hostname = self._hostname_decode(val['name']) + if not hostname: + return + + uid = nft_set_uid(val) + if cmd == CMD_DELETE and uid in self._hosts: + self._hosts[uid].active = False + else: + h = self._hosts[uid] if uid in self._hosts else \ + NftHostSet(self._logger, val, hostname) + + h.active = True + self._hosts[uid] = h + + h.refresh() + self._fw_set_update(h) + class NftMonitor: _logger = None @@ -402,13 +651,13 @@ class NftMonitor: except subprocess.TimeoutExpired: self._child.kill() - def process(self, persistent_sets): + def process(self, process_cb): for line in self._child.stdout: cmd, val = nft_json_pop(json.loads(line)) self._logger.debug('Nft event: %s %s', cmd, val) - persistent_sets.process(cmd, val) + process_cb(cmd, val) parser = argparse.ArgumentParser('Firewall management daemon') @@ -420,6 +669,8 @@ parser.add_argument('-s', '--syslog', action = 'store_true') parser.add_argument('-f', '--flush-period', type = int, default = 60 * 60) parser.add_argument('--state-dir', default = '/var/lib/naros') parser.add_argument('--persistent-set-regex', default = '^set_persist_') +parser.add_argument('--host-set-regex', default = 'set_host.*__(.*)') +parser.add_argument('--hosts-refresh-period', type = int, default = 60 * 60 * 24) args = parser.parse_args(sys.argv[1:]) @@ -456,19 +707,38 @@ with ExitStack() as stack: nft_monitor = stack.enter_context(NftMonitor(logger, args.state_dir)) sel.register(nft_monitor, selectors.EVENT_READ) - psets = stack.enter_context(NftPersistentSets(logger, args.state_dir, + # read initial state from the kernel + init_data = json.loads(subprocess.run(['nft', '-j', 'list', 'sets'], + capture_output = True, check = True).stdout) + + psets = stack.enter_context(NftPersistentSets(logger, init_data, args.state_dir, args.persistent_set_regex, args.flush_period, not args.dry_run)) + + hsets = stack.enter_context(NftHostSets(logger, init_data, args.state_dir, + args.host_set_regex, + args.hosts_refresh_period, + not args.dry_run)) # use SIGHUP to force flush - signal.signal(signal.SIGHUP, lambda sig, stack: psets.flush()) + def sighup_handler(sig, stack): + psets.flush() + hsets.flush() + signal.signal(signal.SIGHUP, sighup_handler) + + def process_cb(cmd, val): + psets.process(cmd, val) + hsets.process(cmd, val) while True: + timeout = min(psets.flush_timeout(), hsets.refresh_timeout()) + logger.debug('polling...') - events = sel.select(psets.flush_timeout()) + events = sel.select(timeout) logger.debug('got events') for key, mask in events: - key.fileobj.process(psets) + key.fileobj.process(process_cb) psets.auto_flush() + hsets.auto_refresh() -- cgit v1.2.3