mirror of
https://github.com/Sneed-Group/pypush-plus-plus
synced 2024-12-23 19:32:29 -06:00
majorly refactor apnsconnection to be more usable
This commit is contained in:
parent
0a901fb7a6
commit
6e3601de58
3 changed files with 100 additions and 163 deletions
208
apns.py
208
apns.py
|
@ -1,149 +1,89 @@
|
|||
from __future__ import annotations
|
||||
|
||||
class Fields:
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Fields:
|
||||
fields = {}
|
||||
|
||||
while len(data) > 0:
|
||||
field = data[0]
|
||||
length = int.from_bytes(data[1:3], "big")
|
||||
value = data[3:3 + length]
|
||||
|
||||
fields[field] = value
|
||||
|
||||
data = data[3 + length:]
|
||||
|
||||
return Fields(fields)
|
||||
|
||||
def __init__(self, fields: dict[int, bytes]):
|
||||
self.fields = fields
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
buffer = bytearray()
|
||||
|
||||
for field, value in self.fields.items():
|
||||
buffer.append(field)
|
||||
buffer.extend(len(value).to_bytes(2, "big"))
|
||||
buffer.extend(value)
|
||||
|
||||
return buffer
|
||||
|
||||
# Debug formating
|
||||
def __str__(self) -> str:
|
||||
return f"{self.fields}"
|
||||
|
||||
# Define number to command name mapping
|
||||
COMMANDS = {
|
||||
0x7: "Connect",
|
||||
0x8: "ConnectResponse",
|
||||
0x9: "PushTopics",
|
||||
0x0A: "PushNotification",
|
||||
0x0B: "Acknowledge",
|
||||
}
|
||||
|
||||
class Payload:
|
||||
@staticmethod
|
||||
def from_stream(stream) -> Payload|None:
|
||||
command = int.from_bytes(stream.read(1), "big")
|
||||
if command == 0:
|
||||
return None # We reached the end of the stream
|
||||
length = int.from_bytes(stream.read(4), "big")
|
||||
fields = Fields.from_bytes(stream.read(length))
|
||||
|
||||
return Payload(command, fields)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Payload:
|
||||
# Convert it to bytes for cleaner printing
|
||||
data = bytes(data)
|
||||
command = data[0]
|
||||
length = int.from_bytes(data[1:5], "big")
|
||||
fields = Fields.from_bytes(data[5:5 + length])
|
||||
|
||||
return Payload(command, fields)
|
||||
|
||||
def __init__(self, command: int, fields: Fields):
|
||||
self.command = command
|
||||
self.fields = fields
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
buffer = bytearray()
|
||||
|
||||
buffer.append(self.command)
|
||||
|
||||
fields = self.fields.to_bytes()
|
||||
|
||||
buffer.extend(len(fields).to_bytes(4, "big"))
|
||||
buffer.extend(fields)
|
||||
|
||||
return buffer
|
||||
|
||||
# Debug formating
|
||||
def __str__(self) -> str:
|
||||
return f"{COMMANDS[self.command]}: {self.fields}"
|
||||
|
||||
import courier
|
||||
import courier, albert
|
||||
from hashlib import sha1
|
||||
|
||||
class APNSConnection():
|
||||
def __init__(self, token: bytes=None, private_key=None, cert=None):
|
||||
self.sock, self.private_key, self.cert = courier.connect(private_key, cert)
|
||||
self.token = token
|
||||
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
if self.token is None:
|
||||
payload = Payload(7, Fields({2: 0x01.to_bytes()}))
|
||||
def _serialize_field(id: int, value: bytes) -> bytes:
|
||||
return id.to_bytes() + len(value).to_bytes(2, "big") + value
|
||||
|
||||
|
||||
def _serialize_payload(id: int, fields: list[(int, bytes)]) -> bytes:
|
||||
payload = b""
|
||||
|
||||
for fid, value in fields:
|
||||
payload += _serialize_field(fid, value)
|
||||
|
||||
return id.to_bytes() + len(payload).to_bytes(4, "big") + payload
|
||||
|
||||
|
||||
def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
|
||||
id = int.from_bytes(stream[:1], "big")
|
||||
length = int.from_bytes(stream[1:3], "big")
|
||||
value = stream[3 : 3 + length]
|
||||
return id, value
|
||||
|
||||
|
||||
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
|
||||
def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None:
|
||||
id = int.from_bytes(stream.read(1), "big")
|
||||
|
||||
if id == 0x0:
|
||||
return None
|
||||
|
||||
length = int.from_bytes(stream.read(4), "big")
|
||||
|
||||
buffer = stream.read(length)
|
||||
|
||||
fields = []
|
||||
|
||||
while len(buffer) > 0:
|
||||
fid, value = _deserialize_field(buffer)
|
||||
fields.append((fid, value))
|
||||
buffer = buffer[3 + len(value) :]
|
||||
|
||||
return id, fields
|
||||
|
||||
|
||||
# Returns the value of the first field with the given id
|
||||
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
|
||||
for field_id, value in fields:
|
||||
if field_id == id:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class APNSConnection:
|
||||
def __init__(self, private_key=None, cert=None):
|
||||
# Generate the private key and certificate if they're not provided
|
||||
if private_key is None or cert is None:
|
||||
self.private_key, self.cert = albert.generate_push_cert()
|
||||
else:
|
||||
payload = Payload(7, Fields({1: self.token, 2: 0x01.to_bytes()}))
|
||||
|
||||
self.sock.write(payload.to_bytes())
|
||||
self.private_key, self.cert = private_key, cert
|
||||
|
||||
resp = Payload.from_stream(self.sock)
|
||||
self.sock = courier.connect(self.private_key, self.cert)
|
||||
|
||||
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes():
|
||||
def connect(self, token: bytes = None):
|
||||
if token is None:
|
||||
payload = _serialize_payload(7, [(2, 0x01.to_bytes())])
|
||||
else:
|
||||
payload = _serialize_payload(7, [(1, token), (2, 0x01.to_bytes())])
|
||||
|
||||
self.sock.write(payload)
|
||||
|
||||
payload = _deserialize_payload(self.sock)
|
||||
|
||||
if payload == None or payload[0] != 8 or _get_field(payload[1], 1) != 0x00.to_bytes():
|
||||
raise Exception("Failed to connect")
|
||||
|
||||
if 3 in resp.fields.fields:
|
||||
self.token = resp.fields.fields[3]
|
||||
self.token = _get_field(payload[1], 3)
|
||||
|
||||
def filter(self, topics: list[str]):
|
||||
payload = Payload(9, Fields({1: self.token, 2: b"".join([sha1(topic.encode()).digest() for topic in topics])}))
|
||||
fields = [(1, self.token)]
|
||||
|
||||
self.sock.write(payload.to_bytes())
|
||||
for topic in topics:
|
||||
fields.append((2, sha1(topic.encode()).digest()))
|
||||
|
||||
|
||||
payload = _serialize_payload(9, fields)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import courier
|
||||
import base64
|
||||
|
||||
sock = courier.connect()
|
||||
|
||||
# Try and read the token from the file
|
||||
try:
|
||||
with open("token", "r") as f:
|
||||
r = f.read()
|
||||
if r == "":
|
||||
raise FileNotFoundError
|
||||
payload = Payload(7, Fields({1: base64.b64decode(r), 2: 0x01.to_bytes()}))
|
||||
except FileNotFoundError:
|
||||
payload = Payload(7, Fields({2: 0x01.to_bytes()}))
|
||||
|
||||
# Send the connect request (with or without the token)
|
||||
sock.write(payload.to_bytes())
|
||||
|
||||
# Read the response
|
||||
resp = Payload.from_stream(sock)
|
||||
# Check if the response is valid
|
||||
if resp.command != 8 or resp.fields.fields[1] != 0x00.to_bytes():
|
||||
raise Exception("Failed to connect")
|
||||
|
||||
# If there's a new token, save it
|
||||
if 3 in resp.fields.fields:
|
||||
with open("token", "wb") as f:
|
||||
f.write(base64.b64encode(resp.fields.fields[3]))
|
||||
|
||||
# Send the push topics request
|
||||
self.sock.write(payload)
|
||||
|
|
19
courier.py
19
courier.py
|
@ -1,28 +1,23 @@
|
|||
import albert
|
||||
import tlslite
|
||||
import socket
|
||||
|
||||
COURIER_HOST = "10-courier.push.apple.com"
|
||||
COURIER_HOST = "10-courier.push.apple.com" # TODO: Get this from config
|
||||
COURIER_PORT = 5223
|
||||
ALPN = [b"apns-security-v2"]
|
||||
#ALPN = None
|
||||
|
||||
def connect(private_key=None, cert=None):
|
||||
# If we don't have a private key or certificate, generate one
|
||||
if private_key is None or cert is None:
|
||||
private_key, cert = albert.generate_push_cert()
|
||||
|
||||
# Connect to the courier server
|
||||
def connect(private_key: str, cert: str) -> tlslite.TLSConnection:
|
||||
# Connect to the courier server
|
||||
sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
|
||||
# Wrap the socket in TLS
|
||||
sock = tlslite.TLSConnection(sock)
|
||||
# Parse the certificate and private key
|
||||
cert_parsed = tlslite.X509CertChain([tlslite.X509().parse(cert)])
|
||||
private_key_parsed = tlslite.parsePEMKey(private_key, private=True)
|
||||
cert = tlslite.X509CertChain([tlslite.X509().parse(cert)])
|
||||
private_key = tlslite.parsePEMKey(private_key, private=True)
|
||||
# Handshake with the server
|
||||
sock.handshakeClientCert(cert_parsed, private_key_parsed, alpn=ALPN)
|
||||
sock.handshakeClientCert(cert, private_key, alpn=ALPN)
|
||||
|
||||
return sock, private_key, cert
|
||||
return sock
|
||||
|
||||
if __name__ == "__main__":
|
||||
sock = connect()
|
||||
|
|
36
demo.py
36
demo.py
|
@ -3,33 +3,35 @@ from base64 import b64decode, b64encode
|
|||
from hashlib import sha1
|
||||
|
||||
conn1 = apns.APNSConnection()
|
||||
conn1.connect()
|
||||
print(f"Push Token 1: {b64encode(conn1.token).decode()}")
|
||||
|
||||
conn2 = apns.APNSConnection()
|
||||
conn2.connect()
|
||||
print(f"Push Token 2: {b64encode(conn2.token).decode()}")
|
||||
|
||||
conn1.filter(["com.apple.madrid"])
|
||||
conn2.filter(["com.apple.madrid"])
|
||||
|
||||
#print(sha1(b"com.apple.madrid").digest())
|
||||
# Send a notification
|
||||
# expiry timestamp in UNIX epoch
|
||||
expiry = 1680761868
|
||||
expiry = expiry.to_bytes(4, "big")
|
||||
# #print(sha1(b"com.apple.madrid").digest())
|
||||
# # Send a notification
|
||||
# # expiry timestamp in UNIX epoch
|
||||
# expiry = 1680761868
|
||||
# expiry = expiry.to_bytes(4, "big")
|
||||
|
||||
# Current time in UNIX nano epoch
|
||||
import time
|
||||
now = int(time.time() * 1000).to_bytes(8, "big")
|
||||
# # Current time in UNIX nano epoch
|
||||
# import time
|
||||
# now = int(time.time() * 1000).to_bytes(8, "big")
|
||||
|
||||
payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()}))
|
||||
conn1.sock.write(payload.to_bytes())
|
||||
# payload = apns.Payload(0x0a, apns.Fields({1: sha1(b"com.apple.madrid").digest(), 2: conn2.token, 3: b"Hello World!", 4: 0x00.to_bytes(), 5: expiry, 6: now, 7: 0x00.to_bytes()}))
|
||||
# conn1.sock.write(payload.to_bytes())
|
||||
|
||||
print("Waiting for response...")
|
||||
# print("Waiting for response...")
|
||||
|
||||
# Check if the notification was sent
|
||||
resp = apns.Payload.from_stream(conn1.sock)
|
||||
print(resp)
|
||||
# # Check if the notification was sent
|
||||
# resp = apns.Payload.from_stream(conn1.sock)
|
||||
# print(resp)
|
||||
|
||||
# Read the message from the other connection
|
||||
resp = apns.Payload.from_stream(conn2.sock)
|
||||
print(resp)
|
||||
# # Read the message from the other connection
|
||||
# resp = apns.Payload.from_stream(conn2.sock)
|
||||
# print(resp)
|
Loading…
Reference in a new issue