majorly refactor apnsconnection to be more usable

This commit is contained in:
JJTech0130 2023-04-06 13:19:13 -04:00
parent 0a901fb7a6
commit 6e3601de58
No known key found for this signature in database
GPG key ID: 23C92EBCCF8F93D6
3 changed files with 100 additions and 163 deletions

208
apns.py
View file

@ -1,149 +1,89 @@
from __future__ import annotations from __future__ import annotations
class Fields: import courier, albert
@staticmethod
def from_bytes(data: bytes) -> Fields:
fields = {}
while len(data) > 0:
field = data[0]
length = int.from_bytes(data[1:3], "big")
value = data[3:3 + length]
fields[field] = value
data = data[3 + length:]
return Fields(fields)
def __init__(self, fields: dict[int, bytes]):
self.fields = fields
def to_bytes(self) -> bytes:
buffer = bytearray()
for field, value in self.fields.items():
buffer.append(field)
buffer.extend(len(value).to_bytes(2, "big"))
buffer.extend(value)
return buffer
# Debug formating
def __str__(self) -> str:
return f"{self.fields}"
# Define number to command name mapping
COMMANDS = {
0x7: "Connect",
0x8: "ConnectResponse",
0x9: "PushTopics",
0x0A: "PushNotification",
0x0B: "Acknowledge",
}
class Payload:
@staticmethod
def from_stream(stream) -> Payload|None:
command = int.from_bytes(stream.read(1), "big")
if command == 0:
return None # We reached the end of the stream
length = int.from_bytes(stream.read(4), "big")
fields = Fields.from_bytes(stream.read(length))
return Payload(command, fields)
@staticmethod
def from_bytes(data: bytes) -> Payload:
# Convert it to bytes for cleaner printing
data = bytes(data)
command = data[0]
length = int.from_bytes(data[1:5], "big")
fields = Fields.from_bytes(data[5:5 + length])
return Payload(command, fields)
def __init__(self, command: int, fields: Fields):
self.command = command
self.fields = fields
def to_bytes(self) -> bytes:
buffer = bytearray()
buffer.append(self.command)
fields = self.fields.to_bytes()
buffer.extend(len(fields).to_bytes(4, "big"))
buffer.extend(fields)
return buffer
# Debug formating
def __str__(self) -> str:
return f"{COMMANDS[self.command]}: {self.fields}"
import courier
from hashlib import sha1 from hashlib import sha1
class APNSConnection():
def __init__(self, token: bytes=None, private_key=None, cert=None):
self.sock, self.private_key, self.cert = courier.connect(private_key, cert)
self.token = token
self._connect() def _serialize_field(id: int, value: bytes) -> bytes:
return id.to_bytes() + len(value).to_bytes(2, "big") + value
def _connect(self):
if self.token is None:
payload = Payload(7, Fields({2: 0x01.to_bytes()})) def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes:
payload = b""
for fid, value in fields:
payload += _serialize_field(fid, value)
return id.to_bytes() + len(payload).to_bytes(4, "big") + payload
def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
id = int.from_bytes(stream[:1], "big")
length = int.from_bytes(stream[1:3], "big")
value = stream[3 : 3 + length]
return id, value
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None:
id = int.from_bytes(stream.read(1), "big")
if id == 0x0:
return None
length = int.from_bytes(stream.read(4), "big")
buffer = stream.read(length)
fields = []
while len(buffer) > 0:
fid, value = _deserialize_field(buffer)
fields.append((fid, value))
buffer = buffer[3 + len(value) :]
return id, fields
# Returns the value of the first field with the given id
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
for field_id, value in fields:
if field_id == id:
return value
return None
class APNSConnection:
def __init__(self, private_key=None, cert=None):
# Generate the private key and certificate if they're not provided
if private_key is None or cert is None:
self.private_key, self.cert = albert.generate_push_cert()
else: else:
payload = Payload(7, Fields({1: self.token, 2: 0x01.to_bytes()})) self.private_key, self.cert = private_key, cert
self.sock.write(payload.to_bytes())
resp = Payload.from_stream(self.sock) self.sock = courier.connect(self.private_key, self.cert)
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes(): def connect(self, token: bytes = None):
if token is None:
payload = _serialize_payload(7, [(2, 0x01.to_bytes())])
else:
payload = _serialize_payload(7, [(1, token), (2, 0x01.to_bytes())])
self.sock.write(payload)
payload = _deserialize_payload(self.sock)
if payload == None or payload[0] != 8 or _get_field(payload[1], 1) != 0x00.to_bytes():
raise Exception("Failed to connect") raise Exception("Failed to connect")
if 3 in resp.fields.fields: self.token = _get_field(payload[1], 3)
self.token = resp.fields.fields[3]
def filter(self, topics: list[str]): def filter(self, topics: list[str]):
payload = Payload(9, Fields({1: self.token, 2: b"".join([sha1(topic.encode()).digest() for topic in topics])})) fields = [(1, self.token)]
self.sock.write(payload.to_bytes()) for topic in topics:
fields.append((2, sha1(topic.encode()).digest()))
payload = _serialize_payload(9, fields)
if __name__ == "__main__": self.sock.write(payload)
import courier
import base64
sock = courier.connect()
# Try and read the token from the file
try:
with open("token", "r") as f:
r = f.read()
if r == "":
raise FileNotFoundError
payload = Payload(7, Fields({1: base64.b64decode(r), 2: 0x01.to_bytes()}))
except FileNotFoundError:
payload = Payload(7, Fields({2: 0x01.to_bytes()}))
# Send the connect request (with or without the token)
sock.write(payload.to_bytes())
# Read the response
resp = Payload.from_stream(sock)
# Check if the response is valid
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes():
raise Exception("Failed to connect")
# If there's a new token, save it
if 3 in resp.fields.fields:
with open("token", "wb") as f:
f.write(base64.b64encode(resp.fields.fields[3]))
# Send the push topics request

View file

@ -1,28 +1,23 @@
import albert
import tlslite import tlslite
import socket import socket
COURIER_HOST = "10-courier.push.apple.com" COURIER_HOST = "10-courier.push.apple.com" # TODO: Get this from config
COURIER_PORT = 5223 COURIER_PORT = 5223
ALPN = [b"apns-security-v2"] ALPN = [b"apns-security-v2"]
#ALPN = None
def connect(private_key=None, cert=None):
# If we don't have a private key or certificate, generate one
if private_key is None or cert is None:
private_key, cert = albert.generate_push_cert()
# Connect to the courier server
def connect(private_key: str, cert: str) -> tlslite.TLSConnection:
# Connect to the courier server # Connect to the courier server
sock = socket.create_connection((COURIER_HOST, COURIER_PORT)) sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
# Wrap the socket in TLS # Wrap the socket in TLS
sock = tlslite.TLSConnection(sock) sock = tlslite.TLSConnection(sock)
# Parse the certificate and private key # Parse the certificate and private key
cert_parsed = tlslite.X509CertChain([tlslite.X509().parse(cert)]) cert = tlslite.X509CertChain([tlslite.X509().parse(cert)])
private_key_parsed = tlslite.parsePEMKey(private_key, private=True) private_key = tlslite.parsePEMKey(private_key, private=True)
# Handshake with the server # Handshake with the server
sock.handshakeClientCert(cert_parsed, private_key_parsed, alpn=ALPN) sock.handshakeClientCert(cert, private_key, alpn=ALPN)
return sock, private_key, cert return sock
if __name__ == "__main__": if __name__ == "__main__":
sock = connect() sock = connect()

36
demo.py
View file

@ -3,33 +3,35 @@ from base64 import b64decode, b64encode
from hashlib import sha1 from hashlib import sha1
conn1 = apns.APNSConnection() conn1 = apns.APNSConnection()
conn1.connect()
print(f"Push Token 1: {b64encode(conn1.token).decode()}") print(f"Push Token 1: {b64encode(conn1.token).decode()}")
conn2 = apns.APNSConnection() conn2 = apns.APNSConnection()
conn2.connect()
print(f"Push Token 2: {b64encode(conn2.token).decode()}") print(f"Push Token 2: {b64encode(conn2.token).decode()}")
conn1.filter(["com.apple.madrid"]) conn1.filter(["com.apple.madrid"])
conn2.filter(["com.apple.madrid"]) conn2.filter(["com.apple.madrid"])
#print(sha1(b"com.apple.madrid").digest()) # #print(sha1(b"com.apple.madrid").digest())
# Send a notification # # Send a notification
# expiry timestamp in UNIX epoch # # expiry timestamp in UNIX epoch
expiry = 1680761868 # expiry = 1680761868
expiry = expiry.to_bytes(4, "big") # expiry = expiry.to_bytes(4, "big")
# Current time in UNIX nano epoch # # Current time in UNIX nano epoch
import time # import time
now = int(time.time() * 1000).to_bytes(8, "big") # now = int(time.time() * 1000).to_bytes(8, "big")
payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()})) # payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()}))
conn1.sock.write(payload.to_bytes()) # conn1.sock.write(payload.to_bytes())
print("Waiting for response...") # print("Waiting for response...")
# Check if the notification was sent # # Check if the notification was sent
resp = apns.Payload.from_stream(conn1.sock) # resp = apns.Payload.from_stream(conn1.sock)
print(resp) # print(resp)
# Read the message from the other connection # # Read the message from the other connection
resp = apns.Payload.from_stream(conn2.sock) # resp = apns.Payload.from_stream(conn2.sock)
print(resp) # print(resp)