diff --git a/ids/query.py b/ids/query.py index 0737481..63e2527 100644 --- a/ids/query.py +++ b/ids/query.py @@ -5,12 +5,13 @@ from base64 import b64encode import apns import bags +import logging from ._helpers import KeyPair, PROTOCOL_VERSION from . import signing -def lookup( +async def lookup( conn: apns.APNSConnection, self_uri: str, id_keypair: KeyPair, @@ -19,12 +20,12 @@ def lookup( ) -> bytes: BAG_KEY = "id-query" - conn.filter([topic]) + await conn.filter([topic]) body = plistlib.dumps({"uris": query}) body = gzip.compress(body, mtime=0) - push_token = b64encode(conn.token).decode() + push_token = b64encode(conn.credentials.token).decode() headers = { "x-id-self-uri": self_uri, @@ -47,25 +48,22 @@ def lookup( "b": body, } - conn.send_message(topic, plistlib.dumps(req, fmt=plistlib.FMT_BINARY)) - - def check_response(x): - if x[0] != 0x0A: + await conn.send_notification(topic, plistlib.dumps(req, fmt=plistlib.FMT_BINARY)) + + def check(payload: apns.APNSPayload): + body = payload.fields_with_id(3)[0].value + if body is None: return False - resp_body = apns._get_field(x[1], 3) - if resp_body is None: - return False - resp_body = plistlib.loads(resp_body) - return resp_body.get('U') == msg_id + body = plistlib.loads(body) + logging.warning(body.get('U')) + return body.get('U') == msg_id - # Lambda to check if the response is the one we want - payload = conn.incoming_queue.wait_pop_find(check_response) - resp = apns._get_field(payload[1], 3) + payload = await conn.expect_notification(topic, check) + + resp = payload.fields_with_id(3)[0].value resp = plistlib.loads(resp) resp = gzip.decompress(resp["b"]) resp = plistlib.loads(resp) - # Acknowledge the message - #conn._send_ack(apns._get_field(payload[1], 4)) if resp['status'] != 0: raise Exception(f'Query failed: {resp}') diff --git a/ids/signing.py b/ids/signing.py index c5f24ce..71550ff 100644 --- a/ids/signing.py +++ b/ids/signing.py @@ -12,8 +12,8 @@ from ._helpers import KeyPair, dearmour # TODO: Move this helper somewhere else -def armour_cert(cert: bytes) -> str: - cert = x509.load_der_x509_certificate(cert) +def armour_cert(c: bytes) -> str: + cert = x509.load_der_x509_certificate(c) return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8").strip()