main.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
"""Minimal example of a Python websocket server
handling OTA updates for ESP32 and ESP8266.
Firmware check and upload work.
The register and state handlers are just examples.
"""
  6. # pylint: disable=W0703,E1101
import asyncio
import copy
import json
import logging
import shlex
import subprocess
import threading
import time
from os import listdir
from os.path import join as join_pth
from pathlib import Path

import websockets
from packaging import version
  19. # Logger settings
  20. logging.basicConfig(filename="ws_server.log")
  21. Logger = logging.getLogger('WS-OTA')
  22. Logger.addHandler(logging.StreamHandler())
  23. Logger.setLevel(logging.INFO)
  24. # Path to directory with FW
  25. fw_path = join_pth(Path().absolute(), "firmware")
  26. def create_path(path: str) -> None:
  27. """Check if path exist or create it"""
  28. Path(path).mkdir(parents=True, exist_ok=True)
  29. def shell(command):
  30. """Handle execution of shell commands"""
  31. with subprocess.Popen(command, shell=True,
  32. stdout=subprocess.PIPE,
  33. universal_newlines=True
  34. ) as process:
  35. for stdout_line in iter(process.stdout.readline, ""):
  36. Logger.debug(stdout_line)
  37. process.stdout.close()
  38. return_code = process.wait()
  39. Logger.debug("Shell returned: %s", return_code)
  40. return process.returncode
  41. return None
  42. async def binary_send(websocket, fw_file):
  43. """Read firmware file, divide it to chunks and send them"""
  44. with open(fw_file, "rb") as binaryfile:
  45. while True:
  46. chunk = binaryfile.read(2048)
  47. if not chunk:
  48. break
  49. try:
  50. await websocket.send(chunk)
  51. except Exception as exception:
  52. Logger.exception(exception)
  53. return False
  54. time.sleep(0.2)
  55. def version_checker(name, vdev, vapp):
  56. """Parse and compare FW version"""
  57. if version.parse(vdev) < version.parse(vapp):
  58. Logger.info("Client(%s) version %s is smaller than %s: Go for update", name, vdev, vapp)
  59. return True
  60. Logger.info("Client(%s) version %s is greater or equal to %s: Not updating", name, vdev, vapp)
  61. return False
  62. class WsOtaHandler (threading.Thread):
  63. """Thread handling ota update
  64. Runing ota directly from message would kill WS
  65. as message bus would timeout.
  66. """
  67. def __init__(self, name, message, websocket):
  68. threading.Thread.__init__(self, daemon=True)
  69. self.name = name
  70. self.msg = message
  71. self.websocket = websocket
  72. def run(self, ):
  73. try:
  74. asyncio.run(self.start_)
  75. except Exception as exception:
  76. Logger.exception(exception)
  77. finally:
  78. pass
  79. async def start_(self):
  80. """Start _ota se asyncio future"""
  81. msg_task = asyncio.ensure_future(
  82. self._ota())
  83. done, pending = await asyncio.wait(
  84. [msg_task],
  85. return_when=asyncio.FIRST_COMPLETED,
  86. )
  87. Logger.info("WS Ota Handler done: %s", done)
  88. for task in pending:
  89. task.cancel()
  90. async def _ota(self):
  91. """Check for new fw and update or pass"""
  92. device_name = self.msg['name']
  93. device_chip = self.msg['chip']
  94. device_version = self.msg['version']
  95. fw_version = ''
  96. fw_name = ''
  97. fw_device = ''
  98. for filename in listdir(fw_path):
  99. fw_info = filename.split("-")
  100. fw_device = fw_info[0]
  101. if fw_device == device_name:
  102. fw_version = fw_info[1]
  103. fw_name = filename
  104. break
  105. if not fw_version:
  106. Logger.info("Client(%s): No fw found!", device_name)
  107. msg = '{"type": "ota", "value":"ok"}'
  108. await self.websocket.send(msg)
  109. return
  110. if not version_checker(device_name, device_version, fw_version):
  111. return
  112. fw_file = join_pth(fw_path, fw_name)
  113. if device_chip == 'esp8266' and not fw_file.endswith('.gz'):
  114. # We can compress fw to make it smaller for upload
  115. fw_cpress = fw_file
  116. fw_file = fw_cpress + ".gz"
  117. cpress = f"gzip -9 {fw_cpress}"
  118. cstate = shell(cpress)
  119. if cstate:
  120. Logger.error("Cannot compress firmware: %s", fw_name)
  121. return
  122. # Get size of fw
  123. size = Path(fw_file).stat().st_size
  124. # Request ota mode
  125. msg = '{"type": "ota", "value":"go", "size":' + str(size) + '}'
  126. await self.websocket.send(msg)
  127. # send file by chunks trough websocket
  128. await binary_send(self.websocket, fw_file)
  129. async def _register(websocket, message):
  130. mac = message.get('mac')
  131. name = message.get('name')
  132. Logger.info("Client(%s) mac: %s", name, mac)
  133. # Some code
  134. response = {'response_type': 'registry', 'state': 'ok'}
  135. await websocket.send(json.dumps(response))
  136. async def _state(websocket, message):
  137. mac = message.get('mac')
  138. name = message.get('name')
  139. Logger.info("Client(%s) mac: %s", name, mac)
  140. # Some code
  141. response = {'response_type': 'state', 'state': 'ok'}
  142. await websocket.send(json.dumps(response))
  143. async def _unhandleld(websocket, msg):
  144. Logger.info("Unhandled message from device: %s", str(msg))
  145. response = {'response_type': 'response', 'state': 'nok'}
  146. await websocket.send(json.dumps(response))
  147. async def _greetings(websocket, message):
  148. WsOtaHandler('thread_ota', copy.deepcopy(message), websocket).start()
  149. async def message_received(websocket, message) -> None:
  150. """Handle incoming messages
  151. Check if message contain json and run waned function
  152. """
  153. switcher = {"greetings": _greetings,
  154. "register": _register,
  155. "state": _state
  156. }
  157. if message[0:1] == "{":
  158. try:
  159. msg_json = json.loads(message)
  160. except Exception as exception:
  161. Logger.error(exception)
  162. return
  163. type_ = msg_json.get('type')
  164. name = msg_json.get('name')
  165. func = switcher.get(type_, _unhandleld)
  166. Logger.debug("Client(%s)said: %s", name, type_)
  167. try:
  168. await func(websocket, msg_json)
  169. except Exception as exception:
  170. Logger.error(exception)
  171. # pylint: disable=W0613
  172. async def ws_server(websocket, path) -> None:
  173. """Run in cycle and wait for new messages"""
  174. async for message in websocket:
  175. await message_received(websocket, message)
  176. async def main():
  177. """Server starter
  178. Normal user can bind only port nubers greater than 1024
  179. """
  180. async with websockets.serve(ws_server, "10.0.1.5", 8081):
  181. await asyncio.Future() # run forever
  182. create_path(fw_path)
  183. asyncio.run(main())