Skip to content

Some hyper 1.0 / axum 0.7 preparations #1513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions server/svix-server/src/core/idempotency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
use std::{collections::HashMap, convert::Infallible, future::Future, pin::Pin, time::Duration};

use axum::{
body::{Body, BoxBody, HttpBody},
http::{Request, Response, StatusCode},
response::IntoResponse,
body::{Body, HttpBody},
http::{Request, StatusCode},
response::{IntoResponse, Response},
};
use blake2::{Blake2b512, Digest};
use http::request::Parts;
Expand Down Expand Up @@ -88,7 +88,7 @@ fn finished_serialized_response_to_response(
code: u16,
headers: Option<HashMap<String, Vec<u8>>>,
body: Option<Vec<u8>>,
) -> Result<Response<BoxBody>, ConversionToResponseError> {
) -> Result<Response, ConversionToResponseError> {
let mut out = body.unwrap_or_default().into_response();

let status = out.status_mut();
Expand All @@ -105,16 +105,16 @@ fn finished_serialized_response_to_response(
Ok(out)
}

async fn resolve_service<S>(
mut service: S,
req: Request<Body>,
) -> Result<Response<BoxBody>, Infallible>
async fn resolve_service<S>(mut service: S, req: Request<Body>) -> Response
where
S: Service<Request<Body>, Error = Infallible> + Clone + Send + 'static,
S::Response: IntoResponse,
S::Future: Send + 'static,
{
service.call(req).await.map(IntoResponse::into_response)
match service.call(req).await {
Ok(res) => res.into_response(),
Err(e) => match e {},
}
}

/// The idempotency middleware itself -- used via the [`Router::layer`] method
Expand All @@ -130,7 +130,7 @@ where
S::Response: IntoResponse,
S::Future: Send + 'static,
{
type Response = Response<BoxBody>;
type Response = Response;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

Expand All @@ -151,15 +151,15 @@ where

// If not a POST request, simply resolve the service as usual
if parts.method != http::Method::POST {
return resolve_service(service, Request::from_parts(parts, body)).await;
return Ok(resolve_service(service, Request::from_parts(parts, body)).await);
}

// Retrieve `IdempotencyKey` from header and URL parts, but returning the service
// normally in the event a key could not be created.
let key = if let Some(key) = get_key(&parts) {
key
} else {
return resolve_service(service, Request::from_parts(parts, body)).await;
return Ok(resolve_service(service, Request::from_parts(parts, body)).await);
};

// Set the [`SerializedResponse::Start`] lock if the key does not exist in the cache
Expand Down Expand Up @@ -231,8 +231,13 @@ where
// If it's set or the lock or the `lock_loop` returns Ok(None), then the key has no
// value, so continue resolving the service while caching the response for 2xx
// responses
resolve_and_cache_response(&cache, &key, service, Request::from_parts(parts, body))
.await
Ok(resolve_and_cache_response(
&cache,
&key,
service,
Request::from_parts(parts, body),
)
.await)
})
} else {
Box::pin(async move { Ok(service.call(req).await.into_response()) })
Expand Down Expand Up @@ -299,22 +304,18 @@ async fn resolve_and_cache_response<S>(
key: &IdempotencyKey,
service: S,
request: Request<Body>,
) -> Result<Response<BoxBody>, Infallible>
) -> Response
where
S: Service<Request<Body>, Error = Infallible> + Clone + Send + 'static,
S::Response: IntoResponse,
S::Future: Send + 'static,
{
let (parts, mut body) = resolve_service(service, request)
.await
// Infallible
.unwrap()
.into_parts();
let (parts, body) = resolve_service(service, request).await.into_parts();

// If a 2xx response, cache the actual response
if parts.status.is_success() {
// TODO: Don't skip over Err value
let bytes = body.data().await.and_then(Result::ok);
let bytes = body.collect().await.ok().map(|c| c.to_bytes());

let resp = SerializedResponse::Finished {
code: parts.status.into(),
Expand All @@ -329,20 +330,20 @@ where
};

if cache.set(key, &resp, expiry_default()).await.is_err() {
return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}

// Assumes None to be an empty byte array
let bytes = bytes.unwrap_or_default();
Ok(Response::from_parts(parts, Body::from(bytes)).into_response())
Response::from_parts(parts, Body::from(bytes)).into_response()
}
// If any other status, unset the start lock and return the response
else {
if cache.delete(key).await.is_err() {
return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}

Ok(Response::from_parts(parts, body).into_response())
Response::from_parts(parts, body).into_response()
}
}

Expand Down
34 changes: 16 additions & 18 deletions server/svix-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

use std::{
borrow::Cow,
net::TcpListener,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
Expand All @@ -21,6 +20,7 @@ use redis::RedisManager;
use sea_orm::DatabaseConnection;
use sentry::integrations::tracing::EventFilter;
use svix_ksuid::{KsuidLike, KsuidMs};
use tokio::net::TcpListener;
use tower::layer::layer_fn;
use tower_http::{
cors::{AllowHeaders, Any, CorsLayer},
Expand Down Expand Up @@ -88,10 +88,9 @@ async fn graceful_shutdown_handler() {
}

#[tracing::instrument(name = "app_start", level = "trace", skip_all)]
pub async fn run(cfg: Configuration, listener: Option<TcpListener>) {
pub async fn run(cfg: Configuration) {
let _metrics = setup_metrics(&cfg);

run_with_prefix(None, cfg, listener).await
run_with_prefix(None, cfg, None).await
}

#[derive(Clone)]
Expand Down Expand Up @@ -184,20 +183,19 @@ pub async fn run_with_prefix(
let (server, worker_loop, expired_message_cleaner_loop) = tokio::join!(
async {
if with_api {
if let Some(l) = listener {
tracing::debug!("API: Listening on {}", l.local_addr().unwrap());
axum::Server::from_tcp(l)
.expect("Error starting http server")
.serve(svc)
.with_graceful_shutdown(graceful_shutdown_handler())
.await
} else {
tracing::debug!("API: Listening on {}", listen_address);
axum::Server::bind(&listen_address)
.serve(svc)
.with_graceful_shutdown(graceful_shutdown_handler())
let listener = match listener {
Some(l) => l,
None => TcpListener::bind(listen_address)
.await
}
.expect("Error binding to listen_address"),
};
tracing::debug!("API: Listening on {}", listener.local_addr().unwrap());

let incoming = hyper::server::conn::AddrIncoming::from_listener(listener)?;
axum::Server::builder(incoming)
.serve(svc)
.with_graceful_shutdown(graceful_shutdown_handler())
.await
} else {
tracing::debug!("API: off");
graceful_shutdown_handler().await;
Expand Down Expand Up @@ -273,7 +271,7 @@ pub fn setup_tracing(
.tracing()
.with_exporter(exporter)
.with_trace_config(
opentelemetry_sdk::trace::config()
opentelemetry_sdk::trace::Config::default()
.with_sampler(
cfg.opentelemetry_sample_ratio
.map(opentelemetry_sdk::trace::Sampler::TraceIdRatioBased)
Expand Down
2 changes: 1 addition & 1 deletion server/svix-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async fn main() {
);
}
None => {
run(cfg, None).await;
run(cfg).await;
}
};

Expand Down
2 changes: 1 addition & 1 deletion server/svix-server/src/metrics/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub enum RedisQueueType<'a> {
SortedSet(&'a str),
}

impl<'a> RedisQueueType<'a> {
impl RedisQueueType<'_> {
pub async fn queue_depth(&self, redis: &RedisManager) -> Result<u64> {
let mut conn = redis.get().await?;
match self {
Expand Down
13 changes: 9 additions & 4 deletions server/svix-server/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
time::Duration,
};

use axum::body::HttpBody as _;
use chrono::Utc;
use futures::future;
use http::{HeaderValue, StatusCode, Version};
Expand Down Expand Up @@ -370,11 +371,15 @@ async fn make_http_call(
None
};

let body = match hyper::body::to_bytes(res.into_body()).await {
Ok(bytes) if bytes.len() > RESPONSE_MAX_SIZE => {
bytes_to_string(bytes.slice(..RESPONSE_MAX_SIZE))
let body = match res.into_body().collect().await {
Ok(collected) => {
let bytes = collected.to_bytes();
if bytes.len() > RESPONSE_MAX_SIZE {
bytes_to_string(bytes.slice(..RESPONSE_MAX_SIZE))
} else {
bytes_to_string(bytes)
}
}
Ok(bytes) => bytes_to_string(bytes),
Err(err) => format!("Error reading response body: {err}"),
};

Expand Down
5 changes: 5 additions & 0 deletions server/svix-server/tests/it/e2e_operational_webhooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ fn start_svix_server_with_operational_webhooks(
let regular_jwt = generate_org_token(&cfg.jwt_signing_config, org_id.clone()).unwrap();

let listener = TcpListener::bind("127.0.0.1:0").unwrap();
// Could update this fn to take a tokio TcpListener instead, but that's a pretty large diff
// for very little benefit (since this is just test code anyways).
listener.set_nonblocking(true).unwrap();
let listener = tokio::net::TcpListener::from_std(listener).unwrap();

let base_url = format!("http://{}", listener.local_addr().unwrap());

cfg.operational_webhook_address = Some(base_url.clone());
Expand Down
15 changes: 6 additions & 9 deletions server/svix-server/tests/it/integ_webhook_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::TcpListener, sync::Arc};

use axum::extract::State;
use axum::{body::HttpBody as _, extract::State};
use http::{header::USER_AGENT, HeaderValue, Request, StatusCode, Version};
use hyper::Body;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -122,14 +122,11 @@ async fn test_client_basic_operation() {
let reqwest_http_req = receiver.req_recv.recv().await.unwrap();

assert_eq!(our_http_req.headers(), reqwest_http_req.headers());
assert_eq!(
hyper::body::to_bytes(our_http_req.into_body())
.await
.unwrap(),
hyper::body::to_bytes(reqwest_http_req.into_body())
.await
.unwrap()
);

let our_body = our_http_req.into_body().collect().await.unwrap().to_bytes();
#[rustfmt::skip]
let reqwest_body = reqwest_http_req.into_body().collect().await.unwrap().to_bytes();
assert_eq!(our_body, reqwest_body);
}

#[tokio::test]
Expand Down
5 changes: 5 additions & 0 deletions server/svix-server/tests/it/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ pub async fn start_svix_server_with_cfg_and_org_id_and_prefix(
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let base_uri = format!("http://{}", listener.local_addr().unwrap());

// Could update this fn to take a tokio TcpListener instead, but that's a pretty large diff
// for very little benefit (since this is just test code anyways).
listener.set_nonblocking(true).unwrap();
let listener = tokio::net::TcpListener::from_std(listener).unwrap();

let jh = tokio::spawn(
svix_server::run_with_prefix(Some(prefix), cfg, Some(listener))
.with_subscriber(tracing_subscriber),
Expand Down