From 4adf75fb52b3c80785151065222dd27901574415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20D=C3=B6rre?= Date: Sat, 14 Oct 2023 11:39:13 +0000 Subject: [PATCH] authenticate and seal firmware images --- CMakeLists.txt | 10 ++++++- bootloader.ld | 76 +++++++++++++++++++++++++------------------------- main.c | 45 ++++++++++++++++++++++-------- standalone.ld | 4 +++ tcp_comm.c | 40 ++++++++++++++++++++++++-- 5 files changed, 122 insertions(+), 53 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eb3fd0a..0616dd5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,9 @@ add_executable(picowota main.c tcp_comm.c dhcpserver/dhcpserver.c + ${PICO_SDK_PATH}/lib/mbedtls/library/sha256.c + ${PICO_SDK_PATH}/lib/mbedtls/library/platform_util.c + ${PICO_SDK_PATH}/lib/mbedtls/library/constant_time.c ) function(target_cl_options option) @@ -54,7 +57,9 @@ pico_add_extra_outputs(picowota) target_include_directories(picowota PRIVATE ${CMAKE_CURRENT_LIST_DIR} # Needed so that lwip can find lwipopts.h - ${CMAKE_CURRENT_LIST_DIR}/dhcpserver) + ${CMAKE_CURRENT_LIST_DIR}/dhcpserver + ${PICO_SDK_PATH}/lib/mbedtls/include +) pico_enable_stdio_usb(picowota 1) @@ -89,6 +94,7 @@ endfunction() picowota_retrieve_variable(PICOWOTA_WIFI_SSID false) picowota_retrieve_variable(PICOWOTA_WIFI_PASS true) picowota_retrieve_variable(PICOWOTA_WIFI_AP false) +picowota_retrieve_variable(PICOWOTA_FLASH_KEY true) if ((NOT PICOWOTA_WIFI_SSID) OR (NOT PICOWOTA_WIFI_PASS)) message(FATAL_ERROR @@ -114,6 +120,8 @@ function(picowota_build_standalone NAME) pico_add_bin_output(${NAME}) endfunction() +pico_set_linker_script(picowota ${CMAKE_CURRENT_LIST_DIR}/bootloader.ld) + # Provide a helper to build a combined target # The build process is roughly: # 1. Build the bootloader, using a special linker script which leaves diff --git a/bootloader.ld b/bootloader.ld index 34a9e34..a2e2db3 100644 --- a/bootloader.ld +++ b/bootloader.ld @@ -63,16 +63,38 @@ SECTIONS launches only, to perform proper flash setup. */ - .flashtext : { + .text : { __logical_binary_start = .; KEEP (*(.vectors)) KEEP (*(.binary_info_header)) __binary_info_header_end = .; KEEP (*(.reset)) - } + /* TODO revisit this now memset/memcpy/float in ROM */ + /* bit of a hack right now to exclude all floating point and time critical (e.g. memset, memcpy) code from + * FLASH ... we will include any thing excluded here in .data below by default */ + *(.init) + *(EXCLUDE_FILE(*libgcc.a: *libc.a:*lib_a-mem*.o *libm.a:) .text*) + *(.fini) + /* Pull all c'tors into .text */ + *crtbegin.o(.ctors) + *crtbegin?.o(.ctors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors) + *(SORT(.ctors.*)) + *(.ctors) + /* Followed by destructors */ + *crtbegin.o(.dtors) + *crtbegin?.o(.dtors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors) + *(SORT(.dtors.*)) + *(.dtors) + + *(.eh_frame*) + . = ALIGN(4); + } > FLASH .rodata : { - /* segments not marked as .flashdata are instead pulled into .data (in RAM) to avoid accidental flash accesses */ + *(EXCLUDE_FILE(*libgcc.a: *libc.a:*lib_a-mem*.o *libm.a:) .rodata*) + . = ALIGN(4); *(SORT_BY_ALIGNMENT(SORT_BY_NAME(.flashdata*))) . = ALIGN(4); } > FLASH @@ -100,42 +122,18 @@ SECTIONS __binary_info_end = .; . = ALIGN(4); - /* Vector table goes first in RAM, to avoid large alignment hole */ - .ram_vector_table (COPY): { + .ram_vector_table (NOLOAD): { *(.ram_vector_table) } > RAM - .text : { - __ram_text_start__ = .; - *(.init) - *(.text*) - *(.fini) - /* Pull all c'tors into .text */ - *crtbegin.o(.ctors) - *crtbegin?.o(.ctors) - *(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors) - *(SORT(.ctors.*)) - *(.ctors) - /* Followed by destructors */ - *crtbegin.o(.dtors) - *crtbegin?.o(.dtors) - *(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors) - *(SORT(.dtors.*)) - *(.dtors) - - *(.eh_frame*) - . = ALIGN(4); - __ram_text_end__ = .; - } > RAM AT> FLASH - __ram_text_source__ = LOADADDR(.text); - - .data : { __data_start__ = .; *(vtable) *(.time_critical*) + /* remaining .text and .rodata; i.e. stuff we exclude above because we want it in RAM */ + *(.text*) . = ALIGN(4); *(.rodata*) . = ALIGN(4); @@ -177,10 +175,10 @@ SECTIONS /* All data end */ __data_end__ = .; } > RAM AT> FLASH - /* __etext is the name of the .data init source pointer (...) */ + /* __etext is (for backwards compatibility) the name of the .data init source pointer (...) */ __etext = LOADADDR(.data); - .uninitialized_data (COPY): { + .uninitialized_data (NOLOAD): { . = ALIGN(4); *(.uninitialized_data*) } > RAM @@ -211,11 +209,11 @@ SECTIONS __bss_end__ = .; } > RAM - .heap (COPY): + .heap (NOLOAD): { __end__ = .; end = __end__; - *(.heap*) + KEEP(*(.heap*)) __HeapLimit = .; } > RAM @@ -228,18 +226,20 @@ SECTIONS /* by default we put core 0 stack at the end of scratch Y, so that if core 1 * stack is not used then all of SCRATCH_X is free. */ - .stack1_dummy (COPY): + .stack1_dummy (NOLOAD): { *(.stack1*) } > SCRATCH_X - .stack_dummy (COPY): + .stack_dummy (NOLOAD): { - *(.stack*) + KEEP(*(.stack*)) } > SCRATCH_Y .flash_end : { - __flash_binary_end = .; + PROVIDE(__flash_binary_end = .); } > FLASH + __wota_image_header_offset = ORIGIN(FLASH)+LENGTH(FLASH); + ASSERT( __flash_binary_end < __wota_image_header_offset, "Flash overflowed into image") /* stack limit is poorly named, but historically is maximum heap ptr */ __StackLimit = ORIGIN(RAM) + LENGTH(RAM); diff --git a/main.c b/main.c index 80b892e..cfc6c3c 100644 --- a/main.c +++ b/main.c @@ -27,7 +27,8 @@ #include "pico/cyw43_arch.h" #include "tcp_comm.h" - +#include "mbedtls/sha256.h" +#include "mbedtls/constant_time.h" #include "picowota/reboot.h" #ifdef DEBUG @@ -60,6 +61,8 @@ const char *wifi_ssid = STR(PICOWOTA_WIFI_SSID); const char *wifi_pass = STR(PICOWOTA_WIFI_PASS); #endif +const char *flash_key = STR(PICOWOTA_FLASH_KEY); + critical_section_t critical_section; #define EVENT_QUEUE_LENGTH 8 @@ -87,8 +90,18 @@ struct event { #define TCP_PORT 4242 -#define IMAGE_HEADER_OFFSET (360 * 1024) +struct image_header { + uint32_t vtor; + uint32_t size; + uint32_t crc; + uint8_t mac[32]; + uint8_t pad[FLASH_PAGE_SIZE - (3 * 4) - 32]; +}; +static_assert(sizeof(struct image_header) == FLASH_PAGE_SIZE, "image_header must be FLASH_PAGE_SIZE bytes"); + +extern struct image_header __wota_image_header_offset; +#define IMAGE_HEADER_OFFSET (((uint32_t)&__wota_image_header_offset) - XIP_BASE) #define WRITE_ADDR_MIN (XIP_BASE + IMAGE_HEADER_OFFSET + FLASH_SECTOR_SIZE) #define ERASE_ADDR_MIN (XIP_BASE + IMAGE_HEADER_OFFSET) #define FLASH_ADDR_MAX (XIP_BASE + PICO_FLASH_SIZE_BYTES) @@ -365,13 +378,6 @@ struct comm_command write_cmd = { .handle = &handle_write, }; -struct image_header { - uint32_t vtor; - uint32_t size; - uint32_t crc; - uint8_t pad[FLASH_PAGE_SIZE - (3 * 4)]; -}; -static_assert(sizeof(struct image_header) == FLASH_PAGE_SIZE, "image_header must be FLASH_PAGE_SIZE bytes"); static bool image_header_ok(struct image_header *hdr) { @@ -384,6 +390,18 @@ static bool image_header_ok(struct image_header *hdr) return false; } + uint8_t hash[32]; + mbedtls_sha256_context ctx = {}; + mbedtls_sha256_init(&ctx); + mbedtls_sha256_starts_ret(&ctx, 0); + mbedtls_sha256_update_ret(&ctx, (uint8_t*) flash_key, strlen(flash_key)); + mbedtls_sha256_update_ret(&ctx, (void*) hdr->vtor, hdr->size); + mbedtls_sha256_finish(&ctx, hash); + mbedtls_sha256_free(&ctx); + if(mbedtls_ct_memcmp(hdr->mac, hash, 32) != 0) { + return false; + } + // Stack pointer needs to be in RAM if (vtor[0] < SRAM_BASE) { return false; @@ -406,6 +424,7 @@ static uint32_t handle_seal(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_ .size = args_in[1], .crc = args_in[2], }; + memcpy(hdr.mac, data_in, 32); if ((hdr.vtor & 0xff) || (hdr.size & 0x3)) { // Must be aligned @@ -428,14 +447,18 @@ static uint32_t handle_seal(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_ return TCP_COMM_RSP_OK; } - +static uint32_t size_seal(uint32_t *args_in, uint32_t *data_len_out, uint32_t *resp_data_len_out) +{ + *data_len_out = 32; + return TCP_COMM_RSP_OK; +} struct comm_command seal_cmd = { // SEAL vtor len crc // OKOK .opcode = CMD_SEAL, .nargs = 3, .resp_nargs = 0, - .size = NULL, + .size = &size_seal, .handle = &handle_seal, }; diff --git a/standalone.ld b/standalone.ld index 040f6ad..74b22f7 100644 --- a/standalone.ld +++ b/standalone.ld @@ -217,6 +217,10 @@ SECTIONS __flash_binary_end = .; } > FLASH + __wota_image_header_offset = ORIGIN(FLASH)+LENGTH(FLASH - 4k); + ASSERT( __flash_binary_end < __wota_image_header_offset, "Flash overflowed into image") + + /* stack limit is poorly named, but historically is maximum heap ptr */ __StackLimit = ORIGIN(RAM) + LENGTH(RAM); __StackOneTop = ORIGIN(SCRATCH_X) + LENGTH(SCRATCH_X); diff --git a/tcp_comm.c b/tcp_comm.c index ea754d1..9ce95c8 100644 --- a/tcp_comm.c +++ b/tcp_comm.c @@ -15,6 +15,9 @@ #include "tcp_comm.h" +#include "mbedtls/sha256.h" +#include "mbedtls/constant_time.h" + #ifdef DEBUG #include #define DEBUG_printf(...) printf(__VA_ARGS__) @@ -27,6 +30,8 @@ #define COMM_MAX_NARG 5 enum conn_state { + CONN_STATE_WRITE_LOGIN, + CONN_STATE_HANDLE_LOGIN, CONN_STATE_WAIT_FOR_SYNC, CONN_STATE_READ_OPCODE, CONN_STATE_READ_ARGS, @@ -56,6 +61,8 @@ struct tcp_comm_ctx { uint32_t resp_data_len; + uint64_t nonce; + const struct comm_command *cmd; const struct comm_command *const *cmds; unsigned int n_cmds; @@ -235,9 +242,27 @@ static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx) return tcp_comm_opcode_begin(ctx); } +extern const char *flash_key; static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx) { switch (ctx->conn_state) { + case CONN_STATE_HANDLE_LOGIN: + { + uint8_t hash[32]; + mbedtls_sha256_context hash_ctx = {}; + mbedtls_sha256_init(&hash_ctx); + mbedtls_sha256_starts_ret(&hash_ctx, 0); + mbedtls_sha256_update_ret(&hash_ctx, (uint8_t *) flash_key, strlen(flash_key)); + mbedtls_sha256_update_ret(&hash_ctx, (uint8_t *) &ctx->nonce, sizeof(ctx->nonce)); + mbedtls_sha256_finish(&hash_ctx, hash); + mbedtls_sha256_free(&hash_ctx); + if(mbedtls_ct_memcmp(ctx->buf, hash, 32) == 0) { + tcp_comm_sync_begin(ctx); + return 0; + } else { + return -1; + } + } case CONN_STATE_WAIT_FOR_SYNC: return tcp_comm_sync_complete(ctx); case CONN_STATE_READ_OPCODE: @@ -247,13 +272,17 @@ static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx) case CONN_STATE_READ_DATA: return tcp_comm_data_complete(ctx); default: - return -1; + return -2; } } static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx) { switch (ctx->conn_state) { + case CONN_STATE_WRITE_LOGIN: + ctx->conn_state = CONN_STATE_HANDLE_LOGIN; + ctx->rx_bytes_needed = 32; + return tcp_comm_handle_input(ctx); case CONN_STATE_WRITE_RESP: return tcp_comm_response_complete(ctx); case CONN_STATE_WRITE_ERROR: @@ -460,12 +489,17 @@ static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb) cyw43_arch_gpio_put (0, true); - tcp_comm_sync_begin(ctx); - + ctx->nonce = get_rand_64(); + ctx->conn_state = CONN_STATE_WRITE_LOGIN; + ctx->tx_bytes_sent = 0; + ctx->tx_bytes_remaining = sizeof(ctx->nonce); + tcp_sent(pcb, tcp_comm_client_sent); tcp_recv(pcb, tcp_comm_client_recv); tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2); tcp_err(pcb, tcp_comm_client_err); + + tcp_write(ctx->client_pcb, &ctx->nonce, ctx->tx_bytes_remaining, 0); } static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)