about summary refs log tree commit diff
path: root/broadlink/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'broadlink/__init__.py')
-rw-r--r--broadlink/__init__.py75
1 files changed, 42 insertions, 33 deletions
diff --git a/broadlink/__init__.py b/broadlink/__init__.py
index c3b2cecd1a7d..7a385fc1556e 100644
--- a/broadlink/__init__.py
+++ b/broadlink/__init__.py
@@ -1,18 +1,18 @@
 #!/usr/bin/python
 
+import codecs
+import random
+import socket
+import threading
+import time
 from datetime import datetime
 
 try:
-    from Crypto.Cipher import AES
-except ImportError as e:
+    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+    from cryptography.hazmat.backends import default_backend
+except ImportError:
     import pyaes
 
-import time
-import random
-import socket
-import threading
-import codecs
-
 
 def gendevice(devtype, host, mac):
     devices = {
@@ -55,10 +55,10 @@ def gendevice(devtype, host, mac):
     }
 
     # Look for the class associated to devtype in devices
-    [deviceClass] = [dev for dev in devices if devtype in devices[dev]] or [None]
-    if deviceClass is None:
+    [device_class] = [dev for dev in devices if devtype in devices[dev]] or [None]
+    if device_class is None:
         return device(host=host, mac=mac, devtype=devtype)
-    return deviceClass(host=host, mac=mac, devtype=devtype)
+    return device_class(host=host, mac=mac, devtype=devtype)
 
 
 def discover(timeout=None, local_ip_address=None):
@@ -145,8 +145,6 @@ class device:
         self.devtype = devtype
         self.timeout = timeout
         self.count = random.randrange(0xffff)
-        self.key = bytearray(
-            [0x09, 0x76, 0x28, 0x34, 0x3f, 0xe9, 0x9e, 0x23, 0x76, 0x5c, 0x15, 0x13, 0xac, 0xcf, 0x8b, 0x02])
         self.iv = bytearray(
             [0x56, 0x2e, 0x17, 0x99, 0x6d, 0x09, 0x3d, 0x28, 0xdd, 0xb3, 0xba, 0x69, 0x5a, 0x2e, 0x6f, 0x58])
         self.id = bytearray([0, 0, 0, 0])
@@ -160,25 +158,38 @@ class device:
         if 'pyaes' in globals():
             self.encrypt = self.encrypt_pyaes
             self.decrypt = self.decrypt_pyaes
+            self.update_aes = self.update_aes_pyaes
+
         else:
-            self.encrypt = self.encrypt_pycrypto
-            self.decrypt = self.decrypt_pycrypto
+            self.encrypt = self.encrypt_crypto
+            self.decrypt = self.decrypt_crypto
+            self.update_aes = self.update_aes_crypto
+
+        self.aes = None
+        key = bytearray(
+            [0x09, 0x76, 0x28, 0x34, 0x3f, 0xe9, 0x9e, 0x23, 0x76, 0x5c, 0x15, 0x13, 0xac, 0xcf, 0x8b, 0x02])
+        self.update_aes(key)
+
+    def update_aes_pyaes(self, key):
+        self.aes = pyaes.AESModeOfOperationCBC(key, iv=bytes(self.iv))
 
     def encrypt_pyaes(self, payload):
-        aes = pyaes.AESModeOfOperationCBC(self.key, iv=bytes(self.iv))
-        return b"".join([aes.encrypt(bytes(payload[i:i + 16])) for i in range(0, len(payload), 16)])
+        return b"".join([self.aes.encrypt(bytes(payload[i:i + 16])) for i in range(0, len(payload), 16)])
 
     def decrypt_pyaes(self, payload):
-        aes = pyaes.AESModeOfOperationCBC(self.key, iv=bytes(self.iv))
-        return b"".join([aes.decrypt(bytes(payload[i:i + 16])) for i in range(0, len(payload), 16)])
+        return b"".join([self.aes.decrypt(bytes(payload[i:i + 16])) for i in range(0, len(payload), 16)])
+
+    def update_aes_crypto(self, key):
+        self.aes = Cipher(algorithms.AES(key), modes.CBC(self.iv),
+                          backend=default_backend())
 
-    def encrypt_pycrypto(self, payload):
-        aes = AES.new(bytes(self.key), AES.MODE_CBC, bytes(self.iv))
-        return aes.encrypt(bytes(payload))
+    def encrypt_crypto(self, payload):
+        encryptor = self.aes.encryptor()
+        return encryptor.update(payload) + encryptor.finalize()
 
-    def decrypt_pycrypto(self, payload):
-        aes = AES.new(bytes(self.key), AES.MODE_CBC, bytes(self.iv))
-        return aes.decrypt(bytes(payload))
+    def decrypt_crypto(self, payload):
+        decryptor = self.aes.decryptor()
+        return decryptor.update(payload) + decryptor.finalize()
 
     def auth(self):
         payload = bytearray(0x50)
@@ -219,7 +230,7 @@ class device:
             return False
 
         self.id = payload[0x00:0x04]
-        self.key = key
+        self.update_aes(key)
 
         return True
 
@@ -278,7 +289,7 @@ class device:
         packet[0x20] = checksum & 0xff
         packet[0x21] = checksum >> 8
 
-        starttime = time.time()
+        start_time = time.time()
         with self.lock:
             while True:
                 try:
@@ -287,7 +298,7 @@ class device:
                     response = self.cs.recvfrom(2048)
                     break
                 except socket.timeout:
-                    if (time.time() - starttime) > self.timeout:
+                    if (time.time() - start_time) > self.timeout:
                         raise
         return bytearray(response[0])
 
@@ -702,7 +713,6 @@ class hysen(device):
     # The sensor command is currently experimental
     def set_mode(self, auto_mode, loop_mode, sensor=0):
         mode_byte = ((loop_mode + 1) << 4) + auto_mode
-        # print 'Mode byte: 0x'+ format(mode_byte, '02x')
         self.send_request(bytearray([0x01, 0x06, 0x00, 0x02, mode_byte, sensor]))
 
     # Advanced settings
@@ -787,8 +797,8 @@ class S1C(device):
     Its VERY VERY VERY DIRTY IMPLEMENTATION of S1C
     """
 
-    def __init__(self, *a, **kw):
-        device.__init__(self, *a, **kw)
+    def __init__(self, host, mac, devtype):
+        device.__init__(self, host, mac, devtype)
         self.type = 'S1C'
 
     def get_sensors_status(self):
@@ -798,9 +808,8 @@ class S1C(device):
         err = response[0x22] | (response[0x23] << 8)
         if err != 0:
             return None
-        aes = AES.new(bytes(self.key), AES.MODE_CBC, bytes(self.iv))
 
-        payload = aes.decrypt(bytes(response[0x38:]))
+        payload = self.decrypt(bytes(response[0x38:]))
         if not payload:
             return None
         count = payload[0x4]