Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions node/rest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,9 @@ workspace = true
[dev-dependencies.base64]
workspace = true

[dev-dependencies.tower]
version = "0.4"
features = ["util"]

[build-dependencies.built]
workspace = true
40 changes: 36 additions & 4 deletions node/rest/src/helpers/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,49 @@ use axum::{
response::{IntoResponse, Response},
};

/// An enum of error handlers for the REST API server.
pub struct RestError(pub String);
/// A generic error for the REST API server.
pub struct RestError {
/// The HTTP status code to return.
pub status: StatusCode,
/// The error message.
pub message: String,
}

impl RestError {
/// Creates a new internal server error.
pub fn new(message: String) -> Self {
Self { status: StatusCode::INTERNAL_SERVER_ERROR, message }
}

/// Creates a new error with a specific status code.
pub fn with_status(message: String, status: StatusCode) -> Self {
Self { status, message }
}

/// Creates a 400 Bad Request error.
pub fn bad_request(message: String) -> Self {
Self::with_status(message, StatusCode::BAD_REQUEST)
}

/// Creates a 404 Not Found error.
pub fn not_found(message: String) -> Self {
Self::with_status(message, StatusCode::NOT_FOUND)
}

/// Creates a 503 Service Unavailable error.
pub fn service_unavailable(message: String) -> Self {
Self::with_status(message, StatusCode::SERVICE_UNAVAILABLE)
}
}

impl IntoResponse for RestError {
fn into_response(self) -> Response {
(StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", self.0)).into_response()
(self.status, format!("Something went wrong: {}", self.message)).into_response()
}
}

impl From<anyhow::Error> for RestError {
fn from(err: anyhow::Error) -> Self {
Self(err.to_string())
Self::new(err.to_string())
}
}
132 changes: 107 additions & 25 deletions node/rest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,33 +132,9 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {

impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) -> Result<()> {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([CONTENT_TYPE]);

// Log the REST rate limit per IP.
debug!("REST rate limit per IP - {rest_rps} RPS");

// Prepare the rate limiting setup.
let governor_config = Box::new(
GovernorConfigBuilder::default()
.per_nanosecond((1_000_000_000 / rest_rps) as u64)
.burst_size(rest_rps)
.error_handler(|error| {
// Properly return a 429 Too Many Requests error
let error_message = error.to_string();
let mut response = Response::new(error_message.clone().into());
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
if error_message.contains("Too Many Requests") {
*response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
}
response
})
.finish()
.expect("Couldn't set up rate limiting for the REST server!"),
);

// Get the network being used.
let network = match N::ID {
snarkvm::console::network::MainnetV0::ID => "mainnet",
Expand All @@ -169,7 +145,32 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
}
};

let router = {
// Closure to build the API routes for a given version.
let build_routes = || {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([CONTENT_TYPE]);

// Prepare the rate limiting setup.
let governor_config = Box::new(
GovernorConfigBuilder::default()
.per_nanosecond((1_000_000_000 / rest_rps) as u64)
.burst_size(rest_rps)
.error_handler(|error| {
// Properly return a 429 Too Many Requests error
let error_message = error.to_string();
let mut response = Response::new(error_message.clone().into());
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
if error_message.contains("Too Many Requests") {
*response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
}
response
})
.finish()
.expect("Couldn't set up rate limiting for the REST server!"),
);

let routes = axum::Router::new()

// All the endpoints before the call to `route_layer` are protected with JWT auth.
Expand Down Expand Up @@ -274,6 +275,11 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
})
};

// Build routers for v1 and v2.
let router_v1 = build_routes().route_layer(middleware::from_fn(legacy_error_middleware));
let router_v2 = axum::Router::new().nest("/v2", build_routes());
let router = router_v1.merge(router_v2);

let rest_listener =
TcpListener::bind(rest_ip).await.with_context(|| "Failed to bind TCP port for REST endpoints")?;

Expand Down Expand Up @@ -307,3 +313,79 @@ pub fn fmt_id(id: impl ToString) -> String {
}
formatted_id
}

/// Middleware to ensure legacy routes always return HTTP 500 on error.
async fn legacy_error_middleware(req: Request<Body>, next: Next) -> Response {
let mut res = next.run(req).await;
if !res.status().is_success() {
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
}
res
}

#[cfg(test)]
mod tests {
use super::*;
use axum::{
Router,
body::Body,
http::{Request, StatusCode},
middleware,
routing::get,
};
use tower::ServiceExt; // for `oneshot`

fn test_app() -> Router {
let build_routes = || {
Router::new()
.route(
"/not_found",
get(|| async { Err::<(), RestError>(RestError::not_found("missing".to_string())) }),
)
.route(
"/bad_request",
get(|| async { Err::<(), RestError>(RestError::bad_request("bad".to_string())) }),
)
.route(
"/service_unavailable",
get(|| async { Err::<(), RestError>(RestError::service_unavailable("gone".to_string())) }),
)
};
let router_v1 = build_routes().route_layer(middleware::from_fn(legacy_error_middleware));
let router_v2 = Router::new().nest("/v2", build_routes());
router_v1.merge(router_v2)
}

#[tokio::test]
async fn v1_routes_force_internal_server_error() {
let app = test_app();

let res = app.clone().oneshot(Request::builder().uri("/not_found").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);

let res =
app.clone().oneshot(Request::builder().uri("/bad_request").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);

let res =
app.oneshot(Request::builder().uri("/service_unavailable").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}

#[tokio::test]
async fn v2_routes_return_specific_errors() {
let app = test_app();

let res =
app.clone().oneshot(Request::builder().uri("/v2/not_found").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);

let res =
app.clone().oneshot(Request::builder().uri("/v2/bad_request").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);

let res =
app.oneshot(Request::builder().uri("/v2/service_unavailable").body(Body::empty()).unwrap()).await.unwrap();
assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
}
}
Loading