diff --git a/apns.py b/apns.py index 3d4b137..4ee9c38 100644 --- a/apns.py +++ b/apns.py @@ -109,6 +109,8 @@ class APNSConnection: await self._queue_park.wait() # Wait for a new payload to be added to the queue logger.debug(f"Woken by event, checking for {id}") # Check if the new payload matches the id + if len(self._incoming_queue) == 0: + continue # all payloads have been removed by someone else if self._incoming_queue[-1].id != id: continue if filter is not None: @@ -294,11 +296,16 @@ class APNSConnection: # TODO: Check ACK code - async def expect_notification(self, topic: str, filter: Callable | None = None): + async def expect_notification(self, topics: str | list[str], filter: Callable | None = None): """Waits for a notification to be received, and acks it""" + if isinstance(topics, list): + topic_hashes = [sha1(topic.encode()).digest() for topic in topics] + else: + topic_hashes = [sha1(topics.encode()).digest()] + def f(payload: APNSPayload): - if payload.fields_with_id(2)[0].value != sha1(topic.encode()).digest(): + if payload.fields_with_id(2)[0].value not in topic_hashes: return False if filter is not None: return filter(payload) diff --git a/demo.py b/demo.py index 571a724..e224acf 100644 --- a/demo.py +++ b/demo.py @@ -13,6 +13,8 @@ import apns import ids import imessage +import trio + logging.basicConfig( level=logging.NOTSET, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] ) @@ -58,13 +60,6 @@ def safe_b64decode(s): return None async def main(): - # Try and load config.json - try: - with open("config.json", "r") as f: - CONFIG = json.load(f) - except FileNotFoundError: - CONFIG = {} - token = CONFIG.get("push", {}).get("token") if token is not None: token = b64decode(token) @@ -138,9 +133,21 @@ async def main(): im = imessage.iMessageUser(conn, user) # Send a message to myself - await im.send(imessage.iMessage.create(im, "Hello, world!", [user.current_handle])) - print(await im.receive()) + async with trio.open_nursery() as nursery: + nursery.start_soon(input_task, im) + nursery.start_soon(output_task, im) + +async def input_task(im: imessage.iMessageUser): + while True: + cmd = await trio.to_thread.run_sync(input, "> ", cancellable=True) + if cmd != "": + await im.send(imessage.iMessage.create(im, cmd, [im.user.current_handle])) + +async def output_task(im: imessage.iMessageUser): + while True: + msg = await im.receive() + print(str(msg)) + if __name__ == "__main__": - import trio trio.run(main) \ No newline at end of file diff --git a/development/proxy/proxy_async.py b/development/proxy/proxy_async.py index cd279db..021bf30 100644 --- a/development/proxy/proxy_async.py +++ b/development/proxy/proxy_async.py @@ -21,7 +21,7 @@ logging.basicConfig( ) async def main(): - apns.COURIER_HOST = "windows.courier.push.apple.com" + apns.COURIER_HOST = "windows.courier.push.apple.com" # Use windows courier so that /etc/hosts override doesn't affect it context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) context.set_alpn_protocols(["apns-security-v3"]) @@ -72,8 +72,8 @@ class APNSProxy: logging.info(f"<- {payload}") def tamper(self, payload: apns.APNSPayload, to_server) -> apns.APNSPayload: - if not to_server: - payload = self.tamper_lookup_keys(payload) + #if not to_server: + # payload = self.tamper_lookup_keys(payload) return payload diff --git a/imessage.py b/imessage.py index 9f37bae..58816a5 100644 --- a/imessage.py +++ b/imessage.py @@ -261,7 +261,8 @@ class iMessage(Message): """Creates a basic outgoing `iMessage` from the given text and participants""" sender = user.user.current_handle - participants += [sender] + if sender not in participants: + participants += [sender] return iMessage( text=text, @@ -492,18 +493,12 @@ class iMessageUser: except: return False - async def receive(self) -> Message | None: + async def receive(self) -> Message: """ Will return the next iMessage in the queue, or None if there are no messages """ - for type, (topic, cls) in MESSAGE_TYPES.items(): - body = await self._receive_raw(type, topic) - if body is not None: - t = cls - break - else: - return None - + body = await self._receive_raw([t for t, _ in MESSAGE_TYPES.items()], [t[0] for _, t in MESSAGE_TYPES.items()]) + t = MESSAGE_TYPES[body["c"]][1] if not await self._verify_payload(body["P"], body["sP"], body["t"]): raise Exception("Failed to verify payload") @@ -516,7 +511,7 @@ class iMessageUser: return t.from_raw(decrypted, body["sP"]) except Exception as e: logger.error(f"Failed to parse message : {e}") - return None + return Message(text="Failed to parse message", sender="System", participants=[], id=uuid.uuid4(), _raw=body) KEY_CACHE_HANDLE: str = "" KEY_CACHE: dict[bytes, dict[str, tuple[bytes, bytes]]] = {} @@ -545,8 +540,7 @@ class iMessageUser: logger.warning(f"Participant {key} has no identities, this is probably not a real account") for key, participant in lookup.items(): - if not key in self.USER_CACHE: - self.USER_CACHE[key] = [] + self.USER_CACHE[key] = [] # Clear so that we don't keep appending multiple times for identity in participant["identities"]: if not "client-data" in identity: @@ -626,18 +620,22 @@ class iMessageUser: await self.connection.send_notification(topic, body, message_id) - async def _receive_raw(self, c: int, topic: str) -> dict: + async def _receive_raw(self, c: int | list[int], topics: str | list[str]) -> dict: def check(payload: apns.APNSPayload): # Check if the "c" key matches body = payload.fields_with_id(3)[0].value if body is None: return False body = plistlib.loads(body) - if not "c" in body or body["c"] != c: + if not "c" in body: + return False + if isinstance(c, int) and body["c"] != c: + return False + elif isinstance(c, list) and body["c"] not in c: return False return True - payload = await self.connection.expect_notification(topic, check) + payload = await self.connection.expect_notification(topics, check) body = payload.fields_with_id(3)[0].value body = plistlib.loads(body)