Skip to content

Addlayer #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## 0.7.1 (13. July, 2023)
### Added
- Layer Feature to allow getting CsrfTokens using a service.
- Example for middleware usage.

## 0.7.0 (12. July, 2023)
### Changed
- Replaced Bcrypt with Argon2.
Expand Down
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
members = [
".",
"example/minimal",
"example/middleware",
]

[package]
name = "axum_csrf"
version = "0.7.0"
version = "0.7.1"
authors = ["Andrew Wheeler <[email protected]>"]
description = "Library to Provide a CSRF (Cross-Site Request Forgery) protection layer."
edition = "2021"
Expand All @@ -16,6 +17,10 @@ documentation = "https://docs.rs/axum_csrf"
keywords = ["Axum", "CSRF", "Cookies"]
repository = "https://github.com/AscendingCreations/AxumCSRF"

[features]
default = []
layer = ["tower-layer", "tower-service"]

[dependencies]
axum-core = "0.3.3"
http = "0.2.9"
Expand All @@ -29,6 +34,8 @@ cookie = { version = "0.17.0", features = [
] }
argon2 = "0.5.0"
thiserror = "1.0.43"
tower-layer = {version = "0.3.2", optional = true}
tower-service = {version = "0.3.2", optional = true}

[dev-dependencies]
anyhow = "1.0.70"
Expand Down
65 changes: 64 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ If you need help with this library please join our [Discord Group](https://disco
```toml
# Cargo.toml
[dependencies]
axum_csrf = "0.7.0"
axum_csrf = "0.7.1"
```

#### Cargo Feature Flags
`default`: []

`layer`: Disables the state and enables a service layer. Useful for middleware interations.

# Example

Add it to axum via shared state:
Expand Down Expand Up @@ -100,6 +105,64 @@ The Template File
</html>
```

Or use the "layer" feature if you dont want to use state:
```rust
use askama::Template;
use axum::{Form, response::IntoResponse, routing::get, Router};
use axum_csrf::{CsrfConfig, CsrfLayer, CsrfToken };
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;

#[derive(Template, Deserialize, Serialize)]
#[template(path = "template.html")]
struct Keys {
authenticity_token: String,
// Your attributes...
}

#[tokio::main]
async fn main() {
// initialize tracing
tracing_subscriber::fmt::init();
let config = CsrfConfig::default();

// build our application with a route
let app = Router::new()
// `GET /` goes to `root` and Post Goes to check key
.route("/", get(root).post(check_key))
.layer(CsrfLayer::new(config));

// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}

// root creates the CSRF Token and sends it into the page for return.
async fn root(token: CsrfToken) -> impl IntoResponse {
let keys = Keys {
//this Token is a hashed Token. it is returned and the original token is hashed for comparison.
authenticity_token: token.authenticity_token().unwrap(),
};

// We must return the token so that into_response will run and add it to our response cookies.
(token, keys).into_response()
}

async fn check_key(token: CsrfToken, Form(payload): Form<Keys>) -> &'static str {
// Verfiy the Hash and return the String message.
if token.verify(&payload.authenticity_token).is_err() {
"Token is invalid"
} else {
"Token is Valid lets do stuff!"
}
}
```

If you already have an encryption key for private cookies, build the CSRF configuration a different way:
```rust
let cookie_key = cookie::Key::generate();
Expand Down
21 changes: 21 additions & 0 deletions example/middleware/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "middleware"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.6.12"
serde = { version = "1.0.159", features = ["derive"] }
tokio = { version = "1.29.1", features = ["full"] }
askama = "0.12.0"
askama_axum = "0.3.0"
tracing = "0.1.37"
tracing-subscriber = "0.3.16"
serde_urlencoded = "0.7.1"
hyper = "0.14.27"
tower = "0.4"
tower-http = { version = "0.4.0", features = ["map-request-body", "util"] }

[dependencies.axum_csrf]
path = "../.."
features = ["layer"]
91 changes: 91 additions & 0 deletions example/middleware/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use askama::Template;
use axum::{
body::{self, BoxBody, Full},
http::{Method, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
routing::{get, post},
Form, Router,
};
use axum_csrf::{CsrfConfig, CsrfLayer, CsrfToken, Key};

use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use tower::ServiceBuilder;
use tower_http::ServiceBuilderExt;

#[derive(Template, Deserialize, Serialize)]
#[template(path = "template.html")]
pub struct Keys {
authenticity_token: String,
// Your attributes...
}

#[tokio::main]
async fn main() {
// initialize tracing
tracing_subscriber::fmt::init();
let cookie_key = Key::generate();
let config = CsrfConfig::default().with_key(Some(cookie_key));

// build our application with a route
let app = Router::new()
.route("/", post(check_key))
.layer(
ServiceBuilder::new()
.map_request_body(body::boxed)
.layer(axum::middleware::from_fn(auth_middleware)),
)
// `GET /` goes to `root` and Post Goes to check key
.route("/", get(root))
.layer(CsrfLayer::new(config));

// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}

// basic handler that responds with a static string
async fn root(token: CsrfToken) -> impl IntoResponse {
let keys = Keys {
authenticity_token: token.authenticity_token().unwrap(),
};

// We must return the token so that into_response will run and add it to our response cookies.
(token, keys).into_response()
}

/// Can only be done with the feature layer enabled
pub async fn auth_middleware(
token: CsrfToken,
method: Method,
mut request: Request<BoxBody>,
next: Next<BoxBody>,
) -> Result<Response, StatusCode> {
if method == Method::POST {
let (parts, body) = request.into_parts();
let bytes = hyper::body::to_bytes(body)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let value = serde_urlencoded::from_bytes(&bytes)
.map_err(|_| -> StatusCode { StatusCode::INTERNAL_SERVER_ERROR })?;
let payload: Form<Keys> = Form(value);
if token.verify(&payload.authenticity_token).is_err() {
return Err(StatusCode::UNAUTHORIZED);
}

request = Request::from_parts(parts, body::boxed(Full::from(bytes)));
}

Ok(next.run(request).await)
}

async fn check_key() -> &'static str {
"Token is Valid lets do stuff!"
}
15 changes: 15 additions & 0 deletions example/middleware/templates/template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html>

<head>
<meta charset="UTF-8" />
<title>Minimal</title>
</head>

<body>
<form method="post" action="/">
<input type="hidden" name="authenticity_token" value="{{ authenticity_token }}"/>
<input id="button" type="submit" value="Submit" tabindex="4" />
</form>
</body>
</html>
72 changes: 72 additions & 0 deletions src/cookies.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use crate::CsrfConfig;
use cookie::{Cookie, CookieJar, Key};
use http::{
self,
header::{COOKIE, SET_COOKIE},
HeaderMap,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};

pub(crate) trait CookiesExt {
fn get_cookie(&self, name: &str, key: &Option<Key>) -> Option<Cookie<'static>>;
fn add_cookie(&mut self, cookie: Cookie<'static>, key: &Option<Key>);
}

impl CookiesExt for CookieJar {
fn get_cookie(&self, name: &str, key: &Option<Key>) -> Option<Cookie<'static>> {
if let Some(key) = key {
self.private(key).get(name)
} else {
self.get(name).cloned()
}
}

fn add_cookie(&mut self, cookie: Cookie<'static>, key: &Option<Key>) {
if let Some(key) = key {
self.private_mut(key).add(cookie)
} else {
self.add(cookie)
}
}
}

pub(crate) fn get_cookies(headers: &mut HeaderMap) -> CookieJar {
let mut jar = CookieJar::new();

let cookie_iter = headers
.get_all(COOKIE)
.into_iter()
.filter_map(|value| value.to_str().ok())
.flat_map(|value| value.split(';'))
.filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok());

for cookie in cookie_iter {
jar.add_original(cookie);
}

jar
}

pub(crate) fn set_cookies(jar: CookieJar, headers: &mut HeaderMap) {
for cookie in jar.delta() {
if let Ok(header_value) = cookie.encoded().to_string().parse() {
headers.append(SET_COOKIE, header_value);
}
}
}

pub(crate) fn get_token(config: &CsrfConfig, headers: &mut HeaderMap) -> String {
let cookie_jar = get_cookies(headers);

//We check if the Cookie Exists as a signed Cookie or not. If so we use the value of the cookie.
//If not we create a new one.
if let Some(cookie) = cookie_jar.get_cookie(&config.cookie_name, &config.key) {
cookie.value().to_owned()
} else {
thread_rng()
.sample_iter(&Alphanumeric)
.take(config.cookie_len)
.map(char::from)
.collect()
}
}
26 changes: 26 additions & 0 deletions src/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::{AxumCsrfService, CsrfConfig};
use tower_layer::Layer;

/// CSRF layer struct used to pass key and CsrfConfig around.
#[derive(Clone)]
pub struct CsrfLayer {
pub(crate) config: CsrfConfig,
}

impl CsrfLayer {
/// Creates the CSRF Protection Layer.
pub fn new(config: CsrfConfig) -> Self {
Self { config }
}
}

impl<S> Layer<S> for CsrfLayer {
type Service = AxumCsrfService<S>;

fn layer(&self, inner: S) -> Self::Service {
AxumCsrfService {
config: self.config.clone(),
inner,
}
}
}
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ mod config;
mod error;
mod token;

pub(crate) mod cookies;

#[cfg(feature = "layer")]
mod layer;
#[cfg(feature = "layer")]
mod service;

#[cfg(feature = "layer")]
pub use layer::CsrfLayer;
#[cfg(feature = "layer")]
pub(crate) use service::AxumCsrfService;

pub use config::{CsrfConfig, Key, SameSite};
pub use error::CsrfError;
pub use token::CsrfToken;
Loading