Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 91 additions & 11 deletions emmett_core/http/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from .._io import loop_open_file


class _RangeNotSatisfiable(Exception):
def __init__(self, max_size: int):
self.max_size = max_size


class HTTPResponse(Exception):
def __init__(
self,
Expand Down Expand Up @@ -45,7 +50,7 @@ async def asgi(self, scope, send):
await self._send_headers(send)
await self._send_body(send)

def rsgi(self, protocol):
def rsgi(self, scope, protocol):
protocol.response_empty(self.status_code, list(self.rsgi_headers()))


Expand All @@ -63,7 +68,7 @@ def __init__(
async def _send_body(self, send):
await send({"type": "http.response.body", "body": self.body, "more_body": False})

def rsgi(self, protocol):
def rsgi(self, scope, protocol):
protocol.response_bytes(self.status_code, list(self.rsgi_headers()), self.body)


Expand All @@ -85,7 +90,7 @@ def encoded_body(self):
async def _send_body(self, send):
await send({"type": "http.response.body", "body": self.encoded_body, "more_body": False})

def rsgi(self, protocol):
def rsgi(self, scope, protocol):
protocol.response_str(self.status_code, list(self.rsgi_headers()), self.body)


Expand Down Expand Up @@ -121,6 +126,60 @@ def _get_stat_headers(self, stat_data):
"etag": etag,
}

def _if_range_feasible(self, http_if_range: str) -> bool:
return http_if_range == self._headers["last-modified"] or http_if_range == self._headers["etag"]

@classmethod
def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
units, hrange_val = http_range.split("=", 1)
units = units.strip().lower()
if units != "bytes":
raise ValueError

ranges = cls._parse_ranges(hrange_val, file_size)
if len(ranges) == 0:
raise ValueError("Range header: range must be requested")
if any(not (0 <= start < file_size) for start, _ in ranges):
raise _RangeNotSatisfiable(file_size)
if any(start > end for start, end in ranges):
raise ValueError("Range header: start must be less than end")

if len(ranges) == 1:
return ranges

#: sort and merge overlapping ranges
ranges.sort()
res = [ranges[0]]
for start, end in ranges[1:]:
last_start, last_end = res[-1]
if start <= last_end:
res[-1] = (last_start, max(last_end, end))
else:
res.append((start, end))
return ranges

@classmethod
def _parse_ranges(cls, hrange: str, file_size: int) -> list[tuple[int, int]]:
ret = []
for part in hrange.split(","):
part = part.strip()
if not part or part == "-":
continue
if "-" not in part:
continue

start_str, end_str = part.split("-", 1)
start_str = start_str.strip()
end_str = end_str.strip()
try:
start = int(start_str) if start_str else file_size - int(end_str)
end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size
ret.append((start, end))
except ValueError:
continue

return ret

async def asgi(self, scope, send):
try:
stat_data = os.stat(self.file_path)
Expand Down Expand Up @@ -153,17 +212,38 @@ async def _send_body(self, send):
}
)

def rsgi(self, protocol):
def rsgi(self, scope, protocol):
try:
stat_data = os.stat(self.file_path)
if not stat.S_ISREG(stat_data.st_mode):
return HTTPResponse(403).rsgi(protocol)
return HTTPResponse(403).rsgi(scope, protocol)
self._headers.update(self._get_stat_headers(stat_data))
except OSError as e:
if e.errno == errno.EACCES:
return HTTPResponse(403).rsgi(protocol)
return HTTPResponse(404).rsgi(protocol)

return HTTPResponse(403).rsgi(scope, protocol)
return HTTPResponse(404).rsgi(scope, protocol)

self._headers["accept-ranges"] = "bytes"
empty_res = scope.method.lower() == "head"
h_range = scope.headers.get("range")
h_if_range = scope.headers.get("if-range")
if h_range or (h_if_range and self._if_range_feasible(h_if_range)):
try:
ranges = self._parse_range_header(h_range, stat_data.st_size)
except _RangeNotSatisfiable as exc:
return protocol.response_empty(416, [("content-range", f"bytes */{exc.max_size}")])
except Exception:
return protocol.response_empty(400)
# FIXME: support multiple ranges in RSGI
range_start, range_end = ranges[0]
self._headers["content-range"] = f"bytes {range_start}-{range_end - 1}/{stat_data.st_size}"
self._headers["content-length"] = str(range_end - range_start)
if empty_res:
return protocol.response_empty(206, list(self.rsgi_headers()))
return protocol.response_file_range(206, list(self.rsgi_headers()), self.file_path, range_start, range_end)

if empty_res:
return protocol.response_empty(self.status_code, list(self.rsgi_headers()))
protocol.response_file(self.status_code, list(self.rsgi_headers()), self.file_path)


Expand Down Expand Up @@ -202,7 +282,7 @@ async def _send_body(self, send):
}
)

def rsgi(self, protocol):
def rsgi(self, scope, protocol):
protocol.response_bytes(self.status_code, list(self.rsgi_headers()), self.io_stream.read())


Expand All @@ -218,7 +298,7 @@ async def _send_body(self, send):
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def rsgi(self, protocol):
async def rsgi(self, scope, protocol):
trx = protocol.response_stream(self.status_code, list(self.rsgi_headers()))
for chunk in self.iter:
await trx.send_bytes(chunk)
Expand All @@ -240,7 +320,7 @@ async def _send_body(self, send):
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def rsgi(self, protocol):
async def rsgi(self, scope, protocol):
trx = protocol.response_stream(self.status_code, list(self.rsgi_headers()))
async for chunk in self.iter:
await trx.send_bytes(chunk)
2 changes: 1 addition & 1 deletion emmett_core/protocols/rsgi/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _configure_methods(self):

async def __call__(self, scope, protocol):
http = await self.pre_handler(scope, protocol, scope.path)
if coro := http.rsgi(protocol):
if coro := http.rsgi(scope, protocol):
if self.app.config.response_timeout is None:
await coro
return
Expand Down
4 changes: 2 additions & 2 deletions tests/http/test_http_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def iterator():
yield b"test"

http = HTTPIterResponse(iterator())
await http.rsgi(rsgi_proto)
await http.rsgi(None, rsgi_proto)
rsgi_proto.data.seek(0)

assert rsgi_proto.code == 200
Expand All @@ -95,7 +95,7 @@ async def iterator():
yield b"test"

http = HTTPAsyncIterResponse(iterator())
await http.rsgi(rsgi_proto)
await http.rsgi(None, rsgi_proto)
rsgi_proto.data.seek(0)

assert rsgi_proto.code == 200
Expand Down