#!/usr/bin/python3

# a HTTP server for sharing files
# Copyright 2019-2022 Anton Khirnov
#
# fshare is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# fshare is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with fshare. If not, see <https://www.gnu.org/licenses/>.

import argparse
import contextlib
import fcntl
import hmac
import itertools
import json
import os
import os.path
from http import HTTPStatus
import http.server as hs
import logging
import logging.handlers
import secrets
import shutil
import socket
import sys
import tempfile
import threading
from urllib import parse as urlparse

# TODO: detect and store mime types

def _open_dirfd(path, mode = 'r', perms = 0o644, dir_fd = None, **kwargs):
    """
    Same as the built-in open(), but with support for file permissions and
    operation with respect to a directory FD.

    :param path:   file path; relative paths are resolved against dir_fd
                   when it is given
    :param mode:   open()-style mode string; 'r', 'w', 'a', 'x' and '+' are
                   translated to os.open() flags
    :param perms:  permission bits for a newly created file (subject to the
                   process umask)
    :param dir_fd: directory file descriptor to resolve path against, or None
    :param kwargs: passed through to os.fdopen()
    :return: a file object wrapping the opened descriptor
    :raises OSError: when os.open() fails, e.g. FileExistsError for an
                     exclusive ('x') open of an existing file
    """
    if '+' in mode:
        flags = os.O_RDWR
    elif 'r' in mode:
        flags = os.O_RDONLY
    else:
        flags = os.O_WRONLY

    if 'w' in mode:
        flags |= os.O_CREAT | os.O_TRUNC
    elif 'a' in mode:
        flags |= os.O_CREAT | os.O_APPEND

    fdopen_mode = mode
    if 'x' in mode:
        # exclusive creation; O_EXCL only has a defined meaning with O_CREAT
        flags |= os.O_CREAT | os.O_EXCL
        # os.fdopen() does not accept 'x'; the file has already been created
        # here, so it only needs to be opened for writing
        has_base_mode = any(c in mode for c in 'rwa')
        fdopen_mode = fdopen_mode.replace('x', '' if has_base_mode else 'w')

    fd = os.open(path, flags, mode = perms, dir_fd = dir_fd)
    try:
        return os.fdopen(fd, fdopen_mode, **kwargs)
    except BaseException:
        # do not leak the descriptor if wrapping it in a file object fails
        os.close(fd)
        raise

class UrlEncoder(object):
    """
    Reversible mapping between non-negative integers and short strings.

    The order of the low block_size bits of a number is reversed (which
    spreads consecutive numbers apart), then the result is written in base
    len(_alphabet) using the characters of _alphabet as digits.

    Author: Michael Fogleman
    License: MIT
    Link: http://code.activestate.com/recipes/576918/
    """
    _alphabet   = 'mn6j2c4rv8bpygw95z7hsdaetxuk3fq'
    _block_size = 8
    _min_length = 1

    def __init__(self, block_size = None):
        # use the requested block size if given, otherwise the class default
        if block_size is not None:
            self._block_size = block_size
        self._mask    = (1 << self._block_size) - 1
        self._mapping = range(self._block_size)

    def encode_url(self, n):
        """ Encode the integer n into a short string. """
        return self.enbase(self.encode(n))

    def decode_url(self, n):
        """
        Decode a string produced by encode_url() back into an integer.

        :raises ValueError: if the string contains characters outside the
                            alphabet
        """
        return self.decode(self.debase(n))

    def encode(self, n):
        # bits above the block are kept, bits inside it are permuted
        return (n & ~self._mask) | self._encode(n & self._mask)

    def _encode(self, n):
        result = 0
        for i, b in enumerate(reversed(self._mapping)):
            if n & (1 << i):
                result |= (1 << b)
        return result

    def decode(self, n):
        return (n & ~self._mask) | self._decode(n & self._mask)

    def _decode(self, n):
        result = 0
        for i, b in enumerate(reversed(self._mapping)):
            if n & (1 << b):
                result |= (1 << i)
        return result

    def enbase(self, x):
        """ Write x in base len(_alphabet), padded to _min_length digits. """
        result = self._enbase(x)
        padding = self._alphabet[0] * (self._min_length - len(result))
        return '%s%s' % (padding, result)

    def _enbase(self, x):
        n = len(self._alphabet)
        if x < n:
            return self._alphabet[x]
        return self._enbase(int(x // n)) + self._alphabet[int(x % n)]

    def debase(self, x):
        """
        Inverse of enbase().

        :raises ValueError: on characters that are not in _alphabet
        """
        n = len(self._alphabet)
        result = 0
        for i, c in enumerate(reversed(x)):
            result += self._alphabet.index(c) * (n ** i)
        return result

class StateCorruptError(Exception):
    """ The persistent state stored in the state directory is unreadable. """
    pass

class URLMap:
    """
    Bidirectional mapping between short IDs and full stored file names,
    persisted in a file in the state directory.

    The file contains one "<short> <full>" pair per line; new pairs are
    appended as they are added. Lookups and additions are serialized by an
    internal lock.
    """
    _lock   = None

    _dir_fd = None
    _fname  = None
    _file   = None

    _next_id = None
    _enc     = None

    _full_to_short = None
    _short_to_full = None

    def __init__(self, state_dir_fd, fname, logger):
        self._fname  = fname
        # keep a private copy of the directory FD, closed in close()
        self._dir_fd = os.dup(state_dir_fd)
        # UrlEncoder used to ignore its block_size argument, so all existing
        # maps were generated with the default block size; keep using it so
        # the sequence of generated short IDs does not change
        self._enc    = UrlEncoder()
        self._lock   = threading.Lock()

    def close(self):
        if self._file is not None:
            self._file.close()
            self._file = None
        if self._dir_fd is not None:
            os.close(self._dir_fd)
            self._dir_fd = None

    def open(self):
        """
        Load the map file, creating it (mode 0600) if it does not exist.

        :raises RuntimeError: if the map is already open
        :raises StateCorruptError: if the map file contains a malformed line
        """
        if self._file is not None:
            raise RuntimeError('Tried to open an already opened URL map')

        # create the file if it does not exist
        try:
            with _open_dirfd(self._fname, 'w+x', perms = 0o600, dir_fd = self._dir_fd):
                pass
        except FileExistsError:
            pass

        try:
            self._file = _open_dirfd(self._fname, 'r+', dir_fd = self._dir_fd)

            data = []
            for line in self._file.readlines():
                fields = line.split()
                if not fields:
                    # tolerate empty lines
                    continue
                if len(fields) != 2:
                    raise StateCorruptError('Malformed line in URL map: %r' % line)
                data.append(fields)

            self._short_to_full = dict(data)
            self._full_to_short = {full: short for short, full in data}

            try:
                # add() skips IDs that are already taken, so starting from
                # the last used one is sufficient
                self._next_id = self._enc.decode_url(data[-1][0])
            except (IndexError, ValueError):
                # data is empty or the mapping uses characters not in alphabet
                self._next_id = 0
        except BaseException:
            self.close()
            raise

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()

    def short_to_full(self, short):
        """
        :raises KeyError: if short is not a known short ID
        """
        with self._lock:
            return self._short_to_full[short]

    def add(self, url):
        """
        Return the short ID for url, allocating and persisting a new one if
        url is not mapped yet.
        """
        with self._lock:
            # mapping already exists, just return it
            if url in self._full_to_short:
                return self._full_to_short[url]

            # find the next non-conflicting short id
            for n in itertools.count(self._next_id):
                short = self._enc.encode_url(n)
                if short not in self._short_to_full:
                    self._next_id = n + 1
                    break

            self._file.write('%s %s\n' % (short, url))
            self._file.flush()

            self._short_to_full[short] = url
            self._full_to_short[url]   = short

            return short

class PersistentState:
    """
    Server state kept in the state directory: the HMAC key used for naming
    uploaded files and, for public servers, the short URL map.
    """
    key    = None
    urlmap = None

    _fname_state = 'state'
    _fname_map   = 'map'

    def __init__(self, state_dir_fd, public, logger):
        """
        :raises StateCorruptError: if the state file cannot be parsed or does
                                   not contain a valid hex key
        """
        # generate a new state file with a random key if none exists yet
        try:
            with _open_dirfd(self._fname_state, 'w+x', perms = 0o600,
                             dir_fd = state_dir_fd) as f:
                data = { 'key' : secrets.token_hex(16) }
                json.dump(data, f, indent = 4)
                logger.info('Generated a new state file')
        except FileExistsError:
            pass

        with _open_dirfd(self._fname_state, 'r', dir_fd = state_dir_fd) as f:
            try:
                data = json.load(f)
                self.key = bytes.fromhex(data['key'])
            except (ValueError, KeyError, TypeError) as e:
                # ValueError also covers JSON syntax errors and non-hex keys,
                # TypeError a top-level value that is not an object
                raise StateCorruptError('Invalid state file: %s' % e) from e

        if public:
            self.urlmap = URLMap(state_dir_fd, self._fname_map, logger)

    def close(self):
        if self.urlmap:
            self.urlmap.close()

    def __enter__(self):
        if self.urlmap:
            self.urlmap.open()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()

class HTTPChunkedRequestReader:
    """
    Reader for a request body sent with "Transfer-Encoding: chunked".

    read() returns the next piece of decoded body data, or an empty bytes
    object once the terminating zero-size chunk (and any trailer fields after
    it) have been consumed.
    """
    _stream = None
    _eof    = False
    _logger = None

    # bytes of the current chunk not yet returned to the caller
    _chunk_left = 0

    # upper bound on the data returned by a single read() call, so a huge
    # chunk is never buffered in memory in its entirety
    _max_read = 1 << 16

    def __init__(self, stream, logger):
        self._stream = stream
        self._logger = logger

    def _readline(self):
        return self._stream.readline().decode('ascii', errors = 'replace')

    def _start_chunk(self):
        """
        Parse the next chunk-size line. On the last (zero-size) chunk also
        consume the trailer section and mark the end of the body.
        """
        l = self._readline()
        self._logger.debug('reading chunk: chunksize %s', l)

        try:
            chunk_size = int(l.split(';')[0], 16)
        except ValueError:
            raise IOError('Invalid chunksize line: %s' % l)
        if chunk_size < 0:
            raise IOError('Invalid negative chunksize: %d' % chunk_size)

        if chunk_size == 0:
            # the last chunk is followed by optional trailer fields and an
            # empty line; consume them, otherwise they would be parsed as the
            # start of the next request on a persistent connection
            while self._readline() not in ('\r\n', '\n', ''):
                pass
            self._eof = True

        self._chunk_left = chunk_size

    def read(self):
        if not self._eof and self._chunk_left == 0:
            self._start_chunk()
        if self._eof:
            return bytes()

        data = self._stream.read(min(self._chunk_left, self._max_read))
        if len(data) == 0:
            raise IOError('Premature EOF')
        self._chunk_left -= len(data)

        if self._chunk_left == 0:
            term_line = self._readline()
            if term_line != '\r\n':
                raise IOError('Invalid chunk terminator: %s' % term_line)

        return data

class HTTPRequestReader:
    """
    Reader for a request body of known size (from Content-Length).

    read() returns the next piece of the body, using at most one read1() on
    the stream, or an empty bytes object once request_size bytes have been
    consumed.
    """
    _stream    = None
    _remainder = 0
    _eof       = False

    def __init__(self, stream, request_size):
        if request_size < 0:
            raise ValueError('Invalid negative request size: %d' % request_size)
        self._stream    = stream
        self._remainder = request_size
        self._eof       = request_size == 0

    def read(self):
        if self._eof:
            return bytes()

        read = self._stream.read1(self._remainder)
        if len(read) == 0:
            raise IOError('Premature EOF')

        self._remainder -= len(read)
        self._eof = self._remainder <= 0

        return read

class FShareRequestHandler(hs.BaseHTTPRequestHandler):
    """
    Handles uploads (POST/PUT), downloads (GET) and deletions (DELETE) of
    files stored in the server's data directory.
    """
    # required for chunked transfer
    protocol_version = "HTTP/1.1"

    _logger = None

    def __init__(self, *args, **kwargs):
        # positional arguments are (request, client_address, server)
        server = args[2]
        self._logger = server._logger.getChild('requesthandler')

        super().__init__(*args, **kwargs)

    def _process_path(self, encoded_path):
        """
        Map a request path to the local path of a stored file.

        Only the first path component is used, with any extension dropped;
        on a public server it is additionally translated through the short
        URL map.

        :raises PermissionError: if the path is malformed/empty or, on a
                                 public server, not a known short URL
        """
        # decode percent-encoding
        path = urlparse.unquote(encoded_path, encoding = 'ascii')

        # normalize the path
        path = os.path.normpath(path)

        # make sure the path is absolute
        if not path.startswith('/'):
            raise PermissionError('Invalid path')

        # drop the leading '/', take the first path component
        path = path[1:].partition('/')[0]

        # discard any extension
        path = os.path.splitext(path)[0]

        if not path:
            raise PermissionError('Empty path')

        if self.server.state.urlmap:
            short = path
            try:
                path = self.server.state.urlmap.short_to_full(short)
            except KeyError:
                raise PermissionError('No such short URL: %s' % short) from None
            self._logger.info('%s->%s', path, short)

        return '/'.join((self.server.data_dir, path))

    def _log_request(self):
        self._logger.info('%s: %s', str(self.client_address), self.requestline)
        self._logger.debug('headers:\n%s', self.headers)

    def do_GET(self):
        self._log_request()

        try:
            path = self._process_path(self.path)
        except PermissionError as e:
            self._logger.error('Invalid request: %s', str(e))
            return self.send_error(HTTPStatus.NOT_FOUND)

        self._logger.info('serve file: %s', path)

        try:
            infile = open(path, 'rb')
        except OSError:
            return self.send_error(HTTPStatus.NOT_FOUND)

        with infile:
            stat = os.fstat(infile.fileno())

            self.send_response(HTTPStatus.OK)
            self.send_header('Content-Length', str(stat.st_size))
            self.end_headers()

            shutil.copyfileobj(infile, self.wfile)

    def _body_reader(self):
        """
        Create a reader for the request body based on the request headers.

        :return: an object with a read() method, or None if the request is
                 malformed or unsupported -- an error response has then
                 already been sent
        """
        te = self.headers.get('Transfer-Encoding')
        if te is not None:
            # transfer-coding names are case-insensitive
            if te.strip().lower() != 'chunked':
                # do not echo the client-supplied value into the status line
                self._logger.error('Unsupported Transfer-Encoding: %s', te)
                self.send_error(HTTPStatus.NOT_IMPLEMENTED,
                                'Unsupported Transfer-Encoding')
                return None
            return HTTPChunkedRequestReader(self.rfile, self._logger.getChild('chreader'))

        cl = self.headers.get('Content-Length')
        if cl is not None:
            cl = cl.strip()
            # only a plain non-negative decimal number is a valid length
            if not (cl.isascii() and cl.isdigit()):
                self._logger.error('Invalid Content-Length: %s', cl)
                self.send_error(HTTPStatus.BAD_REQUEST, 'Invalid Content-Length')
                return None
            return HTTPRequestReader(self.rfile, int(cl))

        self.send_error(HTTPStatus.BAD_REQUEST)
        return None

    def _receive_upload(self, infile):
        """
        Stream the request body into the data directory.

        The stored file is named by the hex HMAC-SHA256 of its contents keyed
        with the server key, so identical uploads end up in the same file.

        :return: (file name, HTTP status): CREATED if the file is new, OK if
                 a file with the same name already existed
        """
        h = hmac.new(self.server.state.key, digestmod = 'SHA256')

        temp_fd, temp_path = tempfile.mkstemp(suffix = '.tmp', dir = self.server.data_dir)
        try:
            try:
                while True:
                    data = infile.read()
                    if len(data) == 0:
                        self._logger.debug('Finished reading')
                        break

                    written = os.write(temp_fd, data)
                    if written < len(data):
                        raise IOError('partial write: %d < %d' % (written, len(data)))

                    h.update(data)
                    self._logger.debug('streamed %d bytes', len(data))
            finally:
                # close the descriptor on the error paths as well
                os.close(temp_fd)

            dst_fname = h.hexdigest()
            self._logger.info('Received file: %s', dst_fname)

            outpath = '/'.join((self.server.data_dir, dst_fname))
            if os.path.exists(outpath):
                retcode = HTTPStatus.OK
            else:
                retcode = HTTPStatus.CREATED
                os.replace(temp_path, outpath)
        finally:
            # the temporary file remains if the upload failed or duplicated
            # an existing file
            with contextlib.suppress(FileNotFoundError):
                os.remove(temp_path)

        return dst_fname, retcode

    def do_POST(self):
        self._log_request()

        # original file name, only used to build the returned URL
        src_fname = os.path.basename(urlparse.unquote(self.path))
        if src_fname in ('.', '..'):
            src_fname = ''

        infile = self._body_reader()
        if infile is None:
            return

        dst_fname, retcode = self._receive_upload(infile)

        # headers.get() returns None (it does not raise) for missing headers
        host = self.headers.get('Host', 'host.missing')

        if self.server.state.urlmap:
            # public server: resulting URL is generated short URL + original extension
            path = self.server.state.urlmap.add(dst_fname)
            self._logger.info('%s->%s', dst_fname, path)

            path += os.path.splitext(src_fname)[1]
        else:
            # private server: resulting URL is the secret HMAC + original basename
            path = dst_fname
            if src_fname:
                path += '/' + src_fname

        path = urlparse.quote(path)
        # Host is client-controlled and may contain non-ASCII characters
        reply = ('https://%s/%s' % (host, path)).encode('ascii', errors = 'replace')

        self.send_response(retcode)
        self.send_header('Content-Type',   'text/plain')
        self.send_header('Content-Length', '%d' % len(reply))
        self.end_headers()

        self.wfile.write(reply)

    def do_PUT(self):
        return self.do_POST()

    def do_DELETE(self):
        # NOTE(review): no authentication is performed here; with --public the
        # short URLs are guessable, so anyone who can reach the server can
        # delete files unless access is restricted elsewhere -- confirm
        self._log_request()

        try:
            local_path = self._process_path(self.path)
        except PermissionError as e:
            self._logger.error('Invalid request: %s', str(e))
            return self.send_error(HTTPStatus.NOT_FOUND)

        try:
            os.remove(local_path)
        except FileNotFoundError:
            self._logger.error('DELETE request for non-existing file: %s', local_path)
            self.send_error(HTTPStatus.NOT_FOUND)
            return

        self.send_response(HTTPStatus.NO_CONTENT)
        self.send_header('Content-Length', '0')
        self.end_headers()

class FShareServer(hs.ThreadingHTTPServer):
    """
    Threading HTTP server serving FShareRequestHandler, with the address
    family chosen from the commandline flags or the bind address.
    """
    data_dir = None
    state    = None

    _logger  = None

    def __init__(self, address, force_v4, force_v6, state, data_dir, logger):
        self.data_dir = data_dir
        self.state    = state
        self._logger  = logger

        family = None
        if force_v4:
            family = socket.AF_INET
        elif force_v6:
            family = socket.AF_INET6

        if family is None and address[0]:
            try:
                family, _, _, _, _ = socket.getaddrinfo(*address)[0]
            except IndexError:
                pass

        if family is None:
            family = socket.AF_INET6

        self.address_family = family

        self._logger.info('Binding to: %s:%d/%s', address[0], address[1], family.name)

        super().__init__(address, FShareRequestHandler)

    def handle_error(self, request, client_address):
        """
        Log exceptions raised while handling a request through the server
        logger instead of printing them to stderr.
        """
        self._logger.exception('Error handling request from %s', str(client_address))

def main():
    """
    Parse the commandline, configure logging, open and lock the state
    directory, load the persistent state and serve requests until
    interrupted.
    """
    # parse commandline arguments
    parser = argparse.ArgumentParser(description = 'fshare server')
    parser.add_argument('-a', '--address', default = 'localhost')
    parser.add_argument('-p', '--port',    type = int, default = 5400)
    parser.add_argument('-P', '--public',  action = 'store_true',
                        help = 'Generate public (short and guessable) URLs')

    group = parser.add_mutually_exclusive_group()
    group.add_argument('-4', '--ipv4', action = 'store_true')
    group.add_argument('-6', '--ipv6', action = 'store_true')

    parser.add_argument('-l', '--loglevel', default = 'WARNING')
    parser.add_argument('-d', '--debug',    action = 'store_true', help = 'log to stderr')

    parser.add_argument('state_dir')
    parser.add_argument('data_dir')

    args = parser.parse_args(sys.argv[1:])

    # configure logging
    progname = os.path.basename(sys.argv[0])
    logging.basicConfig(stream = sys.stderr, level = args.loglevel)

    logger = logging.getLogger(progname)
    # only the handlers added below should emit our messages; propagating to
    # the root handler installed by basicConfig() would print everything to
    # stderr even without --debug (and twice with it)
    logger.propagate = False

    formatter = logging.Formatter(fmt = progname + ': %(message)s')

    syslog = logging.handlers.SysLogHandler('/dev/log')
    handlers = [syslog]
    if args.debug:
        handlers.append(logging.StreamHandler())

    for h in handlers:
        h.setFormatter(formatter)
        logger.addHandler(h)

    # log uncaught top-level exception
    def excepthook(t, v, tb, logger = logger):
        logger.error('Uncaught top-level exception', exc_info = (t, v, tb))
    sys.excepthook = excepthook

    with contextlib.ExitStack() as stack:
        # open the state dir
        try:
            state_dir_fd = os.open(args.state_dir, os.O_RDONLY | os.O_DIRECTORY)
        except (FileNotFoundError, NotADirectoryError) as e:
            logger.error('The state directory "%s" is not an existing directory: %s',
                         args.state_dir, e)
            sys.exit(1)
        stack.callback(os.close, state_dir_fd)

        # lock the state dir
        try:
            fcntl.flock(state_dir_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError:
            logger.error('The state directory is already locked by another process')
            sys.exit(1)
        stack.callback(fcntl.flock, state_dir_fd, fcntl.LOCK_UN)

        try:
            # read the state file
            state = stack.enter_context(PersistentState(state_dir_fd, args.public, logger))
        except StateCorruptError:
            logger.error('Corrupted state file')
            sys.exit(1)

        try:
            # launch the server; the context manager closes its socket on exit
            with FShareServer((args.address, args.port), args.ipv4, args.ipv6,
                              state, args.data_dir, logger) as server:
                server.serve_forever()
        except KeyboardInterrupt:
            logger.info('Interrupted, exiting')
            sys.exit(0)

if __name__ == '__main__':
    main()