aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-10-27 09:28:32 +0100
committerAnton Khirnov <anton@khirnov.net>2019-10-27 09:28:32 +0100
commit078c6314f8da4676d5e967bb66941c7c531f4470 (patch)
tree4a873ec0d6809d275f01a1901f07e77d2e85660c
parent86f99936ecbbd71e2b203638837ae4c18bb022b1 (diff)
Implement most of the basic functionality and clean up.
-rwxr-xr-xdash_server.py229
1 files changed, 140 insertions, 89 deletions
diff --git a/dash_server.py b/dash_server.py
index 403f75a..1b9e1cd 100755
--- a/dash_server.py
+++ b/dash_server.py
@@ -26,65 +26,66 @@ import os
import os.path
from http import HTTPStatus
import http.server as hs
+import logging
import select
+import shutil
+import socket
import sys
import threading
+# monkey-patch in ThreadingHTTPServer for older python versions
+if sys.version_info.minor < 7:
+ import socketserver
+ class ThreadingHTTPServer(socketserver.ThreadingMixIn, hs.HTTPServer):
+ daemon_threads = True
+ hs.ThreadingHTTPServer = ThreadingHTTPServer
+
class HTTPChunkedRequestReader:
_stream = None
_eof = False
- _partial_chunk = None
- _remainder = 0
+ _logger = None
- def __init__(self, stream):
+ def __init__(self, stream, logger):
self._stream = stream
+ self._logger = logger
def fileno(self):
return self._stream.fileno()
- def read(self, size = -1):
- if size != -1:
- raise ValueError
+ def read(self):
if self._eof:
return bytes()
- if self._partial_chunk is None:
- l = self._stream.readline()
- print(b'line: ' + l)
- if l is None:
- return l
- chunk_size = int(l.split(b';')[0], 16)
- if chunk_size == 0:
- self._eof = True
- return bytes()
+ l = self._stream.readline().decode('ascii', errors = 'replace')
+ self._logger.debug('reading chunk: chunksize %s', l)
- self._partial_chunk = bytes()
- self._remainder = chunk_size
+ try:
+ chunk_size = int(l.split(';')[0], 16)
+ except ValueError:
+ raise IOError('Invalid chunksize line: %s' % l)
+ if chunk_size < 0:
+ raise IOError('Invalid negative chunksize: %d' % chunk_size)
+ if chunk_size == 0:
+ self._eof = True
+ return bytes()
- while self._remainder > 0:
- read = self._stream.read(self._remainder)
- if read is None:
- return read
+ data = bytes()
+ remainder = chunk_size
+ while remainder > 0:
+ read = self._stream.read(remainder)
if len(read) == 0:
raise IOError('Premature EOF')
- self._partial_chunk += read
- self._remainder -= len(read)
+ data += read
+ remainder -= len(read)
- term_line = self._stream.readline()
- if term_line != b'\r\n':
- self._partial_chunk = None
- self._remainder = 0
- raise ValueError('Invalid chunk terminator')
+ term_line = self._stream.readline().decode('ascii', errors = 'replace')
+ if term_line != '\r\n':
+ raise IOError('Invalid chunk terminator: %s' % term_line)
- ret = self._partial_chunk
- sys.stderr.write('finished chunk %d %d\n' % (len(ret), self._remainder))
- self._partial_chunk = None
- self._remainder = 0
-
- return ret
+ return data
class HTTPRequestReader:
@@ -96,25 +97,21 @@ class HTTPRequestReader:
self._stream = stream
self._remainder = request_size
self._eof = request_size == 0
- sys.stderr.write('request of length %d\n' % request_size);
def fileno(self):
return self._stream.fileno()
- def read(self, size = -1):
- if size != -1:
- raise ValueError
+ def read(self):
if self._eof:
return bytes()
- read = self._stream.read(self._remainder)
- if read is None:
- return read
+ read = self._stream.read1(self._remainder)
if len(read) == 0:
raise IOError('Premature EOF')
self._remainder -= len(read)
- self._eof = self._remainder <= 0
+ self._eof = self._remainder <= 0
+
return read
class DataStream:
@@ -124,45 +121,56 @@ class DataStream:
_eof = False
def __init__(self):
- self._data = bytes()
+ self._data = []
self._data_cond = threading.Condition()
+ def close(self):
+ with self._data_cond:
+ self._eof = True
+ self._data_cond.notify_all()
+
def write(self, data):
with self._data_cond:
if len(data) == 0:
self._eof = True
else:
- self._data += data
+ if self._eof:
+ raise ValueError('Tried to write data after EOF')
+
+ self._data.append(data)
self._data_cond.notify_all()
- def read(self, offset):
+ def read(self, chunk):
with self._data_cond:
- while self._eof is False and len(self._data) <= offset:
+ while self._eof is False and len(self._data) <= chunk:
self._data_cond.wait()
if self._eof:
return bytes()
- return self._data[offset:]
+ return self._data[chunk]
class StreamCache:
_streams = None
_lock = None
+ _logger = None
- def __init__(self):
+ def __init__(self, logger):
self._streams = {}
self._lock = threading.Lock()
+ self._logger = logger
def __getitem__(self, key):
+ self._logger.debug('reading from cache: %s', key)
with self._lock:
return self._streams[key]
@contextlib.contextmanager
def add_entry(self, key, val):
# XXX handle key already present
- sys.stderr.write('cache add: %s\n' % key)
+ self._logger.debug('cache add: %s', key)
try:
with self._lock:
self._streams[key] = val
@@ -170,22 +178,41 @@ class StreamCache:
finally:
with self._lock:
del self._streams[key]
- sys.stderr.write('cache delete: %s\n' % key)
+ self._logger.debug('cache delete: %s', key)
-class RequestHandler(hs.BaseHTTPRequestHandler):
+class DashRequestHandler(hs.BaseHTTPRequestHandler):
# required for chunked transfer
protocol_version = "HTTP/1.1"
+ _logger = None
+
+ def __init__(self, *args, **kwargs):
+ server = args[2]
+ self._logger = server._logger.getChild('requesthandler')
+
+ super().__init__(*args, **kwargs)
+
def _decode_path(self, encoded_path):
+ # FIXME implement unquoting
return encoded_path
+ def _serve_local(self, path):
+ with open(path, 'rb') as infile:
+ stat = os.fstat(infile.fileno())
+
+ self.send_response(HTTPStatus.OK)
+ self.send_header('Content-Length', str(stat.st_size))
+ self.end_headers()
+
+ shutil.copyfileobj(infile, self.wfile)
+
+ def _log_request(self):
+ self._logger.info('%s: %s', str(self.client_address), self.requestline)
+ self._logger.debug('headers:\n%s', self.headers)
+
def do_GET(self):
- sys.stderr.write('GET\n')
- sys.stderr.write('requestline: %s\n' % self.requestline)
- sys.stderr.write('path: %s\n' % self.path)
- sys.stderr.write('command: %s\n' % self.command)
- sys.stderr.write('headers: %s\n' % self.headers)
+ self._log_request()
local_path = self._decode_path(self.path)
outpath = '%s/%s' % (self.server.serve_dir, local_path)
@@ -193,8 +220,10 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
ds = self.server._streams[local_path]
except KeyError:
if os.path.exists(outpath):
+ return self._serve_local(outpath)
# we managed to finalize the file after the upstream checked for it and before now
- self.send_response('X-Accel-Redirect', self.path)
+ self.send_response(HTTPStatus.OK)
+ self.send_header('X-Accel-Redirect', self.path)
self.end_headers()
else:
self.send_error(HTTPStatus.NOT_FOUND)
@@ -205,40 +234,34 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
self.send_header('Transfer-Encoding', 'chunked')
self.end_headers()
+ chunk = 0
while True:
- data = ds.read()
+ data = ds.read(chunk)
if len(data) == 0:
- self.wfile.write(b'0\r\n')
+ self.wfile.write(b'0\r\n\r\n')
break
- self.wfile.write(hex(len(data))[2:].encode('ascii') + '\r\n')
+ chunk += 1
+
+ self.wfile.write(hex(len(data))[2:].encode('ascii') + b'\r\n')
self.wfile.write(data)
- self.wfile.write('\r\n')
+ self.wfile.write(b'\r\n')
def do_POST(self):
- sys.stderr.write('POST\n')
- sys.stderr.write('requestline: %s\n' % self.requestline)
- sys.stderr.write('path: %s\n' % self.path)
- sys.stderr.write('command: %s\n' % self.command)
- sys.stderr.write('headers: %s\n' % self.headers)
+ self._log_request()
with contextlib.ExitStack() as stack:
local_path = self._decode_path(self.path)
- ds = DataStream()
+ ds = stack.enter_context(contextlib.closing(DataStream()))
stack.enter_context(self.server._streams.add_entry(local_path, ds))
- outpath = '%s/%s' % (self.server.serve_dir, local_path)
- write_path = outpath + '.tmp'
-
- os.set_blocking(self.rfile.fileno(), False)
-
if 'Transfer-Encoding' in self.headers:
if self.headers['Transfer-Encoding'] != 'chunked':
return self.send_error(HTTPStatus.NOT_IMPLEMENTED,
'Unsupported Transfer-Encoding: %s' %
self.headers['Transfer-Encoding'])
- infile = HTTPChunkedRequestReader(self.rfile)
+ infile = HTTPChunkedRequestReader(self.rfile, self._logger.getChild('chreader'))
elif 'Content-Length' in self.headers:
infile = HTTPRequestReader(self.rfile, int(self.headers['Content-Length']))
else:
@@ -247,32 +270,28 @@ class RequestHandler(hs.BaseHTTPRequestHandler):
poll = select.poll()
poll.register(infile, select.POLLIN)
- outfile = stack.enter_context(open(write_path, 'wb'))
+ outpath = '%s/%s' % (self.server.serve_dir, local_path)
+ write_path = outpath + '.tmp'
+ outfile = stack.enter_context(open(write_path, 'wb'))
while True:
data = infile.read()
- if data is None:
- sys.stderr.write('would block, sleeping\n')
- poll.poll()
- continue
ds.write(data)
if len(data) == 0:
- sys.stderr.write('Finished reading\n')
+ self._logger.debug('Finished reading')
break
- sys.stderr.write('read %d bytes\n' % (len(data)))
-
written = outfile.write(data)
if written < len(data):
- sys.stderr.write('partial write: %d < %d\n' % (written, len(data)))
- return self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR)
+ raise IOError('partial write: %d < %d' % (written, len(data)))
- sys.stderr.write('wrote %d bytes\n' % (written))
+ self._logger.debug('streamed %d bytes', len(data))
- retcode = HTTPStatus.CREATED if os.path.exists(outpath) else HTTPStatus.NO_CONTENT
+ retcode = HTTPStatus.NO_CONTENT if os.path.exists(outpath) else HTTPStatus.CREATED
os.replace(write_path, outpath)
self.send_response(retcode)
+ self.send_header('Content-Length', '0')
self.end_headers()
def do_PUT(self):
@@ -286,21 +305,53 @@ class DashServer(hs.ThreadingHTTPServer):
# should only be accessed by the request instances spawned by this server
_streams = None
- def __init__(self, address, port, serve_dir):
+ _logger = None
+
+ def __init__(self, address, force_v4, force_v6, serve_dir, logger):
self.serve_dir = serve_dir
- self._streams = StreamCache()
+ self._streams = StreamCache(logger.getChild('streamcache'))
+ self._logger = logger
+
+ family = None
+ if force_v4:
+ family = socket.AF_INET
+ elif force_v6:
+ family = socket.AF_INET6
- super().__init__(address, port)
+ if family is None and len(address[0]):
+ try:
+ family, _, _, _, _ = socket.getaddrinfo(*address)[0]
+ except IndexError:
+ pass
+
+ if family is None:
+ family = socket.AF_INET6
+
+ self.address_family = family
+
+ super().__init__(address, DashRequestHandler)
def main(argv):
parser = argparse.ArgumentParser('DASH server')
+
parser.add_argument('-a', '--address', default = '')
parser.add_argument('-p', '--port', type = int, default = 8000)
+
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument('-4', '--ipv4', action = 'store_true')
+ group.add_argument('-6', '--ipv6', action = 'store_true')
+
+ parser.add_argument('-l', '--loglevel', default = 'WARNING')
+
parser.add_argument('directory')
args = parser.parse_args(argv[1:])
- server = DashServer((args.address, args.port), RequestHandler, args.directory)
+ logging.basicConfig(stream = sys.stderr, level = args.loglevel)
+ logger = logging.getLogger('DashServer')
+
+ server = DashServer((args.address, args.port), args.ipv4, args.ipv6,
+ args.directory, logger)
server.serve_forever()
if __name__ == '__main__':