From 6e3601de58dd8bb740485045622a15e26345c907 Mon Sep 17 00:00:00 2001 From: JJTech0130 Date: Thu, 6 Apr 2023 13:19:13 -0400 Subject: [PATCH] majorly refactor apnsconnection to be more usable --- apns.py | 208 +++++++++++++++++++---------------------------------- courier.py | 19 ++--- demo.py | 36 +++++----- 3 files changed, 100 insertions(+), 163 deletions(-) diff --git a/apns.py b/apns.py index ac1a9b2..19a2cfc 100644 --- a/apns.py +++ b/apns.py @@ -1,149 +1,89 @@ from __future__ import annotations -class Fields: - @staticmethod - def from_bytes(data: bytes) -> Fields: - fields = {} - - while len(data) > 0: - field = data[0] - length = int.from_bytes(data[1:3], "big") - value = data[3:3 + length] - - fields[field] = value - - data = data[3 + length:] - - return Fields(fields) - - def __init__(self, fields: dict[int, bytes]): - self.fields = fields - - def to_bytes(self) -> bytes: - buffer = bytearray() - - for field, value in self.fields.items(): - buffer.append(field) - buffer.extend(len(value).to_bytes(2, "big")) - buffer.extend(value) - - return buffer - - # Debug formating - def __str__(self) -> str: - return f"{self.fields}" - -# Define number to command name mapping -COMMANDS = { - 0x7: "Connect", - 0x8: "ConnectResponse", - 0x9: "PushTopics", - 0x0A: "PushNotification", - 0x0B: "Acknowledge", -} - -class Payload: - @staticmethod - def from_stream(stream) -> Payload|None: - command = int.from_bytes(stream.read(1), "big") - if command == 0: - return None # We reached the end of the stream - length = int.from_bytes(stream.read(4), "big") - fields = Fields.from_bytes(stream.read(length)) - - return Payload(command, fields) - - @staticmethod - def from_bytes(data: bytes) -> Payload: - # Convert it to bytes for cleaner printing - data = bytes(data) - command = data[0] - length = int.from_bytes(data[1:5], "big") - fields = Fields.from_bytes(data[5:5 + length]) - - return Payload(command, fields) - - def __init__(self, command: int, fields: Fields): - self.command = command - self.fields = fields - - def to_bytes(self) -> bytes: - buffer = bytearray() - - buffer.append(self.command) - - fields = self.fields.to_bytes() - - buffer.extend(len(fields).to_bytes(4, "big")) - buffer.extend(fields) - - return buffer - - # Debug formating - def __str__(self) -> str: - return f"{COMMANDS[self.command]}: {self.fields}" - -import courier +import courier, albert from hashlib import sha1 -class APNSConnection(): - def __init__(self, token: bytes=None, private_key=None, cert=None): - self.sock, self.private_key, self.cert = courier.connect(private_key, cert) - self.token = token - self._connect() - - def _connect(self): - if self.token is None: - payload = Payload(7, Fields({2: 0x01.to_bytes()})) +def _serialize_field(id: int, value: bytes) -> bytes: + return id.to_bytes() + len(value).to_bytes(2, "big") + value + + +def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes: + payload = b"" + + for fid, value in fields: + payload += _serialize_field(fid, value) + + return id.to_bytes() + len(payload).to_bytes(4, "big") + payload + + +def _deserialize_field(stream: bytes) -> tuple[int, bytes]: + id = int.from_bytes(stream[:1], "big") + length = int.from_bytes(stream[1:3], "big") + value = stream[3 : 3 + length] + return id, value + + +# Note: Takes a stream, not a buffer, as we do not know the length of the payload +def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None: + id = int.from_bytes(stream.read(1), "big") + + if id == 0x0: + return None + + length = int.from_bytes(stream.read(4), "big") + + buffer = stream.read(length) + + fields = [] + + while len(buffer) > 0: + fid, value = _deserialize_field(buffer) + fields.append((fid, value)) + buffer = buffer[3 + len(value) :] + + return id, fields + + +# Returns the value of the first field with the given id +def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes: + for field_id, value in fields: + if field_id == id: + return value + return None + + +class APNSConnection: + def __init__(self, private_key=None, cert=None): + # Generate the private key and certificate if they're not provided + if private_key is None or cert is None: + self.private_key, self.cert = albert.generate_push_cert() else: - payload = Payload(7, Fields({1: self.token, 2: 0x01.to_bytes()})) - - self.sock.write(payload.to_bytes()) + self.private_key, self.cert = private_key, cert - resp = Payload.from_stream(self.sock) + self.sock = courier.connect(self.private_key, self.cert) - if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes(): + def connect(self, token: bytes = None): + if token is None: + payload = _serialize_payload(7, [(2, 0x01.to_bytes())]) + else: + payload = _serialize_payload(7, [(1, token), (2, 0x01.to_bytes())]) + + self.sock.write(payload) + + payload = _deserialize_payload(self.sock) + + if payload == None or payload[0] != 8 or _get_field(payload[1], 1) != 0x00.to_bytes(): raise Exception("Failed to connect") - if 3 in resp.fields.fields: - self.token = resp.fields.fields[3] + self.token = _get_field(payload[1], 3) def filter(self, topics: list[str]): - payload = Payload(9, Fields({1: self.token, 2: b"".join([sha1(topic.encode()).digest() for topic in topics])})) + fields = [(1, self.token)] - self.sock.write(payload.to_bytes()) + for topic in topics: + fields.append((2, sha1(topic.encode()).digest())) - + payload = _serialize_payload(9, fields) -if __name__ == "__main__": - import courier - import base64 - - sock = courier.connect() - - # Try and read the token from the file - try: - with open("token", "r") as f: - r = f.read() - if r == "": - raise FileNotFoundError - payload = Payload(7, Fields({1: base64.b64decode(r), 2: 0x01.to_bytes()})) - except FileNotFoundError: - payload = Payload(7, Fields({2: 0x01.to_bytes()})) - - # Send the connect request (with or without the token) - sock.write(payload.to_bytes()) - - # Read the response - resp = Payload.from_stream(sock) - # Check if the response is valid - if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes(): - raise Exception("Failed to connect") - - # If there's a new token, save it - if 3 in resp.fields.fields: - with open("token", "wb") as f: - f.write(base64.b64encode(resp.fields.fields[3])) - - # Send the push topics request + self.sock.write(payload) diff --git a/courier.py b/courier.py index dcfd94d..846eb8f 100644 --- a/courier.py +++ b/courier.py @@ -1,28 +1,23 @@ -import albert import tlslite import socket -COURIER_HOST = "10-courier.push.apple.com" +COURIER_HOST = "10-courier.push.apple.com" # TODO: Get this from config COURIER_PORT = 5223 ALPN = [b"apns-security-v2"] -#ALPN = None - -def connect(private_key=None, cert=None): - # If we don't have a private key or certificate, generate one - if private_key is None or cert is None: - private_key, cert = albert.generate_push_cert() +# Connect to the courier server +def connect(private_key: str, cert: str) -> tlslite.TLSConnection: # Connect to the courier server sock = socket.create_connection((COURIER_HOST, COURIER_PORT)) # Wrap the socket in TLS sock = tlslite.TLSConnection(sock) # Parse the certificate and private key - cert_parsed = tlslite.X509CertChain([tlslite.X509().parse(cert)]) - private_key_parsed = tlslite.parsePEMKey(private_key, private=True) + cert = tlslite.X509CertChain([tlslite.X509().parse(cert)]) + private_key = tlslite.parsePEMKey(private_key, private=True) # Handshake with the server - sock.handshakeClientCert(cert_parsed, private_key_parsed, alpn=ALPN) + sock.handshakeClientCert(cert, private_key, alpn=ALPN) - return sock, private_key, cert + return sock if __name__ == "__main__": sock = connect() diff --git a/demo.py b/demo.py index 4b82754..a10383f 100644 --- a/demo.py +++ b/demo.py @@ -3,33 +3,35 @@ from base64 import b64decode, b64encode from hashlib import sha1 conn1 = apns.APNSConnection() +conn1.connect() print(f"Push Token 1: {b64encode(conn1.token).decode()}") conn2 = apns.APNSConnection() +conn2.connect() print(f"Push Token 2: {b64encode(conn2.token).decode()}") conn1.filter(["com.apple.madrid"]) conn2.filter(["com.apple.madrid"]) -#print(sha1(b"com.apple.madrid").digest()) -# Send a notification -# expiry timestamp in UNIX epoch -expiry = 1680761868 -expiry = expiry.to_bytes(4, "big") +# #print(sha1(b"com.apple.madrid").digest()) +# # Send a notification +# # expiry timestamp in UNIX epoch +# expiry = 1680761868 +# expiry = expiry.to_bytes(4, "big") -# Current time in UNIX nano epoch -import time -now = int(time.time() * 1000).to_bytes(8, "big") +# # Current time in UNIX nano epoch +# import time +# now = int(time.time() * 1000).to_bytes(8, "big") -payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()})) -conn1.sock.write(payload.to_bytes()) +# payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()})) +# conn1.sock.write(payload.to_bytes()) -print("Waiting for response...") +# print("Waiting for response...") -# Check if the notification was sent -resp = apns.Payload.from_stream(conn1.sock) -print(resp) +# # Check if the notification was sent +# resp = apns.Payload.from_stream(conn1.sock) +# print(resp) -# Read the message from the other connection -resp = apns.Payload.from_stream(conn2.sock) -print(resp) \ No newline at end of file +# # Read the message from the other connection +# resp = apns.Payload.from_stream(conn2.sock) +# print(resp) \ No newline at end of file