typing, async, kinda works most of the time

This commit is contained in:
JJTech0130 2023-08-17 21:14:44 -04:00
parent a0f88b0d91
commit 60459cf1bc
No known key found for this signature in database
GPG key ID: 23C92EBCCF8F93D6
6 changed files with 133 additions and 119 deletions

11
apns.py
View file

@ -215,7 +215,7 @@ class APNSConnection:
], ],
) )
if token != b"": if token != b"" and token is not None:
payload.fields.insert(0, APNSField(0x1, token)) payload.fields.insert(0, APNSField(0x1, token))
await self._send(payload) await self._send(payload)
@ -225,7 +225,12 @@ class APNSConnection:
if payload.fields_with_id(1)[0].value != b"\x00": if payload.fields_with_id(1)[0].value != b"\x00":
raise Exception("Failed to connect") raise Exception("Failed to connect")
new_token = payload.fields_with_id(3)[0].value if len(payload.fields_with_id(3)) > 0:
new_token = payload.fields_with_id(3)[0].value
else:
if token is None:
raise Exception("No token received")
new_token = token
logger.debug( logger.debug(
f"Received connect response with token {b64encode(new_token).decode()}" f"Received connect response with token {b64encode(new_token).decode()}"
@ -292,7 +297,7 @@ class APNSConnection:
APNSPayload( APNSPayload(
0x14, 0x14,
[ [
APNSField(1, state.to_bytes(4, "big")), APNSField(1, state.to_bytes(1, "big")),
APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")), APNSField(2, 0x7FFFFFFF.to_bytes(4, "big")),
], ],
) )

24
demo.py
View file

@ -23,11 +23,11 @@ logging.getLogger("py.warnings").setLevel(logging.ERROR) # Ignore warnings from
logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING)
logging.getLogger("jelly").setLevel(logging.INFO) logging.getLogger("jelly").setLevel(logging.INFO)
logging.getLogger("nac").setLevel(logging.INFO) logging.getLogger("nac").setLevel(logging.INFO)
logging.getLogger("apns").setLevel(logging.INFO) logging.getLogger("apns").setLevel(logging.DEBUG)
logging.getLogger("albert").setLevel(logging.INFO) logging.getLogger("albert").setLevel(logging.INFO)
logging.getLogger("ids").setLevel(logging.DEBUG) logging.getLogger("ids").setLevel(logging.DEBUG)
logging.getLogger("bags").setLevel(logging.INFO) logging.getLogger("bags").setLevel(logging.INFO)
logging.getLogger("imessage").setLevel(logging.INFO) logging.getLogger("imessage").setLevel(logging.DEBUG)
logging.captureWarnings(True) logging.captureWarnings(True)
@ -65,13 +65,18 @@ async def main():
except FileNotFoundError: except FileNotFoundError:
CONFIG = {} CONFIG = {}
token = CONFIG.get("push", {}).get("token")
if token is not None:
token = b64decode(token)
else:
token = b""
push_creds = apns.PushCredentials( push_creds = apns.PushCredentials(
CONFIG.get("push", {}).get("key"), CONFIG.get("push", {}).get("cert"), CONFIG.get("push", {}).get("token") CONFIG.get("push", {}).get("key", ""), CONFIG.get("push", {}).get("cert", ""), token)
)
async with apns.APNSConnection.start(push_creds) as conn: async with apns.APNSConnection.start(push_creds) as conn:
conn.set_state(1) await conn.set_state(1)
conn.filter(["com.apple.madrid"]) await conn.filter(["com.apple.madrid"])
user = ids.IDSUser(conn) user = ids.IDSUser(conn)
@ -131,3 +136,10 @@ async def main():
json.dump(CONFIG, f, indent=4) json.dump(CONFIG, f, indent=4)
im = imessage.iMessageUser(conn, user) im = imessage.iMessageUser(conn, user)
# Send a message to myself
await im.send(imessage.iMessage.create(im, "Hello, world!", [user.current_handle]))
if __name__ == "__main__":
import trio
trio.run(main)

View file

@ -3,12 +3,12 @@ from base64 import b64encode
import apns import apns
from . import _helpers, identity, profile, query from . import _helpers, identity, profile, query
from typing import Callable, Any
class IDSUser: class IDSUser:
# Sets self.user_id and self._auth_token # Sets self.user_id and self._auth_token
def _authenticate_for_token( def _authenticate_for_token(
self, username: str, password: str, factor_callback: callable = None self, username: str, password: str, factor_callback: Callable | None = None
): ):
self.user_id, self._auth_token = profile.get_auth_token( self.user_id, self._auth_token = profile.get_auth_token(
username, password, factor_callback username, password, factor_callback
@ -25,22 +25,22 @@ class IDSUser:
): ):
self.push_connection = push_connection self.push_connection = push_connection
self._push_keypair = _helpers.KeyPair( self._push_keypair = _helpers.KeyPair(
self.push_connection.private_key, self.push_connection.cert self.push_connection.credentials.private_key, self.push_connection.credentials.cert
) )
self.ec_key = self.rsa_key = None self.ec_key = self.rsa_key = None
def __str__(self): def __str__(self):
return f"IDSUser(user_id={self.user_id}, handles={self.handles}, push_token={b64encode(self.push_connection.token).decode()})" return f"IDSUser(user_id={self.user_id}, handles={self.handles}, push_token={b64encode(self.push_connection.credentials.token).decode()})"
# Authenticates with a username and password, to create a brand new authentication keypair # Authenticates with a username and password, to create a brand new authentication keypair
def authenticate( def authenticate(
self, username: str, password: str, factor_callback: callable = None self, username: str, password: str, factor_callback: Callable | None = None
): ):
self._authenticate_for_token(username, password, factor_callback) self._authenticate_for_token(username, password, factor_callback)
self._authenticate_for_cert() self._authenticate_for_cert()
self.handles = profile.get_handles( self.handles = profile.get_handles(
b64encode(self.push_connection.token), b64encode(self.push_connection.credentials.token),
self.user_id, self.user_id,
self._auth_keypair, self._auth_keypair,
self._push_keypair, self._push_keypair,
@ -68,7 +68,7 @@ class IDSUser:
cert = identity.register( cert = identity.register(
b64encode(self.push_connection.token), b64encode(self.push_connection.credentials.token),
self.handles, self.handles,
self.user_id, self.user_id,
self._auth_keypair, self._auth_keypair,
@ -81,6 +81,6 @@ class IDSUser:
def restore_identity(self, id_keypair: _helpers.KeyPair): def restore_identity(self, id_keypair: _helpers.KeyPair):
self._id_keypair = id_keypair self._id_keypair = id_keypair
def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> any: async def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> Any:
return query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic) return await query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic)

View file

@ -17,27 +17,28 @@ class IDSIdentity:
def __init__(self, signing_key: str | None = None, encryption_key: str | None = None, signing_public_key: str | None = None, encryption_public_key: str | None = None): def __init__(self, signing_key: str | None = None, encryption_key: str | None = None, signing_public_key: str | None = None, encryption_public_key: str | None = None):
if signing_key is not None: if signing_key is not None:
self.signing_key = signing_key self.signing_key = signing_key
self.signing_public_key = serialize_key(parse_key(signing_key).public_key()) self.signing_public_key = serialize_key(parse_key(signing_key).public_key())# type: ignore
elif signing_public_key is not None: elif signing_public_key is not None:
self.signing_key = None self.signing_key = None
self.signing_public_key = signing_public_key self.signing_public_key = signing_public_key
else: else:
# Generate a new key # Generate a new key
self.signing_key = serialize_key(ec.generate_private_key(ec.SECP256R1())) self.signing_key = serialize_key(ec.generate_private_key(ec.SECP256R1()))
self.signing_public_key = serialize_key(parse_key(self.signing_key).public_key()) self.signing_public_key = serialize_key(parse_key(self.signing_key).public_key())# type: ignore
if encryption_key is not None: if encryption_key is not None:
self.encryption_key = encryption_key self.encryption_key = encryption_key
self.encryption_public_key = serialize_key(parse_key(encryption_key).public_key()) self.encryption_public_key = serialize_key(parse_key(encryption_key).public_key())# type: ignore
elif encryption_public_key is not None: elif encryption_public_key is not None:
self.encryption_key = None self.encryption_key = None
self.encryption_public_key = encryption_public_key self.encryption_public_key = encryption_public_key
else: else:
self.encryption_key = serialize_key(rsa.generate_private_key(65537, 1280)) self.encryption_key = serialize_key(rsa.generate_private_key(65537, 1280))
self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key()) self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key())# type: ignore
def decode(input: bytes) -> 'IDSIdentity': @staticmethod
input = BytesIO(input) def decode(inp: bytes) -> 'IDSIdentity':
input = BytesIO(inp)
assert input.read(5) == b'\x30\x81\xF6\x81\x43' # DER header assert input.read(5) == b'\x30\x81\xF6\x81\x43' # DER header
raw_ecdsa = input.read(67) raw_ecdsa = input.read(67)
@ -75,13 +76,13 @@ class IDSIdentity:
raw_rsa.write(b'\x00\xAC') raw_rsa.write(b'\x00\xAC')
raw_rsa.write(b'\x30\x81\xA9') raw_rsa.write(b'\x30\x81\xA9')
raw_rsa.write(b'\x02\x81\xA1') raw_rsa.write(b'\x02\x81\xA1')
raw_rsa.write(parse_key(self.encryption_public_key).public_numbers().n.to_bytes(161, "big")) raw_rsa.write(parse_key(self.encryption_public_key).public_numbers().n.to_bytes(161, "big")) # type: ignore
raw_rsa.write(b'\x02\x03\x01\x00\x01') # Hardcode the exponent raw_rsa.write(b'\x02\x03\x01\x00\x01') # Hardcode the exponent
output.write(b'\x30\x81\xF6\x81\x43') output.write(b'\x30\x81\xF6\x81\x43')
output.write(b'\x00\x41\x04') output.write(b'\x00\x41\x04')
output.write(parse_key(self.signing_public_key).public_numbers().x.to_bytes(32, "big")) output.write(parse_key(self.signing_public_key).public_numbers().x.to_bytes(32, "big"))# type: ignore
output.write(parse_key(self.signing_public_key).public_numbers().y.to_bytes(32, "big")) output.write(parse_key(self.signing_public_key).public_numbers().y.to_bytes(32, "big"))# type: ignore
output.write(b'\x82\x81\xAE') output.write(b'\x82\x81\xAE')
output.write(raw_rsa.getvalue()) output.write(raw_rsa.getvalue())

View file

@ -18,8 +18,10 @@ from ._helpers import PROTOCOL_VERSION, USER_AGENT, KeyPair
import logging import logging
logger = logging.getLogger("ids") logger = logging.getLogger("ids")
from typing import Any, Callable
def _auth_token_request(username: str, password: str) -> any:
def _auth_token_request(username: str, password: str) -> Any:
# Turn the PET into an auth token # Turn the PET into an auth token
data = { data = {
"username": username, "username": username,
@ -46,7 +48,7 @@ def _auth_token_request(username: str, password: str) -> any:
# If factor_gen is not None, it will be called to get the 2FA code, otherwise it will be prompted # If factor_gen is not None, it will be called to get the 2FA code, otherwise it will be prompted
# Returns (realm user id, auth token) # Returns (realm user id, auth token)
def get_auth_token( def get_auth_token(
username: str, password: str, factor_gen: callable = None username: str, password: str, factor_gen: Callable | None = None
) -> tuple[str, str]: ) -> tuple[str, str]:
from sys import platform from sys import platform
@ -154,7 +156,7 @@ def get_handles(push_token, user_id: str, auth_key: KeyPair, push_key: KeyPair):
"x-auth-user-id": user_id, "x-auth-user-id": user_id,
} }
signing.add_auth_signature( signing.add_auth_signature(
headers, None, BAG_KEY, auth_key, push_key, push_token headers, b"", BAG_KEY, auth_key, push_key, push_token
) )
r = requests.get( r = requests.get(

View file

@ -50,7 +50,7 @@ class MMCSFile(AttachmentFile):
logger.info( logger.info(
requests.get( requests.get(
url=self.url, url=self.url, # type: ignore
headers={ headers={
"User-Agent": f"IMTransferAgent/900 CFNetwork/596.2.3 Darwin/12.2.0 (x86_64) (Macmini5,1)", "User-Agent": f"IMTransferAgent/900 CFNetwork/596.2.3 Darwin/12.2.0 (x86_64) (Macmini5,1)",
# "MMCS-Url": self.url, # "MMCS-Url": self.url,
@ -79,8 +79,8 @@ class Attachment:
def __init__(self, message_raw_content: dict, xml_element: ElementTree.Element): def __init__(self, message_raw_content: dict, xml_element: ElementTree.Element):
attrib = xml_element.attrib attrib = xml_element.attrib
self.name = attrib["name"] if "name" in attrib else None self.name = attrib["name"] if "name" in attrib else None # type: ignore
self.mime_type = attrib["mime-type"] if "mime-type" in attrib else None self.mime_type = attrib["mime-type"] if "mime-type" in attrib else None # type: ignore
if "inline-attachment" in attrib: if "inline-attachment" in attrib:
# just grab the inline attachment ! # just grab the inline attachment !
@ -121,7 +121,7 @@ class Attachment:
# case "decryption-key": # case "decryption-key":
# versions[index].decryption_key = base64.b16decode(val)[1:] # versions[index].decryption_key = base64.b16decode(val)[1:]
self.versions = versions self.versions = versions # type: ignore
def __repr__(self): def __repr__(self):
return f'<Attachment name="{self.name}" type="{self.mime_type}">' return f'<Attachment name="{self.name}" type="{self.mime_type}">'
@ -136,16 +136,23 @@ class Message:
_compressed: bool = True _compressed: bool = True
xml: str | None = None xml: str | None = None
@staticmethod
def from_raw(message: bytes, sender: str | None = None) -> "Message": def from_raw(message: bytes, sender: str | None = None) -> "Message":
"""Create a `Message` from raw message bytes""" """Create a `Message` from raw message bytes"""
raise NotImplementedError() raise NotImplementedError()
def __str__(): def to_raw(self) -> bytes:
"""Convert a `Message` to raw message bytes"""
raise NotImplementedError()
def __str__(self):
raise NotImplementedError() raise NotImplementedError()
@dataclass @dataclass
class SMSReflectedMessage(Message): class SMSReflectedMessage(Message):
@staticmethod
def from_raw(message: bytes, sender: str | None = None) -> "SMSReflectedMessage": def from_raw(message: bytes, sender: str | None = None) -> "SMSReflectedMessage":
"""Create a `SMSReflectedMessage` from raw message bytes""" """Create a `SMSReflectedMessage` from raw message bytes"""
@ -161,11 +168,11 @@ class SMSReflectedMessage(Message):
logger.info(f"Decoding SMSReflectedMessage: {message}") logger.info(f"Decoding SMSReflectedMessage: {message}")
return SMSReflectedMessage( return SMSReflectedMessage(
text=message["mD"]["plain-body"], text=message["mD"]["plain-body"], # type: ignore
sender=sender, sender=sender, # type: ignore
participants=[re["id"] for re in message["re"]] + [sender], participants=[re["id"] for re in message["re"]] + [sender], # type: ignore
id=uuid.UUID(message["mD"]["guid"]), id=uuid.UUID(message["mD"]["guid"]), # type: ignore
_raw=message, _raw=message, # type: ignore
_compressed=compressed, _compressed=compressed,
) )
@ -209,6 +216,7 @@ class SMSReflectedMessage(Message):
@dataclass @dataclass
class SMSIncomingMessage(Message): class SMSIncomingMessage(Message):
@staticmethod
def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingMessage": def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingMessage":
"""Create a `SMSIncomingMessage` from raw message bytes""" """Create a `SMSIncomingMessage` from raw message bytes"""
@ -224,11 +232,11 @@ class SMSIncomingMessage(Message):
logger.debug(f"Decoding SMSIncomingMessage: {message}") logger.debug(f"Decoding SMSIncomingMessage: {message}")
return SMSIncomingMessage( return SMSIncomingMessage(
text=message["k"][0]["data"].decode(), text=message["k"][0]["data"].decode(), # type: ignore
sender=message["h"], # Don't use sender parameter, that is the phone that forwarded the message sender=message["h"], # Don't use sender parameter, that is the phone that forwarded the message # type: ignore
participants=[message["h"], message["co"]], participants=[message["h"], message["co"]], # type: ignore
id=uuid.UUID(message["g"]), id=uuid.UUID(message["g"]), # type: ignore
_raw=message, _raw=message, # type: ignore
_compressed=compressed, _compressed=compressed,
) )
@ -237,16 +245,18 @@ class SMSIncomingMessage(Message):
@dataclass @dataclass
class SMSIncomingImage(Message): class SMSIncomingImage(Message):
@staticmethod
def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingImage": def from_raw(message: bytes, sender: str | None = None) -> "SMSIncomingImage":
"""Create a `SMSIncomingImage` from raw message bytes""" """Create a `SMSIncomingImage` from raw message bytes"""
# TODO: Implement this # TODO: Implement this
return "SMSIncomingImage" return "SMSIncomingImage" # type: ignore
@dataclass @dataclass
class iMessage(Message): class iMessage(Message):
effect: str | None = None effect: str | None = None
@staticmethod
def create(user: "iMessageUser", text: str, participants: list[str]) -> "iMessage": def create(user: "iMessageUser", text: str, participants: list[str]) -> "iMessage":
"""Creates a basic outgoing `iMessage` from the given text and participants""" """Creates a basic outgoing `iMessage` from the given text and participants"""
@ -260,6 +270,7 @@ class iMessage(Message):
id=uuid.uuid4(), id=uuid.uuid4(),
) )
@staticmethod
def from_raw(message: bytes, sender: str | None = None) -> "iMessage": def from_raw(message: bytes, sender: str | None = None) -> "iMessage":
"""Create a `iMessage` from raw message bytes""" """Create a `iMessage` from raw message bytes"""
@ -275,14 +286,14 @@ class iMessage(Message):
logger.debug(f"Decoding iMessage: {message}") logger.debug(f"Decoding iMessage: {message}")
return iMessage( return iMessage(
text=message["t"], text=message["t"], # type: ignore
participants=message["p"], participants=message["p"], # type: ignore
sender=sender, sender=sender, # type: ignore
id=uuid.UUID(message["r"]) if "r" in message else None, id=uuid.UUID(message["r"]) if "r" in message else None, # type: ignore
xml=message["x"] if "x" in message else None, xml=message["x"] if "x" in message else None, # type: ignore
_raw=message, _raw=message, # type: ignore
_compressed=compressed, _compressed=compressed,
effect=message["iid"] if "iid" in message else None, effect=message["iid"] if "iid" in message else None, # type: ignore
) )
def to_raw(self) -> bytes: def to_raw(self) -> bytes:
@ -338,8 +349,9 @@ class iMessageUser:
self.connection = connection self.connection = connection
self.user = user self.user = user
def _parse_payload(payload: bytes) -> tuple[bytes, bytes]: @staticmethod
payload = BytesIO(payload) def _parse_payload(p: bytes) -> tuple[bytes, bytes]:
payload = BytesIO(p)
tag = payload.read(1) tag = payload.read(1)
# print("TAG", tag) # print("TAG", tag)
@ -351,6 +363,7 @@ class iMessageUser:
return (body, signature) return (body, signature)
@staticmethod
def _construct_payload(body: bytes, signature: bytes) -> bytes: def _construct_payload(body: bytes, signature: bytes) -> bytes:
payload = ( payload = (
b"\x02" b"\x02"
@ -361,6 +374,7 @@ class iMessageUser:
) )
return payload return payload
@staticmethod
def _hash_identity(id: bytes) -> bytes: def _hash_identity(id: bytes) -> bytes:
iden = ids.identity.IDSIdentity.decode(id) iden = ids.identity.IDSIdentity.decode(id)
@ -369,13 +383,11 @@ class iMessageUser:
output.write(b"\x00\x41\x04") output.write(b"\x00\x41\x04")
output.write( output.write(
ids._helpers.parse_key(iden.signing_public_key) ids._helpers.parse_key(iden.signing_public_key)
.public_numbers() .public_numbers().x.to_bytes(32, "big") # type: ignore
.x.to_bytes(32, "big")
) )
output.write( output.write(
ids._helpers.parse_key(iden.signing_public_key) ids._helpers.parse_key(iden.signing_public_key)
.public_numbers() .public_numbers().y.to_bytes(32, "big") # type: ignore
.y.to_bytes(32, "big")
) )
output.write(b"\x00\xAC") output.write(b"\x00\xAC")
@ -383,8 +395,7 @@ class iMessageUser:
output.write(b"\x02\x81\xA1") output.write(b"\x02\x81\xA1")
output.write( output.write(
ids._helpers.parse_key(iden.encryption_public_key) ids._helpers.parse_key(iden.encryption_public_key)
.public_numbers() .public_numbers().n.to_bytes(161, "big") # type: ignore
.n.to_bytes(161, "big")
) )
output.write(b"\x02\x03\x01\x00\x01") output.write(b"\x02\x03\x01\x00\x01")
@ -417,7 +428,7 @@ class iMessageUser:
# Encrypt the AES key with the public key of the recipient # Encrypt the AES key with the public key of the recipient
recipient_key = ids._helpers.parse_key(key.encryption_public_key) recipient_key = ids._helpers.parse_key(key.encryption_public_key)
rsa_body = recipient_key.encrypt( rsa_body = recipient_key.encrypt( # type: ignore
aes_key + encrypted[:100], aes_key + encrypted[:100],
padding.OAEP( padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()), mgf=padding.MGF1(algorithm=hashes.SHA1()),
@ -428,20 +439,20 @@ class iMessageUser:
# Construct the payload # Construct the payload
body = rsa_body + encrypted[100:] body = rsa_body + encrypted[100:]
sig = ids._helpers.parse_key(self.user.encryption_identity.signing_key).sign( sig = ids._helpers.parse_key(self.user.encryption_identity.signing_key).sign( # type: ignore
body, ec.ECDSA(hashes.SHA1()) body, ec.ECDSA(hashes.SHA1()) # type: ignore
) )
payload = iMessageUser._construct_payload(body, sig) payload = iMessageUser._construct_payload(body, sig)
return payload return payload
def _decrypt_payload(self, payload: bytes) -> dict: def _decrypt_payload(self, p: bytes) -> bytes:
payload = iMessageUser._parse_payload(payload) payload = iMessageUser._parse_payload(p)
body = BytesIO(payload[0]) body = BytesIO(payload[0])
rsa_body = ids._helpers.parse_key( rsa_body = ids._helpers.parse_key(
self.user.encryption_identity.encryption_key self.user.encryption_identity.encryption_key # type: ignore
).decrypt( ).decrypt( # type: ignore
body.read(160), body.read(160),
padding.OAEP( padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()), mgf=padding.MGF1(algorithm=hashes.SHA1()),
@ -455,9 +466,9 @@ class iMessageUser:
return decrypted return decrypted
def _verify_payload(self, payload: bytes, sender: str, sender_token: str) -> bool: async def _verify_payload(self, p: bytes, sender: str, sender_token: str) -> bool:
# Get the public key for the sender # Get the public key for the sender
self._cache_keys([sender], "com.apple.madrid") await self._cache_keys([sender], "com.apple.madrid")
if not sender_token in self.KEY_CACHE: if not sender_token in self.KEY_CACHE:
logger.warning("Unable to find the public key of the sender, cannot verify") logger.warning("Unable to find the public key of the sender, cannot verify")
@ -468,25 +479,25 @@ class iMessageUser:
) )
sender_ec_key = ids._helpers.parse_key(identity_keys.signing_public_key) sender_ec_key = ids._helpers.parse_key(identity_keys.signing_public_key)
payload = iMessageUser._parse_payload(payload) payload = iMessageUser._parse_payload(p)
try: try:
# Verify the signature (will throw an exception if it fails) # Verify the signature (will throw an exception if it fails)
sender_ec_key.verify( sender_ec_key.verify( # type: ignore
payload[1], payload[1],
payload[0], payload[0],
ec.ECDSA(hashes.SHA1()), ec.ECDSA(hashes.SHA1()), # type: ignore
) )
return True return True
except: except:
return False return False
def receive(self) -> Message | None: async def receive(self) -> Message | None:
""" """
Will return the next iMessage in the queue, or None if there are no messages Will return the next iMessage in the queue, or None if there are no messages
""" """
for type, (topic, cls) in MESSAGE_TYPES.items(): for type, (topic, cls) in MESSAGE_TYPES.items():
body = self._receive_raw(type, topic) body = await self._receive_raw(type, topic)
if body is not None: if body is not None:
t = cls t = cls
break break
@ -494,7 +505,7 @@ class iMessageUser:
return None return None
if not self._verify_payload(body["P"], body["sP"], body["t"]): if not await self._verify_payload(body["P"], body["sP"], body["t"]):
raise Exception("Failed to verify payload") raise Exception("Failed to verify payload")
logger.debug(f"Encrypted body : {body}") logger.debug(f"Encrypted body : {body}")
@ -513,7 +524,7 @@ class iMessageUser:
USER_CACHE: dict[str, list[bytes]] = {} USER_CACHE: dict[str, list[bytes]] = {}
"""Mapping of handle : [push tokens]""" """Mapping of handle : [push tokens]"""
def _cache_keys(self, participants: list[str], topic: str): async def _cache_keys(self, participants: list[str], topic: str):
# Clear the cache if the handle has changed # Clear the cache if the handle has changed
if self.KEY_CACHE_HANDLE != self.user.current_handle: if self.KEY_CACHE_HANDLE != self.user.current_handle:
self.KEY_CACHE_HANDLE = self.user.current_handle self.KEY_CACHE_HANDLE = self.user.current_handle
@ -526,7 +537,7 @@ class iMessageUser:
# TODO: This doesn't work since it doesn't check if they are cached for all topics # TODO: This doesn't work since it doesn't check if they are cached for all topics
# Look up the public keys for the participants, and cache a token : public key mapping # Look up the public keys for the participants, and cache a token : public key mapping
lookup = self.user.lookup(participants, topic=topic) lookup = await self.user.lookup(participants, topic=topic)
logger.debug(f"Lookup response : {lookup}") logger.debug(f"Lookup response : {lookup}")
for key, participant in lookup.items(): for key, participant in lookup.items():
@ -559,7 +570,7 @@ class iMessageUser:
identity["session-token"], identity["session-token"],
) )
def _send_raw( async def _send_raw(
self, self,
type: int, type: int,
participants: list[str], participants: list[str],
@ -568,12 +579,12 @@ class iMessageUser:
id: uuid.UUID | None = None, id: uuid.UUID | None = None,
extra: dict = {}, extra: dict = {},
): ):
self._cache_keys(participants, topic) await self._cache_keys(participants, topic)
dtl = [] dtl = []
for participant in participants: for participant in participants:
for push_token in self.USER_CACHE[participant]: for push_token in self.USER_CACHE[participant]:
if push_token == self.connection.token: if push_token == self.connection.credentials.token:
continue # Don't send to ourselves continue # Don't send to ourselves
identity_keys = ids.identity.IDSIdentity.decode( identity_keys = ids.identity.IDSIdentity.decode(
@ -613,52 +624,35 @@ class iMessageUser:
body = plistlib.dumps(body, fmt=plistlib.FMT_BINARY) body = plistlib.dumps(body, fmt=plistlib.FMT_BINARY)
self.connection.send_message(topic, body, message_id) await self.connection.send_notification(topic, body, message_id)
def _receive_raw(self, c: int | list[int], topic: str | list[str]) -> dict | None: async def _receive_raw(self, c: int, topic: str) -> dict | None:
def check_response(x): def check(payload: apns.APNSPayload):
if x[0] != 0x0A: # Check if the "c" key matches
body = payload.fields_with_id(3)[0].value
if body is None:
return False return False
# Check if it matches any of the topics body = plistlib.loads(body)
if isinstance(topic, list): if not "c" in body or "c" != c:
for t in topic:
if apns._get_field(x[1], 2) == sha1(t.encode()).digest():
break
else:
return False
else:
if apns._get_field(x[1], 2) != sha1(topic.encode()).digest():
return False
resp_body = apns._get_field(x[1], 3)
if resp_body is None:
return False
resp_body = plistlib.loads(resp_body)
#logger.info(f"See type {resp_body['c']}")
if isinstance(c, list):
if not resp_body["c"] in c:
return False
elif resp_body["c"] != c:
return False return False
return True return True
payload = self.connection.incoming_queue.pop_find(check_response) payload = await self.connection.expect_notification(topic, check)
if payload is None: if payload is None:
return None return None
body = apns._get_field(payload[1], 3) body = payload.fields_with_id(3)[0].value
body = plistlib.loads(body) body = plistlib.loads(body)
return body return body
def activate_sms(self) -> bool: async def activate_sms(self):
""" """
Try to activate SMS forwarding Try to activate SMS forwarding
Returns True if we are able to perform SMS forwarding, False otherwise Returns True if we are able to perform SMS forwarding, False otherwise
Call repeatedly until it returns True Call repeatedly until it returns True
""" """
act_message = self._receive_raw(145, "com.apple.private.alloy.sms") act_message = await self._receive_raw(145, "com.apple.private.alloy.sms")
if act_message is None: if act_message is None:
return False return False
@ -672,7 +666,7 @@ class iMessageUser:
else: else:
logger.info("SMS forwarding de-activated, sending response") logger.info("SMS forwarding de-activated, sending response")
self._send_raw( await self._send_raw(
147, 147,
[self.user.current_handle], [self.user.current_handle],
"com.apple.private.alloy.sms", "com.apple.private.alloy.sms",
@ -681,7 +675,7 @@ class iMessageUser:
} }
) )
def send(self, message: Message): async def send(self, message: Message):
# Check what type of message we are sending # Check what type of message we are sending
for t, (topic, cls) in MESSAGE_TYPES.items(): for t, (topic, cls) in MESSAGE_TYPES.items():
if isinstance(message, cls): if isinstance(message, cls):
@ -691,9 +685,9 @@ class iMessageUser:
send_to = message.participants if isinstance(message, iMessage) else [self.user.current_handle] send_to = message.participants if isinstance(message, iMessage) else [self.user.current_handle]
self._cache_keys(send_to, topic) await self._cache_keys(send_to, topic)
self._send_raw( await self._send_raw(
t, t,
send_to, send_to,
topic, topic,
@ -713,12 +707,12 @@ class iMessageUser:
for p in send_to: for p in send_to:
for t in self.USER_CACHE[p]: for t in self.USER_CACHE[p]:
if t == self.connection.token: if t == self.connection.credentials.token:
continue continue
total += 1 total += 1
while count < total and time.time() - start < 2: while count < total and time.time() - start < 2:
resp = self._receive_raw(255, topic) resp = await self._receive_raw(255, topic)
if resp is None: if resp is None:
continue continue
count += 1 count += 1