Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ add_executable(picowota
main.c
tcp_comm.c
dhcpserver/dhcpserver.c
${PICO_SDK_PATH}/lib/mbedtls/library/sha256.c
${PICO_SDK_PATH}/lib/mbedtls/library/platform_util.c
${PICO_SDK_PATH}/lib/mbedtls/library/constant_time.c
)

function(target_cl_options option)
Expand All @@ -54,7 +57,9 @@ pico_add_extra_outputs(picowota)

target_include_directories(picowota PRIVATE
${CMAKE_CURRENT_LIST_DIR} # Needed so that lwip can find lwipopts.h
${CMAKE_CURRENT_LIST_DIR}/dhcpserver)
${CMAKE_CURRENT_LIST_DIR}/dhcpserver
${PICO_SDK_PATH}/lib/mbedtls/include
)

pico_enable_stdio_usb(picowota 1)

Expand Down Expand Up @@ -89,6 +94,7 @@ endfunction()
picowota_retrieve_variable(PICOWOTA_WIFI_SSID false)
picowota_retrieve_variable(PICOWOTA_WIFI_PASS true)
picowota_retrieve_variable(PICOWOTA_WIFI_AP false)
picowota_retrieve_variable(PICOWOTA_FLASH_KEY true)

if ((NOT PICOWOTA_WIFI_SSID) OR (NOT PICOWOTA_WIFI_PASS))
message(FATAL_ERROR
Expand All @@ -114,6 +120,8 @@ function(picowota_build_standalone NAME)
pico_add_bin_output(${NAME})
endfunction()

pico_set_linker_script(picowota ${CMAKE_CURRENT_LIST_DIR}/bootloader.ld)

# Provide a helper to build a combined target
# The build process is roughly:
# 1. Build the bootloader, using a special linker script which leaves
Expand Down
76 changes: 38 additions & 38 deletions bootloader.ld
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,38 @@ SECTIONS
launches only, to perform proper flash setup.
*/

.flashtext : {
.text : {
__logical_binary_start = .;
KEEP (*(.vectors))
KEEP (*(.binary_info_header))
__binary_info_header_end = .;
KEEP (*(.reset))
}
/* TODO revisit this now memset/memcpy/float in ROM */
/* bit of a hack right now to exclude all floating point and time critical (e.g. memset, memcpy) code from
* FLASH ... we will include any thing excluded here in .data below by default */
*(.init)
*(EXCLUDE_FILE(*libgcc.a: *libc.a:*lib_a-mem*.o *libm.a:) .text*)
*(.fini)
/* Pull all c'tors into .text */
*crtbegin.o(.ctors)
*crtbegin?.o(.ctors)
*(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors)
*(SORT(.ctors.*))
*(.ctors)
/* Followed by destructors */
*crtbegin.o(.dtors)
*crtbegin?.o(.dtors)
*(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors)
*(SORT(.dtors.*))
*(.dtors)

*(.eh_frame*)
. = ALIGN(4);
} > FLASH

.rodata : {
/* segments not marked as .flashdata are instead pulled into .data (in RAM) to avoid accidental flash accesses */
*(EXCLUDE_FILE(*libgcc.a: *libc.a:*lib_a-mem*.o *libm.a:) .rodata*)
. = ALIGN(4);
*(SORT_BY_ALIGNMENT(SORT_BY_NAME(.flashdata*)))
. = ALIGN(4);
} > FLASH
Expand Down Expand Up @@ -100,42 +122,18 @@ SECTIONS
__binary_info_end = .;
. = ALIGN(4);

/* Vector table goes first in RAM, to avoid large alignment hole */
.ram_vector_table (COPY): {
.ram_vector_table (NOLOAD): {
*(.ram_vector_table)
} > RAM

.text : {
__ram_text_start__ = .;
*(.init)
*(.text*)
*(.fini)
/* Pull all c'tors into .text */
*crtbegin.o(.ctors)
*crtbegin?.o(.ctors)
*(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors)
*(SORT(.ctors.*))
*(.ctors)
/* Followed by destructors */
*crtbegin.o(.dtors)
*crtbegin?.o(.dtors)
*(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors)
*(SORT(.dtors.*))
*(.dtors)

*(.eh_frame*)
. = ALIGN(4);
__ram_text_end__ = .;
} > RAM AT> FLASH
__ram_text_source__ = LOADADDR(.text);


.data : {
__data_start__ = .;
*(vtable)

*(.time_critical*)

/* remaining .text and .rodata; i.e. stuff we exclude above because we want it in RAM */
*(.text*)
. = ALIGN(4);
*(.rodata*)
. = ALIGN(4);
Expand Down Expand Up @@ -177,10 +175,10 @@ SECTIONS
/* All data end */
__data_end__ = .;
} > RAM AT> FLASH
/* __etext is the name of the .data init source pointer (...) */
/* __etext is (for backwards compatibility) the name of the .data init source pointer (...) */
__etext = LOADADDR(.data);

.uninitialized_data (COPY): {
.uninitialized_data (NOLOAD): {
. = ALIGN(4);
*(.uninitialized_data*)
} > RAM
Expand Down Expand Up @@ -211,11 +209,11 @@ SECTIONS
__bss_end__ = .;
} > RAM

.heap (COPY):
.heap (NOLOAD):
{
__end__ = .;
end = __end__;
*(.heap*)
KEEP(*(.heap*))
__HeapLimit = .;
} > RAM

Expand All @@ -228,18 +226,20 @@ SECTIONS
/* by default we put core 0 stack at the end of scratch Y, so that if core 1
* stack is not used then all of SCRATCH_X is free.
*/
.stack1_dummy (COPY):
.stack1_dummy (NOLOAD):
{
*(.stack1*)
} > SCRATCH_X
.stack_dummy (COPY):
.stack_dummy (NOLOAD):
{
*(.stack*)
KEEP(*(.stack*))
} > SCRATCH_Y

.flash_end : {
__flash_binary_end = .;
PROVIDE(__flash_binary_end = .);
} > FLASH
__wota_image_header_offset = ORIGIN(FLASH)+LENGTH(FLASH);
ASSERT( __flash_binary_end < __wota_image_header_offset, "Flash overflowed into image")

/* stack limit is poorly named, but historically is maximum heap ptr */
__StackLimit = ORIGIN(RAM) + LENGTH(RAM);
Expand Down
45 changes: 34 additions & 11 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
#include "pico/cyw43_arch.h"

#include "tcp_comm.h"

#include "mbedtls/sha256.h"
#include "mbedtls/constant_time.h"
#include "picowota/reboot.h"

#ifdef DEBUG
Expand Down Expand Up @@ -60,6 +61,8 @@ const char *wifi_ssid = STR(PICOWOTA_WIFI_SSID);
const char *wifi_pass = STR(PICOWOTA_WIFI_PASS);
#endif

const char *flash_key = STR(PICOWOTA_FLASH_KEY);

critical_section_t critical_section;

#define EVENT_QUEUE_LENGTH 8
Expand Down Expand Up @@ -87,8 +90,18 @@ struct event {

#define TCP_PORT 4242

#define IMAGE_HEADER_OFFSET (360 * 1024)
struct image_header {
uint32_t vtor;
uint32_t size;
uint32_t crc;
uint8_t mac[32];
uint8_t pad[FLASH_PAGE_SIZE - (3 * 4) - 32];
};
static_assert(sizeof(struct image_header) == FLASH_PAGE_SIZE, "image_header must be FLASH_PAGE_SIZE bytes");

extern struct image_header __wota_image_header_offset;

#define IMAGE_HEADER_OFFSET (((uint32_t)&__wota_image_header_offset) - XIP_BASE)
#define WRITE_ADDR_MIN (XIP_BASE + IMAGE_HEADER_OFFSET + FLASH_SECTOR_SIZE)
#define ERASE_ADDR_MIN (XIP_BASE + IMAGE_HEADER_OFFSET)
#define FLASH_ADDR_MAX (XIP_BASE + PICO_FLASH_SIZE_BYTES)
Expand Down Expand Up @@ -365,13 +378,6 @@ struct comm_command write_cmd = {
.handle = &handle_write,
};

struct image_header {
uint32_t vtor;
uint32_t size;
uint32_t crc;
uint8_t pad[FLASH_PAGE_SIZE - (3 * 4)];
};
static_assert(sizeof(struct image_header) == FLASH_PAGE_SIZE, "image_header must be FLASH_PAGE_SIZE bytes");

static bool image_header_ok(struct image_header *hdr)
{
Expand All @@ -384,6 +390,18 @@ static bool image_header_ok(struct image_header *hdr)
return false;
}

uint8_t hash[32];
mbedtls_sha256_context ctx = {};
mbedtls_sha256_init(&ctx);
mbedtls_sha256_starts_ret(&ctx, 0);
mbedtls_sha256_update_ret(&ctx, (uint8_t*) flash_key, strlen(flash_key));
mbedtls_sha256_update_ret(&ctx, (void*) hdr->vtor, hdr->size);
mbedtls_sha256_finish(&ctx, hash);
mbedtls_sha256_free(&ctx);
if(mbedtls_ct_memcmp(hdr->mac, hash, 32) != 0) {
return false;
}

// Stack pointer needs to be in RAM
if (vtor[0] < SRAM_BASE) {
return false;
Expand All @@ -406,6 +424,7 @@ static uint32_t handle_seal(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_
.size = args_in[1],
.crc = args_in[2],
};
memcpy(hdr.mac, data_in, 32);

if ((hdr.vtor & 0xff) || (hdr.size & 0x3)) {
// Must be aligned
Expand All @@ -428,14 +447,18 @@ static uint32_t handle_seal(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_

return TCP_COMM_RSP_OK;
}

static uint32_t size_seal(uint32_t *args_in, uint32_t *data_len_out, uint32_t *resp_data_len_out)
{
*data_len_out = 32;
return TCP_COMM_RSP_OK;
}
struct comm_command seal_cmd = {
// SEAL vtor len crc
// OKOK
.opcode = CMD_SEAL,
.nargs = 3,
.resp_nargs = 0,
.size = NULL,
.size = &size_seal,
.handle = &handle_seal,
};

Expand Down
4 changes: 4 additions & 0 deletions standalone.ld
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ SECTIONS
__flash_binary_end = .;
} > FLASH

__wota_image_header_offset = ORIGIN(FLASH)+LENGTH(FLASH - 4k);
ASSERT( __flash_binary_end < __wota_image_header_offset, "Flash overflowed into image")


/* stack limit is poorly named, but historically is maximum heap ptr */
__StackLimit = ORIGIN(RAM) + LENGTH(RAM);
__StackOneTop = ORIGIN(SCRATCH_X) + LENGTH(SCRATCH_X);
Expand Down
40 changes: 37 additions & 3 deletions tcp_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

#include "tcp_comm.h"

#include "mbedtls/sha256.h"
#include "mbedtls/constant_time.h"

#ifdef DEBUG
#include <stdio.h>
#define DEBUG_printf(...) printf(__VA_ARGS__)
Expand All @@ -27,6 +30,8 @@
#define COMM_MAX_NARG 5

enum conn_state {
CONN_STATE_WRITE_LOGIN,
CONN_STATE_HANDLE_LOGIN,
CONN_STATE_WAIT_FOR_SYNC,
CONN_STATE_READ_OPCODE,
CONN_STATE_READ_ARGS,
Expand Down Expand Up @@ -56,6 +61,8 @@ struct tcp_comm_ctx {

uint32_t resp_data_len;

uint64_t nonce;

const struct comm_command *cmd;
const struct comm_command *const *cmds;
unsigned int n_cmds;
Expand Down Expand Up @@ -235,9 +242,27 @@ static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
return tcp_comm_opcode_begin(ctx);
}

extern const char *flash_key;
static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
{
switch (ctx->conn_state) {
case CONN_STATE_HANDLE_LOGIN:
{
uint8_t hash[32];
mbedtls_sha256_context hash_ctx = {};
mbedtls_sha256_init(&hash_ctx);
mbedtls_sha256_starts_ret(&hash_ctx, 0);
mbedtls_sha256_update_ret(&hash_ctx, (uint8_t *) flash_key, strlen(flash_key));
mbedtls_sha256_update_ret(&hash_ctx, (uint8_t *) &ctx->nonce, sizeof(ctx->nonce));
mbedtls_sha256_finish(&hash_ctx, hash);
mbedtls_sha256_free(&hash_ctx);
if(mbedtls_ct_memcmp(ctx->buf, hash, 32) == 0) {
tcp_comm_sync_begin(ctx);
return 0;
} else {
return -1;
}
}
case CONN_STATE_WAIT_FOR_SYNC:
return tcp_comm_sync_complete(ctx);
case CONN_STATE_READ_OPCODE:
Expand All @@ -247,13 +272,17 @@ static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
case CONN_STATE_READ_DATA:
return tcp_comm_data_complete(ctx);
default:
return -1;
return -2;
}
}

static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
{
switch (ctx->conn_state) {
case CONN_STATE_WRITE_LOGIN:
ctx->conn_state = CONN_STATE_HANDLE_LOGIN;
ctx->rx_bytes_needed = 32;
return tcp_comm_handle_input(ctx);
case CONN_STATE_WRITE_RESP:
return tcp_comm_response_complete(ctx);
case CONN_STATE_WRITE_ERROR:
Expand Down Expand Up @@ -460,12 +489,17 @@ static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)

cyw43_arch_gpio_put (0, true);

tcp_comm_sync_begin(ctx);

ctx->nonce = get_rand_64();
ctx->conn_state = CONN_STATE_WRITE_LOGIN;
ctx->tx_bytes_sent = 0;
ctx->tx_bytes_remaining = sizeof(ctx->nonce);

tcp_sent(pcb, tcp_comm_client_sent);
tcp_recv(pcb, tcp_comm_client_recv);
tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
tcp_err(pcb, tcp_comm_client_err);

tcp_write(ctx->client_pcb, &ctx->nonce, ctx->tx_bytes_remaining, 0);
}

static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
Expand Down