- Remove unused imports

- extract idsuser authentication functionality to separate function so that it can be reused
- add more typing hints to make lsp happier
- access dictionary values more safely with walrus operator
- simplify some list comprehension and iteration
- print proxy errors in more detail so they're easier to debug
- store apnsconnection in proxy so that we can use it to make a user and decrypt payloads if needed
This commit is contained in:
itsjunetime 2023-08-21 21:10:04 -06:00
parent 5d7fab9cdd
commit f80acd2e09
11 changed files with 124 additions and 117 deletions

38
demo.py
View file

@ -1,10 +1,6 @@
import json import json
import logging import logging
import os
import threading
import time
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from getpass import getpass
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
from rich.logging import RichHandler from rich.logging import RichHandler
@ -74,37 +70,7 @@ async def main():
await conn.filter(["com.apple.madrid"]) await conn.filter(["com.apple.madrid"])
user = ids.IDSUser(conn) user = ids.IDSUser(conn)
user.auth_and_set_encryption_from_config(CONFIG)
if CONFIG.get("auth", {}).get("cert") is not None:
auth_keypair = ids._helpers.KeyPair(CONFIG["auth"]["key"], CONFIG["auth"]["cert"])
user_id = CONFIG["auth"]["user_id"]
handles = CONFIG["auth"]["handles"]
user.restore_authentication(auth_keypair, user_id, handles)
else:
username = input("Username: ")
password = getpass("Password: ")
user.authenticate(username, password)
user.encryption_identity = ids.identity.IDSIdentity(
encryption_key=CONFIG.get("encryption", {}).get("rsa_key"),
signing_key=CONFIG.get("encryption", {}).get("ec_key"),
)
if (
CONFIG.get("id", {}).get("cert") is not None
and user.encryption_identity is not None
):
id_keypair = ids._helpers.KeyPair(CONFIG["id"]["key"], CONFIG["id"]["cert"])
user.restore_identity(id_keypair)
else:
logging.info("Registering new identity...")
import emulated.nac
vd = emulated.nac.generate_validation_data()
vd = b64encode(vd).decode()
user.register(vd)
# Write config.json # Write config.json
CONFIG["encryption"] = { CONFIG["encryption"] = {
@ -150,4 +116,4 @@ async def output_task(im: imessage.iMessageUser):
if __name__ == "__main__": if __name__ == "__main__":
trio.run(main) trio.run(main)

View file

@ -353,4 +353,4 @@ def c_string(bytes, start: int = 0) -> str:
#print(start) #print(start)
#print(chr(bytes[i])) #print(chr(bytes[i]))
i += 1 i += 1
return out return out

View file

@ -23,12 +23,12 @@ from datetime import datetime
from json import dump from json import dump
from math import exp, log from math import exp, log
from os import SEEK_END from os import SEEK_END
from re import split
from struct import unpack from struct import unpack
from uuid import UUID from uuid import UUID
from typing import Any
#from asn1crypto.cms import ContentInfo from asn1crypto.cms import ContentInfo
#from asn1crypto.x509 import DirectoryString from asn1crypto.x509 import DirectoryString
from plistlib import loads from plistlib import loads
#import mdictionary as mdictionary #import mdictionary as mdictionary
@ -52,7 +52,7 @@ class Parser():
self.__is_64_bit = True # default place-holder self.__is_64_bit = True # default place-holder
self.__is_little_endian = True # ^^ self.__is_little_endian = True # ^^
self.__macho = {} self.__macho = {}
self.__output = { self.__output: dict[str, Any] = {
'name': 'IMDAppleServices' 'name': 'IMDAppleServices'
} }
@ -931,7 +931,7 @@ class Parser():
n_value = self.get_ll() if self.__is_64_bit else self.get_int() n_value = self.get_ll() if self.__is_64_bit else self.get_int()
symbol = { symbol: dict[str, int | str] = {
'n_strx': n_strx, 'n_strx': n_strx,
'n_sect': n_sect, 'n_sect': n_sect,
'n_desc': n_desc, 'n_desc': n_desc,
@ -2298,4 +2298,4 @@ class mdictionary:
2147483648 + 11: 'POWERPC_7450 (LIB64)', 2147483648 + 11: 'POWERPC_7450 (LIB64)',
2147483648 + 100: 'POWERPC_970 (LIB64)' 2147483648 + 100: 'POWERPC_970 (LIB64)'
} }
} }

View file

@ -1,4 +1,6 @@
from base64 import b64encode from base64 import b64encode
from getpass import getpass
import logging
import apns import apns
@ -27,6 +29,9 @@ class IDSUser:
self._push_keypair = _helpers.KeyPair( self._push_keypair = _helpers.KeyPair(
self.push_connection.credentials.private_key, self.push_connection.credentials.cert self.push_connection.credentials.private_key, self.push_connection.credentials.cert
) )
# set the encryption_identity to a default randomized value so that
# it's still valid if we can't pull it from the config
self.encryption_identity: identity.IDSIdentity = identity.IDSIdentity()
self.ec_key = self.rsa_key = None self.ec_key = self.rsa_key = None
@ -63,10 +68,6 @@ class IDSUser:
self.ec_key, self.rsa_key will be set to a randomly gnenerated EC and RSA keypair self.ec_key, self.rsa_key will be set to a randomly gnenerated EC and RSA keypair
if they are not already set if they are not already set
""" """
if self.encryption_identity is None:
self.encryption_identity = identity.IDSIdentity()
cert = identity.register( cert = identity.register(
b64encode(self.push_connection.credentials.token), b64encode(self.push_connection.credentials.token),
self.handles, self.handles,
@ -81,6 +82,48 @@ class IDSUser:
def restore_identity(self, id_keypair: _helpers.KeyPair): def restore_identity(self, id_keypair: _helpers.KeyPair):
self._id_keypair = id_keypair self._id_keypair = id_keypair
def auth_and_set_encryption_from_config(self, config: dict[str, dict[str, Any]]):
auth = config.get("auth", {})
if (
((key := auth.get("key")) is not None) and
((cert := auth.get("cert")) is not None) and
((user_id := auth.get("user_id")) is not None) and
((handles := auth.get("handles")) is not None)
):
auth_keypair = _helpers.KeyPair(key, cert)
self.restore_authentication(auth_keypair, user_id, handles)
else:
username = input("Username: ")
password = getpass("Password: ")
self.authenticate(username, password)
encryption: dict[str, str] = config.get("encryption", {})
id: dict[str, str] = config.get("id", {})
if (
(rsa_key := encryption.get("rsa_key")) and
(signing_key := encryption.get("ec_key")) and
(cert := id.get("cert")) and
(key := id.get("key"))
):
self.encryption_identity = identity.IDSIdentity(
encryption_key=rsa_key,
signing_key=signing_key,
)
id_keypair = _helpers.KeyPair(key, cert)
self.restore_identity(id_keypair)
else:
logging.info("Registering new identity...")
import emulated.nac
vd = emulated.nac.generate_validation_data()
vd = b64encode(vd).decode()
self.register(vd)
async def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> Any: async def lookup(self, uris: list[str], topic: str = "com.apple.madrid") -> Any:
return await query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic) return await query.lookup(self.push_connection, self.current_handle, self._id_keypair, uris, topic)

View file

@ -3,18 +3,26 @@ from base64 import b64decode
import requests import requests
from ._helpers import PROTOCOL_VERSION, USER_AGENT, KeyPair, parse_key, serialize_key from ._helpers import PROTOCOL_VERSION, KeyPair, parse_key, serialize_key
from .signing import add_auth_signature, armour_cert from .signing import add_auth_signature, armour_cert
from io import BytesIO from io import BytesIO
from cryptography.hazmat.primitives.asymmetric import ec, rsa from cryptography.hazmat.primitives.asymmetric import ec, rsa
from typing import Self
import logging import logging
logger = logging.getLogger("ids") logger = logging.getLogger("ids")
class IDSIdentity: class IDSIdentity:
def __init__(self, signing_key: str | None = None, encryption_key: str | None = None, signing_public_key: str | None = None, encryption_public_key: str | None = None): def __init__(
self,
signing_key: str | None = None,
encryption_key: str | None = None,
signing_public_key: str | None = None,
encryption_public_key: str | None = None
):
if signing_key is not None: if signing_key is not None:
self.signing_key = signing_key self.signing_key = signing_key
self.signing_public_key = serialize_key(parse_key(signing_key).public_key())# type: ignore self.signing_public_key = serialize_key(parse_key(signing_key).public_key())# type: ignore
@ -36,8 +44,8 @@ class IDSIdentity:
self.encryption_key = serialize_key(rsa.generate_private_key(65537, 1280)) self.encryption_key = serialize_key(rsa.generate_private_key(65537, 1280))
self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key())# type: ignore self.encryption_public_key = serialize_key(parse_key(self.encryption_key).public_key())# type: ignore
@staticmethod @classmethod
def decode(inp: bytes) -> 'IDSIdentity': def decode(cls, inp: bytes) -> Self:
input = BytesIO(inp) input = BytesIO(inp)
assert input.read(5) == b'\x30\x81\xF6\x81\x43' # DER header assert input.read(5) == b'\x30\x81\xF6\x81\x43' # DER header

View file

@ -7,13 +7,13 @@ import requests
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
import bags import bags
from . import signing from . import signing
from ._helpers import PROTOCOL_VERSION, USER_AGENT, KeyPair from ._helpers import PROTOCOL_VERSION, KeyPair
import logging import logging
logger = logging.getLogger("ids") logger = logging.getLogger("ids")
@ -50,8 +50,6 @@ def _auth_token_request(username: str, password: str) -> Any:
def get_auth_token( def get_auth_token(
username: str, password: str, factor_gen: Callable | None = None username: str, password: str, factor_gen: Callable | None = None
) -> tuple[str, str]: ) -> tuple[str, str]:
from sys import platform
result = _auth_token_request(username, password) result = _auth_token_request(username, password)
if result["status"] != 0: if result["status"] != 0:
if result["status"] == 5000: if result["status"] == 5000:

View file

@ -5,7 +5,6 @@ from base64 import b64encode
import apns import apns
import bags import bags
import logging
from ._helpers import KeyPair, PROTOCOL_VERSION from ._helpers import KeyPair, PROTOCOL_VERSION
from . import signing from . import signing

View file

@ -5,8 +5,7 @@ from datetime import datetime
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.x509.oid import NameOID
from ._helpers import KeyPair, dearmour from ._helpers import KeyPair, dearmour
@ -24,8 +23,6 @@ Generates a nonce in this format:
000001876d008cc5 # unix time 000001876d008cc5 # unix time
r1r2r3r4r5r6r7r8 # random bytes r1r2r3r4r5r6r7r8 # random bytes
""" """
def generate_nonce() -> bytes: def generate_nonce() -> bytes:
return ( return (
b"\x01" b"\x01"
@ -33,17 +30,13 @@ def generate_nonce() -> bytes:
+ random.randbytes(8) + random.randbytes(8)
) )
import typing
# Creates a payload from individual parts for signing # Creates a payload from individual parts for signing
def _create_payload( def _create_payload(
bag_key: str, bag_key: str,
query_string: str, query_string: str,
push_token: typing.Union[str, bytes], push_token: str | bytes,
payload: bytes, payload: bytes,
nonce: typing.Union[bytes, None] = None, nonce: bytes | None = None,
) -> tuple[bytes, bytes]: ) -> tuple[bytes, bytes]:
# Generate the nonce # Generate the nonce
if nonce is None: if nonce is None:

View file

@ -4,9 +4,10 @@ import logging
import plistlib import plistlib
import random import random
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass
from hashlib import sha1, sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from typing import Any
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives.asymmetric import ec, padding
@ -504,15 +505,15 @@ class iMessageUser:
""" """
Will return the next iMessage in the queue, or None if there are no messages Will return the next iMessage in the queue, or None if there are no messages
""" """
body = await self._receive_raw([t for t, _ in MESSAGE_TYPES.items()], [t[0] for _, t in MESSAGE_TYPES.items()]) body: dict[str, Any] = await self._receive_raw(list(MESSAGE_TYPES.keys()), [t[0] for t in MESSAGE_TYPES.values()])
t = MESSAGE_TYPES[body["c"]][1] t: type[Message] = MESSAGE_TYPES[body["c"]][1]
if not await self._verify_payload(body["P"], body["sP"], body["t"]): if not await self._verify_payload(body["P"], body["sP"], body["t"]):
raise Exception("Failed to verify payload") raise Exception("Failed to verify payload")
logger.debug(f"Encrypted body : {body}") logger.debug(f"Encrypted body : {body}")
decrypted = self._decrypt_payload(body["P"]) decrypted: bytes = self._decrypt_payload(body["P"])
try: try:
return t.from_raw(decrypted, body["sP"]) return t.from_raw(decrypted, body["sP"])
@ -627,7 +628,7 @@ class iMessageUser:
await self.connection.send_notification(topic, body, message_id) await self.connection.send_notification(topic, body, message_id)
async def _receive_raw(self, c: int | list[int], topics: str | list[str]) -> dict: async def _receive_raw(self, c: int | list[int], topics: str | list[str]) -> dict[str, Any]:
def check(payload: apns.APNSPayload): def check(payload: apns.APNSPayload):
# Check if the "c" key matches # Check if the "c" key matches
body = payload.fields_with_id(3)[0].value body = payload.fields_with_id(3)[0].value
@ -644,8 +645,8 @@ class iMessageUser:
payload = await self.connection.expect_notification(topics, check) payload = await self.connection.expect_notification(topics, check)
body = payload.fields_with_id(3)[0].value body_bytes: bytes = payload.fields_with_id(3)[0].value
body = plistlib.loads(body) body: dict[str, Any] = plistlib.loads(body_bytes)
return body return body
async def activate_sms(self): async def activate_sms(self):
@ -655,14 +656,12 @@ class iMessageUser:
Call repeatedly until it returns True Call repeatedly until it returns True
""" """
act_message = await self._receive_raw(145, "com.apple.private.alloy.sms") act_message: dict[str, Any] = await self._receive_raw(145, "com.apple.private.alloy.sms")
if act_message is None:
return False
logger.info(f"Received SMS activation message : {act_message}") logger.info(f"Received SMS activation message : {act_message}")
# Decrypt the payload # Decrypt the payload
act_message = self._decrypt_payload(act_message["P"]) act_message_bytes: bytes = self._decrypt_payload(act_message["P"])
act_message = plistlib.loads(maybe_decompress(act_message)) act_message = plistlib.loads(maybe_decompress(act_message_bytes))
if act_message == {'wc': False, 'ar': True}: if act_message == {'wc': False, 'ar': True}:
logger.info("SMS forwarding activated, sending response") logger.info("SMS forwarding activated, sending response")
@ -715,7 +714,7 @@ class iMessageUser:
total += 1 total += 1
while count < total and time.time() - start < 2: while count < total and time.time() - start < 2:
resp = await self._receive_raw(255, topic) resp: dict[str, Any] = await self._receive_raw(255, topic)
#if resp is None: #if resp is None:
# continue # continue
count += 1 count += 1

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,6 @@
import os import os
import sys import sys
import traceback
# setting path so we can import the needed packages # setting path so we can import the needed packages
sys.path.append(os.path.join(sys.path[0], "../")) sys.path.append(os.path.join(sys.path[0], "../"))
@ -39,8 +40,10 @@ async def handle_proxy(stream: trio.SocketStream):
try: try:
p = APNSProxy(stream) p = APNSProxy(stream)
await p.start() await p.start()
except Exception as e: except Exception:
logging.error("APNSProxy instance encountered exception: " + str(e)) logging.error(f"APNSProxy instance encountered exception:")
traceback.print_exc()
#raise e #raise e
class APNSProxy: class APNSProxy:
@ -54,7 +57,7 @@ class APNSProxy:
try: try:
apns_server = apns.APNSConnection(nursery) apns_server = apns.APNSConnection(nursery)
await apns_server._connect_socket() await apns_server._connect_socket()
self.server = apns_server.sock self.connection = apns_server
nursery.start_soon(self.proxy, True) nursery.start_soon(self.proxy, True)
nursery.start_soon(self.proxy, False) nursery.start_soon(self.proxy, False)
@ -69,10 +72,11 @@ class APNSProxy:
async def proxy(self, to_server: bool): async def proxy(self, to_server: bool):
if to_server: if to_server:
from_stream = self.client from_stream = self.client
to_stream = self.server to_stream = self.connection.sock
else: else:
from_stream = self.server from_stream = self.connection.sock
to_stream = self.client to_stream = self.client
while True: while True:
payload = await apns.APNSPayload.read_from_stream(from_stream) payload = await apns.APNSPayload.read_from_stream(from_stream)
payload = self.tamper(payload, to_server) payload = self.tamper(payload, to_server)
@ -95,15 +99,15 @@ class APNSProxy:
def tamper_lookup_keys(self, payload: apns.APNSPayload) -> apns.APNSPayload: def tamper_lookup_keys(self, payload: apns.APNSPayload) -> apns.APNSPayload:
if payload.id == 0xA: # Notification if payload.id == 0xA: # Notification
if payload.fields_with_id(2)[0].value == sha1(b"com.apple.madrid").digest(): # Topic if payload.fields_with_id(2)[0].value == sha1(b"com.apple.madrid").digest(): # Topic
if body := payload.fields_with_id(3)[0].value is not None: if (body := payload.fields_with_id(3)[0].value) is not None:
body = plistlib.loads(body) body = plistlib.loads(body)
if body['c'] == 97: # Lookup response if body['c'] == 97: # Lookup response
resp = gzip.decompress(body["b"]) # HTTP body resp = gzip.decompress(body["b"]) # HTTP body
resp = plistlib.loads(resp) resp = plistlib.loads(resp)
# Replace public keys # Replace public keys
for r in resp["results"].keys(): for result in resp["results"].values():
for identity in resp["results"][r]["identities"]: for identity in result["identities"]:
if "client-data" in identity: if "client-data" in identity:
identity["client-data"]["public-message-identity-key"] = b"REDACTED" identity["client-data"]["public-message-identity-key"] = b"REDACTED"
@ -117,4 +121,4 @@ class APNSProxy:
return payload return payload
if __name__ == "__main__": if __name__ == "__main__":
trio.run(main) trio.run(main)