pypush-plus-plus/apns.py

386 lines
12 KiB
Python
Raw Normal View History

2023-04-05 18:52:14 -05:00
from __future__ import annotations
2023-08-17 18:48:09 -05:00
import logging
2023-04-11 11:23:04 -05:00
import random
2023-08-17 18:48:09 -05:00
import ssl
2023-04-07 21:32:00 -05:00
import time
2023-08-17 18:48:09 -05:00
from base64 import b64encode
from contextlib import asynccontextmanager
from dataclasses import dataclass
2023-04-11 11:23:04 -05:00
from hashlib import sha1
2023-08-17 18:48:09 -05:00
from typing import Callable
2023-08-17 16:19:48 -05:00
import trio
2023-08-17 18:48:09 -05:00
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
2023-08-17 16:19:48 -05:00
2023-04-11 11:23:04 -05:00
import albert
2023-05-09 19:01:22 -05:00
import bags
2023-08-17 18:48:09 -05:00
logger = logging.getLogger("apns")

# Fetch the APNs init bag once and reuse it for both lookups below, instead of
# calling bags.apns_init_bag() separately for the host count and the hostname.
_APNS_INIT_BAG = bags.apns_init_bag()

# Pick a random courier server, numbered from 1 to APNSCourierHostcount
# (the number is not zero-padded).
COURIER_HOST = (
    f"{random.randint(1, _APNS_INIT_BAG['APNSCourierHostcount'])}"
    f"-{_APNS_INIT_BAG['APNSCourierHostname']}"
)
COURIER_PORT = 5223
ALPN = [b"apns-security-v3"]
2023-08-17 18:48:09 -05:00
2023-08-17 16:19:48 -05:00
async def apns_test():
    """Open an APNs connection, print its push token, then tick forever.

    The token is printed base64-encoded. Afterwards a "." is printed once per
    second until the task is cancelled or an exception propagates.
    """
    async with APNSConnection.start() as conn:
        encoded_token = b64encode(conn.credentials.token)
        print(encoded_token.decode())
        while True:
            await trio.sleep(1)
            print(".")
            # await conn.set_state(1)

    # Not reached while the loop above has no exit condition.
    print("Finished")
2023-08-17 18:48:09 -05:00
2023-08-17 16:19:48 -05:00
def main():
    """Configure logging (rendered through rich) and run the APNs test under trio."""
    from rich.logging import RichHandler

    logging.basicConfig(
        level=logging.NOTSET,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler()],
    )

    # Set sane log levels
    levels_by_logger = {
        "urllib3": logging.WARNING,
        "py.warnings": logging.ERROR,  # Ignore warnings from urllib3
        "asyncio": logging.WARNING,
        "jelly": logging.INFO,
        "nac": logging.INFO,
        "apns": logging.DEBUG,
        "albert": logging.INFO,
        "ids": logging.DEBUG,
        "bags": logging.INFO,
        "imessage": logging.DEBUG,
    }
    for logger_name, level in levels_by_logger.items():
        logging.getLogger(logger_name).setLevel(level)

    # Route warnings.warn() output through the "py.warnings" logger.
    logging.captureWarnings(True)

    print("APNs Test:")

    trio.run(apns_test)
2023-05-02 19:53:18 -05:00
2023-08-17 16:19:48 -05:00
@dataclass
class PushCredentials:
    """Credentials identifying this client to APNs.

    Every field defaults to an empty value; APNSConnection.connect() fills in
    the key/certificate pair when either is empty and stores the token it
    receives from the server.
    """

    # PEM text; encoded and loaded with load_pem_private_key when connecting.
    private_key: str = ""
    # PEM text; encoded and loaded with load_pem_x509_certificate when connecting.
    cert: str = ""
    # Raw push token bytes (empty until a connection has been established).
    token: bytes = b""
2023-05-02 19:53:18 -05:00
2023-08-17 16:19:48 -05:00
class APNSConnection:
    """A connection to the APNs server

    Incoming payloads are read by a background task (`_queue_filler`) into a
    per-connection queue; `_receive` hands them out by payload id. Use
    `APNSConnection.start()` to get a connection whose background tasks and
    socket are cleaned up automatically.
    """

    # Payloads received from the APNs server that nobody has consumed yet.
    # Created per instance in __init__ so connections never share a queue.
    _incoming_queue: list[APNSPayload]
    # Event that is set whenever a new payload is appended to the queue; it is
    # replaced with a fresh event after every wake-up. Created per instance.
    _queue_park: trio.Event

    def __init__(
        self, nursery: trio.Nursery, credentials: PushCredentials | None = None
    ):
        """Creates a raw APNSConnection. Make sure to call aclose() on the socket and cancel the nursery when you're done with it

        Args:
            nursery: Nursery that will own the queue filler and keep alive tasks.
            credentials: Credentials to connect with. When omitted, a fresh
                empty PushCredentials is created for this connection (a shared
                default instance would be mutated by connect()).
        """
        self._nursery = nursery
        self.credentials = PushCredentials() if credentials is None else credentials
        self._incoming_queue = []
        self._queue_park = trio.Event()
        # Set by connect(); None until the TLS stream has been created.
        self.sock = None

    @staticmethod
    @asynccontextmanager
    async def start(credentials: PushCredentials | None = None):
        """Sets up a nursery and connection and yields the connection

        On exit (normal or via an exception) the background tasks are
        cancelled and the socket, if one was opened, is closed.
        """
        # NOTE: @staticmethod must be the outermost decorator so that
        # asynccontextmanager wraps the plain function, not the descriptor.
        async with trio.open_nursery() as nursery:
            connection = APNSConnection(nursery, credentials)
            try:
                await connection.connect()
                yield connection
            finally:
                nursery.cancel_scope.cancel()  # Cancel heartbeat and queue filler tasks
                if connection.sock is not None:
                    await connection.sock.aclose()  # Close the socket

    async def _send(self, payload: APNSPayload):
        """Sends a payload to the APNs server"""
        await payload.write_to_stream(self.sock)

    async def _receive(self, id: int, filter: Callable | None = None):
        """
        Waits for a payload with the given id to be in the queue, removes it from the queue and returns it.

        If filter is not None, it will be called with each candidate payload as an argument, and if it
        returns False, that payload is left in the queue and another will be waited for.

        The whole queue is scanned (oldest first) both before waiting and after every wake-up, so a
        payload that arrived before this call, or together with other payloads, is not missed.

        NOTE: If several tasks wait with the same id and filter, each matching payload is delivered to
        only one of them; the others keep waiting for the next match.
        """
        while True:
            for index, payload in enumerate(self._incoming_queue):
                if payload.id != id:
                    continue
                if filter is not None and not filter(payload):
                    continue
                # Safe to mutate here: we return immediately after removal.
                del self._incoming_queue[index]
                return payload

            # Nothing matched; there is no checkpoint between the scan above
            # and this wait, so no payload can slip in unnoticed.
            await self._queue_park.wait()
            logger.debug(f"Woken by event, checking for {id}")

    async def _queue_filler(self):
        """Fills the queue with payloads from the APNs socket"""
        while True:
            payload = await APNSPayload.read_from_stream(self.sock)
            logger.debug(f"Received payload: {payload}")

            self._incoming_queue.append(payload)

            # Signal to any waiting tasks that we have a new payload, then
            # install a fresh event for the next round of waiters.
            self._queue_park.set()
            self._queue_park = trio.Event()
            logger.debug(f"Queue length: {len(self._incoming_queue)}")

    async def _keep_alive(self):
        """Sends keep alive messages to the APNs server every 5 minutes"""
        while True:
            await trio.sleep(300)
            logger.debug("Sending keep alive message")
            await self._send(APNSPayload(0x0C, []))
            await self._receive(0x0D)
            logger.debug("Got keep alive response")

    async def connect(self):
        """Connects to the APNs server and starts the keep alive and queue filler tasks"""
        sock = await trio.open_tcp_stream(COURIER_HOST, COURIER_PORT)

        # NOTE(review): a context built directly with PROTOCOL_TLS does not
        # verify the server certificate or hostname. Left unchanged here;
        # confirm whether verification against the courier's CA is wanted.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
        context.set_alpn_protocols(["apns-security-v3"])

        self.sock = trio.SSLStream(sock, context, server_hostname=COURIER_HOST)

        await self.sock.do_handshake()
        logger.info(f"Connected to APNs ({COURIER_HOST})")

        # Generate a key/certificate pair if the caller did not supply one.
        if self.credentials.cert == "" or self.credentials.private_key == "":
            (
                self.credentials.private_key,
                self.credentials.cert,
            ) = albert.generate_push_cert()

        # Start the queue filler and keep alive tasks
        self._nursery.start_soon(self._queue_filler)
        self._nursery.start_soon(self._keep_alive)

        self.credentials.token = await self._connect(self.credentials.token)

    async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes:
        """Sends the APNs connect message

        Args:
            token: Existing push token to present. None or empty bytes (the
                PushCredentials default) means no token field is sent.
            root: When True, an extra flag bit (0b0100) is set in the connect flags.

        Returns:
            The token from the connect response, or the supplied token when
            the response carries none.

        Raises:
            Exception: If the response status is missing or not b"\\x00", or
                if no token is available from either the response or the caller.
        """
        cert = x509.load_pem_x509_certificate(self.credentials.cert.encode())
        private_key = serialization.load_pem_private_key(
            self.credentials.private_key.encode(), password=None
        )

        if not token:
            # Treat an empty token like None so we never send an empty token field.
            token = None
            logger.debug(f"Sending connect message without token (root={root})")
        else:
            logger.debug(
                f"Sending connect message with token {b64encode(token).decode()} (root={root})"
            )

        flags = 0b01000001
        if root:
            flags |= 0b0100

        cert = cert.public_bytes(serialization.Encoding.DER)
        # Nonce layout: one zero byte, 8-byte millisecond timestamp, 8 random bytes.
        nonce = (
            b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8)
        )
        signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1())  # type: ignore

        payload = APNSPayload(
            7,
            [
                APNSField(0x2, b"\x01"),
                APNSField(0x5, flags.to_bytes(4, "big")),
                APNSField(0xC, cert),
                APNSField(0xD, nonce),
                APNSField(0xE, signature),
            ],
        )
        if token is not None:
            payload.fields.insert(0, APNSField(0x1, token))

        await self._send(payload)

        response = await self._receive(8)

        status_fields = response.fields_with_id(1)
        if not status_fields or status_fields[0].value != b"\x00":
            raise Exception("Failed to connect")

        token_fields = response.fields_with_id(3)
        if token_fields:
            new_token = token_fields[0].value
        elif token is not None:
            # NOTE(review): presumably the server may omit the token when the
            # client already presented one; fall back to it instead of
            # failing with an IndexError. TODO confirm.
            new_token = token
        else:
            raise Exception("No token received in connect response")

        logger.debug(
            f"Received connect response with token {b64encode(new_token).decode()}"
        )

        return new_token

    async def filter(self, topics: list[str]):
        """Sends the APNs filter message

        Each topic is sent as the SHA-1 digest of its encoded name.
        """
        logger.debug(f"Sending filter message with topics {topics}")

        fields = [APNSField(1, self.credentials.token)]
        fields.extend(APNSField(2, sha1(topic.encode()).digest()) for topic in topics)

        await self._send(APNSPayload(9, fields))

    async def send_notification(self, topic: str, payload: bytes, id=None):
        """Sends a notification to the APNs server

        Args:
            topic: Topic name; sent as its SHA-1 digest.
            payload: Raw notification body.
            id: Message id bytes; 4 random bytes are used when None.
        """
        if id is None:
            id = random.randbytes(4)

        p = APNSPayload(
            0xA,
            [
                APNSField(4, id),
                APNSField(1, sha1(topic.encode()).digest()),
                APNSField(2, self.credentials.token),
                APNSField(3, payload),
            ],
        )
        await self._send(p)

        # Wait for ACK
        await self._receive(0xB)

        # TODO: Check ACK code

    async def expect_notification(self, topic: str, filter: Callable | None = None):
        """Waits for a notification to be received, and acks it

        Args:
            topic: Only notifications whose field 1 equals the SHA-1 digest of
                this topic are accepted.
            filter: Optional extra predicate called with the payload.

        Returns:
            The matching notification payload.
        """
        topic_hash = sha1(topic.encode()).digest()

        def f(payload: APNSPayload):
            topic_fields = payload.fields_with_id(1)
            if not topic_fields or topic_fields[0].value != topic_hash:
                return False
            if filter is not None:
                return filter(payload)
            return True

        r = await self._receive(0xA, f)
        await self._send_ack(r.fields_with_id(4)[0].value)
        return r

    async def set_state(self, state: int):
        """Sends the APNs state message"""
        logger.debug(f"Sending state message with state {state}")
        await self._send(
            APNSPayload(
                0x14,
                [
                    APNSField(1, state.to_bytes(4, "big")),
                    APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")),
                ],
            )
        )

    async def _send_ack(self, id: bytes):
        """Sends an ACK for a notification with the given id"""
        logger.debug(f"Sending ACK for message {id}")
        payload = APNSPayload(
            0xB,
            [
                APNSField(1, self.credentials.token),
                APNSField(4, id),
                APNSField(8, b"\x00"),
            ],
        )
        await self._send(payload)
2023-04-07 21:32:00 -05:00
2023-08-17 18:48:09 -05:00
@dataclass
class APNSField:
    """A field in an APNS payload

    Wire format: 1-byte field id, 2-byte big-endian value length, then the value.
    """

    id: int
    value: bytes

    @staticmethod
    def from_buffer(stream: bytes) -> APNSField:
        """Parses the field at the start of the given buffer.

        Bytes after the field's declared length are ignored.

        Raises:
            ValueError: If the buffer is too short to hold the 3-byte header
                or the number of value bytes the header declares.
        """
        if len(stream) < 3:
            raise ValueError(
                f"Truncated APNS field header: expected 3 bytes, got {len(stream)}"
            )

        id = int.from_bytes(stream[:1], "big")
        length = int.from_bytes(stream[1:3], "big")
        value = stream[3 : 3 + length]

        # Slicing never fails on short input, so check the length explicitly
        # rather than silently returning a shortened value.
        if len(value) != length:
            raise ValueError(
                f"Truncated APNS field {id}: expected {length} value bytes, got {len(value)}"
            )

        return APNSField(id, value)

    def to_buffer(self) -> bytes:
        """Serializes the field as id (1 byte) + length (2 bytes, big-endian) + value."""
        return (
            self.id.to_bytes(1, "big") + len(self.value).to_bytes(2, "big") + self.value
        )
2023-08-17 16:19:48 -05:00
2023-04-07 21:32:00 -05:00
2023-08-17 18:48:09 -05:00
@dataclass
class APNSPayload:
    """An APNS payload

    Wire format: 1-byte payload id, 4-byte big-endian body length, then the
    body, which is a sequence of serialized APNSField entries.
    """

    id: int
    fields: list[APNSField]

    @staticmethod
    async def _read_exactly(stream: trio.abc.Stream, count: int, what: str) -> bytes:
        """Reads exactly `count` bytes from the stream.

        `receive_some` may return fewer bytes than requested and returns an
        empty result at end of stream, so keep reading until the requested
        amount has arrived.

        Raises:
            EOFError: If the stream ends before `count` bytes were read.
        """
        collected = bytearray()
        while len(collected) < count:
            chunk = await stream.receive_some(count - len(collected))
            if not chunk:
                raise EOFError(f"Unable to read {what} from stream")
            collected += chunk
        return bytes(collected)

    @staticmethod
    async def read_from_stream(stream: trio.abc.Stream) -> APNSPayload:
        """Reads a payload from the given stream

        Raises:
            EOFError: If the stream ends in the middle of (or before) a payload.
            Exception: If the payload id is 0x0.
        """
        raw_id = await APNSPayload._read_exactly(stream, 1, "payload id")
        id = int.from_bytes(raw_id, "big")

        if id == 0x0:
            raise Exception("Received id 0x0, which is not valid")

        raw_length = await APNSPayload._read_exactly(stream, 4, "payload length")
        length = int.from_bytes(raw_length, "big")

        if length == 0:
            return APNSPayload(id, [])

        buffer = await APNSPayload._read_exactly(stream, length, "payload")

        fields = []
        while len(buffer) > 0:
            field = APNSField.from_buffer(buffer)
            fields.append(field)
            # Skip the 3-byte field header plus the value just parsed.
            buffer = buffer[3 + len(field.value) :]

        return APNSPayload(id, fields)

    async def write_to_stream(self, stream: trio.abc.Stream):
        """Writes the payload to the given stream"""
        body = b"".join(field.to_buffer() for field in self.fields)

        buffer = self.id.to_bytes(1, "big") + len(body).to_bytes(4, "big") + body

        await stream.send_all(buffer)

    def fields_with_id(self, id: int):
        """Returns all fields with the given id"""
        return [field for field in self.fields if field.id == id]
2023-04-07 21:32:00 -05:00
2023-08-17 16:19:48 -05:00
# Run the APNs test only when this file is executed as a script.
if __name__ == "__main__":
    main()