From 078c6314f8da4676d5e967bb66941c7c531f4470 Mon Sep 17 00:00:00 2001
From: Anton Khirnov
Date: Sun, 27 Oct 2019 09:28:32 +0100
Subject: Implement most of the basic functionality and clean up.

---
 dash_server.py | 229 +++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 140 insertions(+), 89 deletions(-)

diff --git a/dash_server.py b/dash_server.py
index 403f75a..1b9e1cd 100755
--- a/dash_server.py
+++ b/dash_server.py
@@ -26,65 +26,66 @@
 import os
 import os.path
 from http import HTTPStatus
 import http.server as hs
+import logging
 import select
+import shutil
+import socket
 import sys
 import threading
 
 
+# monkey-patch in ThreadingHTTPServer for older python versions
+if sys.version_info.minor < 7:
+    import socketserver
+    class ThreadingHTTPServer(socketserver.ThreadingMixIn, hs.HTTPServer):
+        daemon_threads = True
+    hs.ThreadingHTTPServer = ThreadingHTTPServer
+
 class HTTPChunkedRequestReader:
     _stream = None
     _eof = False
 
-    _partial_chunk = None
-    _remainder = 0
+    _logger = None
 
-    def __init__(self, stream):
+    def __init__(self, stream, logger):
         self._stream = stream
+        self._logger = logger
 
     def fileno(self):
         return self._stream.fileno()
 
-    def read(self, size = -1):
-        if size != -1:
-            raise ValueError
+    def read(self):
         if self._eof:
             return bytes()
 
-        if self._partial_chunk is None:
-            l = self._stream.readline()
-            print(b'line: ' + l)
-            if l is None:
-                return l
-            chunk_size = int(l.split(b';')[0], 16)
-            if chunk_size == 0:
-                self._eof = True
-                return bytes()
+        l = self._stream.readline().decode('ascii', errors = 'replace')
+        self._logger.debug('reading chunk: chunksize %s', l)
 
-            self._partial_chunk = bytes()
-            self._remainder = chunk_size
+        try:
+            chunk_size = int(l.split(';')[0], 16)
+        except ValueError:
+            raise IOError('Invalid chunksize line: %s' % l)
+        if chunk_size < 0:
+            raise IOError('Invalid negative chunksize: %d' % chunk_size)
+        if chunk_size == 0:
+            self._eof = True
+            return bytes()
 
-        while self._remainder > 0:
-            read = self._stream.read(self._remainder)
-            if read is None:
-                return read
+        data = bytes()
+        remainder = chunk_size
+        while remainder > 0:
+            read = self._stream.read(remainder)
             if len(read) == 0:
                 raise IOError('Premature EOF')
-            self._partial_chunk += read
-            self._remainder -= len(read)
+            data += read
+            remainder -= len(read)
 
-        term_line = self._stream.readline()
-        if term_line != b'\r\n':
-            self._partial_chunk = None
-            self._remainder = 0
-            raise ValueError('Invalid chunk terminator')
+        term_line = self._stream.readline().decode('ascii', errors = 'replace')
+        if term_line != '\r\n':
+            raise IOError('Invalid chunk terminator: %s' % term_line)
 
-        ret = self._partial_chunk
-        sys.stderr.write('finished chunk %d %d\n' % (len(ret), self._remainder))
-        self._partial_chunk = None
-        self._remainder = 0
-
-        return ret
+        return data
 
 
 class HTTPRequestReader:
@@ -96,25 +97,21 @@ class HTTPRequestReader:
         self._stream = stream
         self._remainder = request_size
         self._eof = request_size == 0
-        sys.stderr.write('request of length %d\n' % request_size);
 
     def fileno(self):
         return self._stream.fileno()
 
-    def read(self, size = -1):
-        if size != -1:
-            raise ValueError
+    def read(self):
         if self._eof:
             return bytes()
 
-        read = self._stream.read(self._remainder)
-        if read is None:
-            return read
+        read = self._stream.read1(self._remainder)
         if len(read) == 0:
             raise IOError('Premature EOF')
 
         self._remainder -= len(read)
-        self._eof = self._remainder <= 0
+        self._eof = self._remainder <= 0
+
         return read
 
 
@@ -124,45 +121,56 @@ class DataStream:
     _eof = False
 
     def __init__(self):
-        self._data = bytes()
+        self._data = []
         self._data_cond = threading.Condition()
 
+    def close(self):
+        with self._data_cond:
+            self._eof = True
+            self._data_cond.notify_all()
+
     def write(self, data):
         with self._data_cond:
             if len(data) == 0:
                 self._eof = True
             else:
-                self._data += data
+                if self._eof:
+                    raise ValueError('Tried to write data after EOF')
+
+                self._data.append(data)
 
             self._data_cond.notify_all()
 
-    def read(self, offset):
+    def read(self, chunk):
         with self._data_cond:
-            while self._eof is False and len(self._data) <= offset:
+            while self._eof is False and len(self._data) <= chunk:
                 self._data_cond.wait()
 
             if self._eof:
                 return bytes()
 
-            return self._data[offset:]
+            return self._data[chunk]
 
 
 class StreamCache:
     _streams = None
     _lock = None
+    _logger = None
 
-    def __init__(self):
+    def __init__(self, logger):
         self._streams = {}
         self._lock = threading.Lock()
+        self._logger = logger
 
     def __getitem__(self, key):
+        self._logger.debug('reading from cache: %s', key)
         with self._lock:
             return self._streams[key]
 
     @contextlib.contextmanager
     def add_entry(self, key, val):
         # XXX handle key already present
-        sys.stderr.write('cache add: %s\n' % key)
+        self._logger.debug('cache add: %s', key)
         try:
             with self._lock:
                 self._streams[key] = val
@@ -170,22 +178,41 @@ class StreamCache:
         finally:
             with self._lock:
                 del self._streams[key]
-            sys.stderr.write('cache delete: %s\n' % key)
+            self._logger.debug('cache delete: %s', key)
 
 
-class RequestHandler(hs.BaseHTTPRequestHandler):
+class DashRequestHandler(hs.BaseHTTPRequestHandler):
     # required for chunked transfer
     protocol_version = "HTTP/1.1"
 
+    _logger = None
+
+    def __init__(self, *args, **kwargs):
+        server = args[2]
+        self._logger = server._logger.getChild('requesthandler')
+
+        super().__init__(*args, **kwargs)
+
     def _decode_path(self, encoded_path):
+        # FIXME implement unquoting
         return encoded_path
 
+    def _serve_local(self, path):
+        with open(path, 'rb') as infile:
+            stat = os.fstat(infile.fileno())
+
+            self.send_response(HTTPStatus.OK)
+            self.send_header('Content-Length', str(stat.st_size))
+            self.end_headers()
+
+            shutil.copyfileobj(infile, self.wfile)
+
+    def _log_request(self):
+        self._logger.info('%s: %s', str(self.client_address), self.requestline)
+        self._logger.debug('headers:\n%s', self.headers)
+
     def do_GET(self):
-        sys.stderr.write('GET\n')
-        sys.stderr.write('requestline: %s\n' % self.requestline)
-        sys.stderr.write('path: %s\n' % self.path)
-        sys.stderr.write('command: %s\n' % self.command)
-        sys.stderr.write('headers: %s\n' % self.headers)
+        self._log_request()
 
         local_path = self._decode_path(self.path)
         outpath = '%s/%s' % (self.server.serve_dir, local_path)
@@ -193,8 +220,10 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
             ds = self.server._streams[local_path]
         except KeyError:
             if os.path.exists(outpath):
+                return self._serve_local(outpath)
                 # we managed to finalize the file after the upstream checked for it and before now
-                self.send_response('X-Accel-Redirect', self.path)
+                self.send_response(HTTPStatus.OK)
+                self.send_header('X-Accel-Redirect', self.path)
                 self.end_headers()
             else:
                 self.send_error(HTTPStatus.NOT_FOUND)
@@ -205,40 +234,34 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
         self.send_header('Transfer-Encoding', 'chunked')
         self.end_headers()
 
+        chunk = 0
         while True:
-            data = ds.read()
+            data = ds.read(chunk)
             if len(data) == 0:
-                self.wfile.write(b'0\r\n')
+                self.wfile.write(b'0\r\n\r\n')
                 break
 
-            self.wfile.write(hex(len(data))[2:].encode('ascii') + '\r\n')
+            chunk += 1
+
+            self.wfile.write(hex(len(data))[2:].encode('ascii') + b'\r\n')
             self.wfile.write(data)
-            self.wfile.write('\r\n')
+            self.wfile.write(b'\r\n')
 
     def do_POST(self):
-        sys.stderr.write('POST\n')
-        sys.stderr.write('requestline: %s\n' % self.requestline)
-        sys.stderr.write('path: %s\n' % self.path)
-        sys.stderr.write('command: %s\n' % self.command)
-        sys.stderr.write('headers: %s\n' % self.headers)
+        self._log_request()
 
         with contextlib.ExitStack() as stack:
            local_path = self._decode_path(self.path)
 
-            ds = DataStream()
+            ds = stack.enter_context(contextlib.closing(DataStream()))
             stack.enter_context(self.server._streams.add_entry(local_path, ds))
 
-            outpath = '%s/%s' % (self.server.serve_dir, local_path)
-            write_path = outpath + '.tmp'
-
-            os.set_blocking(self.rfile.fileno(), False)
-
             if 'Transfer-Encoding' in self.headers:
                 if self.headers['Transfer-Encoding'] != 'chunked':
                     return self.send_error(HTTPStatus.NOT_IMPLEMENTED,
                                            'Unsupported Transfer-Encoding: %s' %
                                            self.headers['Transfer-Encoding'])
-                infile = HTTPChunkedRequestReader(self.rfile)
+                infile = HTTPChunkedRequestReader(self.rfile, self._logger.getChild('chreader'))
             elif 'Content-Length' in self.headers:
                 infile = HTTPRequestReader(self.rfile, int(self.headers['Content-Length']))
             else:
@@ -247,32 +270,28 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
             poll = select.poll()
             poll.register(infile, select.POLLIN)
 
-            outfile = stack.enter_context(open(write_path, 'wb'))
+            outpath = '%s/%s' % (self.server.serve_dir, local_path)
+            write_path = outpath + '.tmp'
+            outfile = stack.enter_context(open(write_path, 'wb'))
 
             while True:
                 data = infile.read()
-                if data is None:
-                    sys.stderr.write('would block, sleeping\n')
-                    poll.poll()
-                    continue
                 ds.write(data)
                 if len(data) == 0:
-                    sys.stderr.write('Finished reading\n')
+                    self._logger.debug('Finished reading')
                    break
 
-                sys.stderr.write('read %d bytes\n' % (len(data)))
-
                written = outfile.write(data)
                if written < len(data):
-                    sys.stderr.write('partial write: %d < %d\n' % (written, len(data)))
-                    return self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR)
+                    raise IOError('partial write: %d < %d' % (written, len(data)))
 
-                sys.stderr.write('wrote %d bytes\n' % (written))
+                self._logger.debug('streamed %d bytes', len(data))
 
-        retcode = HTTPStatus.CREATED if os.path.exists(outpath) else HTTPStatus.NO_CONTENT
+        retcode = HTTPStatus.NO_CONTENT if os.path.exists(outpath) else HTTPStatus.CREATED
         os.replace(write_path, outpath)
 
         self.send_response(retcode)
+        self.send_header('Content-Length', '0')
         self.end_headers()
 
     def do_PUT(self):
@@ -286,21 +305,53 @@ class DashServer(hs.ThreadingHTTPServer):
     # should only be accessed by the request instances spawned by this server
     _streams = None
 
-    def __init__(self, address, port, serve_dir):
+    _logger = None
+
+    def __init__(self, address, force_v4, force_v6, serve_dir, logger):
         self.serve_dir = serve_dir
-        self._streams = StreamCache()
+        self._streams = StreamCache(logger.getChild('streamcache'))
+        self._logger = logger
+
+        family = None
+        if force_v4:
+            family = socket.AF_INET
+        elif force_v6:
+            family = socket.AF_INET6
 
-        super().__init__(address, port)
+        if family is None and len(address[0]):
+            try:
+                family, _, _, _, _ = socket.getaddrinfo(*address)[0]
+            except IndexError:
+                pass
+
+        if family is None:
+            family = socket.AF_INET6
+
+        self.address_family = family
+
+        super().__init__(address, DashRequestHandler)
 
 
 def main(argv):
     parser = argparse.ArgumentParser('DASH server')
 
+    parser.add_argument('-a', '--address', default = '')
     parser.add_argument('-p', '--port', type = int, default = 8000)
+
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('-4', '--ipv4', action = 'store_true')
+    group.add_argument('-6', '--ipv6', action = 'store_true')
+
+    parser.add_argument('-l', '--loglevel', default = 'WARNING')
+
     parser.add_argument('directory')
     args = parser.parse_args(argv[1:])
 
-    server = DashServer((args.address, args.port), RequestHandler, args.directory)
+    logging.basicConfig(stream = sys.stderr, level = args.loglevel)
+    logger = logging.getLogger('DashServer')
+
+    server = DashServer((args.address, args.port), args.ipv4, args.ipv6,
+                        args.directory, logger)
     server.serve_forever()
 
 if __name__ == '__main__':
-- 
cgit v1.2.3