diff --git a/server/svix-server/src/core/idempotency.rs b/server/svix-server/src/core/idempotency.rs index 581153535..656e0a1be 100644 --- a/server/svix-server/src/core/idempotency.rs +++ b/server/svix-server/src/core/idempotency.rs @@ -13,9 +13,9 @@ use std::{collections::HashMap, convert::Infallible, future::Future, pin::Pin, time::Duration}; use axum::{ - body::{Body, BoxBody, HttpBody}, - http::{Request, Response, StatusCode}, - response::IntoResponse, + body::{Body, HttpBody}, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, }; use blake2::{Blake2b512, Digest}; use http::request::Parts; @@ -88,7 +88,7 @@ fn finished_serialized_response_to_response( code: u16, headers: Option>>, body: Option>, -) -> Result, ConversionToResponseError> { +) -> Result { let mut out = body.unwrap_or_default().into_response(); let status = out.status_mut(); @@ -105,16 +105,16 @@ fn finished_serialized_response_to_response( Ok(out) } -async fn resolve_service( - mut service: S, - req: Request, -) -> Result, Infallible> +async fn resolve_service(mut service: S, req: Request) -> Response where S: Service, Error = Infallible> + Clone + Send + 'static, S::Response: IntoResponse, S::Future: Send + 'static, { - service.call(req).await.map(IntoResponse::into_response) + match service.call(req).await { + Ok(res) => res.into_response(), + Err(e) => match e {}, + } } /// The idempotency middleware itself -- used via the [`Router::layer`] method @@ -130,7 +130,7 @@ where S::Response: IntoResponse, S::Future: Send + 'static, { - type Response = Response; + type Response = Response; type Error = Infallible; type Future = Pin> + Send>>; @@ -151,7 +151,7 @@ where // If not a POST request, simply resolve the service as usual if parts.method != http::Method::POST { - return resolve_service(service, Request::from_parts(parts, body)).await; + return Ok(resolve_service(service, Request::from_parts(parts, body)).await); } // Retrieve `IdempotencyKey` from header and URL parts, but returning the service @@ -159,7 +159,7 @@ where let key = if let Some(key) = get_key(&parts) { key } else { - return resolve_service(service, Request::from_parts(parts, body)).await; + return Ok(resolve_service(service, Request::from_parts(parts, body)).await); }; // Set the [`SerializedResponse::Start`] lock if the key does not exist in the cache @@ -231,8 +231,13 @@ where // If it's set or the lock or the `lock_loop` returns Ok(None), then the key has no // value, so continue resolving the service while caching the response for 2xx // responses - resolve_and_cache_response(&cache, &key, service, Request::from_parts(parts, body)) - .await + Ok(resolve_and_cache_response( + &cache, + &key, + service, + Request::from_parts(parts, body), + ) + .await) }) } else { Box::pin(async move { Ok(service.call(req).await.into_response()) }) @@ -299,22 +304,18 @@ async fn resolve_and_cache_response( key: &IdempotencyKey, service: S, request: Request, -) -> Result, Infallible> +) -> Response where S: Service, Error = Infallible> + Clone + Send + 'static, S::Response: IntoResponse, S::Future: Send + 'static, { - let (parts, mut body) = resolve_service(service, request) - .await - // Infallible - .unwrap() - .into_parts(); + let (parts, body) = resolve_service(service, request).await.into_parts(); // If a 2xx response, cache the actual response if parts.status.is_success() { // TODO: Don't skip over Err value - let bytes = body.data().await.and_then(Result::ok); + let bytes = body.collect().await.ok().map(|c| c.to_bytes()); let resp = SerializedResponse::Finished { code: parts.status.into(), @@ -329,20 +330,20 @@ where }; if cache.set(key, &resp, expiry_default()).await.is_err() { - return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } // Assumes None to be an empty byte array let bytes = bytes.unwrap_or_default(); - Ok(Response::from_parts(parts, Body::from(bytes)).into_response()) + Response::from_parts(parts, Body::from(bytes)).into_response() } // If any other status, unset the start lock and return the response else { if cache.delete(key).await.is_err() { - return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } - Ok(Response::from_parts(parts, body).into_response()) + Response::from_parts(parts, body).into_response() } } diff --git a/server/svix-server/src/lib.rs b/server/svix-server/src/lib.rs index f71b18f3c..6a02f0254 100644 --- a/server/svix-server/src/lib.rs +++ b/server/svix-server/src/lib.rs @@ -6,7 +6,6 @@ use std::{ borrow::Cow, - net::TcpListener, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; @@ -21,6 +20,7 @@ use redis::RedisManager; use sea_orm::DatabaseConnection; use sentry::integrations::tracing::EventFilter; use svix_ksuid::{KsuidLike, KsuidMs}; +use tokio::net::TcpListener; use tower::layer::layer_fn; use tower_http::{ cors::{AllowHeaders, Any, CorsLayer}, @@ -88,10 +88,9 @@ async fn graceful_shutdown_handler() { } #[tracing::instrument(name = "app_start", level = "trace", skip_all)] -pub async fn run(cfg: Configuration, listener: Option) { +pub async fn run(cfg: Configuration) { let _metrics = setup_metrics(&cfg); - - run_with_prefix(None, cfg, listener).await + run_with_prefix(None, cfg, None).await } #[derive(Clone)] @@ -184,20 +183,19 @@ pub async fn run_with_prefix( let (server, worker_loop, expired_message_cleaner_loop) = tokio::join!( async { if with_api { - if let Some(l) = listener { - tracing::debug!("API: Listening on {}", l.local_addr().unwrap()); - axum::Server::from_tcp(l) - .expect("Error starting http server") - .serve(svc) - .with_graceful_shutdown(graceful_shutdown_handler()) - .await - } else { - tracing::debug!("API: Listening on {}", listen_address); - axum::Server::bind(&listen_address) - .serve(svc) - .with_graceful_shutdown(graceful_shutdown_handler()) + let listener = match listener { + Some(l) => l, + None => TcpListener::bind(listen_address) .await - } + .expect("Error binding to listen_address"), + }; + tracing::debug!("API: Listening on {}", listener.local_addr().unwrap()); + + let incoming = hyper::server::conn::AddrIncoming::from_listener(listener)?; + axum::Server::builder(incoming) + .serve(svc) + .with_graceful_shutdown(graceful_shutdown_handler()) + .await } else { tracing::debug!("API: off"); graceful_shutdown_handler().await; @@ -273,7 +271,7 @@ pub fn setup_tracing( .tracing() .with_exporter(exporter) .with_trace_config( - opentelemetry_sdk::trace::config() + opentelemetry_sdk::trace::Config::default() .with_sampler( cfg.opentelemetry_sample_ratio .map(opentelemetry_sdk::trace::Sampler::TraceIdRatioBased) diff --git a/server/svix-server/src/main.rs b/server/svix-server/src/main.rs index 126c08012..3ae00b6ae 100644 --- a/server/svix-server/src/main.rs +++ b/server/svix-server/src/main.rs @@ -196,7 +196,7 @@ async fn main() { ); } None => { - run(cfg, None).await; + run(cfg).await; } }; diff --git a/server/svix-server/src/metrics/redis.rs b/server/svix-server/src/metrics/redis.rs index bb514a144..34cb0fe29 100644 --- a/server/svix-server/src/metrics/redis.rs +++ b/server/svix-server/src/metrics/redis.rs @@ -14,7 +14,7 @@ pub enum RedisQueueType<'a> { SortedSet(&'a str), } -impl<'a> RedisQueueType<'a> { +impl RedisQueueType<'_> { pub async fn queue_depth(&self, redis: &RedisManager) -> Result { let mut conn = redis.get().await?; match self { diff --git a/server/svix-server/src/worker.rs b/server/svix-server/src/worker.rs index 18e9016db..1a4fc80c2 100644 --- a/server/svix-server/src/worker.rs +++ b/server/svix-server/src/worker.rs @@ -10,6 +10,7 @@ use std::{ time::Duration, }; +use axum::body::HttpBody as _; use chrono::Utc; use futures::future; use http::{HeaderValue, StatusCode, Version}; @@ -370,11 +371,15 @@ async fn make_http_call( None }; - let body = match hyper::body::to_bytes(res.into_body()).await { - Ok(bytes) if bytes.len() > RESPONSE_MAX_SIZE => { - bytes_to_string(bytes.slice(..RESPONSE_MAX_SIZE)) + let body = match res.into_body().collect().await { + Ok(collected) => { + let bytes = collected.to_bytes(); + if bytes.len() > RESPONSE_MAX_SIZE { + bytes_to_string(bytes.slice(..RESPONSE_MAX_SIZE)) + } else { + bytes_to_string(bytes) + } } - Ok(bytes) => bytes_to_string(bytes), Err(err) => format!("Error reading response body: {err}"), }; diff --git a/server/svix-server/tests/it/e2e_operational_webhooks.rs b/server/svix-server/tests/it/e2e_operational_webhooks.rs index 69362885f..3487b3dbd 100644 --- a/server/svix-server/tests/it/e2e_operational_webhooks.rs +++ b/server/svix-server/tests/it/e2e_operational_webhooks.rs @@ -114,6 +114,11 @@ fn start_svix_server_with_operational_webhooks( let regular_jwt = generate_org_token(&cfg.jwt_signing_config, org_id.clone()).unwrap(); let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + // Could update this fn to take a tokio TcpListener instead, but that's a pretty large diff + // for very little benefit (since this is just test code anyways). + listener.set_nonblocking(true).unwrap(); + let listener = tokio::net::TcpListener::from_std(listener).unwrap(); + let base_url = format!("http://{}", listener.local_addr().unwrap()); cfg.operational_webhook_address = Some(base_url.clone()); diff --git a/server/svix-server/tests/it/integ_webhook_http_client.rs b/server/svix-server/tests/it/integ_webhook_http_client.rs index db3f01a02..eb071d1cc 100644 --- a/server/svix-server/tests/it/integ_webhook_http_client.rs +++ b/server/svix-server/tests/it/integ_webhook_http_client.rs @@ -1,6 +1,6 @@ use std::{net::TcpListener, sync::Arc}; -use axum::extract::State; +use axum::{body::HttpBody as _, extract::State}; use http::{header::USER_AGENT, HeaderValue, Request, StatusCode, Version}; use hyper::Body; use serde::{Deserialize, Serialize}; @@ -122,14 +122,11 @@ async fn test_client_basic_operation() { let reqwest_http_req = receiver.req_recv.recv().await.unwrap(); assert_eq!(our_http_req.headers(), reqwest_http_req.headers()); - assert_eq!( - hyper::body::to_bytes(our_http_req.into_body()) - .await - .unwrap(), - hyper::body::to_bytes(reqwest_http_req.into_body()) - .await - .unwrap() - ); + + let our_body = our_http_req.into_body().collect().await.unwrap().to_bytes(); + #[rustfmt::skip] + let reqwest_body = reqwest_http_req.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(our_body, reqwest_body); } #[tokio::test] diff --git a/server/svix-server/tests/it/utils/mod.rs b/server/svix-server/tests/it/utils/mod.rs index 93620bbb9..789019eac 100644 --- a/server/svix-server/tests/it/utils/mod.rs +++ b/server/svix-server/tests/it/utils/mod.rs @@ -348,6 +348,11 @@ pub async fn start_svix_server_with_cfg_and_org_id_and_prefix( let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let base_uri = format!("http://{}", listener.local_addr().unwrap()); + // Could update this fn to take a tokio TcpListener instead, but that's a pretty large diff + // for very little benefit (since this is just test code anyways). + listener.set_nonblocking(true).unwrap(); + let listener = tokio::net::TcpListener::from_std(listener).unwrap(); + let jh = tokio::spawn( svix_server::run_with_prefix(Some(prefix), cfg, Some(listener)) .with_subscriber(tracing_subscriber),