#!/usr/bin/python3
"""Ban SSH abusers based on log lines read from a FIFO.

Each input line is expected to be ``<timestamp> <message>``. Messages are
classified by the regular expressions in ``regexes`` and fed to a ``Judge``,
which keeps expiring per-host counters and decides when the user-supplied
action executable has to be run with ``<verdict> <host>`` as arguments.
"""

import argparse
import logging
import logging.handlers
import os
import re
import select
import shlex
import signal
import subprocess
import sys
import time

# Verdicts returned by Judge.process; a verdict other than ACT_NOTHING is
# passed as the first argument to the action executable.
ACT_NOTHING = "nothing"
ACT_BAN_SHORT = "ban_short"
ACT_BAN_MEDIUM = "ban_medium"
ACT_BAN_LONG = "ban_long"

# Message classifications, used as keys of ``regexes`` and returned by
# process_msg.
IFF_EVIL = 0
IFF_GOOD = 1
IFF_GRAY = 2

MINUTE = 60 # seconds
HOUR = 60 * MINUTE
DAY = 24 * HOUR

# Patterns per classification; group 1 captures the remote host.  Patterns are
# tried in dict order (GOOD, EVIL, GRAY) and the first match wins.
regexes = {
    IFF_GOOD : [
        r'^Accepted publickey .* from (\S+)',
    ],
    IFF_EVIL : [
        r'^Invalid user .* from (\S+)',
        r'^Failed password .* from (\S+)',
        r'^PAM .* authentication failure .* rhost=(\S+)',
        r'^error: maximum authentication attempts exceeded\b.* from (\S+)',
        r'^banner exchange: Connection from (\S+) port \d+: invalid format',
        r'^ssh_dispatch_run_fatal: Connection from (\S+) port \d+: message authentication code incorrect \[preauth\]',
        r'^ssh_dispatch_run_fatal: Connection from (\S+) port \d+: Connection corrupted \[preauth\]',
    ],
    IFF_GRAY : [
        r'^Received disconnect from (\S+)',
        r'^Connection reset by (\S+) port \d+',
        r'^Connection closed by authenticating user \S+ (\S+) port \d+ \[preauth\]',
        r'^Connection closed by (\S+) port \d+ \[preauth\]',
        r'^Unable to negotiate with (\S+) port \d+: no matching key exchange method found.',
    ]
}

# Compiled once, since process_msg runs for every log line.
_compiled_regexes = {
    iff: [re.compile(r) for r in rr] for iff, rr in regexes.items()
}


def process_msg(ts, msg):
    """Classify a single log message.

    :param ts: timestamp field of the log line (currently unused)
    :param msg: message text following the timestamp
    :returns: ``(iff, host)`` for the first matching pattern, where ``iff`` is
        IFF_GOOD, IFF_EVIL or IFF_GRAY and ``host`` is the captured remote
        address/hostname; ``None`` if no pattern matches.
    """
    for iff, patterns in _compiled_regexes.items():
        for pattern in patterns:
            m = pattern.search(msg)
            if m is not None:
                return (iff, m.group(1))
    return None


class ExpiringCounter:
    """Per-key counters that expire ``default_timeout`` seconds after the key
    was last updated.

    Time is read from CLOCK_BOOTTIME.  Expired entries are dropped lazily on
    lookup and by a periodic sweep in ``inc``.
    """

    default_timeout = None
    _data = None        # key -> (timestamp of last update, count)
    _gc_counter = None  # number of inc() calls, drives the periodic sweep

    def __init__(self, default_timeout):
        self._data = {}
        self.default_timeout = default_timeout
        self._gc_counter = 0

    def __str__(self):
        """One line per live key: count, age and remaining lifetime."""
        self._gc()
        now = self._now()
        return ''.join(
            '%s(%d): %gs, %gs remaining\n'
            % (key, count, now - ts, self.default_timeout - (now - ts))
            for key, (ts, count) in self._data.items())

    def __contains__(self, key):
        """True if ``key`` has a non-expired counter; drops it if expired."""
        entry = self._data.get(key)
        if entry is None:
            return False
        ts, _ = entry
        if self._now() - ts > self.default_timeout:
            del self._data[key]
            return False
        return True

    def __delitem__(self, key):
        del self._data[key]

    def _now(self):
        return time.clock_gettime(time.CLOCK_BOOTTIME)

    def _gc(self):
        """Remove every expired entry."""
        now = self._now()
        expired = [key for key, (ts, _) in self._data.items()
                   if now - ts > self.default_timeout]
        for key in expired:
            del self._data[key]

    def inc(self, key, count = 1):
        """Add ``count`` (may be negative) to ``key`` and refresh its timestamp.

        The result is clamped at 0, and a counter that reaches 0 is removed.
        Every 1024 calls all expired entries are swept, so keys that are
        never looked up again do not accumulate forever.

        :returns: the new counter value
        """
        now = self._now()
        oldval = self._data[key][1] if key in self else 0
        newval = max(0, oldval + count)
        if newval > 0:
            self._data[key] = (now, newval)
        else:
            self._data.pop(key, None)
        self._gc_counter += 1
        if (self._gc_counter & ((1 << 10) - 1)) == 0:
            self._gc()
        return newval

    def dec(self, key, count = 1):
        """Subtract ``count`` from ``key``; see ``inc``.  Returns the new value."""
        return self.inc(key, -count)


class Judge:
    """Decide what to do about a host based on its recent log history.

    Keeps a whitelist and a graylist (counters expiring after a day) and three
    blacklists with windows of a minute, an hour and a day.  ``thresh`` maps
    ACT_BAN_SHORT / ACT_BAN_MEDIUM / ACT_BAN_LONG to the count a host has to
    exceed within the corresponding window to receive that verdict.
    """

    # FIXME: arbitrary constants
    _whitelist = None
    _blacklists = None
    _graylist = None
    _gray_threshold = None
    _black_thresholds = None

    def __init__(self, thresh):
        self._whitelist = ExpiringCounter(DAY)
        self._graylist = ExpiringCounter(DAY)
        self._blacklists = {
            ACT_BAN_SHORT: ExpiringCounter(MINUTE),
            ACT_BAN_MEDIUM: ExpiringCounter(HOUR),
            ACT_BAN_LONG: ExpiringCounter(DAY),
        }
        self._black_thresholds = thresh
        self._gray_threshold = 8 * thresh[ACT_BAN_MEDIUM]

    def process(self, iff, host):
        """Record one classified event for ``host`` and return a verdict.

        * IFF_GOOD: whitelist the host, clear its graylist entry and subtract
          4 from each of its blacklist counters.
        * IFF_GRAY: count the event unless the host is whitelisted; return
          ACT_BAN_MEDIUM once the count exceeds 8x the medium threshold.
        * IFF_EVIL: count the event in the day, hour and minute blacklists
          (in that order) and return the first ban whose threshold is
          exceeded; shorter windows are not incremented once a longer one
          triggers.

        :returns: one of the ACT_* constants
        """
        if iff == IFF_GOOD:
            self._whitelist.inc(host)
            if host in self._graylist:
                del self._graylist[host]
            # iterate the counters themselves, not the dict's string keys
            for bl in self._blacklists.values():
                if host in bl:
                    bl.dec(host, 4)
        elif iff == IFF_GRAY:
            if host not in self._whitelist:
                count = self._graylist.inc(host)
                if count > self._gray_threshold:
                    return ACT_BAN_MEDIUM
        elif iff == IFF_EVIL:
            for bl_id in (ACT_BAN_LONG, ACT_BAN_MEDIUM, ACT_BAN_SHORT):
                count = self._blacklists[bl_id].inc(host)
                if count > self._black_thresholds[bl_id]:
                    return bl_id
        return ACT_NOTHING

    def __str__(self):
        ret = 'Judge:\n wl: %s\n gl: %s\n' % (str(self._whitelist), str(self._graylist))
        for key, val in self._blacklists.items():
            ret += ' bl %s: %s\n' % (key, str(val))
        return ret


def build_arg_parser():
    """Return the command-line parser for this program."""
    parser = argparse.ArgumentParser(
        description = 'Parse logs and ban SSH abusers')
    parser.add_argument('-s', '--thresh-short', type = int, default = 8,
                        help = 'Maximum number of abuses per minute to get banned')
    parser.add_argument('-m', '--thresh-medium', type = int, default = 16,
                        help = 'Maximum number of abuses per hour to get banned')
    parser.add_argument('-l', '--thresh-long', type = int, default = 32,
                        help = 'Maximum number of abuses per day to get banned')
    parser.add_argument('-d', '--debug', action = 'store_true')
    parser.add_argument('inputfifo',
                        help = 'FIFO from which the log lines will be read')
    parser.add_argument('action',
                        help = 'Executable to run. It will get two parameters:'
                               ' the action to take and the hostname/address of the offender')
    return parser


def _setup_logger(progname, debug):
    """Create the program logger: syslog always, plus stderr in debug mode."""
    logger = logging.getLogger(progname)
    logger.setLevel(logging.DEBUG if debug else logging.INFO)
    formatter = logging.Formatter(fmt = progname + ': %(message)s')
    handlers = [logging.handlers.SysLogHandler('/dev/log')]
    if debug:
        handlers.append(logging.StreamHandler())
    for h in handlers:
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger


def _run_action(logger, action, verdict, host):
    """Run ``action + [verdict, host]`` and log any failure without raising."""
    # NOTE(review): ``host`` is text captured from log lines; the action
    # script should treat it as untrusted input.
    cmdline = action + [verdict, host]
    try:
        proc = subprocess.run(cmdline, capture_output = True, text = True,
                              errors = 'replace')
    except OSError as e:
        # e.g. missing or non-executable action: report it, keep the daemon up
        logger.error('Error running action "%s": %s', str(cmdline), e)
        return
    if proc.returncode != 0:
        logger.error('Error running action "%s": return code %d',
                     str(cmdline), proc.returncode)
        if proc.stderr:
            logger.error('stderr: %s', proc.stderr)


def _handle_line(logger, judge, action, line):
    """Parse one ``<timestamp> <message>`` line and act on the verdict."""
    line = line.strip()
    if not line:
        return
    parts = line.split(maxsplit = 1)
    if len(parts) != 2:
        logger.error('Invalid log line: %s', line)
        return
    ts, msg = parts
    logger.debug('processing message: %s', msg)
    match = process_msg(ts, msg)
    if match is None:
        logger.debug('message not matched')
        return
    iff, host = match
    verdict = judge.process(iff, host)
    if verdict == ACT_NOTHING:
        return
    logger.info('Action %s for: %s', verdict, host)
    # TODO: rate-limit actions?
    _run_action(logger, action, verdict, host)


def main():
    """Parse arguments, set up logging and process the input FIFO forever."""
    args = build_arg_parser().parse_args(sys.argv[1:])
    progname = os.path.basename(sys.argv[0])
    action = shlex.split(args.action)
    logger = _setup_logger(progname, args.debug)

    # log uncaught top-level exception
    def excepthook(t, v, tb):
        logger.error('Uncaught top-level exception', exc_info = (t, v, tb))
    sys.excepthook = excepthook

    judge = Judge({
        ACT_BAN_SHORT: args.thresh_short,
        ACT_BAN_MEDIUM: args.thresh_medium,
        ACT_BAN_LONG: args.thresh_long,
    })

    # use SIGUSR1 to print state
    def log_state(sig, stack):
        for text in str(judge).splitlines():
            logger.info(text)
    signal.signal(signal.SIGUSR1, log_state)

    # open FIFO read-write so poll() won't return HUP endlessly if the writer dies
    fifofd = os.open(args.inputfifo, os.O_RDWR | os.O_NONBLOCK)
    # errors='replace': log text may contain bytes invalid in the locale
    # encoding, and a decode error must not kill the daemon.
    with open(fifofd, errors = 'replace') as fifo, select.epoll() as poll:
        poll.register(fifofd, select.EPOLLIN)
        pending = ''
        while True:
            for chunk in fifo:
                # On a non-blocking fd iteration stops as soon as no more data
                # is available, so the last chunk may be an incomplete line;
                # keep it until the rest (and its newline) arrives.
                if not chunk.endswith('\n'):
                    pending += chunk
                    continue
                line, pending = pending + chunk, ''
                _handle_line(logger, judge, action, line)
            logger.debug('polling input')
            poll.poll()


if __name__ == '__main__':
    main()