mirror of
https://github.com/Sneed-Group/pypush-plus-plus
synced 2025-01-09 17:33:47 +00:00
start async rewrite
This commit is contained in:
parent
35e4254ded
commit
e49fe9f916
1 changed files with 138 additions and 149 deletions
289
apns.py
289
apns.py
|
@ -10,6 +10,8 @@ import logging
|
|||
logger = logging.getLogger("apns")
|
||||
import ssl
|
||||
|
||||
import trio
|
||||
|
||||
import albert
|
||||
import bags
|
||||
|
||||
|
@ -19,122 +21,134 @@ COURIER_HOST = f"{random.randint(1, bags.apns_init_bag()['APNSCourierHostcount']
|
|||
COURIER_PORT = 5223
|
||||
ALPN = [b"apns-security-v3"]
|
||||
|
||||
async def apns_test():
|
||||
async with APNSConnection.start() as connection:
|
||||
print(b64encode(connection.credentials.token).decode())
|
||||
while True:
|
||||
await trio.sleep(1)
|
||||
print(".")
|
||||
#await connection.set_state(1)
|
||||
|
||||
|
||||
print("Finished")
|
||||
def main():
|
||||
from rich.logging import RichHandler
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.NOTSET, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
|
||||
)
|
||||
|
||||
# Set sane log levels
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from urllib3
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
logging.getLogger("jelly").setLevel(logging.INFO)
|
||||
logging.getLogger("nac").setLevel(logging.INFO)
|
||||
logging.getLogger("apns").setLevel(logging.DEBUG)
|
||||
logging.getLogger("albert").setLevel(logging.INFO)
|
||||
logging.getLogger("ids").setLevel(logging.DEBUG)
|
||||
logging.getLogger("bags").setLevel(logging.INFO)
|
||||
logging.getLogger("imessage").setLevel(logging.DEBUG)
|
||||
|
||||
logging.captureWarnings(True)
|
||||
print("APNs Test:")
|
||||
trio.run(apns_test)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class PushCredentials:
|
||||
private_key: str
|
||||
cert: str
|
||||
token: bytes
|
||||
|
||||
class APNSConnection:
|
||||
_incoming_queue: list = [] # We don't need a lock because this is trio and we only have one thread
|
||||
_queue_park: trio.Event = trio.Event()
|
||||
|
||||
async def _send(self, id: int, fields: list[tuple[int, bytes]]):
|
||||
payload = _serialize_payload(id, fields)
|
||||
await self.sock.send_all(payload)
|
||||
|
||||
async def _receive(self, id: int):
|
||||
# Check if anything currently in the queue matches the id
|
||||
for payload in self._incoming_queue:
|
||||
if payload[0] == id:
|
||||
return payload
|
||||
while True:
|
||||
await self._queue_park.wait() # Wait for a new payload to be added to the queue
|
||||
logger.debug(f"Woken by event, checking for {id}")
|
||||
# Check if the new payload matches the id
|
||||
if self._incoming_queue[-1][0] == id:
|
||||
return self._incoming_queue.pop()
|
||||
# Otherwise, wait for another payload to be added to the queue
|
||||
|
||||
async def _queue_filler(self):
|
||||
while True:
|
||||
payload = await _deserialize_payload(self.sock)
|
||||
|
||||
logger.debug(f"Received payload: {payload}")
|
||||
self._incoming_queue.append(payload)
|
||||
# Signal to any waiting tasks that we have a new payload
|
||||
self._queue_park.set()
|
||||
self._queue_park = trio.Event() # Reset the event
|
||||
logger.debug(f"Queue length: {len(self._incoming_queue)}")
|
||||
|
||||
async def _keep_alive(self):
|
||||
while True:
|
||||
#await trio.sleep(300)
|
||||
await trio.sleep(1)
|
||||
logger.debug("Sending keep alive message")
|
||||
await self._send(0x0C, [])
|
||||
await self._receive(0x0D)
|
||||
logger.debug("Got keep alive response")
|
||||
|
||||
@asynccontextmanager
|
||||
async def start(credentials: PushCredentials | None = None):
|
||||
"""Sets up a nursery and connection and yields the connection"""
|
||||
async with trio.open_nursery() as nursery:
|
||||
connection = APNSConnection(nursery, credentials)
|
||||
await connection.connect()
|
||||
yield connection
|
||||
nursery.cancel_scope.cancel() # Cancel heartbeat and queue filler tasks
|
||||
await connection.sock.aclose() # Close the socket
|
||||
|
||||
def __init__(self, nursery: trio.Nursery, credentials: PushCredentials | None = None):
|
||||
self._nursery = nursery
|
||||
self.credentials = credentials
|
||||
|
||||
async def connect(self):
|
||||
"""Connects to the APNs server and starts the keep alive and queue filler tasks"""
|
||||
sock = await trio.open_tcp_stream(COURIER_HOST, COURIER_PORT)
|
||||
|
||||
# Connect to the courier server
|
||||
def _connect(private_key: str, cert: str) -> ssl.SSLSocket:
|
||||
# Connect to the courier server
|
||||
sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||
context.set_alpn_protocols(["apns-security-v3"])
|
||||
# Wrap the socket in TLS
|
||||
ssock = context.wrap_socket(sock)
|
||||
# Handshake with the server
|
||||
ssock.do_handshake()
|
||||
|
||||
self.sock = trio.SSLStream(sock, context, server_hostname=COURIER_HOST)
|
||||
|
||||
await self.sock.do_handshake()
|
||||
|
||||
logger.info(f"Connected to APNs ({COURIER_HOST})")
|
||||
|
||||
return ssock
|
||||
if self.credentials is None:
|
||||
self.credentials = PushCredentials(*albert.generate_push_cert(), None)
|
||||
|
||||
# Start the queue filler and keep alive tasks
|
||||
self._nursery.start_soon(self._queue_filler)
|
||||
self._nursery.start_soon(self._keep_alive)
|
||||
|
||||
class IncomingQueue:
|
||||
def __init__(self):
|
||||
self.queue = []
|
||||
self.lock = threading.Lock()
|
||||
self.credentials.token = await self._connect(self.credentials.token)
|
||||
|
||||
def append(self, item):
|
||||
with self.lock:
|
||||
self.queue.append(item)
|
||||
|
||||
def pop(self, index = -1):
|
||||
with self.lock:
|
||||
return self.queue.pop(index)
|
||||
|
||||
def __getitem__(self, index):
|
||||
with self.lock:
|
||||
return self.queue[index]
|
||||
|
||||
def __len__(self):
|
||||
with self.lock:
|
||||
return len(self.queue)
|
||||
|
||||
def find(self, finder):
|
||||
with self.lock:
|
||||
return next((i for i in self.queue if finder(i)), None)
|
||||
|
||||
def pop_find(self, finder):
|
||||
with self.lock:
|
||||
found = next((i for i in self.queue if finder(i)), None)
|
||||
if found is not None:
|
||||
# We have the lock, so we can safely remove it
|
||||
self.queue.remove(found)
|
||||
return found
|
||||
|
||||
def remove_all(self, id):
|
||||
with self.lock:
|
||||
self.queue = [i for i in self.queue if i[0] != id]
|
||||
|
||||
def wait_pop_find(self, finder, delay=0.1):
|
||||
found = None
|
||||
while found is None:
|
||||
found = self.pop_find(finder)
|
||||
if found is None:
|
||||
time.sleep(delay)
|
||||
return found
|
||||
|
||||
|
||||
class APNSConnection:
|
||||
incoming_queue = IncomingQueue()
|
||||
|
||||
# Sink everything in the queue
|
||||
def sink(self):
|
||||
self.incoming_queue = IncomingQueue()
|
||||
|
||||
def _queue_filler(self):
|
||||
while True:
|
||||
payload = _deserialize_payload(self.sock)
|
||||
|
||||
if payload is not None:
|
||||
# Automatically ACK incoming notifications to prevent APNs from getting mad at us
|
||||
if payload[0] == 0x0A:
|
||||
logger.debug("Sending automatic ACK")
|
||||
self._send_ack(_get_field(payload[1], 4))
|
||||
logger.debug(f"Received payload: {payload}")
|
||||
self.incoming_queue.append(payload)
|
||||
logger.debug(f"Queue length: {len(self.incoming_queue)}")
|
||||
|
||||
def _keep_alive_loop(self):
|
||||
while True:
|
||||
time.sleep(300)
|
||||
self._keep_alive()
|
||||
|
||||
def __init__(self, private_key=None, cert=None):
|
||||
# Generate the private key and certificate if they're not provided
|
||||
if private_key is None or cert is None:
|
||||
logger.debug("APNs needs a new push certificate")
|
||||
self.private_key, self.cert = albert.generate_push_cert()
|
||||
else:
|
||||
self.private_key, self.cert = private_key, cert
|
||||
|
||||
self.sock = _connect(self.private_key, self.cert)
|
||||
|
||||
self.queue_filler_thread = threading.Thread(
|
||||
target=self._queue_filler, daemon=True
|
||||
)
|
||||
self.queue_filler_thread.start()
|
||||
|
||||
self.keep_alive_thread = threading.Thread(
|
||||
target=self._keep_alive_loop, daemon=True
|
||||
)
|
||||
self.keep_alive_thread.start()
|
||||
|
||||
|
||||
def connect(self, root: bool = True, token: bytes = None):
|
||||
async def _connect(self, token: bytes | None = None, root: bool = False) -> bytes:
|
||||
"""Sends the APNs connect message"""
|
||||
# Parse self.certificate
|
||||
from cryptography import x509
|
||||
cert = x509.load_pem_x509_certificate(self.cert.encode())
|
||||
cert = x509.load_pem_x509_certificate(self.credentials.cert.encode())
|
||||
# Parse private key
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
private_key = serialization.load_pem_private_key(self.private_key.encode(), password=None)
|
||||
private_key = serialization.load_pem_private_key(self.credentials.private_key.encode(), password=None)
|
||||
|
||||
if token is None:
|
||||
logger.debug(f"Sending connect message without token (root={root})")
|
||||
|
@ -153,48 +167,28 @@ class APNSConnection:
|
|||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
signature = b"\x01\x01" + private_key.sign(nonce, padding.PKCS1v15(), hashes.SHA1())
|
||||
|
||||
|
||||
if token is None:
|
||||
payload = _serialize_payload(
|
||||
7, [(2, 0x01.to_bytes(1, "big")), (5, flags.to_bytes(4, "big")), (0x0c, cert), (0x0d, nonce), (0x0e, signature)]
|
||||
)
|
||||
else:
|
||||
payload = _serialize_payload(
|
||||
7,
|
||||
[
|
||||
(1, token),
|
||||
(2, 0x01.to_bytes(1, "big")),
|
||||
payload = [
|
||||
(2, b"\x01"),
|
||||
(5, flags.to_bytes(4, "big")),
|
||||
(0x0c, cert),
|
||||
(0x0d, nonce),
|
||||
(0x0e, signature),
|
||||
]
|
||||
if token is not None:
|
||||
payload.insert(0, (1, token))
|
||||
|
||||
],
|
||||
)
|
||||
await self._send(7, payload)
|
||||
|
||||
#self.sock.write(payload)
|
||||
self.sock.sendall(payload)
|
||||
payload = await self._receive(8)
|
||||
|
||||
payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 8)
|
||||
|
||||
if (
|
||||
payload == None
|
||||
or payload[0] != 8
|
||||
or _get_field(payload[1], 1) != 0x00.to_bytes(1, "big")
|
||||
):
|
||||
if _get_field(payload[1], 1) != b"\x00":
|
||||
raise Exception("Failed to connect")
|
||||
|
||||
new_token = _get_field(payload[1], 3)
|
||||
if new_token is not None:
|
||||
self.token = new_token
|
||||
elif token is not None:
|
||||
self.token = token
|
||||
else:
|
||||
raise Exception("No token")
|
||||
|
||||
logger.debug(f"Recieved connect response with token {b64encode(self.token).decode()}")
|
||||
logger.debug(f"Recieved connect response with token {b64encode(new_token).decode()}")
|
||||
|
||||
return self.token
|
||||
return new_token
|
||||
|
||||
def filter(self, topics: list[str]):
|
||||
logger.debug(f"Sending filter message with topics {topics}")
|
||||
|
@ -230,20 +224,9 @@ class APNSConnection:
|
|||
if payload[1][0][1] != 0x00.to_bytes(1, "big"):
|
||||
raise Exception("Failed to send message")
|
||||
|
||||
def set_state(self, state: int):
|
||||
async def set_state(self, state: int):
|
||||
logger.debug(f"Sending state message with state {state}")
|
||||
self.sock.sendall(
|
||||
_serialize_payload(
|
||||
0x14,
|
||||
[(1, state.to_bytes(1, "big")), (2, 0x7FFFFFFF.to_bytes(4, "big"))],
|
||||
)
|
||||
)
|
||||
|
||||
def _keep_alive(self):
|
||||
logger.debug("Sending keep alive message")
|
||||
self.sock.sendall(_serialize_payload(0x0C, []))
|
||||
# Remove any keep alive responses we have or missed
|
||||
self.incoming_queue.remove_all(0x0D)
|
||||
await self._send(0x14, [(1, b"\x01"), (2, 0x7FFFFFFF.to_bytes(4, "big"))])
|
||||
|
||||
|
||||
def _send_ack(self, id: bytes):
|
||||
|
@ -287,15 +270,18 @@ def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
|
|||
|
||||
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
|
||||
# WILL BLOCK IF THE STREAM IS EMPTY
|
||||
def _deserialize_payload(stream: ssl.SSLSocket) -> tuple[int, list[tuple[int, bytes]]] | None:
|
||||
id = int.from_bytes(stream.read(1), "big")
|
||||
async def _deserialize_payload(stream: trio.SSLStream) -> tuple[int, list[tuple[int, bytes]]] | None:
|
||||
id = int.from_bytes(await stream.receive_some(1), "big")
|
||||
|
||||
if id == 0x0:
|
||||
return None
|
||||
|
||||
length = int.from_bytes(stream.read(4), "big")
|
||||
length = int.from_bytes(await stream.receive_some(4), "big")
|
||||
|
||||
buffer = stream.read(length)
|
||||
if length == 0:
|
||||
return id, []
|
||||
|
||||
buffer = await stream.receive_some(length)
|
||||
|
||||
fields = []
|
||||
|
||||
|
@ -338,3 +324,6 @@ def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
|
|||
if field_id == id:
|
||||
return value
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue