Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 11 additions & 26 deletions __main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,46 +146,31 @@ def main():
reboot_after = not args.dont_reboot_after
uboot_only = args.uboot_only
boot_only = args.boot_only
should_create_keys = False


if jtag_hardware == "auto":
jtag_hardware = detect_jtag_hardware()
log.info("Detected JTAG hardware '{}'".format(jtag_hardware))

import sshkeys
if args.ssh_public_key:
with open(args.ssh_public_key, 'r') as f:
with open(args.ssh_public_key, 'rb') as f:
ssh_pubkey_data = f.read()
if not sshkeys.check_public_key(ssh_pubkey_data):
raise Exception("RSA key is not valid")
log.info("Using RSA key in {}".format(args.ssh_public_key))
else:
if (os.path.exists("{}".format(args.output_ssh_key)) and os.path.exists("{}.pub".format(args.output_ssh_key))):
# Still load the public key
with open("{}".format(args.output_ssh_key), 'r') as f:
ssh_pubkey_data = f.read()

# If the if statment above loads the public key, this check should pass, if the file is empty, assumed is that this check should not pass
if (should_create_keys = not sshkeys.check_public_key(ssh_pubkey_data)):
log.info("RSA key is not valid")

if should_create_keys
# Do not write the public and private keys, they get overwritten everytime this script runs.
# Found that out the hard way..
(pub, priv) = sshkeys.generate_key_pair(args.private_key_password)
ssh_pubkey_data = pub
with open("{}".format(args.output_ssh_key), 'w') as f:
f.write(priv)
with open("{}.pub".format(args.output_ssh_key), 'w') as f:
f.write(pub)
log.info("Written private and public key pair to {0} and {0}.pub, respectively".format(args.output_ssh_key))
else:
log.info("Not creating a key pair, because one exists.")

(pub, priv) = sshkeys.generate_key_pair(args.private_key_password)
ssh_pubkey_data = pub
with open("{}".format(args.output_ssh_key), 'wb') as f:
f.write(priv)
with open("{}.pub".format(args.output_ssh_key), 'wb') as f:
f.write(pub)
log.info("Written private and public key pair to {0} and {0}.pub, respectively".format(args.output_ssh_key))

import json
params = {
"port" : serial_path,
"ssh_pubkey_data" : ssh_pubkey_data,
"ssh_pubkey_data" : ssh_pubkey_data.decode('ascii'),
"has_jtag" : jtag_available,
"check_uboot" : check_current_bootloader,
"cleanup_payload" : cleanup_payload,
Expand Down
74 changes: 37 additions & 37 deletions rooter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import random
from time import sleep
from serial.serialutil import Timeout
import StringIO
import io
import tempfile

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -24,7 +24,7 @@ def __init__(self, **params):
baudrate=115200
)
self._port = params['port']
self._ssh_pubkey_data = params['ssh_pubkey_data']
self._ssh_pubkey_data = params['ssh_pubkey_data'].encode('ascii')
self._has_jtag = params['has_jtag']
self._check_uboot = params['check_uboot']
self._cleanup_payload = params['cleanup_payload']
Expand Down Expand Up @@ -85,16 +85,16 @@ def run(self):
def read_uboot_version(self):
version_line_match = re.compile(r'^U-Boot ([^ ]+)')
while True:
line = self._port.readline().strip()
line = self._port.readline().strip().decode('ascii')
match = version_line_match.match(line)
if match:
return match.group(1)

def access_uboot(self, password):
log.info("Logging in to U-Boot")
self._port.write(password)
self._port.write(password.encode('ascii'))
self._port.flush()
log.debug(self._port.read_until("U-Boot>"))
log.debug(self._port.read_until(b"U-Boot>"))
log.debug("Logged in to U-Boot")

def patch_uboot(self):
Expand All @@ -103,16 +103,16 @@ def patch_uboot(self):
log.info("Patching U-Boot")
port.reset_input_buffer()
sleep(0.1)
port.write("printenv\n")
port.write(b"printenv\n")
port.flush()
add_misc_match = re.compile(r'^addmisc=(.+)$')
add_misc_match = re.compile(br'^addmisc=(.+)$')
add_misc_val = None

sleep(0.5)

lines = port.read_until("U-Boot>")
lines = port.read_until(b"U-Boot>")
log.debug(lines)
for line in lines.split('\n'):
for line in lines.split(b'\n'):
line = line.strip()
log.debug(line)
match = add_misc_match.match(line)
Expand All @@ -123,11 +123,11 @@ def patch_uboot(self):
log.error("Could not find value for addmisc environment variable")
return

cmd = "setenv addmisc " + re.sub(r'([\$;])',r'\\\1', add_misc_val + " init=/bin/sh")
port.write(cmd + "\n")
cmd = b"setenv addmisc " + re.sub(br'([\$;])',br'\\\1', add_misc_val + b" init=/bin/sh")
port.write(cmd + b"\n")
port.flush()
log.debug(port.read_until("U-Boot>"))
port.write("run boot_nand\n")
log.debug(port.read_until(b"U-Boot>"))
port.write(b"run boot_nand\n")
port.flush()

def create_payload_tar(self):
Expand All @@ -136,49 +136,49 @@ def create_payload_tar(self):
with tarfile.open(tar_path, "w:gz") as tar:
tar.add('payload/', arcname='payload')

ssh_key_str = StringIO.StringIO(ssh_key)
ssh_key_str = io.BytesIO(ssh_key)

info = tarfile.TarInfo(name="payload/id_rsa.pub")
info.size=len(ssh_key)

tar.addfile(tarinfo=info, fileobj=StringIO.StringIO(ssh_key))
tar.addfile(tarinfo=info, fileobj=io.BytesIO(ssh_key))
return tar_path

def write_payload(self):
port = self._port
tar_path = self.create_payload_tar()

log.debug(port.read_until("/ # "))
port.write("timeout -t 60 base64 -d | tar zxf -\n")
log.debug(port.read_until(b"/ # "))
port.write(b"timeout -t 60 base64 -d | tar zxf -\n")
port.flush()

log.info("Transferring payload")
with open(tar_path, 'r') as f:
with open(tar_path, 'rb') as f:
base64.encode(f, port)

log.info("Transferring payload done. Waiting for toon to finish.")
os.remove(tar_path)

port.flush()
port.reset_input_buffer()
port.write("\x04")
port.write(b"\x04")
port.flush()
port.write("\n")
log.debug(port.read_until("/ # "))
port.write(b"\n")
log.debug(port.read_until(b"/ # "))
log.info("Transferring payload finished")

def patch_toon(self):
(port, clean_up, reboot) = (
self._port, self._cleanup_payload, self._reboot_after)
log.info("Patching Toon")
password = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(8))
port.write("if [ -f payload/patch_toon.sh ] ; then sh payload/patch_toon.sh \"{}\" ; fi\n".format(password))
port.write("if [ -f payload/patch_toon.sh ] ; then sh payload/patch_toon.sh \"{}\" ; fi\n".format(password).encode('ascii'))
try:
while True:
line = read_until(port, ["/ # ", "\n"])
if line == "/ # ":
line = read_until(port, [b"/ # ", b"\n"])
if line == b"/ # ":
break
if line.startswith(">>>"):
if line.startswith(b">>>"):
log.info(line.strip())
else:
log.debug(line.strip())
Expand All @@ -187,11 +187,11 @@ def patch_toon(self):
sleep(5)
if clean_up:
log.info("Cleaning up")
port.write("rm -r payload\n")
log.debug(port.read_until("/ # "))
port.write(b"rm -r payload\n")
log.debug(port.read_until(b"/ # "))
if reboot:
log.info("Rebooting")
port.write("/etc/init.d/reboot\n")
port.write(b"/etc/init.d/reboot\n")

def start_bootloader(self, bin_path):

Expand All @@ -209,20 +209,20 @@ def start_bootloader(self, bin_path):
log.info("Waiting for {} seconds".format(wait))
sleep(wait)
client = telnetlib.Telnet('localhost', 4444)
log.debug(client.read_until("> "))
log.debug(client.read_until(b"> "))
log.info("Halting CPU")
client.write("soft_reset_halt\n")
log.debug(client.read_until("> "))
client.write(b"soft_reset_halt\n")
log.debug(client.read_until(b"> "))
sleep(0.1)
client.write("reset halt\n")
log.debug(client.read_until("> "))
client.write(b"reset halt\n")
log.debug(client.read_until(b"> "))
sleep(0.1)
log.info("Loading new image to RAM")
client.write("load_image {} 0xa1f00000\n".format(bin_path))
log.debug(client.read_until("> "))
client.write("load_image {} 0xa1f00000\n".format(bin_path).encode("ascii"))
log.debug(client.read_until(b"> "))
sleep(0.1)
log.info("Starting up new image")
client.write("resume 0xa1f00000\n")
client.write(b"resume 0xa1f00000\n")
except:
try:
log.exception(proc.communicate()[0])
Expand All @@ -240,7 +240,7 @@ def read_until(port, terminators=None, size=None):
"""
if not terminators:
terminators = ['\n']
terms = map(lambda t: (t, len(t)), terminators)
terms = [(t, len(t)) for t in terminators]
line = bytearray()
timeout = Timeout(port._timeout)
while True:
Expand Down