From b5f644b989cc8a4c467de8193b280455bccd54ca Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Thu, 17 Aug 2023 19:48:09 -0400 Subject: [PATCH] majorly refactor, fix types --- apns.py | 420 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 238 insertions(+), 182 deletions(-) diff --git a/apns.py b/apns.py index e5f1862..a2489f0 100644 --- a/apns.py +++ b/apns.py @@ -1,47 +1,57 @@ from __future__ import annotations -import random -import socket -import threading -import time -from hashlib import sha1 -from base64 import b64encode, b64decode import logging -logger = logging.getLogger("apns") +import random import ssl +import time +from base64 import b64encode +from contextlib import asynccontextmanager +from dataclasses import dataclass +from hashlib import sha1 +from typing import Callable import trio +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding import albert import bags -#COURIER_HOST = "windows.courier.push.apple.com" # TODO: Get this from config +logger = logging.getLogger("apns") + # Pick a random courier server from 01 to APNSCourierHostcount COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount'])}-{bags.apns_init_bag()['APNSCourierHostname']}" COURIER_PORT = 5223 ALPN = [b"apns-security-v3"] + async def apns_test(): async with APNSConnection.start() as connection: print(b64encode(connection.credentials.token).decode()) while True: await trio.sleep(1) print(".") - #await connection.set_state(1) - + # await connection.set_state(1) print("Finished") + + def main(): from rich.logging import RichHandler - logging.basicConfig( - level=logging.NOTSET, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] + level=logging.NOTSET, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler()], ) # Set sane log levels logging.getLogger("urllib3").setLevel(logging.WARNING) - logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from urllib3 + logging.getLogger("py.warnings").setLevel( + logging.ERROR + ) # Ignore warnings from urllib3 logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("jelly").setLevel(logging.INFO) logging.getLogger("nac").setLevel(logging.INFO) @@ -52,70 +62,96 @@ def main(): logging.getLogger("imessage").setLevel(logging.DEBUG) logging.captureWarnings(True) + print("APNs Test:") + trio.run(apns_test) -from contextlib import asynccontextmanager -from dataclasses import dataclass @dataclass class PushCredentials: - private_key: str - cert: str - token: bytes + private_key: str = "" + cert: str = "" + token: bytes = b"" + class APNSConnection: - _incoming_queue: list = [] # We don't need a lock because this is trio and we only have one thread - _queue_park: trio.Event = trio.Event() + """A connection to the APNs server""" + + _incoming_queue: list[APNSPayload] = [] + """A queue of payloads that have been received from the APNs server""" + _queue_park: trio.Event = trio.Event() + """An event that is set when a new payload is added to the queue""" + + async def _send(self, payload: APNSPayload): + """Sends a payload to the APNs server""" + await payload.write_to_stream(self.sock) + + async def _receive(self, id: int, filter: Callable | None = None): + """ + Waits for a payload with the given id to be added to the queue, then returns it. + If filter is not None, it will be called with the payload as an argument, and if it returns False, + the payload will be ignored and another will be waited for. + + NOTE: It is not defined what happens if receive is called twice with the same id and filter, + as the first payload will be removed from the queue, so the second call might never return + """ - async def _send(self, id: int, fields: list[tuple[int, bytes]]): - payload = _serialize_payload(id, fields) - await self.sock.send_all(payload) - - async def _receive(self, id: int): # Check if anything currently in the queue matches the id for payload in self._incoming_queue: - if payload[0] == id: + if payload.id == id: return payload while True: - await self._queue_park.wait() # Wait for a new payload to be added to the queue + await self._queue_park.wait() # Wait for a new payload to be added to the queue logger.debug(f"Woken by event, checking for {id}") # Check if the new payload matches the id - if self._incoming_queue[-1][0] == id: - return self._incoming_queue.pop() + if self._incoming_queue[-1].id != id: + continue + if filter is not None: + if not filter(self._incoming_queue[-1]): + continue + return self._incoming_queue.pop() # Otherwise, wait for another payload to be added to the queue async def _queue_filler(self): + """Fills the queue with payloads from the APNs socket""" while True: - payload = await _deserialize_payload(self.sock) + payload = await APNSPayload.read_from_stream(self.sock) logger.debug(f"Received payload: {payload}") + self._incoming_queue.append(payload) + # Signal to any waiting tasks that we have a new payload self._queue_park.set() - self._queue_park = trio.Event() # Reset the event + self._queue_park = trio.Event() # Reset the event + logger.debug(f"Queue length: {len(self._incoming_queue)}") - + async def _keep_alive(self): + """Sends keep alive messages to the APNs server every 5 minutes""" while True: - #await trio.sleep(300) - await trio.sleep(1) + await trio.sleep(300) logger.debug("Sending keep alive message") - await self._send(0x0C, []) + await self._send(APNSPayload(0x0C, [])) await self._receive(0x0D) logger.debug("Got keep alive response") @asynccontextmanager - async def start(credentials: PushCredentials | None = None): + @staticmethod + async def start(credentials: PushCredentials = PushCredentials()): """Sets up a nursery and connection and yields the connection""" async with trio.open_nursery() as nursery: connection = APNSConnection(nursery, credentials) await connection.connect() yield connection - nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks - await connection.sock.aclose() # Close the socket + nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks + await connection.sock.aclose() # Close the socket - def __init__(self, nursery: trio.Nursery, credentials: PushCredentials | None = None): + def __init__( + self, nursery: trio.Nursery, credentials: PushCredentials = PushCredentials() + ): + """Creates a raw APNSConnection. Make sure to call aclose() on the socket and cancel the nursery when you're done with it""" self._nursery = nursery self.credentials = credentials @@ -132,8 +168,11 @@ class APNSConnection: logger.info(f"Connected to APNs ({COURIER_HOST})") - if self.credentials is None: - self.credentials = PushCredentials(*albert.generate_push_cert(), None) + if self.credentials.cert == "" or self.credentials.private_key == "": + ( + self.credentials.private_key, + self.credentials.cert, + ) = albert.generate_push_cert() # Start the queue filler and keep alive tasks self._nursery.start_soon(self._queue_filler) @@ -143,187 +182,204 @@ class APNSConnection: async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes: """Sends the APNs connect message""" - # Parse self.certificate - from cryptography import x509 + cert = x509.load_pem_x509_certificate(self.credentials.cert.encode()) - # Parse private key - from cryptography.hazmat.primitives import serialization - private_key = serialization.load_pem_private_key(self.credentials.private_key.encode(), password=None) + private_key = serialization.load_pem_private_key( + self.credentials.private_key.encode(), password=None + ) if token is None: logger.debug(f"Sending connect message without token (root={root})") else: - logger.debug(f"Sending connect message with token {b64encode(token).decode()} (root={root})") + logger.debug( + f"Sending connect message with token {b64encode(token).decode()} (root={root})" + ) flags = 0b01000001 if root: flags |= 0b0100 - # 1 byte fixed 00, 8 bytes timestamp (milliseconds since Unix epoch), 8 bytes random cert = cert.public_bytes(serialization.Encoding.DER) - nonce = b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8) - #signature = private_key.sign(nonce, signature_algorithm=serialization.NoEncryption()) - # RSASSA-PKCS1-SHA1 - from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import padding - signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1()) - - payload = [ - (2, b"\x01"), - (5, flags.to_bytes(4, "big")), - (0x0c, cert), - (0x0d, nonce), - (0x0e, signature), - ] - if token is not None: - payload.insert(0, (1, token)) - - await self._send(7, payload) - - payload = await self._receive(8) + nonce = ( + b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8) + ) + signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1()) # type: ignore - if _get_field(payload[1], 1) != b"\x00": - raise Exception("Failed to connect") - - new_token = _get_field(payload[1], 3) - - logger.debug(f"Recieved connect response with token {b64encode(new_token).decode()}") - - return new_token - - def filter(self, topics: list[str]): - logger.debug(f"Sending filter message with topics {topics}") - fields = [(1, self.token)] - - for topic in topics: - fields.append((2, sha1(topic.encode()).digest())) - - payload = _serialize_payload(9, fields) - - self.sock.sendall(payload) - - def send_message(self, topic: str, payload: str, id=None): - logger.debug(f"Sending message to topic {topic} with payload {payload}") - if id is None: - id = random.randbytes(4) - - payload = _serialize_payload( - 0x0A, + payload = APNSPayload( + 7, [ - (4, id), - (1, sha1(topic.encode()).digest()), - (2, self.token), - (3, payload), + APNSField(0x2, b"\x01"), + APNSField(0x5, flags.to_bytes(4, "big")), + APNSField(0xC, cert), + APNSField(0xD, nonce), + APNSField(0xE, signature), ], ) - self.sock.sendall(payload) + if token is not None: + payload.fields.insert(0, APNSField(0x1, token)) + + await self._send(payload) + + payload = await self._receive(8) + + if payload.fields_with_id(1)[0].value != b"\x00": + raise Exception("Failed to connect") + + new_token = payload.fields_with_id(3)[0].value + + logger.debug( + f"Received connect response with token {b64encode(new_token).decode()}" + ) + + return new_token + + async def filter(self, topics: list[str]): + """Sends the APNs filter message""" + logger.debug(f"Sending filter message with topics {topics}") + + payload = APNSPayload(9, [APNSField(1, self.credentials.token)]) + + for topic in topics: + payload.fields.append(APNSField(2, sha1(topic.encode()).digest())) + + await payload.write_to_stream(self.sock) + + async def send_notification(self, topic: str, payload: bytes, id=None): + """Sends a notification to the APNs server""" + if id is None: + id = random.randbytes(4) + + p = APNSPayload( + 0xA, + [ + APNSField(4, id), + APNSField(1, sha1(topic.encode()).digest()), + APNSField(2, self.credentials.token), + APNSField(3, payload), + ], + ) + + await self._send(p) # Wait for ACK - payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0B) + r = await self._receive(0xB) - if payload[1][0][1] != 0x00.to_bytes(1, "big"): - raise Exception("Failed to send message") + # TODO: Check ACK code + + async def expect_notification(self, topic: str, filter: Callable | None = None): + """Waits for a notification to be received, and acks it""" + + def f(payload: APNSPayload): + if payload.fields_with_id(1)[0].value != sha1(topic.encode()).digest(): + return False + if filter is not None: + return filter(payload) + return True + + r = await self._receive(0xA, f) + await self._send_ack(r.fields_with_id(4)[0].value) + return r async def set_state(self, state: int): + """Sends the APNs state message""" logger.debug(f"Sending state message with state {state}") - await self._send(0x14, [(1, b"\x01"), (2, 0x7FFFFFFF.to_bytes(4, "big"))]) - + await self._send( + APNSPayload( + 0x14, + [ + APNSField(1, state.to_bytes(4, "big")), + APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")), + ], + ) + ) - def _send_ack(self, id: bytes): + async def _send_ack(self, id: bytes): + """Sends an ACK for a notification with the given id""" logger.debug(f"Sending ACK for message {id}") - payload = _serialize_payload(0x0B, [(1, self.token), (4, id), (8, b"\x00")]) - self.sock.sendall(payload) - # #self.sock.write(_serialize_payload(0x0B, [(4, id)]) - # #pass - - # def recieve_message(self): - # payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0A) - # # Send ACK - # self._send_ack(_get_field(payload[1], 4)) - # return _get_field(payload[1], 3) - - # TODO: Find a way to make this non-blocking - # def expect_message(self) -> tuple[int, list[tuple[int, bytes]]] | None: - # return _deserialize_payload(self.sock) + payload = APNSPayload( + 0xB, + [ + APNSField(1, self.credentials.token), + APNSField(4, id), + APNSField(8, b"\x00"), + ], + ) + await payload.write_to_stream(self.sock) -def _serialize_field(id: int, value: bytes) -> bytes: - return id.to_bytes(1, "big") + len(value).to_bytes(2, "big") + value +@dataclass +class APNSField: + """A field in an APNS payload""" + + id: int + value: bytes + + @staticmethod + def from_buffer(stream: bytes) -> APNSField: + id = int.from_bytes(stream[:1], "big") + length = int.from_bytes(stream[1:3], "big") + value = stream[3 : 3 + length] + return APNSField(id, value) + + def to_buffer(self) -> bytes: + return ( + self.id.to_bytes(1, "big") + len(self.value).to_bytes(2, "big") + self.value + ) -def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes: - payload = b"" +@dataclass +class APNSPayload: + """An APNS payload""" - for fid, value in fields: - if fid is not None: - payload += _serialize_field(fid, value) + id: int + fields: list[APNSField] - return id.to_bytes(1, "big") + len(payload).to_bytes(4, "big") + payload + @staticmethod + async def read_from_stream(stream: trio.abc.Stream) -> APNSPayload: + """Reads a payload from the given stream""" + id = await stream.receive_some(1) + if id is None: + raise Exception("Unable to read payload id from stream") + id = int.from_bytes(id, "big") + if id == 0x0: + raise Exception("Received id 0x0, which is not valid") -def _deserialize_field(stream: bytes) -> tuple[int, bytes]: - id = int.from_bytes(stream[:1], "big") - length = int.from_bytes(stream[1:3], "big") - value = stream[3 : 3 + length] - return id, value + length = await stream.receive_some(4) + if length is None: + raise Exception("Unable to read payload length from stream") + length = int.from_bytes(length, "big") + if length == 0: + return APNSPayload(id, []) -# Note: Takes a stream, not a buffer, as we do not know the length of the payload -# WILL BLOCK IF THE STREAM IS EMPTY -async def _deserialize_payload(stream: trio.SSLStream) -> tuple[int, list[tuple[int, bytes]]] | None: - id = int.from_bytes(await stream.receive_some(1), "big") + buffer = await stream.receive_some(length) + if buffer is None: + raise Exception("Unable to read payload from stream") + fields = [] - if id == 0x0: - return None + while len(buffer) > 0: + field = APNSField.from_buffer(buffer) + fields.append(field) + buffer = buffer[3 + len(field.value) :] - length = int.from_bytes(await stream.receive_some(4), "big") + return APNSPayload(id, fields) - if length == 0: - return id, [] + async def write_to_stream(self, stream: trio.abc.Stream): + """Writes the payload to the given stream""" + payload = b"" - buffer = await stream.receive_some(length) + for field in self.fields: + payload += field.to_buffer() - fields = [] + buffer = self.id.to_bytes(1, "big") + len(payload).to_bytes(4, "big") + payload - while len(buffer) > 0: - fid, value = _deserialize_field(buffer) - fields.append((fid, value)) - buffer = buffer[3 + len(value) :] + await stream.send_all(buffer) - return id, fields + def fields_with_id(self, id: int): + """Returns all fields with the given id""" + return [field for field in self.fields if field.id == id] -def _deserialize_payload_from_buffer( - buffer: bytes, -) -> tuple[int, list[tuple[int, bytes]]] | None: - id = int.from_bytes(buffer[:1], "big") - - if id == 0x0: - return None - - length = int.from_bytes(buffer[1:5], "big") - - buffer = buffer[5:] - - if len(buffer) < length: - raise Exception("Buffer is too short") - - fields = [] - - while len(buffer) > 0: - fid, value = _deserialize_field(buffer) - fields.append((fid, value)) - buffer = buffer[3 + len(value) :] - - return id, fields - - -# Returns the value of the first field with the given id -def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes: - for field_id, value in fields: - if field_id == id: - return value - return None - if __name__ == "__main__": - main() \ No newline at end of file + main()