diff --git a/apns.py b/apns.py index 755abf6..81d5a3b 100644 --- a/apns.py +++ b/apns.py @@ -8,11 +8,7 @@ from hashlib import sha1 from base64 import b64encode, b64decode import logging logger = logging.getLogger("apns") - -import tlslite -if tlslite.__version__ != "0.8.0-alpha43": - logger.warning("tlslite-ng is not the correct version!") - logger.warning("Please install tlslite-ng==0.8.0a43 or you will experience issues!") +import ssl import albert import bags @@ -21,24 +17,23 @@ import bags # Pick a random courier server from 01 to APNSCourierHostcount COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount'])}-{bags.apns_init_bag()['APNSCourierHostname']}" COURIER_PORT = 5223 -ALPN = [b"apns-security-v2"] +ALPN = [b"apns-security-v3"] # Connect to the courier server -def _connect(private_key: str, cert: str) -> tlslite.TLSConnection: +def _connect(private_key: str, cert: str) -> ssl.SSLSocket: # Connect to the courier server sock = socket.create_connection((COURIER_HOST, COURIER_PORT)) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.set_alpn_protocols(["apns-security-v3"]) # Wrap the socket in TLS - sock = tlslite.TLSConnection(sock) - # Parse the certificate and private key - cert = tlslite.X509CertChain([tlslite.X509().parse(cert)]) - private_key = tlslite.parsePEMKey(private_key, private=True) + ssock = context.wrap_socket(sock) # Handshake with the server - sock.handshakeClientCert(cert, private_key, alpn=ALPN) + ssock.do_handshake() logger.info(f"Connected to APNs ({COURIER_HOST})") - return sock + return ssock class IncomingQueue: @@ -95,7 +90,7 @@ class APNSConnection: self.incoming_queue = IncomingQueue() def _queue_filler(self): - while True and not self.sock.closed: + while True: payload = _deserialize_payload(self.sock) if payload is not None: @@ -108,7 +103,7 @@ class APNSConnection: logger.debug(f"Queue length: {len(self.incoming_queue)}") def _keep_alive_loop(self): - while True and not self.sock.closed: + while True: time.sleep(300) self._keep_alive() @@ -134,6 +129,13 @@ class APNSConnection: def connect(self, root: bool = True, token: bytes = None): + # Parse self.certificate + from cryptography import x509 + cert = x509.load_pem_x509_certificate(self.cert.encode()) + # Parse private key + from cryptography.hazmat.primitives import serialization + private_key = serialization.load_pem_private_key(self.private_key.encode(), password=None) + if token is None: logger.debug(f"Sending connect message without token (root={root})") else: @@ -142,9 +144,19 @@ class APNSConnection: if root: flags |= 0b0100 + # 1 byte fixed 00, 8 bytes timestamp (milliseconds since Unix epoch), 8 bytes random + cert = cert.public_bytes(serialization.Encoding.DER) + nonce = b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8) + #signature = private_key.sign(nonce, signature_algorithm=serialization.NoEncryption()) + # RSASSA-PKCS1-SHA1 + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import padding + signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1()) + + if token is None: payload = _serialize_payload( - 7, [(2, 0x01.to_bytes(1, "big")), (5, flags.to_bytes(4, "big"))] + 7, [(2, 0x01.to_bytes(1, "big")), (5, flags.to_bytes(4, "big")), (0x0c, cert), (0x0d, nonce), (0x0e, signature)] ) else: payload = _serialize_payload( @@ -153,10 +165,15 @@ class APNSConnection: (1, token), (2, 0x01.to_bytes(1, "big")), (5, flags.to_bytes(4, "big")), + (0x0c, cert), + (0x0d, nonce), + (0x0e, signature), + ], ) - self.sock.write(payload) + #self.sock.write(payload) + self.sock.sendall(payload) payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 8) @@ -188,7 +205,7 @@ class APNSConnection: payload = _serialize_payload(9, fields) - self.sock.write(payload) + self.sock.sendall(payload) def send_message(self, topic: str, payload: str, id=None): logger.debug(f"Sending message to topic {topic} with payload {payload}") @@ -205,7 +222,7 @@ class APNSConnection: ], ) - self.sock.write(payload) + self.sock.sendall(payload) # Wait for ACK payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0B) @@ -215,7 +232,7 @@ class APNSConnection: def set_state(self, state: int): logger.debug(f"Sending state message with state {state}") - self.sock.write( + self.sock.sendall( _serialize_payload( 0x14, [(1, state.to_bytes(1, "big")), (2, 0x7FFFFFFF.to_bytes(4, "big"))], @@ -224,7 +241,7 @@ class APNSConnection: def _keep_alive(self): logger.debug("Sending keep alive message") - self.sock.write(_serialize_payload(0x0C, [])) + self.sock.sendall(_serialize_payload(0x0C, [])) # Remove any keep alive responses we have or missed self.incoming_queue.remove_all(0x0D) @@ -232,7 +249,7 @@ class APNSConnection: def _send_ack(self, id: bytes): logger.debug(f"Sending ACK for message {id}") payload = _serialize_payload(0x0B, [(1, self.token), (4, id), (8, b"\x00")]) - self.sock.write(payload) + self.sock.sendall(payload) # #self.sock.write(_serialize_payload(0x0B, [(4, id)]) # #pass @@ -270,7 +287,7 @@ def _deserialize_field(stream: bytes) -> tuple[int, bytes]: # Note: Takes a stream, not a buffer, as we do not know the length of the payload # WILL BLOCK IF THE STREAM IS EMPTY -def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None: +def _deserialize_payload(stream: ssl.SSLSocket) -> tuple[int, list[tuple[int, bytes]]] | None: id = int.from_bytes(stream.read(1), "big") if id == 0x0: diff --git a/requirements.txt b/requirements.txt index 3d7c2f4..d4f987a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ requests cryptography wheel -tlslite-ng==0.8.0a43 srp pbkdf2 unicorn