diff --git a/apns.py b/apns.py index 09106a4..8b69b48 100644 --- a/apns.py +++ b/apns.py @@ -375,6 +375,16 @@ class APNSField: ) +async def receive_exact(stream: trio.abc.Stream, amount: int): + """Reads exactly the given amount of bytes from the given stream""" + buffer = b"" + while len(buffer) < amount: + # Check for EOF + if (b := await stream.receive_some(1)) == b"": + return None # None is how EOF's were represented in the old code, so we'll keep it that way + buffer += b + return buffer + @dataclass class APNSPayload: """An APNS payload""" @@ -385,18 +395,18 @@ class APNSPayload: @staticmethod async def read_from_stream(stream: trio.abc.Stream) -> APNSPayload: """Reads a payload from the given stream""" - if not (id_bytes := await stream.receive_some(1)): + if not (id_bytes := await receive_exact(stream, 1)): raise Exception("Unable to read payload id from stream") id: int = int.from_bytes(id_bytes, "big") - if (length := await stream.receive_some(4)) is None: + if (length := await receive_exact(stream, 4)) is None: raise Exception("Unable to read payload length from stream") length = int.from_bytes(length, "big") if length == 0: return APNSPayload(id, []) - buffer = await stream.receive_some(length) + buffer = await receive_exact(stream, length) if buffer is None: raise Exception("Unable to read payload from stream") fields = []