summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2024-03-14 19:13:01 +0100
committerAnton Khirnov <anton@khirnov.net>2024-03-14 19:13:01 +0100
commitfb5f59d5897c1cf3f21322afd7367c738fa79d12 (patch)
tree7ede2fc988874674e7461b8e83aae27b918ddd2b
parentdf98562922d401d4357099cdc21f9178a01ca1f9 (diff)
Add host sets.
-rwxr-xr-xnaros.py306
1 files changed, 288 insertions, 18 deletions
diff --git a/naros.py b/naros.py
index 170752c..ae9b844 100755
--- a/naros.py
+++ b/naros.py
@@ -13,6 +13,7 @@ import os
import re
import selectors
import signal
+from socket import getaddrinfo, gaierror, AF_INET, AF_INET6
import sys
import subprocess
import time
@@ -35,6 +36,9 @@ def nft_json_pop(data):
return key, val
+def nft_set_uid(d):
+ return '@'.join((d['family'], d['table'], d['name']))
+
class NftSetElem:
TYPE_ADDR = 'addr'
@@ -205,7 +209,7 @@ class NftPersistentSets:
_flush_period = None
_flush_next = None
- def __init__(self, logger, state_dir, persistent_set_regex,
+ def __init__(self, logger, init_data, state_dir, persistent_set_regex,
flush_period, fw_modify = True):
self._logger = logger.getChild(str(self))
self._state_dir = state_dir
@@ -218,17 +222,14 @@ class NftPersistentSets:
# read initial state from the kernel
self._logger.info('Reading initial sets from the kernel')
- res = subprocess.run(['nft', '-j', 'list', 'sets'], capture_output = True,
- check = True)
- data = json.loads(res.stdout)
count = 0
- for it in data['nftables']:
+ for it in init_data['nftables']:
key, val = next(iter(it.items()))
if key == OBJ_SET and self._regex.search(val['name']):
- uid = self._set_key(val)
+ uid = nft_set_uid(val)
self._logger.info('Parsing set from the kernel: %s', uid)
- self._set_added(self._set_key(val), val)
+ self._set_added(uid, val)
count += 1
self._logger.info('Read %d initial sets from the kernel', count)
@@ -257,7 +258,7 @@ class NftPersistentSets:
self._logger.info('Loading sets from storage')
for s in sets:
- uid = self._set_key(s)
+ uid = nft_set_uid(s)
if uid in self._sets:
raise ValueError('Duplicate set in stored data: %s', s)
@@ -292,9 +293,6 @@ class NftPersistentSets:
if t is not None and t <= 0:
self.flush()
- def _set_key(self, d):
- return '@'.join((d['family'], d['table'], d['name']))
-
def process(self, cmd, data):
if cmd not in (CMD_ADD, CMD_DELETE):
return
@@ -305,7 +303,7 @@ class NftPersistentSets:
not self._regex.search(val['name'])):
return
- uid = self._set_key(val)
+ uid = nft_set_uid(val)
func = (self._set_added if (cmd == CMD_ADD and obj == OBJ_SET) else
self._set_removed if (cmd == CMD_DELETE and obj == OBJ_SET) else
self._elem_added if (cmd == CMD_ADD and obj == OBJ_ELEMENT) else
@@ -372,6 +370,257 @@ class NftPersistentSets:
self._logger.debug('Element removed from %s: %s', s, elem)
s.remove(elem)
+class NftHostSet:
+ data = None
+ hostname = None
+ addrs = None
+ active = None
+
+ _logger = None
+
+ def __init__(self, logger, data, hostname):
+ if data['type'] not in ('ipv4_addr', 'ipv6_addr'):
+ raise ValueError('Unexpected type for a host set: %s', data['type'])
+
+ self.data = data
+ self.hostname = hostname
+ self.addrs = set()
+ self.active = False
+ self._logger = logger.getChild(str(self))
+
+ if 'elem' in data:
+ self._logger.info('Loading initial host addresses: %s', data['elem'])
+ self.addrs.update(data['elem'])
+
+ def __str__(self):
+ return '%s(%s, %s, %s, %s):%s' % \
+ (self.__class__.__name__, self.family, self.table, self.name,
+ self.type, self.hostname)
+
+ def refresh(self):
+ family = AF_INET if self.type == 'ipv4_addr' else AF_INET6
+
+ try:
+ info = getaddrinfo(self.hostname, None, family)
+ except gaierror as e:
+ self._logger.info('Error resolving %s: %s', self.hostname, str(e))
+ return False
+
+ addrs = set((r[4][0] for r in info))
+ ret = addrs != self.addrs
+
+ self._logger.debug('Resolved %s: %s', self.hostname, ' '.join(addrs))
+
+ self.addrs = addrs
+
+ return ret
+
+ def __getattr__(self, key):
+ return self.data[key]
+
+ def to_dict(self):
+ data = deepcopy(self.data)
+ data['elem'] = sorted(self.addrs)
+
+ return data
+
+class NftHostSets:
+
+ _STATE_FNAME = 'host_sets.json.gz'
+
+ _logger = None
+ _regex = None
+ _fw_modify = None
+ _state_dir = None
+
+ _refresh_next = None
+ _refresh_period = None
+
+ def __init__(self, logger, init_data, state_dir, host_set_regex,
+ refresh_period, fw_modify = True):
+ self._logger = logger.getChild(str(self))
+ self._state_dir = state_dir
+ self._refresh_period = refresh_period
+ self._fw_modify = fw_modify
+ self._regex = re.compile(host_set_regex)
+ if self._regex.groups < 1:
+ raise ValueError('Host set regex must have a capturing group for the encoded hostname')
+
+ self._hosts = {}
+ self._load()
+
+ # read initial state from the kernel
+ self._logger.info('Reading initial host sets from the kernel')
+
+ count = 0
+ for it in init_data['nftables']:
+ key, val = next(iter(it.items()))
+
+ if key != OBJ_SET:
+ continue
+
+ hostname = self._hostname_decode(val['name'])
+ if not hostname:
+ continue
+
+ uid = nft_set_uid(val)
+ self._logger.info('Parsing host set from the kernel: %s', uid)
+
+ h = NftHostSet(self._logger, val, hostname)
+ h.active = True
+
+ self._hosts[uid] = h
+
+ count += 1
+
+ self._logger.info('Read %d initial host sets from the kernel', count)
+
+ self._refresh_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._refresh_period
+ self.refresh()
+
+ def __str__(self):
+ return self.__class__.__name__
+
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.flush()
+
+ def _hostname_decode(self, set_name):
+ m = self._regex.search(set_name)
+ if not m:
+ return None
+
+ to_decode = m.group(1)
+
+ chars = []
+
+ i = 0
+ while i < len(to_decode):
+ n = to_decode[i]
+ if n != '_':
+ chars.append(n)
+ i += 1
+ continue
+
+ if i == len(to_decode) - 1:
+ raise ValueError('Trailing _ in host name:', to_decode)
+
+ n1 = to_decode[i + 1]
+ if n1 == '_':
+ chars.append('_')
+ i += 2
+ continue
+
+ m = re.match('[0-9]+', to_decode[i + 1:])
+ if m:
+ chars.append(chr(int(m[0])))
+ i += 1 + len(m[0])
+ continue
+
+ raise ValueError('Underscore followed by unhandled character:', to_decode[i + 1:])
+
+ return ''.join(chars)
+
+
+ def _load(self):
+ src_path = os.path.join(self._state_dir, self._STATE_FNAME)
+
+ try:
+ with closing(gzip.open(src_path, 'rt')) as f:
+ data = json.load(f)
+ hosts = data['sets']
+ except (FileNotFoundError, KeyError):
+ return
+
+ self._logger.info('Loading host sets from storage')
+
+ for h in hosts:
+ uid = nft_set_uid(h)
+ if uid in self._hosts:
+ raise ValueError('Duplicate host set in stored data: %s', h)
+
+ self._logger.info('Loading host set from storage: %s', uid)
+
+ hostname = self._hostname_decode(h['name'])
+ self._hosts[uid] = NftHostSet(self._logger, h, hostname)
+
+ self._logger.info('Loaded %d host sets from storage', len(self._hosts))
+
+ def _fw_set_update(self, h):
+ cmd = 'flush set %s %s %s;' % (h.family, h.table, h.name)
+
+ if h.addrs:
+ cmd += 'add element %s %s %s { %s };' % \
+ (h.family, h.table, h.name, ','.join(h.addrs))
+
+ self._logger.info('%seplacing set: %s',
+ 'R' if self._fw_modify else 'Not r',
+ cmd * self._logger.isEnabledFor(logging.DEBUG))
+
+ if self._fw_modify:
+ subprocess.run(['nft', '-f', '-'], input = cmd, check = True, text = True)
+
+ def refresh(self):
+ need_flush = False
+
+ for h in self._hosts.values():
+ if h.active and h.refresh():
+ self._fw_set_update(h)
+ need_flush = True
+
+ if need_flush:
+ self.flush()
+
+ self._refresh_next = time.clock_gettime(time.CLOCK_BOOTTIME) + self._refresh_period
+
+ def flush(self):
+ self._logger.info('Writing persistent sets to disk...')
+
+ data = { 'sets' : [h.to_dict() for h in self._hosts.values()] }
+ dst_path = os.path.join(self._state_dir, self._STATE_FNAME)
+ prev_path = dst_path + '.prev'
+
+ if os.path.exists(dst_path):
+ os.replace(dst_path, prev_path)
+
+ with closing(gzip.open(dst_path, 'wt', encoding = 'utf-8')) as f:
+ json.dump(data, f, indent = 1)
+
+ def refresh_timeout(self):
+ return self._refresh_next - time.clock_gettime(time.CLOCK_BOOTTIME)
+
+ def auto_refresh(self):
+ t = self.refresh_timeout()
+ if t <= 0:
+ self.refresh()
+
+ def process(self, cmd, data):
+ if cmd not in (CMD_ADD, CMD_DELETE):
+ return
+
+ obj, val = nft_json_pop(data)
+
+ if obj != OBJ_SET:
+ return
+
+ hostname = self._hostname_decode(val['name'])
+ if not hostname:
+ return
+
+ uid = nft_set_uid(val)
+ if cmd == CMD_DELETE and uid in self._hosts:
+ self._hosts[uid].active = False
+ else:
+ h = self._hosts[uid] if uid in self._hosts else \
+ NftHostSet(self._logger, val, hostname)
+
+ h.active = True
+ self._hosts[uid] = h
+
+ h.refresh()
+ self._fw_set_update(h)
+
class NftMonitor:
_logger = None
@@ -402,13 +651,13 @@ class NftMonitor:
except subprocess.TimeoutExpired:
self._child.kill()
- def process(self, persistent_sets):
+ def process(self, process_cb):
for line in self._child.stdout:
cmd, val = nft_json_pop(json.loads(line))
self._logger.debug('Nft event: %s %s', cmd, val)
- persistent_sets.process(cmd, val)
+ process_cb(cmd, val)
parser = argparse.ArgumentParser('Firewall management daemon')
@@ -420,6 +669,8 @@ parser.add_argument('-s', '--syslog', action = 'store_true')
parser.add_argument('-f', '--flush-period', type = int, default = 60 * 60)
parser.add_argument('--state-dir', default = '/var/lib/naros')
parser.add_argument('--persistent-set-regex', default = '^set_persist_')
+parser.add_argument('--host-set-regex', default = 'set_host.*__(.*)')
+parser.add_argument('--hosts-refresh-period', type = int, default = 60 * 60 * 24)
args = parser.parse_args(sys.argv[1:])
@@ -456,19 +707,38 @@ with ExitStack() as stack:
nft_monitor = stack.enter_context(NftMonitor(logger, args.state_dir))
sel.register(nft_monitor, selectors.EVENT_READ)
- psets = stack.enter_context(NftPersistentSets(logger, args.state_dir,
+ # read initial state from the kernel
+ init_data = json.loads(subprocess.run(['nft', '-j', 'list', 'sets'],
+ capture_output = True, check = True).stdout)
+
+ psets = stack.enter_context(NftPersistentSets(logger, init_data, args.state_dir,
args.persistent_set_regex,
args.flush_period,
not args.dry_run))
+
+ hsets = stack.enter_context(NftHostSets(logger, init_data, args.state_dir,
+ args.host_set_regex,
+ args.hosts_refresh_period,
+ not args.dry_run))
# use SIGHUP to force flush
- signal.signal(signal.SIGHUP, lambda sig, stack: psets.flush())
+ def sighup_handler(sig, stack):
+ psets.flush()
+ hsets.flush()
+ signal.signal(signal.SIGHUP, sighup_handler)
+
+ def process_cb(cmd, val):
+ psets.process(cmd, val)
+ hsets.process(cmd, val)
while True:
+ timeout = min(psets.flush_timeout(), hsets.refresh_timeout())
+
logger.debug('polling...')
- events = sel.select(psets.flush_timeout())
+ events = sel.select(timeout)
logger.debug('got events')
for key, mask in events:
- key.fileobj.process(psets)
+ key.fileobj.process(process_cb)
psets.auto_flush()
+ hsets.auto_refresh()