diff --git a/homekit/accessoryserver.py b/homekit/accessoryserver.py index be545fcd..1c265080 100644 --- a/homekit/accessoryserver.py +++ b/homekit/accessoryserver.py @@ -20,6 +20,7 @@ import json from json.decoder import JSONDecodeError import select +import threading from http.server import HTTPServer, BaseHTTPRequestHandler from socketserver import ThreadingMixIn @@ -38,7 +39,7 @@ from homekit.crypto.srp import SrpServer from homekit.exceptions import ConfigurationError, ConfigLoadingError, ConfigSavingError, FormatError, \ - CharacteristicPermissionError + CharacteristicPermissionError, DisconnectedControllerError from homekit.http_impl import HttpStatusCodes from homekit.model import Accessories, Categories from homekit.model.characteristics import CharacteristicsTypes @@ -79,6 +80,18 @@ def __init__(self, config_file, logger=sys.stderr): HTTPServer.__init__(self, (self.data.ip, self.data.port), AccessoryRequestHandler) + def write_event(self, characteristics, source=None): + dead_sessions = [] + for session_id, session in self.sessions.items(): + if source and session_id == source: + continue + try: + session['handler'].write_event(characteristics) + except DisconnectedControllerError: + dead_sessions.append(session_id) + for session_id in dead_sessions: + del self.sessions[session_id] + def add_accessory(self, accessory): self.accessories.add_accessory(accessory) @@ -278,8 +291,7 @@ class AccessoryRequestHandler(BaseHTTPRequestHandler): def __init__(self, request, client_address, server): # keep pycharm from complaining about those not being define in __init__ - # self.session_id = '{ip}:{port}'.format(ip=client_address[0], port= client_address[1]) - self.session_id = '{ip}'.format(ip=client_address[0]) + self.session_id = '{ip}:{port}'.format(ip=client_address[0], port=client_address[1]) if self.session_id not in server.sessions: server.sessions[self.session_id] = {'handler': self} self.rfile = None @@ -314,9 +326,76 @@ def __init__(self, request, client_address, server): # get the identify callback function from calling server self.identify_callback = server.identify_callback + self.write_lock = threading.Lock() + self.subscriptions = set() + # init super class BaseHTTPRequestHandler.__init__(self, request, client_address, server) + def setup(self): + super().setup() + + self.orig_wfile = self.wfile + self.orig_rfile = self.rfile + + def write_event(self, characteristics): + tmp = [] + for (aid, iid) in characteristics: + if (aid, iid) not in self.subscriptions: + continue + + char = self._get_characteristic_instance(aid, iid) + + tmp.append({ + 'aid': aid, + 'iid': iid, + 'value': char.get_value(), + }) + + # Bail out if this connection isnt subscribing to any of these characteristics + if not tmp: + return + + body = json.dumps({'characteristics': tmp}) + + event = [ + 'EVENT/1.0 200 OK', + 'Content-Type: application/hap+json', + 'Content-Length: {}'.format(len(body)), + '', + body + ] + + self.write_encrypted_bytes('\r\n'.join(event).encode('utf-8')) + + def write_encrypted_bytes(self, data): + with self.write_lock: + if AccessoryRequestHandler.DEBUG_CRYPT: + self.log_message('response >%s<', data) + self.log_message('len(response) %s', len(data)) + + block_size = 1024 + out_data = bytearray() + while len(data) > 0: + block = data[:block_size] + if AccessoryRequestHandler.DEBUG_CRYPT: + self.log_message('==> BLOCK: len %s', len(block)) + data = data[block_size:] + + len_bytes = len(block).to_bytes(2, byteorder='little') + a2c_key = self.server.sessions[self.session_id]['accessory_to_controller_key'] + cnt_bytes = self.server.sessions[self.session_id]['accessory_to_controller_count'].\ + to_bytes(8, byteorder='little') + ciper_and_mac = chacha20_aead_encrypt(len_bytes, a2c_key, cnt_bytes, bytes([0, 0, 0, 0]), block) + self.server.sessions[self.session_id]['accessory_to_controller_count'] += 1 + out_data += len_bytes + ciper_and_mac[0] + ciper_and_mac[1] + + try: + self.orig_wfile.write(out_data) + self.orig_wfile.flush() + except ValueError: + raise DisconnectedControllerError() + def handle_one_request(self): """ This is used to determine whether the request is encrypted or not. This is done by looking at the first bytes of @@ -412,34 +491,12 @@ def handle_one_request(self): self.wfile.seek(0) in_data = self.wfile.read(65537) - if AccessoryRequestHandler.DEBUG_CRYPT: - self.log_message('response >%s<', in_data) - self.log_message('len(response) %s', len(in_data)) - - block_size = 1024 - out_data = bytearray() - while len(in_data) > 0: - block = in_data[:block_size] - if AccessoryRequestHandler.DEBUG_CRYPT: - self.log_message('==> BLOCK: len %s', len(block)) - in_data = in_data[block_size:] - - len_bytes = len(block).to_bytes(2, byteorder='little') - a2c_key = self.server.sessions[self.session_id]['accessory_to_controller_key'] - cnt_bytes = self.server.sessions[self.session_id]['accessory_to_controller_count'].\ - to_bytes(8, byteorder='little') - ciper_and_mac = chacha20_aead_encrypt(len_bytes, a2c_key, cnt_bytes, bytes([0, 0, 0, 0]), block) - self.server.sessions[self.session_id]['accessory_to_controller_count'] += 1 - out_data += len_bytes + ciper_and_mac[0] + ciper_and_mac[1] + self.write_encrypted_bytes(in_data) # change back to originals to handle multiple calls self.rfile = old_rfile self.wfile = old_wfile - # send data to original requester - self.wfile.write(out_data) - self.wfile.flush() - def _get_characteristics(self): """ As described on page 84 @@ -519,7 +576,7 @@ def _get_characteristics(self): errors += 1 if ev: # TODO handling of events is missing - result['characteristics'][-1]['ev'] = False + result['characteristics'][-1]['ev'] = (aid, cid) in self.subscriptions if include_type: result['characteristics'][-1]['type'] = \ CharacteristicsTypes.get_short_uuid(characteristic.type) @@ -552,6 +609,16 @@ def _get_characteristics(self): self.end_headers() self.wfile.write(result_bytes) + def _get_characteristic_instance(self, aid, iid): + for accessory in self.server.accessories.accessories: + if accessory.aid != aid: + continue + for service in accessory.services: + for characteristic in service.characteristics: + if characteristic.iid != iid: + continue + return characteristic + def _put_characteristics(self): """ Defined page 80 ff @@ -566,6 +633,7 @@ def _put_characteristics(self): result = { 'characteristics': [] } + changed = [] errors = 0 for characteristic_to_set in characteristics_to_set: aid = characteristic_to_set['aid'] @@ -583,7 +651,10 @@ def _put_characteristics(self): if AccessoryRequestHandler.DEBUG_PUT_CHARACTERISTICS: self.log_message('set ev >%s< >%s< >%s<', aid, cid, characteristic_to_set['ev']) if 'ev' in characteristic.perms: - characteristic.set_events(characteristic_to_set['ev']) + if characteristic_to_set['ev']: + self.subscriptions.add((aid, cid)) + else: + self.subscriptions.discard((aid, cid)) result['characteristics'].append({'aid': aid, 'iid': cid, 'status': 0}) else: result['characteristics'].append( @@ -595,6 +666,7 @@ def _put_characteristics(self): try: characteristic.set_value(characteristic_to_set['value']) result['characteristics'].append({'aid': aid, 'iid': cid, 'status': 0}) + changed.append((aid, cid)) except FormatError: result['characteristics'].append( {'aid': aid, 'iid': cid, 'status': HapStatusCodes.INVALID_VALUE}) @@ -610,6 +682,9 @@ def _put_characteristics(self): {'aid': aid, 'iid': cid, 'status': HapStatusCodes.RESOURCE_NOT_EXIST}) errors += 1 + if changed: + self.server.write_event(changed, self.session_id) + if len(result['characteristics']) == errors: self.send_response(HttpStatusCodes.BAD_REQUEST) elif len(result['characteristics']) > errors > 0: diff --git a/homekit/exceptions.py b/homekit/exceptions.py index ca4daf95..9d974029 100644 --- a/homekit/exceptions.py +++ b/homekit/exceptions.py @@ -254,3 +254,8 @@ class TransportNotSupportedError(HomeKitException): def __init__(self, transport): Exception.__init__(self, 'Transport {t} not supported. See setup.py for required dependencies.'.format(t=transport)) + + +class DisconnectedControllerError(HomeKitException): + def __init__(self): + Exception.__init__(self, 'Controller has passed away')