From 947f27d3c2392bb6b8e978dac1c3fc51f5cdd305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20D=C3=B6rre?= Date: Sat, 14 Oct 2023 11:33:08 +0000 Subject: [PATCH] make tcp server more timing resistant --- tcp_comm.c | 57 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/tcp_comm.c b/tcp_comm.c index ea754d1..dd24262 100644 --- a/tcp_comm.c +++ b/tcp_comm.c @@ -46,6 +46,7 @@ struct tcp_comm_ctx { // Note: sizeof(buf) is used elsewhere, so if this is changed to not // be an array, those will need updating uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + TCP_COMM_MAX_DATA_LEN]; + uint32_t args[COMM_MAX_NARG]; uint16_t rx_start_offs; uint16_t rx_bytes_received; @@ -95,6 +96,7 @@ static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx); static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx); static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx); static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx); +static int tcp_comm_handle_input(struct tcp_comm_ctx *ctx); static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx) { @@ -119,7 +121,7 @@ static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx) ctx->conn_state = CONN_STATE_READ_OPCODE; ctx->rx_bytes_needed = sizeof(uint32_t); - return 0; + return tcp_comm_handle_input(ctx); } static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx) @@ -152,14 +154,17 @@ static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx) const struct comm_command *cmd = ctx->cmd; uint32_t data_len = 0; + memcpy(ctx->args, ctx->buf, cmd->nargs * sizeof(uint32_t)); if (cmd->size) { - uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf), + uint32_t status = cmd->size(ctx->args, &data_len, &ctx->resp_data_len); if (is_error(status)) { return tcp_comm_error_begin(ctx); } + } else { + ctx->resp_data_len = 0; } return tcp_comm_data_begin(ctx, data_len); @@ -182,8 +187,8 @@ static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx) const struct comm_command *cmd = ctx->cmd; if (cmd->handle) { - uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf), - COMM_BUF_BODY(ctx->buf, cmd->nargs), + uint32_t status = cmd->handle(ctx->args, + ctx->buf, COMM_BUF_ARGS(ctx->buf), COMM_BUF_BODY(ctx->buf, cmd->resp_nargs)); if (is_error(status)) { @@ -365,7 +370,32 @@ static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len) return ERR_OK; } +static int tcp_comm_handle_input(struct tcp_comm_ctx *ctx) +{ + while (ctx->rx_bytes_received >= ctx->rx_bytes_needed && ctx->rx_bytes_needed > 0) { + uint16_t consumed = ctx->rx_bytes_needed; + + int res = tcp_comm_rx_complete(ctx); + if (res == -2 ) { + return ERR_OK; + } + if (res) { + return tcp_comm_client_complete(ctx, ERR_ARG); + } + ctx->rx_start_offs += consumed; + ctx->rx_bytes_received -= consumed; + + if (ctx->rx_bytes_received == 0) { + ctx->rx_start_offs = 0; + break; + }else{ + memmove(ctx->buf, ctx->buf + ctx->rx_start_offs,ctx->rx_bytes_received); + ctx->rx_start_offs = 0; + } + } + return ERR_OK; +} static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err) { struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg; @@ -412,22 +442,9 @@ static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf * ctx->rx_bytes_received += p->tot_len; tcp_recved(tpcb, p->tot_len); - - while (ctx->rx_bytes_received >= ctx->rx_bytes_needed) { - uint16_t consumed = ctx->rx_bytes_needed; - - int res = tcp_comm_rx_complete(ctx); - if (res) { - return tcp_comm_client_complete(ctx, ERR_ARG); - } - - ctx->rx_start_offs += consumed; - ctx->rx_bytes_received -= consumed; - - if (ctx->rx_bytes_received == 0) { - ctx->rx_start_offs = 0; - break; - } + err_t err = tcp_comm_handle_input(ctx); + if(err != ERR_OK) { + return err; } } pbuf_free(p);