pypush-plus-plus/apns.py

180 lines
5.5 KiB
Python
Raw Normal View History

2023-04-05 18:52:14 -05:00
from __future__ import annotations
import courier, albert
2023-04-06 09:38:29 -05:00
from hashlib import sha1
2023-04-07 21:32:00 -05:00
import threading
import time
import random
class APNSConnection:
    """A connection to the courier push service.

    A daemon thread reads framed payloads from ``self.sock`` (via
    ``_deserialize_payload``) and appends them to ``self.incoming_queue``.
    Request/response methods then claim their reply from that queue by packet id.
    """

    def __init__(self, private_key=None, cert=None):
        """Open a courier connection and start the background reader thread.

        If either *private_key* or *cert* is None, a fresh pair is obtained from
        ``albert.generate_push_cert()``; otherwise both are used as given.
        """
        # Generate the private key and certificate if they're not provided
        if private_key is None or cert is None:
            self.private_key, self.cert = albert.generate_push_cert()
        else:
            self.private_key, self.cert = private_key, cert

        self.sock = courier.connect(self.private_key, self.cert)

        # Payloads received but not yet claimed, oldest first. Kept per-instance
        # (not as a class attribute) so separate connections never share or steal
        # each other's packets. The condition guards the list and wakes threads
        # blocked in wait_for_packet when the reader thread appends.
        self.incoming_queue: list[tuple[int, list[tuple[int, bytes]]]] = []
        self._queue_cond = threading.Condition()

        # Start the queue filler thread only after the queue exists, since it
        # begins appending immediately.
        self.queue_filler_thread = threading.Thread(target=self._queue_filler, daemon=True)
        self.queue_filler_thread.start()

    def _queue_filler(self):
        """Reader-thread loop: deserialize payloads until the socket reports closed."""
        while not self.sock.closed:
            print("QUEUE: Got payload?")
            payload = _deserialize_payload(self.sock)
            # None means the id byte read as 0 (including an empty read), so
            # there is nothing to enqueue.
            if payload is not None:
                print("QUEUE: Received payload: " + str(payload))
                self._enqueue(payload)
        print("QUEUE: Thread ended")

    def _enqueue(self, payload: tuple[int, list[tuple[int, bytes]]]):
        """Append *payload* to the queue and wake every thread waiting in wait_for_packet."""
        with self._queue_cond:
            self.incoming_queue.append(payload)
            self._queue_cond.notify_all()

    def _pop_by_id(self, id: int) -> tuple[int, list[tuple[int, bytes]]] | None:
        """Remove and return the oldest queued payload whose packet id is *id*.

        Returns None when no queued payload matches. Prints the queue contents
        as a debugging aid.
        """
        with self._queue_cond:
            print("QUEUE: Looking for id " + str(id) + " in " + str(self.incoming_queue))
            for index, entry in enumerate(self.incoming_queue):
                if entry[0] == id:
                    return self.incoming_queue.pop(index)
        return None

    def wait_for_packet(
        self, id: int, timeout: float | None = None
    ) -> tuple[int, list[tuple[int, bytes]]] | None:
        """Block until a payload with packet id *id* is queued, then remove and return it.

        :param id: packet id to wait for.
        :param timeout: maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely, as before.
        :returns: the ``(id, fields)`` payload, or ``None`` if *timeout* expired first.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._queue_cond:
            while (payload := self._pop_by_id(id)) is None:
                if deadline is None:
                    # Sleep until the reader thread enqueues something (no polling).
                    self._queue_cond.wait()
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                self._queue_cond.wait(remaining)
            return payload

    def connect(self, root: bool = True, token: bytes | None = None) -> bytes:
        """Send the connect packet (id 7) and return the token from the reply (id 8).

        :param root: when true, bit ``0b0100`` is set in the flags field (field 5).
        :param token: if given, sent as field 1; omitted from the packet when None.
        :returns: field 3 of the reply, also stored as ``self.token``.
        :raises Exception: if the reply's field 1 is not a single zero byte.
        """
        flags = 0b01000001
        if root:
            flags |= 0b0100

        fields = [] if token is None else [(1, token)]
        fields += [(2, b"\x01"), (5, flags.to_bytes(4, "big"))]
        self.sock.write(_serialize_payload(7, fields))

        payload = self.wait_for_packet(8)
        if payload is None or payload[0] != 8 or _get_field(payload[1], 1) != b"\x00":
            raise Exception("Failed to connect")
        self.token = _get_field(payload[1], 3)
        return self.token

    def filter(self, topics: list[str]):
        """Send a filter packet (id 9): our token plus the SHA-1 digest of each topic name."""
        fields = [(1, self.token)]
        fields += [(2, sha1(topic.encode()).digest()) for topic in topics]
        self.sock.write(_serialize_payload(9, fields))

    def send_message(
        self, topic: str, payload: bytes | str
    ) -> tuple[int, list[tuple[int, bytes]]] | None:
        """Send *payload* on *topic* as packet 0x0a and wait for the 0x0b reply.

        *payload* is sent as raw bytes; a ``str`` is UTF-8 encoded first (a str
        previously failed with TypeError while the frame was being built).

        :returns: the ``(id, fields)`` reply, which is also printed.
        """
        if isinstance(payload, str):
            payload = payload.encode("utf-8")

        message = _serialize_payload(
            0x0A,
            [
                (4, random.randbytes(4)),  # 4 random bytes; presumably a message id — confirm
                (1, sha1(topic.encode()).digest()),
                (2, self.token),
                (3, payload),
            ],
        )
        self.sock.write(message)

        reply = self.wait_for_packet(0x0B)
        print(reply)
        return reply

    def set_state(self, state: int):
        """Send packet 0x14 with *state* as a 1-byte field 1 and 0x7FFFFFFF as 4-byte field 2."""
        self.sock.write(_serialize_payload(0x14, [(1, state.to_bytes(1)), (2, 0x7FFFFFFF.to_bytes(4))]))

    def keep_alive(self):
        """Send an empty packet 0x0c (used as a keep-alive, per this method's name)."""
        self.sock.write(_serialize_payload(0x0c, []))
def _serialize_field(id: int, value: bytes) -> bytes:
return id.to_bytes() + len(value).to_bytes(2, "big") + value
def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes:
payload = b""
for fid, value in fields:
if fid is not None:
payload += _serialize_field(fid, value)
return id.to_bytes() + len(payload).to_bytes(4, "big") + payload
def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
id = int.from_bytes(stream[:1], "big")
length = int.from_bytes(stream[1:3], "big")
value = stream[3 : 3 + length]
return id, value
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
# WILL BLOCK IF THE STREAM IS EMPTY
def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None:
id = int.from_bytes(stream.read(1), "big")
if id == 0x0:
return None
length = int.from_bytes(stream.read(4), "big")
buffer = stream.read(length)
fields = []
while len(buffer) > 0:
fid, value = _deserialize_field(buffer)
fields.append((fid, value))
buffer = buffer[3 + len(value) :]
return id, fields
def _deserialize_payload_from_buffer(buffer: bytes) -> tuple[int, list[tuple[int, bytes]]] | None:
id = int.from_bytes(buffer[:1], "big")
if id == 0x0:
return None
length = int.from_bytes(buffer[1:5], "big")
buffer = buffer[5:]
if len(buffer) < length:
raise Exception("Buffer is too short")
fields = []
while len(buffer) > 0:
fid, value = _deserialize_field(buffer)
fields.append((fid, value))
buffer = buffer[3 + len(value) :]
return id, fields
# Returns the value of the first field with the given id
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
for field_id, value in fields:
if field_id == id:
return value
return None