diff --git a/__main__.py b/__main__.py index 1211ae0..210e8b5 100644 --- a/__main__.py +++ b/__main__.py @@ -146,46 +146,31 @@ def main(): reboot_after = not args.dont_reboot_after uboot_only = args.uboot_only boot_only = args.boot_only - should_create_keys = False - + if jtag_hardware == "auto": jtag_hardware = detect_jtag_hardware() log.info("Detected JTAG hardware '{}'".format(jtag_hardware)) import sshkeys if args.ssh_public_key: - with open(args.ssh_public_key, 'r') as f: + with open(args.ssh_public_key, 'rb') as f: ssh_pubkey_data = f.read() if not sshkeys.check_public_key(ssh_pubkey_data): raise Exception("RSA key is not valid") log.info("Using RSA key in {}".format(args.ssh_public_key)) else: - if (os.path.exists("{}".format(args.output_ssh_key)) and os.path.exists("{}.pub".format(args.output_ssh_key))): - # Still load the public key - with open("{}".format(args.output_ssh_key), 'r') as f: - ssh_pubkey_data = f.read() - - # If the if statment above loads the public key, this check should pass, if the file is empty, assumed is that this check should not pass - if (should_create_keys = not sshkeys.check_public_key(ssh_pubkey_data)): - log.info("RSA key is not valid") - - if should_create_keys - # Do not write the public and private keys, they get overwritten everytime this script runs. - # Found that out the hard way.. - (pub, priv) = sshkeys.generate_key_pair(args.private_key_password) - ssh_pubkey_data = pub - with open("{}".format(args.output_ssh_key), 'w') as f: - f.write(priv) - with open("{}.pub".format(args.output_ssh_key), 'w') as f: - f.write(pub) - log.info("Written private and public key pair to {0} and {0}.pub, respectively".format(args.output_ssh_key)) - else: - log.info("Not creating a key pair, because one exists.") - + (pub, priv) = sshkeys.generate_key_pair(args.private_key_password) + ssh_pubkey_data = pub + with open("{}".format(args.output_ssh_key), 'wb') as f: + f.write(priv) + with open("{}.pub".format(args.output_ssh_key), 'wb') as f: + f.write(pub) + log.info("Written private and public key pair to {0} and {0}.pub, respectively".format(args.output_ssh_key)) + import json params = { "port" : serial_path, - "ssh_pubkey_data" : ssh_pubkey_data, + "ssh_pubkey_data" : ssh_pubkey_data.decode('ascii'), "has_jtag" : jtag_available, "check_uboot" : check_current_bootloader, "cleanup_payload" : cleanup_payload, diff --git a/rooter.py b/rooter.py index 78e8d6e..0d1f2fd 100644 --- a/rooter.py +++ b/rooter.py @@ -10,7 +10,7 @@ import random from time import sleep from serial.serialutil import Timeout -import StringIO +import io import tempfile logging.basicConfig(level=logging.DEBUG) @@ -24,7 +24,7 @@ def __init__(self, **params): baudrate=115200 ) self._port = params['port'] - self._ssh_pubkey_data = params['ssh_pubkey_data'] + self._ssh_pubkey_data = params['ssh_pubkey_data'].encode('ascii') self._has_jtag = params['has_jtag'] self._check_uboot = params['check_uboot'] self._cleanup_payload = params['cleanup_payload'] @@ -85,16 +85,16 @@ def run(self): def read_uboot_version(self): version_line_match = re.compile(r'^U-Boot ([^ ]+)') while True: - line = self._port.readline().strip() + line = self._port.readline().strip().decode('ascii') match = version_line_match.match(line) if match: return match.group(1) def access_uboot(self, password): log.info("Logging in to U-Boot") - self._port.write(password) + self._port.write(password.encode('ascii')) self._port.flush() - log.debug(self._port.read_until("U-Boot>")) + log.debug(self._port.read_until(b"U-Boot>")) log.debug("Logged in to U-Boot") def patch_uboot(self): @@ -103,16 +103,16 @@ def patch_uboot(self): log.info("Patching U-Boot") port.reset_input_buffer() sleep(0.1) - port.write("printenv\n") + port.write(b"printenv\n") port.flush() - add_misc_match = re.compile(r'^addmisc=(.+)$') + add_misc_match = re.compile(br'^addmisc=(.+)$') add_misc_val = None sleep(0.5) - lines = port.read_until("U-Boot>") + lines = port.read_until(b"U-Boot>") log.debug(lines) - for line in lines.split('\n'): + for line in lines.split(b'\n'): line = line.strip() log.debug(line) match = add_misc_match.match(line) @@ -123,11 +123,11 @@ def patch_uboot(self): log.error("Could not find value for addmisc environment variable") return - cmd = "setenv addmisc " + re.sub(r'([\$;])',r'\\\1', add_misc_val + " init=/bin/sh") - port.write(cmd + "\n") + cmd = b"setenv addmisc " + re.sub(br'([\$;])',br'\\\1', add_misc_val + b" init=/bin/sh") + port.write(cmd + b"\n") port.flush() - log.debug(port.read_until("U-Boot>")) - port.write("run boot_nand\n") + log.debug(port.read_until(b"U-Boot>")) + port.write(b"run boot_nand\n") port.flush() def create_payload_tar(self): @@ -136,24 +136,24 @@ def create_payload_tar(self): with tarfile.open(tar_path, "w:gz") as tar: tar.add('payload/', arcname='payload') - ssh_key_str = StringIO.StringIO(ssh_key) + ssh_key_str = io.BytesIO(ssh_key) info = tarfile.TarInfo(name="payload/id_rsa.pub") info.size=len(ssh_key) - tar.addfile(tarinfo=info, fileobj=StringIO.StringIO(ssh_key)) + tar.addfile(tarinfo=info, fileobj=io.BytesIO(ssh_key)) return tar_path def write_payload(self): port = self._port tar_path = self.create_payload_tar() - log.debug(port.read_until("/ # ")) - port.write("timeout -t 60 base64 -d | tar zxf -\n") + log.debug(port.read_until(b"/ # ")) + port.write(b"timeout -t 60 base64 -d | tar zxf -\n") port.flush() log.info("Transferring payload") - with open(tar_path, 'r') as f: + with open(tar_path, 'rb') as f: base64.encode(f, port) log.info("Transferring payload done. Waiting for toon to finish.") @@ -161,10 +161,10 @@ def write_payload(self): port.flush() port.reset_input_buffer() - port.write("\x04") + port.write(b"\x04") port.flush() - port.write("\n") - log.debug(port.read_until("/ # ")) + port.write(b"\n") + log.debug(port.read_until(b"/ # ")) log.info("Transferring payload finished") def patch_toon(self): @@ -172,13 +172,13 @@ def patch_toon(self): self._port, self._cleanup_payload, self._reboot_after) log.info("Patching Toon") password = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(8)) - port.write("if [ -f payload/patch_toon.sh ] ; then sh payload/patch_toon.sh \"{}\" ; fi\n".format(password)) + port.write("if [ -f payload/patch_toon.sh ] ; then sh payload/patch_toon.sh \"{}\" ; fi\n".format(password).encode('ascii')) try: while True: - line = read_until(port, ["/ # ", "\n"]) - if line == "/ # ": + line = read_until(port, [b"/ # ", b"\n"]) + if line == b"/ # ": break - if line.startswith(">>>"): + if line.startswith(b">>>"): log.info(line.strip()) else: log.debug(line.strip()) @@ -187,11 +187,11 @@ def patch_toon(self): sleep(5) if clean_up: log.info("Cleaning up") - port.write("rm -r payload\n") - log.debug(port.read_until("/ # ")) + port.write(b"rm -r payload\n") + log.debug(port.read_until(b"/ # ")) if reboot: log.info("Rebooting") - port.write("/etc/init.d/reboot\n") + port.write(b"/etc/init.d/reboot\n") def start_bootloader(self, bin_path): @@ -209,20 +209,20 @@ def start_bootloader(self, bin_path): log.info("Waiting for {} seconds".format(wait)) sleep(wait) client = telnetlib.Telnet('localhost', 4444) - log.debug(client.read_until("> ")) + log.debug(client.read_until(b"> ")) log.info("Halting CPU") - client.write("soft_reset_halt\n") - log.debug(client.read_until("> ")) + client.write(b"soft_reset_halt\n") + log.debug(client.read_until(b"> ")) sleep(0.1) - client.write("reset halt\n") - log.debug(client.read_until("> ")) + client.write(b"reset halt\n") + log.debug(client.read_until(b"> ")) sleep(0.1) log.info("Loading new image to RAM") - client.write("load_image {} 0xa1f00000\n".format(bin_path)) - log.debug(client.read_until("> ")) + client.write("load_image {} 0xa1f00000\n".format(bin_path).encode("ascii")) + log.debug(client.read_until(b"> ")) sleep(0.1) log.info("Starting up new image") - client.write("resume 0xa1f00000\n") + client.write(b"resume 0xa1f00000\n") except: try: log.exception(proc.communicate()[0]) @@ -240,7 +240,7 @@ def read_until(port, terminators=None, size=None): """ if not terminators: terminators = ['\n'] - terms = map(lambda t: (t, len(t)), terminators) + terms = [(t, len(t)) for t in terminators] line = bytearray() timeout = Timeout(port._timeout) while True: