diff --git a/apns.py b/apns.py index 81d5a3b..e5f1862 100644 --- a/apns.py +++ b/apns.py @@ -10,6 +10,8 @@ import logging logger = logging.getLogger("apns") import ssl +import trio + import albert import bags @@ -19,122 +21,134 @@ COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount'] COURIER_PORT = 5223 ALPN = [b"apns-security-v3"] - -# Connect to the courier server -def _connect(private_key: str, cert: str) -> ssl.SSLSocket: - # Connect to the courier server - sock = socket.create_connection((COURIER_HOST, COURIER_PORT)) - context = ssl.SSLContext(ssl.PROTOCOL_TLS) - context.set_alpn_protocols(["apns-security-v3"]) - # Wrap the socket in TLS - ssock = context.wrap_socket(sock) - # Handshake with the server - ssock.do_handshake() - - logger.info(f"Connected to APNs ({COURIER_HOST})") - - return ssock +async def apns_test(): + async with APNSConnection.start() as connection: + print(b64encode(connection.credentials.token).decode()) + while True: + await trio.sleep(1) + print(".") + #await connection.set_state(1) -class IncomingQueue: - def __init__(self): - self.queue = [] - self.lock = threading.Lock() + print("Finished") +def main(): + from rich.logging import RichHandler - def append(self, item): - with self.lock: - self.queue.append(item) - def pop(self, index = -1): - with self.lock: - return self.queue.pop(index) + logging.basicConfig( + level=logging.NOTSET, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] + ) - def __getitem__(self, index): - with self.lock: - return self.queue[index] + # Set sane log levels + logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from urllib3 + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("jelly").setLevel(logging.INFO) + logging.getLogger("nac").setLevel(logging.INFO) + logging.getLogger("apns").setLevel(logging.DEBUG) + logging.getLogger("albert").setLevel(logging.INFO) + logging.getLogger("ids").setLevel(logging.DEBUG) + logging.getLogger("bags").setLevel(logging.INFO) + logging.getLogger("imessage").setLevel(logging.DEBUG) - def __len__(self): - with self.lock: - return len(self.queue) + logging.captureWarnings(True) + print("APNs Test:") + trio.run(apns_test) - def find(self, finder): - with self.lock: - return next((i for i in self.queue if finder(i)), None) - - def pop_find(self, finder): - with self.lock: - found = next((i for i in self.queue if finder(i)), None) - if found is not None: - # We have the lock, so we can safely remove it - self.queue.remove(found) - return found - - def remove_all(self, id): - with self.lock: - self.queue = [i for i in self.queue if i[0] != id] - - def wait_pop_find(self, finder, delay=0.1): - found = None - while found is None: - found = self.pop_find(finder) - if found is None: - time.sleep(delay) - return found +from contextlib import asynccontextmanager +from dataclasses import dataclass +@dataclass +class PushCredentials: + private_key: str + cert: str + token: bytes class APNSConnection: - incoming_queue = IncomingQueue() + _incoming_queue: list = [] # We don't need a lock because this is trio and we only have one thread + _queue_park: trio.Event = trio.Event() - # Sink everything in the queue - def sink(self): - self.incoming_queue = IncomingQueue() - - def _queue_filler(self): + async def _send(self, id: int, fields: list[tuple[int, bytes]]): + payload = _serialize_payload(id, fields) + await self.sock.send_all(payload) + + async def _receive(self, id: int): + # Check if anything currently in the queue matches the id + for payload in self._incoming_queue: + if payload[0] == id: + return payload while True: - payload = _deserialize_payload(self.sock) + await self._queue_park.wait() # Wait for a new payload to be added to the queue + logger.debug(f"Woken by event, checking for {id}") + # Check if the new payload matches the id + if self._incoming_queue[-1][0] == id: + return self._incoming_queue.pop() + # Otherwise, wait for another payload to be added to the queue - if payload is not None: - # Automatically ACK incoming notifications to prevent APNs from getting mad at us - if payload[0] == 0x0A: - logger.debug("Sending automatic ACK") - self._send_ack(_get_field(payload[1], 4)) - logger.debug(f"Received payload: {payload}") - self.incoming_queue.append(payload) - logger.debug(f"Queue length: {len(self.incoming_queue)}") - - def _keep_alive_loop(self): + async def _queue_filler(self): while True: - time.sleep(300) - self._keep_alive() + payload = await _deserialize_payload(self.sock) - def __init__(self, private_key=None, cert=None): - # Generate the private key and certificate if they're not provided - if private_key is None or cert is None: - logger.debug("APNs needs a new push certificate") - self.private_key, self.cert = albert.generate_push_cert() - else: - self.private_key, self.cert = private_key, cert + logger.debug(f"Received payload: {payload}") + self._incoming_queue.append(payload) + # Signal to any waiting tasks that we have a new payload + self._queue_park.set() + self._queue_park = trio.Event() # Reset the event + logger.debug(f"Queue length: {len(self._incoming_queue)}") + + async def _keep_alive(self): + while True: + #await trio.sleep(300) + await trio.sleep(1) + logger.debug("Sending keep alive message") + await self._send(0x0C, []) + await self._receive(0x0D) + logger.debug("Got keep alive response") - self.sock = _connect(self.private_key, self.cert) + @asynccontextmanager + async def start(credentials: PushCredentials | None = None): + """Sets up a nursery and connection and yields the connection""" + async with trio.open_nursery() as nursery: + connection = APNSConnection(nursery, credentials) + await connection.connect() + yield connection + nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks + await connection.sock.aclose() # Close the socket - self.queue_filler_thread = threading.Thread( - target=self._queue_filler, daemon=True - ) - self.queue_filler_thread.start() + def __init__(self, nursery: trio.Nursery, credentials: PushCredentials | None = None): + self._nursery = nursery + self.credentials = credentials - self.keep_alive_thread = threading.Thread( - target=self._keep_alive_loop, daemon=True - ) - self.keep_alive_thread.start() + async def connect(self): + """Connects to the APNs server and starts the keep alive and queue filler tasks""" + sock = await trio.open_tcp_stream(COURIER_HOST, COURIER_PORT) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.set_alpn_protocols(["apns-security-v3"]) - def connect(self, root: bool = True, token: bytes = None): + self.sock = trio.SSLStream(sock, context, server_hostname=COURIER_HOST) + + await self.sock.do_handshake() + + logger.info(f"Connected to APNs ({COURIER_HOST})") + + if self.credentials is None: + self.credentials = PushCredentials(*albert.generate_push_cert(), None) + + # Start the queue filler and keep alive tasks + self._nursery.start_soon(self._queue_filler) + self._nursery.start_soon(self._keep_alive) + + self.credentials.token = await self._connect(self.credentials.token) + + async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes: + """Sends the APNs connect message""" # Parse self.certificate from cryptography import x509 - cert = x509.load_pem_x509_certificate(self.cert.encode()) + cert = x509.load_pem_x509_certificate(self.credentials.cert.encode()) # Parse private key from cryptography.hazmat.primitives import serialization - private_key = serialization.load_pem_private_key(self.private_key.encode(), password=None) + private_key = serialization.load_pem_private_key(self.credentials.private_key.encode(), password=None) if token is None: logger.debug(f"Sending connect message without token (root={root})") @@ -153,48 +167,28 @@ class APNSConnection: from cryptography.hazmat.primitives.asymmetric import padding signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1()) - - if token is None: - payload = _serialize_payload( - 7, [(2, 0x01.to_bytes(1, "big")), (5, flags.to_bytes(4, "big")), (0x0c, cert), (0x0d, nonce), (0x0e, signature)] - ) - else: - payload = _serialize_payload( - 7, - [ - (1, token), - (2, 0x01.to_bytes(1, "big")), - (5, flags.to_bytes(4, "big")), - (0x0c, cert), - (0x0d, nonce), - (0x0e, signature), - - ], - ) - - #self.sock.write(payload) - self.sock.sendall(payload) - - payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 8) - - if ( - payload == None - or payload[0] != 8 - or _get_field(payload[1], 1) != 0x00.to_bytes(1, "big") - ): - raise Exception("Failed to connect") - - new_token = _get_field(payload[1], 3) - if new_token is not None: - self.token = new_token - elif token is not None: - self.token = token - else: - raise Exception("No token") + payload = [ + (2, b"\x01"), + (5, flags.to_bytes(4, "big")), + (0x0c, cert), + (0x0d, nonce), + (0x0e, signature), + ] + if token is not None: + payload.insert(0, (1, token)) - logger.debug(f"Recieved connect response with token {b64encode(self.token).decode()}") + await self._send(7, payload) + + payload = await self._receive(8) - return self.token + if _get_field(payload[1], 1) != b"\x00": + raise Exception("Failed to connect") + + new_token = _get_field(payload[1], 3) + + logger.debug(f"Recieved connect response with token {b64encode(new_token).decode()}") + + return new_token def filter(self, topics: list[str]): logger.debug(f"Sending filter message with topics {topics}") @@ -230,20 +224,9 @@ class APNSConnection: if payload[1][0][1] != 0x00.to_bytes(1, "big"): raise Exception("Failed to send message") - def set_state(self, state: int): + async def set_state(self, state: int): logger.debug(f"Sending state message with state {state}") - self.sock.sendall( - _serialize_payload( - 0x14, - [(1, state.to_bytes(1, "big")), (2, 0x7FFFFFFF.to_bytes(4, "big"))], - ) - ) - - def _keep_alive(self): - logger.debug("Sending keep alive message") - self.sock.sendall(_serialize_payload(0x0C, [])) - # Remove any keep alive responses we have or missed - self.incoming_queue.remove_all(0x0D) + await self._send(0x14, [(1, b"\x01"), (2, 0x7FFFFFFF.to_bytes(4, "big"))]) def _send_ack(self, id: bytes): @@ -287,15 +270,18 @@ def _deserialize_field(stream: bytes) -> tuple[int, bytes]: # Note: Takes a stream, not a buffer, as we do not know the length of the payload # WILL BLOCK IF THE STREAM IS EMPTY -def _deserialize_payload(stream: ssl.SSLSocket) -> tuple[int, list[tuple[int, bytes]]] | None: - id = int.from_bytes(stream.read(1), "big") +async def _deserialize_payload(stream: trio.SSLStream) -> tuple[int, list[tuple[int, bytes]]] | None: + id = int.from_bytes(await stream.receive_some(1), "big") if id == 0x0: return None - length = int.from_bytes(stream.read(4), "big") + length = int.from_bytes(await stream.receive_some(4), "big") - buffer = stream.read(length) + if length == 0: + return id, [] + + buffer = await stream.receive_some(length) fields = [] @@ -338,3 +324,6 @@ def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes: if field_id == id: return value return None + +if __name__ == "__main__": + main() \ No newline at end of file