mirror of
https://github.com/Sneed-Group/pypush-plus-plus
synced 2024-10-30 08:27:52 +00:00
majorly refactor apnsconnection to be more usable
This commit is contained in:
parent
0a901fb7a6
commit
6e3601de58
3 changed files with 100 additions and 163 deletions
208
apns.py
208
apns.py
|
@ -1,149 +1,89 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
class Fields:
|
import courier, albert
|
||||||
@staticmethod
|
|
||||||
def from_bytes(data: bytes) -> Fields:
|
|
||||||
fields = {}
|
|
||||||
|
|
||||||
while len(data) > 0:
|
|
||||||
field = data[0]
|
|
||||||
length = int.from_bytes(data[1:3], "big")
|
|
||||||
value = data[3:3 + length]
|
|
||||||
|
|
||||||
fields[field] = value
|
|
||||||
|
|
||||||
data = data[3 + length:]
|
|
||||||
|
|
||||||
return Fields(fields)
|
|
||||||
|
|
||||||
def __init__(self, fields: dict[int, bytes]):
|
|
||||||
self.fields = fields
|
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
|
||||||
buffer = bytearray()
|
|
||||||
|
|
||||||
for field, value in self.fields.items():
|
|
||||||
buffer.append(field)
|
|
||||||
buffer.extend(len(value).to_bytes(2, "big"))
|
|
||||||
buffer.extend(value)
|
|
||||||
|
|
||||||
return buffer
|
|
||||||
|
|
||||||
# Debug formating
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"{self.fields}"
|
|
||||||
|
|
||||||
# Define number to command name mapping
|
|
||||||
COMMANDS = {
|
|
||||||
0x7: "Connect",
|
|
||||||
0x8: "ConnectResponse",
|
|
||||||
0x9: "PushTopics",
|
|
||||||
0x0A: "PushNotification",
|
|
||||||
0x0B: "Acknowledge",
|
|
||||||
}
|
|
||||||
|
|
||||||
class Payload:
|
|
||||||
@staticmethod
|
|
||||||
def from_stream(stream) -> Payload|None:
|
|
||||||
command = int.from_bytes(stream.read(1), "big")
|
|
||||||
if command == 0:
|
|
||||||
return None # We reached the end of the stream
|
|
||||||
length = int.from_bytes(stream.read(4), "big")
|
|
||||||
fields = Fields.from_bytes(stream.read(length))
|
|
||||||
|
|
||||||
return Payload(command, fields)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_bytes(data: bytes) -> Payload:
|
|
||||||
# Convert it to bytes for cleaner printing
|
|
||||||
data = bytes(data)
|
|
||||||
command = data[0]
|
|
||||||
length = int.from_bytes(data[1:5], "big")
|
|
||||||
fields = Fields.from_bytes(data[5:5 + length])
|
|
||||||
|
|
||||||
return Payload(command, fields)
|
|
||||||
|
|
||||||
def __init__(self, command: int, fields: Fields):
|
|
||||||
self.command = command
|
|
||||||
self.fields = fields
|
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
|
||||||
buffer = bytearray()
|
|
||||||
|
|
||||||
buffer.append(self.command)
|
|
||||||
|
|
||||||
fields = self.fields.to_bytes()
|
|
||||||
|
|
||||||
buffer.extend(len(fields).to_bytes(4, "big"))
|
|
||||||
buffer.extend(fields)
|
|
||||||
|
|
||||||
return buffer
|
|
||||||
|
|
||||||
# Debug formating
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"{COMMANDS[self.command]}: {self.fields}"
|
|
||||||
|
|
||||||
import courier
|
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
|
||||||
class APNSConnection():
|
|
||||||
def __init__(self, token: bytes=None, private_key=None, cert=None):
|
|
||||||
self.sock, self.private_key, self.cert = courier.connect(private_key, cert)
|
|
||||||
self.token = token
|
|
||||||
|
|
||||||
self._connect()
|
def _serialize_field(id: int, value: bytes) -> bytes:
|
||||||
|
return id.to_bytes() + len(value).to_bytes(2, "big") + value
|
||||||
def _connect(self):
|
|
||||||
if self.token is None:
|
|
||||||
payload = Payload(7, Fields({2: 0x01.to_bytes()}))
|
def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes:
|
||||||
|
payload = b""
|
||||||
|
|
||||||
|
for fid, value in fields:
|
||||||
|
payload += _serialize_field(fid, value)
|
||||||
|
|
||||||
|
return id.to_bytes() + len(payload).to_bytes(4, "big") + payload
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
|
||||||
|
id = int.from_bytes(stream[:1], "big")
|
||||||
|
length = int.from_bytes(stream[1:3], "big")
|
||||||
|
value = stream[3 : 3 + length]
|
||||||
|
return id, value
|
||||||
|
|
||||||
|
|
||||||
|
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
|
||||||
|
def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None:
|
||||||
|
id = int.from_bytes(stream.read(1), "big")
|
||||||
|
|
||||||
|
if id == 0x0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
length = int.from_bytes(stream.read(4), "big")
|
||||||
|
|
||||||
|
buffer = stream.read(length)
|
||||||
|
|
||||||
|
fields = []
|
||||||
|
|
||||||
|
while len(buffer) > 0:
|
||||||
|
fid, value = _deserialize_field(buffer)
|
||||||
|
fields.append((fid, value))
|
||||||
|
buffer = buffer[3 + len(value) :]
|
||||||
|
|
||||||
|
return id, fields
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the value of the first field with the given id
|
||||||
|
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
|
||||||
|
for field_id, value in fields:
|
||||||
|
if field_id == id:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class APNSConnection:
|
||||||
|
def __init__(self, private_key=None, cert=None):
|
||||||
|
# Generate the private key and certificate if they're not provided
|
||||||
|
if private_key is None or cert is None:
|
||||||
|
self.private_key, self.cert = albert.generate_push_cert()
|
||||||
else:
|
else:
|
||||||
payload = Payload(7, Fields({1: self.token, 2: 0x01.to_bytes()}))
|
self.private_key, self.cert = private_key, cert
|
||||||
|
|
||||||
self.sock.write(payload.to_bytes())
|
|
||||||
|
|
||||||
resp = Payload.from_stream(self.sock)
|
self.sock = courier.connect(self.private_key, self.cert)
|
||||||
|
|
||||||
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes():
|
def connect(self, token: bytes = None):
|
||||||
|
if token is None:
|
||||||
|
payload = _serialize_payload(7, [(2, 0x01.to_bytes())])
|
||||||
|
else:
|
||||||
|
payload = _serialize_payload(7, [(1, token), (2, 0x01.to_bytes())])
|
||||||
|
|
||||||
|
self.sock.write(payload)
|
||||||
|
|
||||||
|
payload = _deserialize_payload(self.sock)
|
||||||
|
|
||||||
|
if payload == None or payload[0] != 8 or _get_field(payload[1], 1) != 0x00.to_bytes():
|
||||||
raise Exception("Failed to connect")
|
raise Exception("Failed to connect")
|
||||||
|
|
||||||
if 3 in resp.fields.fields:
|
self.token = _get_field(payload[1], 3)
|
||||||
self.token = resp.fields.fields[3]
|
|
||||||
|
|
||||||
def filter(self, topics: list[str]):
|
def filter(self, topics: list[str]):
|
||||||
payload = Payload(9, Fields({1: self.token, 2: b"".join([sha1(topic.encode()).digest() for topic in topics])}))
|
fields = [(1, self.token)]
|
||||||
|
|
||||||
self.sock.write(payload.to_bytes())
|
for topic in topics:
|
||||||
|
fields.append((2, sha1(topic.encode()).digest()))
|
||||||
|
|
||||||
|
payload = _serialize_payload(9, fields)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
self.sock.write(payload)
|
||||||
import courier
|
|
||||||
import base64
|
|
||||||
|
|
||||||
sock = courier.connect()
|
|
||||||
|
|
||||||
# Try and read the token from the file
|
|
||||||
try:
|
|
||||||
with open("token", "r") as f:
|
|
||||||
r = f.read()
|
|
||||||
if r == "":
|
|
||||||
raise FileNotFoundError
|
|
||||||
payload = Payload(7, Fields({1: base64.b64decode(r), 2: 0x01.to_bytes()}))
|
|
||||||
except FileNotFoundError:
|
|
||||||
payload = Payload(7, Fields({2: 0x01.to_bytes()}))
|
|
||||||
|
|
||||||
# Send the connect request (with or without the token)
|
|
||||||
sock.write(payload.to_bytes())
|
|
||||||
|
|
||||||
# Read the response
|
|
||||||
resp = Payload.from_stream(sock)
|
|
||||||
# Check if the response is valid
|
|
||||||
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes():
|
|
||||||
raise Exception("Failed to connect")
|
|
||||||
|
|
||||||
# If there's a new token, save it
|
|
||||||
if 3 in resp.fields.fields:
|
|
||||||
with open("token", "wb") as f:
|
|
||||||
f.write(base64.b64encode(resp.fields.fields[3]))
|
|
||||||
|
|
||||||
# Send the push topics request
|
|
||||||
|
|
19
courier.py
19
courier.py
|
@ -1,28 +1,23 @@
|
||||||
import albert
|
|
||||||
import tlslite
|
import tlslite
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
COURIER_HOST = "10-courier.push.apple.com"
|
COURIER_HOST = "10-courier.push.apple.com" # TODO: Get this from config
|
||||||
COURIER_PORT = 5223
|
COURIER_PORT = 5223
|
||||||
ALPN = [b"apns-security-v2"]
|
ALPN = [b"apns-security-v2"]
|
||||||
#ALPN = None
|
|
||||||
|
|
||||||
def connect(private_key=None, cert=None):
|
|
||||||
# If we don't have a private key or certificate, generate one
|
|
||||||
if private_key is None or cert is None:
|
|
||||||
private_key, cert = albert.generate_push_cert()
|
|
||||||
|
|
||||||
|
# Connect to the courier server
|
||||||
|
def connect(private_key: str, cert: str) -> tlslite.TLSConnection:
|
||||||
# Connect to the courier server
|
# Connect to the courier server
|
||||||
sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
|
sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
|
||||||
# Wrap the socket in TLS
|
# Wrap the socket in TLS
|
||||||
sock = tlslite.TLSConnection(sock)
|
sock = tlslite.TLSConnection(sock)
|
||||||
# Parse the certificate and private key
|
# Parse the certificate and private key
|
||||||
cert_parsed = tlslite.X509CertChain([tlslite.X509().parse(cert)])
|
cert = tlslite.X509CertChain([tlslite.X509().parse(cert)])
|
||||||
private_key_parsed = tlslite.parsePEMKey(private_key, private=True)
|
private_key = tlslite.parsePEMKey(private_key, private=True)
|
||||||
# Handshake with the server
|
# Handshake with the server
|
||||||
sock.handshakeClientCert(cert_parsed, private_key_parsed, alpn=ALPN)
|
sock.handshakeClientCert(cert, private_key, alpn=ALPN)
|
||||||
|
|
||||||
return sock, private_key, cert
|
return sock
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sock = connect()
|
sock = connect()
|
||||||
|
|
36
demo.py
36
demo.py
|
@ -3,33 +3,35 @@ from base64 import b64decode, b64encode
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
|
||||||
conn1 = apns.APNSConnection()
|
conn1 = apns.APNSConnection()
|
||||||
|
conn1.connect()
|
||||||
print(f"Push Token 1: {b64encode(conn1.token).decode()}")
|
print(f"Push Token 1: {b64encode(conn1.token).decode()}")
|
||||||
|
|
||||||
conn2 = apns.APNSConnection()
|
conn2 = apns.APNSConnection()
|
||||||
|
conn2.connect()
|
||||||
print(f"Push Token 2: {b64encode(conn2.token).decode()}")
|
print(f"Push Token 2: {b64encode(conn2.token).decode()}")
|
||||||
|
|
||||||
conn1.filter(["com.apple.madrid"])
|
conn1.filter(["com.apple.madrid"])
|
||||||
conn2.filter(["com.apple.madrid"])
|
conn2.filter(["com.apple.madrid"])
|
||||||
|
|
||||||
#print(sha1(b"com.apple.madrid").digest())
|
# #print(sha1(b"com.apple.madrid").digest())
|
||||||
# Send a notification
|
# # Send a notification
|
||||||
# expiry timestamp in UNIX epoch
|
# # expiry timestamp in UNIX epoch
|
||||||
expiry = 1680761868
|
# expiry = 1680761868
|
||||||
expiry = expiry.to_bytes(4, "big")
|
# expiry = expiry.to_bytes(4, "big")
|
||||||
|
|
||||||
# Current time in UNIX nano epoch
|
# # Current time in UNIX nano epoch
|
||||||
import time
|
# import time
|
||||||
now = int(time.time() * 1000).to_bytes(8, "big")
|
# now = int(time.time() * 1000).to_bytes(8, "big")
|
||||||
|
|
||||||
payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()}))
|
# payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()}))
|
||||||
conn1.sock.write(payload.to_bytes())
|
# conn1.sock.write(payload.to_bytes())
|
||||||
|
|
||||||
print("Waiting for response...")
|
# print("Waiting for response...")
|
||||||
|
|
||||||
# Check if the notification was sent
|
# # Check if the notification was sent
|
||||||
resp = apns.Payload.from_stream(conn1.sock)
|
# resp = apns.Payload.from_stream(conn1.sock)
|
||||||
print(resp)
|
# print(resp)
|
||||||
|
|
||||||
# Read the message from the other connection
|
# # Read the message from the other connection
|
||||||
resp = apns.Payload.from_stream(conn2.sock)
|
# resp = apns.Payload.from_stream(conn2.sock)
|
||||||
print(resp)
|
# print(resp)
|
Loading…
Reference in a new issue