diff --git a/apns.py b/apns.py index ee88642..45f1042 100644 --- a/apns.py +++ b/apns.py @@ -215,7 +215,7 @@ class APNSConnection: ], ) - if token != b"": + if token != b"" and token is not None: payload.fields.insert(0, APNSField(0x1, token)) await self._send(payload) @@ -225,7 +225,12 @@ class APNSConnection: if payload.fields_with_id(1)[0].value != b"\x00": raise Exception("Failed to connect") - new_token = payload.fields_with_id(3)[0].value + if len(payload.fields_with_id(3)) > 0: + new_token = payload.fields_with_id(3)[0].value + else: + if token is None: + raise Exception("No token received") + new_token = token logger.debug( f"Received connect response with token {b64encode(new_token).decode()}" @@ -292,7 +297,7 @@ class APNSConnection: APNSPayload( 0x14, [ - APNSField(1, state.to_bytes(4, "big")), + APNSField(1, state.to_bytes(1, "big")), APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")), ], ) diff --git a/demo.py b/demo.py index 319028b..f6c0797 100644 --- a/demo.py +++ b/demo.py @@ -23,11 +23,11 @@ logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("jelly").setLevel(logging.INFO) logging.getLogger("nac").setLevel(logging.INFO) -logging.getLogger("apns").setLevel(logging.INFO) +logging.getLogger("apns").setLevel(logging.DEBUG) logging.getLogger("albert").setLevel(logging.INFO) logging.getLogger("ids").setLevel(logging.DEBUG) logging.getLogger("bags").setLevel(logging.INFO) -logging.getLogger("imessage").setLevel(logging.INFO) +logging.getLogger("imessage").setLevel(logging.DEBUG) logging.captureWarnings(True) @@ -65,13 +65,18 @@ async def main(): except FileNotFoundError: CONFIG = {} + token = CONFIG.get("push", {}).get("token") + if token is not None: + token = b64decode(token) + else: + token = b"" + push_creds = apns.PushCredentials( - CONFIG.get("push", {}).get("key"), CONFIG.get("push", {}).get("cert"), CONFIG.get("push", {}).get("token") - ) + CONFIG.get("push", {}).get("key", ""), CONFIG.get("push", {}).get("cert", ""), token) async with apns.APNSConnection.start(push_creds) as conn: - conn.set_state(1) - conn.filter(["com.apple.madrid"]) + await conn.set_state(1) + await conn.filter(["com.apple.madrid"]) user = ids.IDSUser(conn) @@ -130,4 +135,11 @@ async def main(): with open("config.json", "w") as f: json.dump(CONFIG, f, indent=4) - im = imessage.iMessageUser(conn, user) \ No newline at end of file + im = imessage.iMessageUser(conn, user) + + # Send a message to myself + await im.send(imessage.iMessage.create(im, "Hello, world!", [user.current_handle])) + +if __name__ == "__main__": + import trio + trio.run(main) \ No newline at end of file diff --git a/ids/__init__.py b/ids/__init__.py index 7e561c4..5c78b55 100644 --- a/ids/__init__.py +++ b/ids/__init__.py @@ -3,12 +3,12 @@ from base64 import b64encode import apns from . import _helpers, identity, profile, query - +from typing import Callable, Any class IDSUser: # Sets self.user_id and self._auth_token def _authenticate_for_token( - self, username: str, password: str, factor_callback: callable = None + self, username: str, password: str, factor_callback: Callable | None = None ): self.user_id, self._auth_token = profile.get_auth_token( username, password, factor_callback @@ -25,22 +25,22 @@ class IDSUser: ): self.push_connection = push_connection self._push_keypair = _helpers.KeyPair( - self.push_connection.private_key, self.push_connection.cert + self.push_connection.credentials.private_key, self.push_connection.credentials.cert ) self.ec_key = self.rsa_key = None def __str__(self): - return f"IDSUser(user_id={self.user_id}, handles={self.handles}, push_token={b64encode(self.push_connection.token).decode()})" + return f"IDSUser(user_id={self.user_id}, handles={self.handles}, push_token={b64encode(self.push_connection.credentials.token).decode()})" # Authenticates with a username and password, to create a brand new authentication keypair def authenticate( - self, username: str, password: str, factor_callback: callable = None + self, username: str, password: str, factor_callback: Callable | None = None ): self._authenticate_for_token(username, password, factor_callback) self._authenticate_for_cert() self.handles = profile.get_handles( - b64encode(self.push_connection.token), + b64encode(self.push_connection.credentials.token), self.user_id, self._auth_keypair, self._push_keypair, @@ -68,7 +68,7 @@ class IDSUser: cert = identity.register( - b64encode(self.push_connection.token), + b64encode(self.push_connection.credentials.token), self.handles, self.user_id, self._auth_keypair, @@ -81,6 +81,6 @@ class IDSUser: def restore_identity(self, id_keypair: _helpers.KeyPair): self._id_keypair = id_keypair - def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> any: - return query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic) + async def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> Any: + return await query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic) diff --git a/ids/identity.py b/ids/identity.py index a918f55..f7ee5e3 100644 --- a/ids/identity.py +++ b/ids/identity.py @@ -17,27 +17,28 @@ class IDSIdentity: def __init__(self, signing_key: str | None = None, encryption_key: str | None = None, signing_public_key: str | None = None, encryption_public_key: str | None = None): if signing_key is not None: self.signing_key = signing_key - self.signing_public_key = serialize_key(parse_key(signing_key).public_key()) + self.signing_public_key = serialize_key(parse_key(signing_key).public_key())# type: ignore elif signing_public_key is not None: self.signing_key = None self.signing_public_key = signing_public_key else: # Generate a new key self.signing_key = serialize_key(ec.generate_private_key(ec.SECP256R1())) - self.signing_public_key = serialize_key(parse_key(self.signing_key).public_key()) + self.signing_public_key = serialize_key(parse_key(self.signing_key).public_key())# type: ignore if encryption_key is not None: self.encryption_key = encryption_key - self.encryption_public_key = serialize_key(parse_key(encryption_key).public_key()) + self.encryption_public_key = serialize_key(parse_key(encryption_key).public_key())# type: ignore elif encryption_public_key is not None: self.encryption_key = None self.encryption_public_key = encryption_public_key else: self.encryption_key = serialize_key(rsa.generate_private_key(65537, 1280)) - self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key()) - - def decode(input: bytes) -> 'IDSIdentity': - input = BytesIO(input) + self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key())# type: ignore + + @staticmethod + def decode(inp: bytes) -> 'IDSIdentity': + input = BytesIO(inp) assert input.read(5) == b'\x30\x81\xF6\x81\x43' # DER header raw_ecdsa = input.read(67) @@ -75,13 +76,13 @@ class IDSIdentity: raw_rsa.write(b'\x00\xAC') raw_rsa.write(b'\x30\x81\xA9') raw_rsa.write(b'\x02\x81\xA1') - raw_rsa.write(parse_key(self.encryption_public_key).public_numbers().n.to_bytes(161, "big")) + raw_rsa.write(parse_key(self.encryption_public_key).public_numbers().n.to_bytes(161, "big")) # type: ignore raw_rsa.write(b'\x02\x03\x01\x00\x01') # Hardcode the exponent output.write(b'\x30\x81\xF6\x81\x43') output.write(b'\x00\x41\x04') - output.write(parse_key(self.signing_public_key).public_numbers().x.to_bytes(32, "big")) - output.write(parse_key(self.signing_public_key).public_numbers().y.to_bytes(32, "big")) + output.write(parse_key(self.signing_public_key).public_numbers().x.to_bytes(32, "big"))# type: ignore + output.write(parse_key(self.signing_public_key).public_numbers().y.to_bytes(32, "big"))# type: ignore output.write(b'\x82\x81\xAE') output.write(raw_rsa.getvalue()) diff --git a/ids/profile.py b/ids/profile.py index 3af9d73..f5ae576 100644 --- a/ids/profile.py +++ b/ids/profile.py @@ -18,8 +18,10 @@ from ._helpers import PROTOCOL_VERSION, USER_AGENT, KeyPair import logging logger = logging.getLogger("ids") +from typing import Any, Callable -def _auth_token_request(username: str, password: str) -> any: + +def _auth_token_request(username: str, password: str) -> Any: # Turn the PET into an auth token data = { "username": username, @@ -46,7 +48,7 @@ def _auth_token_request(username: str, password: str) -> any: # If factor_gen is not None, it will be called to get the 2FA code, otherwise it will be prompted # Returns (realm user id, auth token) def get_auth_token( - username: str, password: str, factor_gen: callable = None + username: str, password: str, factor_gen: Callable | None = None ) -> tuple[str, str]: from sys import platform @@ -154,7 +156,7 @@ def get_handles(push_token, user_id: str, auth_key: KeyPair, push_key: KeyPair): "x-auth-user-id": user_id, } signing.add_auth_signature( - headers, None, BAG_KEY, auth_key, push_key, push_token + headers, b"", BAG_KEY, auth_key, push_key, push_token ) r = requests.get( diff --git a/imessage.py b/imessage.py index 8951cde..ac76369 100644 --- a/imessage.py +++ b/imessage.py @@ -50,7 +50,7 @@ class MMCSFile(AttachmentFile): logger.info( requests.get( - url=self.url, + url=self.url, # type: ignore headers={ "User-Agent": f"IMTransferAgent/900 CFNetwork/596.2.3 Darwin/12.2.0 (x86_64) (Macmini5,1)", # "MMCS-Url": self.url, @@ -79,8 +79,8 @@ class Attachment: def __init__(self, message_raw_content: dict, xml_element: ElementTree.Element): attrib = xml_element.attrib - self.name = attrib["name"] if "name" in attrib else None - self.mime_type = attrib["mime-type"] if "mime-type" in attrib else None + self.name = attrib["name"] if "name" in attrib else None # type: ignore + self.mime_type = attrib["mime-type"] if "mime-type" in attrib else None # type: ignore if "inline-attachment" in attrib: # just grab the inline attachment ! @@ -121,7 +121,7 @@ class Attachment: # case "decryption-key": # versions[index].decryption_key = base64.b16decode(val)[1:] - self.versions = versions + self.versions = versions # type: ignore def __repr__(self): return f'' @@ -136,16 +136,23 @@ class Message: _compressed: bool = True xml: str | None = None + @staticmethod def from_raw(message: bytes, sender: str | None = None) -> "Message": """Create a `Message` from raw message bytes""" raise NotImplementedError() - def __str__(): + def to_raw(self) -> bytes: + """Convert a `Message` to raw message bytes""" + + raise NotImplementedError() + + def __str__(self): raise NotImplementedError() @dataclass class SMSReflectedMessage(Message): + @staticmethod def from_raw(message: bytes, sender: str | None = None) -> "SMSReflectedMessage": """Create a `SMSReflectedMessage` from raw message bytes""" @@ -161,11 +168,11 @@ class SMSReflectedMessage(Message): logger.info(f"Decoding SMSReflectedMessage: {message}") return SMSReflectedMessage( - text=message["mD"]["plain-body"], - sender=sender, - participants=[re["id"] for re in message["re"]] + [sender], - id=uuid.UUID(message["mD"]["guid"]), - _raw=message, + text=message["mD"]["plain-body"], # type: ignore + sender=sender, # type: ignore + participants=[re["id"] for re in message["re"]] + [sender], # type: ignore + id=uuid.UUID(message["mD"]["guid"]), # type: ignore + _raw=message, # type: ignore _compressed=compressed, ) @@ -209,6 +216,7 @@ class SMSReflectedMessage(Message): @dataclass class SMSIncomingMessage(Message): + @staticmethod def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingMessage": """Create a `SMSIncomingMessage` from raw message bytes""" @@ -224,11 +232,11 @@ class SMSIncomingMessage(Message): logger.debug(f"Decoding SMSIncomingMessage: {message}") return SMSIncomingMessage( - text=message["k"][0]["data"].decode(), - sender=message["h"], # Don't use sender parameter, that is the phone that forwarded the message - participants=[message["h"], message["co"]], - id=uuid.UUID(message["g"]), - _raw=message, + text=message["k"][0]["data"].decode(), # type: ignore + sender=message["h"], # Don't use sender parameter, that is the phone that forwarded the message # type: ignore + participants=[message["h"], message["co"]], # type: ignore + id=uuid.UUID(message["g"]), # type: ignore + _raw=message, # type: ignore _compressed=compressed, ) @@ -237,16 +245,18 @@ class SMSIncomingMessage(Message): @dataclass class SMSIncomingImage(Message): + @staticmethod def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingImage": """Create a `SMSIncomingImage` from raw message bytes""" # TODO: Implement this - return "SMSIncomingImage" + return "SMSIncomingImage" # type: ignore @dataclass class iMessage(Message): effect: str | None = None + @staticmethod def create(user: "iMessageUser", text: str, participants: list[str]) -> "iMessage": """Creates a basic outgoing `iMessage` from the given text and participants""" @@ -260,6 +270,7 @@ class iMessage(Message): id=uuid.uuid4(), ) + @staticmethod def from_raw(message: bytes, sender: str | None = None) -> "iMessage": """Create a `iMessage` from raw message bytes""" @@ -275,14 +286,14 @@ class iMessage(Message): logger.debug(f"Decoding iMessage: {message}") return iMessage( - text=message["t"], - participants=message["p"], - sender=sender, - id=uuid.UUID(message["r"]) if "r" in message else None, - xml=message["x"] if "x" in message else None, - _raw=message, + text=message["t"], # type: ignore + participants=message["p"], # type: ignore + sender=sender, # type: ignore + id=uuid.UUID(message["r"]) if "r" in message else None, # type: ignore + xml=message["x"] if "x" in message else None, # type: ignore + _raw=message, # type: ignore _compressed=compressed, - effect=message["iid"] if "iid" in message else None, + effect=message["iid"] if "iid" in message else None, # type: ignore ) def to_raw(self) -> bytes: @@ -338,8 +349,9 @@ class iMessageUser: self.connection = connection self.user = user - def _parse_payload(payload: bytes) -> tuple[bytes, bytes]: - payload = BytesIO(payload) + @staticmethod + def _parse_payload(p: bytes) -> tuple[bytes, bytes]: + payload = BytesIO(p) tag = payload.read(1) # print("TAG", tag) @@ -351,6 +363,7 @@ class iMessageUser: return (body, signature) + @staticmethod def _construct_payload(body: bytes, signature: bytes) -> bytes: payload = ( b"\x02" @@ -361,6 +374,7 @@ class iMessageUser: ) return payload + @staticmethod def _hash_identity(id: bytes) -> bytes: iden = ids.identity.IDSIdentity.decode(id) @@ -369,13 +383,11 @@ class iMessageUser: output.write(b"\x00\x41\x04") output.write( ids._helpers.parse_key(iden.signing_public_key) - .public_numbers() - .x.to_bytes(32, "big") + .public_numbers().x.to_bytes(32, "big") # type: ignore ) output.write( ids._helpers.parse_key(iden.signing_public_key) - .public_numbers() - .y.to_bytes(32, "big") + .public_numbers().y.to_bytes(32, "big") # type: ignore ) output.write(b"\x00\xAC") @@ -383,8 +395,7 @@ class iMessageUser: output.write(b"\x02\x81\xA1") output.write( ids._helpers.parse_key(iden.encryption_public_key) - .public_numbers() - .n.to_bytes(161, "big") + .public_numbers().n.to_bytes(161, "big") # type: ignore ) output.write(b"\x02\x03\x01\x00\x01") @@ -417,7 +428,7 @@ class iMessageUser: # Encrypt the AES key with the public key of the recipient recipient_key = ids._helpers.parse_key(key.encryption_public_key) - rsa_body = recipient_key.encrypt( + rsa_body = recipient_key.encrypt( # type: ignore aes_key + encrypted[:100], padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA1()), @@ -428,20 +439,20 @@ class iMessageUser: # Construct the payload body = rsa_body + encrypted[100:] - sig = ids._helpers.parse_key(self.user.encryption_identity.signing_key).sign( - body, ec.ECDSA(hashes.SHA1()) + sig = ids._helpers.parse_key(self.user.encryption_identity.signing_key).sign( # type: ignore + body, ec.ECDSA(hashes.SHA1()) # type: ignore ) payload = iMessageUser._construct_payload(body, sig) return payload - def _decrypt_payload(self, payload: bytes) -> dict: - payload = iMessageUser._parse_payload(payload) + def _decrypt_payload(self, p: bytes) -> bytes: + payload = iMessageUser._parse_payload(p) body = BytesIO(payload[0]) rsa_body = ids._helpers.parse_key( - self.user.encryption_identity.encryption_key - ).decrypt( + self.user.encryption_identity.encryption_key # type: ignore + ).decrypt( # type: ignore body.read(160), padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA1()), @@ -455,9 +466,9 @@ class iMessageUser: return decrypted - def _verify_payload(self, payload: bytes, sender: str, sender_token: str) -> bool: + async def _verify_payload(self, p: bytes, sender: str, sender_token: str) -> bool: # Get the public key for the sender - self._cache_keys([sender], "com.apple.madrid") + await self._cache_keys([sender], "com.apple.madrid") if not sender_token in self.KEY_CACHE: logger.warning("Unable to find the public key of the sender, cannot verify") @@ -468,25 +479,25 @@ class iMessageUser: ) sender_ec_key = ids._helpers.parse_key(identity_keys.signing_public_key) - payload = iMessageUser._parse_payload(payload) + payload = iMessageUser._parse_payload(p) try: # Verify the signature (will throw an exception if it fails) - sender_ec_key.verify( + sender_ec_key.verify( # type: ignore payload[1], payload[0], - ec.ECDSA(hashes.SHA1()), + ec.ECDSA(hashes.SHA1()), # type: ignore ) return True except: return False - def receive(self) -> Message | None: + async def receive(self) -> Message | None: """ Will return the next iMessage in the queue, or None if there are no messages """ for type, (topic, cls) in MESSAGE_TYPES.items(): - body = self._receive_raw(type, topic) + body = await self._receive_raw(type, topic) if body is not None: t = cls break @@ -494,7 +505,7 @@ class iMessageUser: return None - if not self._verify_payload(body["P"], body["sP"], body["t"]): + if not await self._verify_payload(body["P"], body["sP"], body["t"]): raise Exception("Failed to verify payload") logger.debug(f"Encrypted body : {body}") @@ -513,7 +524,7 @@ class iMessageUser: USER_CACHE: dict[str, list[bytes]] = {} """Mapping of handle : [push tokens]""" - def _cache_keys(self, participants: list[str], topic: str): + async def _cache_keys(self, participants: list[str], topic: str): # Clear the cache if the handle has changed if self.KEY_CACHE_HANDLE != self.user.current_handle: self.KEY_CACHE_HANDLE = self.user.current_handle @@ -526,7 +537,7 @@ class iMessageUser: # TODO: This doesn't work since it doesn't check if they are cached for all topics # Look up the public keys for the participants, and cache a token : public key mapping - lookup = self.user.lookup(participants, topic=topic) + lookup = await self.user.lookup(participants, topic=topic) logger.debug(f"Lookup response : {lookup}") for key, participant in lookup.items(): @@ -559,7 +570,7 @@ class iMessageUser: identity["session-token"], ) - def _send_raw( + async def _send_raw( self, type: int, participants: list[str], @@ -568,12 +579,12 @@ class iMessageUser: id: uuid.UUID | None = None, extra: dict = {}, ): - self._cache_keys(participants, topic) + await self._cache_keys(participants, topic) dtl = [] for participant in participants: for push_token in self.USER_CACHE[participant]: - if push_token == self.connection.token: + if push_token == self.connection.credentials.token: continue # Don't send to ourselves identity_keys = ids.identity.IDSIdentity.decode( @@ -613,52 +624,35 @@ class iMessageUser: body = plistlib.dumps(body, fmt=plistlib.FMT_BINARY) - self.connection.send_message(topic, body, message_id) + await self.connection.send_notification(topic, body, message_id) - def _receive_raw(self, c: int | list[int], topic: str | list[str]) -> dict | None: - def check_response(x): - if x[0] != 0x0A: + async def _receive_raw(self, c: int, topic: str) -> dict | None: + def check(payload: apns.APNSPayload): + # Check if the "c" key matches + body = payload.fields_with_id(3)[0].value + if body is None: return False - # Check if it matches any of the topics - if isinstance(topic, list): - for t in topic: - if apns._get_field(x[1], 2) == sha1(t.encode()).digest(): - break - else: - return False - else: - if apns._get_field(x[1], 2) != sha1(topic.encode()).digest(): - return False - - resp_body = apns._get_field(x[1], 3) - if resp_body is None: - return False - resp_body = plistlib.loads(resp_body) - - #logger.info(f"See type {resp_body['c']}") - - if isinstance(c, list): - if not resp_body["c"] in c: - return False - elif resp_body["c"] != c: + body = plistlib.loads(body) + if not "c" in body or "c" != c: return False return True + + payload = await self.connection.expect_notification(topic, check) - payload = self.connection.incoming_queue.pop_find(check_response) if payload is None: return None - body = apns._get_field(payload[1], 3) + body = payload.fields_with_id(3)[0].value body = plistlib.loads(body) return body - def activate_sms(self) -> bool: + async def activate_sms(self): """ Try to activate SMS forwarding Returns True if we are able to perform SMS forwarding, False otherwise Call repeatedly until it returns True """ - act_message = self._receive_raw(145, "com.apple.private.alloy.sms") + act_message = await self._receive_raw(145, "com.apple.private.alloy.sms") if act_message is None: return False @@ -672,7 +666,7 @@ class iMessageUser: else: logger.info("SMS forwarding de-activated, sending response") - self._send_raw( + await self._send_raw( 147, [self.user.current_handle], "com.apple.private.alloy.sms", @@ -681,7 +675,7 @@ class iMessageUser: } ) - def send(self, message: Message): + async def send(self, message: Message): # Check what type of message we are sending for t, (topic, cls) in MESSAGE_TYPES.items(): if isinstance(message, cls): @@ -691,9 +685,9 @@ class iMessageUser: send_to = message.participants if isinstance(message, iMessage) else [self.user.current_handle] - self._cache_keys(send_to, topic) + await self._cache_keys(send_to, topic) - self._send_raw( + await self._send_raw( t, send_to, topic, @@ -713,12 +707,12 @@ class iMessageUser: for p in send_to: for t in self.USER_CACHE[p]: - if t == self.connection.token: + if t == self.connection.credentials.token: continue total += 1 while count < total and time.time() - start < 2: - resp = self._receive_raw(255, topic) + resp = await self._receive_raw(255, topic) if resp is None: continue count += 1