majorly refactor, fix types

This commit is contained in:
JJTech0130 2023-08-17 19:48:09 -04:00
parent e49fe9f916
commit b5f644b989
No known key found for this signature in database
GPG key ID: 23C92EBCCF8F93D6

420
apns.py
View file

@ -1,47 +1,57 @@
from __future__ import annotations from __future__ import annotations
import random
import socket
import threading
import time
from hashlib import sha1
from base64 import b64encode, b64decode
import logging import logging
logger = logging.getLogger("apns") import random
import ssl import ssl
import time
from base64 import b64encode
from contextlib import asynccontextmanager
from dataclasses import dataclass
from hashlib import sha1
from typing import Callable
import trio import trio
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
import albert import albert
import bags import bags
#COURIER_HOST = "windows.courier.push.apple.com" # TODO: Get this from config logger = logging.getLogger("apns")
# Pick a random courier server from 01 to APNSCourierHostcount # Pick a random courier server from 01 to APNSCourierHostcount
COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount'])}-{bags.apns_init_bag()['APNSCourierHostname']}" COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount'])}-{bags.apns_init_bag()['APNSCourierHostname']}"
COURIER_PORT = 5223 COURIER_PORT = 5223
ALPN = [b"apns-security-v3"] ALPN = [b"apns-security-v3"]
async def apns_test(): async def apns_test():
async with APNSConnection.start() as connection: async with APNSConnection.start() as connection:
print(b64encode(connection.credentials.token).decode()) print(b64encode(connection.credentials.token).decode())
while True: while True:
await trio.sleep(1) await trio.sleep(1)
print(".") print(".")
#await connection.set_state(1) # await connection.set_state(1)
print("Finished") print("Finished")
def main(): def main():
from rich.logging import RichHandler from rich.logging import RichHandler
logging.basicConfig( logging.basicConfig(
level=logging.NOTSET, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] level=logging.NOTSET,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler()],
) )
# Set sane log levels # Set sane log levels
logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from urllib3 logging.getLogger("py.warnings").setLevel(
logging.ERROR
) # Ignore warnings from urllib3
logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING)
logging.getLogger("jelly").setLevel(logging.INFO) logging.getLogger("jelly").setLevel(logging.INFO)
logging.getLogger("nac").setLevel(logging.INFO) logging.getLogger("nac").setLevel(logging.INFO)
@ -52,70 +62,96 @@ def main():
logging.getLogger("imessage").setLevel(logging.DEBUG) logging.getLogger("imessage").setLevel(logging.DEBUG)
logging.captureWarnings(True) logging.captureWarnings(True)
print("APNs Test:") print("APNs Test:")
trio.run(apns_test) trio.run(apns_test)
from contextlib import asynccontextmanager
from dataclasses import dataclass
@dataclass @dataclass
class PushCredentials: class PushCredentials:
private_key: str private_key: str = ""
cert: str cert: str = ""
token: bytes token: bytes = b""
class APNSConnection: class APNSConnection:
_incoming_queue: list = [] # We don't need a lock because this is trio and we only have one thread """A connection to the APNs server"""
_queue_park: trio.Event = trio.Event()
_incoming_queue: list[APNSPayload] = []
"""A queue of payloads that have been received from the APNs server"""
_queue_park: trio.Event = trio.Event()
"""An event that is set when a new payload is added to the queue"""
async def _send(self, payload: APNSPayload):
"""Sends a payload to the APNs server"""
await payload.write_to_stream(self.sock)
async def _receive(self, id: int, filter: Callable | None = None):
"""
Waits for a payload with the given id to be added to the queue, then returns it.
If filter is not None, it will be called with the payload as an argument, and if it returns False,
the payload will be ignored and another will be waited for.
NOTE: It is not defined what happens if receive is called twice with the same id and filter,
as the first payload will be removed from the queue, so the second call might never return
"""
async def _send(self, id: int, fields: list[tuple[int, bytes]]):
payload = _serialize_payload(id, fields)
await self.sock.send_all(payload)
async def _receive(self, id: int):
# Check if anything currently in the queue matches the id # Check if anything currently in the queue matches the id
for payload in self._incoming_queue: for payload in self._incoming_queue:
if payload[0] == id: if payload.id == id:
return payload return payload
while True: while True:
await self._queue_park.wait() # Wait for a new payload to be added to the queue await self._queue_park.wait() # Wait for a new payload to be added to the queue
logger.debug(f"Woken by event, checking for {id}") logger.debug(f"Woken by event, checking for {id}")
# Check if the new payload matches the id # Check if the new payload matches the id
if self._incoming_queue[-1][0] == id: if self._incoming_queue[-1].id != id:
return self._incoming_queue.pop() continue
if filter is not None:
if not filter(self._incoming_queue[-1]):
continue
return self._incoming_queue.pop()
# Otherwise, wait for another payload to be added to the queue # Otherwise, wait for another payload to be added to the queue
async def _queue_filler(self): async def _queue_filler(self):
"""Fills the queue with payloads from the APNs socket"""
while True: while True:
payload = await _deserialize_payload(self.sock) payload = await APNSPayload.read_from_stream(self.sock)
logger.debug(f"Received payload: {payload}") logger.debug(f"Received payload: {payload}")
self._incoming_queue.append(payload) self._incoming_queue.append(payload)
# Signal to any waiting tasks that we have a new payload # Signal to any waiting tasks that we have a new payload
self._queue_park.set() self._queue_park.set()
self._queue_park = trio.Event() # Reset the event self._queue_park = trio.Event() # Reset the event
logger.debug(f"Queue length: {len(self._incoming_queue)}") logger.debug(f"Queue length: {len(self._incoming_queue)}")
async def _keep_alive(self): async def _keep_alive(self):
"""Sends keep alive messages to the APNs server every 5 minutes"""
while True: while True:
#await trio.sleep(300) await trio.sleep(300)
await trio.sleep(1)
logger.debug("Sending keep alive message") logger.debug("Sending keep alive message")
await self._send(0x0C, []) await self._send(APNSPayload(0x0C, []))
await self._receive(0x0D) await self._receive(0x0D)
logger.debug("Got keep alive response") logger.debug("Got keep alive response")
@asynccontextmanager @asynccontextmanager
async def start(credentials: PushCredentials | None = None): @staticmethod
async def start(credentials: PushCredentials = PushCredentials()):
"""Sets up a nursery and connection and yields the connection""" """Sets up a nursery and connection and yields the connection"""
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
connection = APNSConnection(nursery, credentials) connection = APNSConnection(nursery, credentials)
await connection.connect() await connection.connect()
yield connection yield connection
nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks
await connection.sock.aclose() # Close the socket await connection.sock.aclose() # Close the socket
def __init__(self, nursery: trio.Nursery, credentials: PushCredentials | None = None): def __init__(
self, nursery: trio.Nursery, credentials: PushCredentials = PushCredentials()
):
"""Creates a raw APNSConnection. Make sure to call aclose() on the socket and cancel the nursery when you're done with it"""
self._nursery = nursery self._nursery = nursery
self.credentials = credentials self.credentials = credentials
@ -132,8 +168,11 @@ class APNSConnection:
logger.info(f"Connected to APNs ({COURIER_HOST})") logger.info(f"Connected to APNs ({COURIER_HOST})")
if self.credentials is None: if self.credentials.cert == "" or self.credentials.private_key == "":
self.credentials = PushCredentials(*albert.generate_push_cert(), None) (
self.credentials.private_key,
self.credentials.cert,
) = albert.generate_push_cert()
# Start the queue filler and keep alive tasks # Start the queue filler and keep alive tasks
self._nursery.start_soon(self._queue_filler) self._nursery.start_soon(self._queue_filler)
@ -143,187 +182,204 @@ class APNSConnection:
async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes: async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes:
"""Sends the APNs connect message""" """Sends the APNs connect message"""
# Parse self.certificate
from cryptography import x509
cert = x509.load_pem_x509_certificate(self.credentials.cert.encode()) cert = x509.load_pem_x509_certificate(self.credentials.cert.encode())
# Parse private key private_key = serialization.load_pem_private_key(
from cryptography.hazmat.primitives import serialization self.credentials.private_key.encode(), password=None
private_key = serialization.load_pem_private_key(self.credentials.private_key.encode(), password=None) )
if token is None: if token is None:
logger.debug(f"Sending connect message without token (root={root})") logger.debug(f"Sending connect message without token (root={root})")
else: else:
logger.debug(f"Sending connect message with token {b64encode(token).decode()} (root={root})") logger.debug(
f"Sending connect message with token {b64encode(token).decode()} (root={root})"
)
flags = 0b01000001 flags = 0b01000001
if root: if root:
flags |= 0b0100 flags |= 0b0100
# 1 byte fixed 00, 8 bytes timestamp (milliseconds since Unix epoch), 8 bytes random
cert = cert.public_bytes(serialization.Encoding.DER) cert = cert.public_bytes(serialization.Encoding.DER)
nonce = b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8) nonce = (
#signature = private_key.sign(nonce, signature_algorithm=serialization.NoEncryption()) b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8)
# RSASSA-PKCS1-SHA1 )
from cryptography.hazmat.primitives import hashes signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1()) # type: ignore
from cryptography.hazmat.primitives.asymmetric import padding
signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1())
payload = [
(2, b"\x01"),
(5, flags.to_bytes(4, "big")),
(0x0c, cert),
(0x0d, nonce),
(0x0e, signature),
]
if token is not None:
payload.insert(0, (1, token))
await self._send(7, payload)
payload = await self._receive(8)
if _get_field(payload[1], 1) != b"\x00": payload = APNSPayload(
raise Exception("Failed to connect") 7,
new_token = _get_field(payload[1], 3)
logger.debug(f"Recieved connect response with token {b64encode(new_token).decode()}")
return new_token
def filter(self, topics: list[str]):
logger.debug(f"Sending filter message with topics {topics}")
fields = [(1, self.token)]
for topic in topics:
fields.append((2, sha1(topic.encode()).digest()))
payload = _serialize_payload(9, fields)
self.sock.sendall(payload)
def send_message(self, topic: str, payload: str, id=None):
logger.debug(f"Sending message to topic {topic} with payload {payload}")
if id is None:
id = random.randbytes(4)
payload = _serialize_payload(
0x0A,
[ [
(4, id), APNSField(0x2, b"\x01"),
(1, sha1(topic.encode()).digest()), APNSField(0x5, flags.to_bytes(4, "big")),
(2, self.token), APNSField(0xC, cert),
(3, payload), APNSField(0xD, nonce),
APNSField(0xE, signature),
], ],
) )
self.sock.sendall(payload) if token is not None:
payload.fields.insert(0, APNSField(0x1, token))
await self._send(payload)
payload = await self._receive(8)
if payload.fields_with_id(1)[0].value != b"\x00":
raise Exception("Failed to connect")
new_token = payload.fields_with_id(3)[0].value
logger.debug(
f"Received connect response with token {b64encode(new_token).decode()}"
)
return new_token
async def filter(self, topics: list[str]):
"""Sends the APNs filter message"""
logger.debug(f"Sending filter message with topics {topics}")
payload = APNSPayload(9, [APNSField(1, self.credentials.token)])
for topic in topics:
payload.fields.append(APNSField(2, sha1(topic.encode()).digest()))
await payload.write_to_stream(self.sock)
async def send_notification(self, topic: str, payload: bytes, id=None):
"""Sends a notification to the APNs server"""
if id is None:
id = random.randbytes(4)
p = APNSPayload(
0xA,
[
APNSField(4, id),
APNSField(1, sha1(topic.encode()).digest()),
APNSField(2, self.credentials.token),
APNSField(3, payload),
],
)
await self._send(p)
# Wait for ACK # Wait for ACK
payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0B) r = await self._receive(0xB)
if payload[1][0][1] != 0x00.to_bytes(1, "big"): # TODO: Check ACK code
raise Exception("Failed to send message")
async def expect_notification(self, topic: str, filter: Callable | None = None):
"""Waits for a notification to be received, and acks it"""
def f(payload: APNSPayload):
if payload.fields_with_id(1)[0].value != sha1(topic.encode()).digest():
return False
if filter is not None:
return filter(payload)
return True
r = await self._receive(0xA, f)
await self._send_ack(r.fields_with_id(4)[0].value)
return r
async def set_state(self, state: int): async def set_state(self, state: int):
"""Sends the APNs state message"""
logger.debug(f"Sending state message with state {state}") logger.debug(f"Sending state message with state {state}")
await self._send(0x14, [(1, b"\x01"), (2, 0x7FFFFFFF.to_bytes(4, "big"))]) await self._send(
APNSPayload(
0x14,
[
APNSField(1, state.to_bytes(4, "big")),
APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")),
],
)
)
def _send_ack(self, id: bytes): async def _send_ack(self, id: bytes):
"""Sends an ACK for a notification with the given id"""
logger.debug(f"Sending ACK for message {id}") logger.debug(f"Sending ACK for message {id}")
payload = _serialize_payload(0x0B, [(1, self.token), (4, id), (8, b"\x00")]) payload = APNSPayload(
self.sock.sendall(payload) 0xB,
# #self.sock.write(_serialize_payload(0x0B, [(4, id)]) [
# #pass APNSField(1, self.credentials.token),
APNSField(4, id),
# def recieve_message(self): APNSField(8, b"\x00"),
# payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0A) ],
# # Send ACK )
# self._send_ack(_get_field(payload[1], 4)) await payload.write_to_stream(self.sock)
# return _get_field(payload[1], 3)
# TODO: Find a way to make this non-blocking
# def expect_message(self) -> tuple[int, list[tuple[int, bytes]]] | None:
# return _deserialize_payload(self.sock)
def _serialize_field(id: int, value: bytes) -> bytes: @dataclass
return id.to_bytes(1, "big") + len(value).to_bytes(2, "big") + value class APNSField:
"""A field in an APNS payload"""
id: int
value: bytes
@staticmethod
def from_buffer(stream: bytes) -> APNSField:
id = int.from_bytes(stream[:1], "big")
length = int.from_bytes(stream[1:3], "big")
value = stream[3 : 3 + length]
return APNSField(id, value)
def to_buffer(self) -> bytes:
return (
self.id.to_bytes(1, "big") + len(self.value).to_bytes(2, "big") + self.value
)
def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes: @dataclass
payload = b"" class APNSPayload:
"""An APNS payload"""
for fid, value in fields: id: int
if fid is not None: fields: list[APNSField]
payload += _serialize_field(fid, value)
return id.to_bytes(1, "big") + len(payload).to_bytes(4, "big") + payload @staticmethod
async def read_from_stream(stream: trio.abc.Stream) -> APNSPayload:
"""Reads a payload from the given stream"""
id = await stream.receive_some(1)
if id is None:
raise Exception("Unable to read payload id from stream")
id = int.from_bytes(id, "big")
if id == 0x0:
raise Exception("Received id 0x0, which is not valid")
def _deserialize_field(stream: bytes) -> tuple[int, bytes]: length = await stream.receive_some(4)
id = int.from_bytes(stream[:1], "big") if length is None:
length = int.from_bytes(stream[1:3], "big") raise Exception("Unable to read payload length from stream")
value = stream[3 : 3 + length] length = int.from_bytes(length, "big")
return id, value
if length == 0:
return APNSPayload(id, [])
# Note: Takes a stream, not a buffer, as we do not know the length of the payload buffer = await stream.receive_some(length)
# WILL BLOCK IF THE STREAM IS EMPTY if buffer is None:
async def _deserialize_payload(stream: trio.SSLStream) -> tuple[int, list[tuple[int, bytes]]] | None: raise Exception("Unable to read payload from stream")
id = int.from_bytes(await stream.receive_some(1), "big") fields = []
if id == 0x0: while len(buffer) > 0:
return None field = APNSField.from_buffer(buffer)
fields.append(field)
buffer = buffer[3 + len(field.value) :]
length = int.from_bytes(await stream.receive_some(4), "big") return APNSPayload(id, fields)
if length == 0: async def write_to_stream(self, stream: trio.abc.Stream):
return id, [] """Writes the payload to the given stream"""
payload = b""
buffer = await stream.receive_some(length) for field in self.fields:
payload += field.to_buffer()
fields = [] buffer = self.id.to_bytes(1, "big") + len(payload).to_bytes(4, "big") + payload
while len(buffer) > 0: await stream.send_all(buffer)
fid, value = _deserialize_field(buffer)
fields.append((fid, value))
buffer = buffer[3 + len(value) :]
return id, fields def fields_with_id(self, id: int):
"""Returns all fields with the given id"""
return [field for field in self.fields if field.id == id]
def _deserialize_payload_from_buffer(
buffer: bytes,
) -> tuple[int, list[tuple[int, bytes]]] | None:
id = int.from_bytes(buffer[:1], "big")
if id == 0x0:
return None
length = int.from_bytes(buffer[1:5], "big")
buffer = buffer[5:]
if len(buffer) < length:
raise Exception("Buffer is too short")
fields = []
while len(buffer) > 0:
fid, value = _deserialize_field(buffer)
fields.append((fid, value))
buffer = buffer[3 + len(value) :]
return id, fields
# Returns the value of the first field with the given id
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
for field_id, value in fields:
if field_id == id:
return value
return None
if __name__ == "__main__": if __name__ == "__main__":
main() main()