diff --git a/README.md b/README.md index a0fb6cb..9463b44 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,12 @@ Authentication middleware for Axum applications using the SnazzyFellas authentic ## Features -- **Middleware**: Automatically redirect unauthenticated users to the SF auth endpoint +- **Struct Middleware**: Tower-based middleware with dynamic redirect URI callback - **Extractor**: Type-safe access to authenticated user information via the `SfUser` extractor - **Callback Handler**: Ready-to-use route handler for authentication callbacks - **Session Integration**: Seamless integration with tower-sessions - **Fail-Closed Security**: Validation failures result in denied access, not automatic approval +- **Flexible Redirects**: Use a callback function to dynamically determine redirect URIs ## Installation @@ -25,27 +26,25 @@ tokio = { version = "1", features = ["full"] } ## Quick Start ```rust -use axum::{routing::get, Router, middleware}; -use sf_auth_middleware_axum::{SfAuthConfig, sf_auth_middleware, auth_callback, SfUser}; +use axum::{routing::get, Router}; +use sf_auth_middleware_axum::{SfAuthLayer, create_auth_callback, SfUser}; use tower_sessions::{MemoryStore, SessionManagerLayer}; #[tokio::main] async fn main() { - // Configure the authentication middleware - let config = SfAuthConfig::new("https://myapp.com/dashboard"); - // Set up session store let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store); // Build your application let app = Router::new() - // Public callback route (no auth required) - .route("/auth/callback", get(auth_callback)) + // Public callback route - redirects to /dashboard after successful auth + .route("/auth/callback", get(create_auth_callback("/dashboard"))) // Protected routes .route("/dashboard", get(dashboard)) - .layer(middleware::from_fn(move |session, req, next| { - sf_auth_middleware(config.clone(), session, req, next) + // Apply authentication middleware - points to the callback route + .layer(SfAuthLayer::new(|_req| { + "http://localhost:3000/auth/callback".to_string() })) // Add session layer .layer(session_layer); @@ -62,58 +61,117 @@ async fn dashboard(user: SfUser) -> String { ## How It Works -1. **Protection**: Apply the middleware to routes that require authentication +1. **Protection**: Apply the middleware layer to routes that require authentication 2. **Session Check**: The middleware checks for `sf_username` and `sf_user_id` in the session -3. **Redirect**: If not authenticated, redirects to the SF authentication endpoint: +3. **Redirect to SF Auth**: If not authenticated, calls your callback to get the redirect URI (should point to your `/auth/callback` route), then redirects to: ``` - https://snazzyfellas.com/api/redirect/authenticate?redirect_uri={your_configured_uri} + https://snazzyfellas.com/api/redirect/authenticate?redirect_uri={your_callback_route} ``` -4. **Callback**: The SF server redirects back to `/auth/callback` with credentials (`user_id`, `username`, `key`) -5. **Validation**: The callback handler validates credentials with the SF server: +4. **SF Auth Validates**: The SF server authenticates the user and redirects back to your callback route with credentials (`user_id`, `username`, `key`) +5. **Callback Validates**: Your callback handler validates credentials with the SF server: ``` POST https://snazzyfellas.com/api/redirect/validate Body: { "user_id": "...", "key": "..." } Response: { "valid": true, "user_id": "..." } ``` 6. **Session Setup**: On successful validation, sets `sf_username` and `sf_user_id` in the session -7. **Access Granted**: Use the `SfUser` extractor in handlers to access authenticated user data +7. **Final Redirect**: Redirects to the URI specified when you created the callback (e.g., `/dashboard`) +8. **Access Granted**: Subsequent requests use the `SfUser` extractor to access authenticated user data + +### Example Flow + +``` +User -> /dashboard (protected) + ↓ +Middleware checks session (not authenticated) + ↓ +Redirect to: https://snazzyfellas.com/api/redirect/authenticate?redirect_uri=http://myapp.com/auth/callback + ↓ +SF Auth validates user + ↓ +Redirect to: http://myapp.com/auth/callback?user_id=123&username=john&key=abc + ↓ +Callback validates with SF server + ↓ +Session set (sf_username, sf_user_id) + ↓ +Redirect to: /dashboard + ↓ +User accesses /dashboard (authenticated) +``` ## Architecture -### Configuration (`SfAuthConfig`) +### Middleware Layer (`SfAuthLayer`) -Configure the redirect URI where users should land after authentication: +The `SfAuthLayer` is a Tower middleware that checks authentication before allowing requests through. It takes a callback function that determines where users should be redirected after authentication. +**Simple Static Redirect to Callback:** ```rust -let config = SfAuthConfig::new("https://myapp.com/dashboard"); +// Point to your callback route +let auth_layer = SfAuthLayer::new(|_req| { + "http://localhost:3000/auth/callback".to_string() +}); ``` -### Middleware (`sf_auth_middleware`) +**Dynamic Redirect Based on Environment:** +```rust +let auth_layer = SfAuthLayer::new(|req| { + // Use different callback URLs for different environments + let host = req.headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .unwrap_or("localhost:3000"); + + format!("http://{}/auth/callback", host) +}); +``` -The middleware function checks authentication and redirects unauthenticated users: +**Multiple Callback Routes:** +```rust +// Different sections can use different callback routes with different post-auth destinations +let admin_layer = SfAuthLayer::new(|_req| { + "http://myapp.com/admin/callback".to_string() +}); + +let user_layer = SfAuthLayer::new(|_req| { + "http://myapp.com/user/callback".to_string() +}); + +// Then define different callback handlers: +.route("/admin/callback", get(create_auth_callback("/admin/dashboard"))) +.route("/user/callback", get(create_auth_callback("/user/profile"))) +``` + +### Callback Route (`create_auth_callback`) + +Create a callback handler that specifies where to redirect users after successful authentication: ```rust -use axum::middleware; +// Redirect to dashboard after auth +.route("/auth/callback", get(create_auth_callback("/dashboard"))) -.layer(middleware::from_fn(move |session, req, next| { - sf_auth_middleware(config.clone(), session, req, next) +// Or use a full URL +.route("/auth/callback", get(create_auth_callback("https://myapp.com/dashboard"))) +``` + +**Important**: The callback route must be publicly accessible (not behind the auth middleware). + +The callback handler: +- Receives `user_id`, `username`, and `key` as query parameters from SF auth server +- Validates credentials with the SF server +- Sets session values (`sf_username`, `sf_user_id`) on successful validation +- Redirects to the specified URI on success +- Returns error on validation failure (fail-closed) + +**Middleware Configuration**: The middleware's redirect URI callback should point to this callback route: + +```rust +.layer(SfAuthLayer::new(|_req| { + "http://localhost:3000/auth/callback".to_string() })) ``` -### Callback Route (`auth_callback`) - -Mount this handler at `/auth/callback` to receive authentication callbacks: - -```rust -.route("/auth/callback", get(auth_callback)) -``` - -This route: -- Receives `user_id`, `username`, and `key` as query parameters -- Validates credentials with the SF server -- Sets session values on successful validation -- Returns error on validation failure (fail-closed) - ### Extractor (`SfUser`) Use the `SfUser` extractor in your handlers to access authenticated user data: @@ -171,6 +229,54 @@ let pool = sqlx::PgPool::connect("...").await?; let session_store = PostgresStore::new(pool); ``` +## Advanced Usage + +### Protecting Specific Routes + +You can apply the middleware to specific route groups: + +```rust +let app = Router::new() + // Public routes + .route("/", get(home)) + .route("/about", get(about)) + // Public callback that redirects to admin dashboard + .route("/auth/callback", get(create_auth_callback("/admin/dashboard"))) + // Protected admin routes + .nest("/admin", admin_routes().layer(SfAuthLayer::new(|_| { + "http://myapp.com/auth/callback".to_string() + }))) + // Apply session layer to everything + .layer(session_layer); + +fn admin_routes() -> Router { + Router::new() + .route("/dashboard", get(admin_dashboard)) + .route("/users", get(admin_users)) +} +``` + +### Multiple Callback Routes for Different Sections + +Different parts of your app can use different callback routes that redirect to different destinations: + +```rust +let app = Router::new() + // Admin callback redirects to admin dashboard + .route("/admin/callback", get(create_auth_callback("/admin/dashboard"))) + // User callback redirects to user profile + .route("/user/callback", get(create_auth_callback("/user/profile"))) + // Admin section uses admin callback + .nest("/admin", admin_routes().layer(SfAuthLayer::new(|_| { + "http://myapp.com/admin/callback".to_string() + }))) + // User section uses user callback + .nest("/user", user_routes().layer(SfAuthLayer::new(|_| { + "http://myapp.com/user/callback".to_string() + }))) + .layer(session_layer); +``` + ## Examples Run the included example: @@ -182,17 +288,19 @@ cargo run --example basic Then visit: - `http://localhost:3000/` - Public home page - `http://localhost:3000/dashboard` - Protected page (will redirect to SF auth) +- `http://localhost:3000/profile` - Another protected page ## API Reference -### `SfAuthConfig` +### `SfAuthLayer` ```rust -pub struct SfAuthConfig { /* ... */ } +pub struct SfAuthLayer { /* ... */ } -impl SfAuthConfig { - pub fn new(redirect_uri: impl Into) -> Self - pub fn redirect_uri(&self) -> &str +impl SfAuthLayer { + pub fn new(redirect_uri_fn: F) -> Self + where + F: Fn(&Request) -> String + Send + Sync + 'static } ``` @@ -207,25 +315,15 @@ impl SfUser { } ``` -### `sf_auth_middleware` +### `create_auth_callback` ```rust -pub async fn sf_auth_middleware( - config: SfAuthConfig, - session: Session, - req: Request, - next: Next, -) -> Response +pub fn create_auth_callback( + redirect_uri: impl Into, +) -> impl Fn(Session, Query) -> Future> ``` -### `auth_callback` - -```rust -pub async fn auth_callback( - session: Session, - Query(params): Query, -) -> Result -``` +Creates a handler that validates authentication and redirects to the specified URI on success. ## Security Considerations @@ -233,6 +331,7 @@ pub async fn auth_callback( 2. **Secure Sessions**: Configure session cookies with `secure` and `httponly` flags 3. **Session Expiry**: Set appropriate session expiration times 4. **Fail-Closed**: The middleware denies access on any validation errors +5. **Callback Security**: The `/auth/callback` route validates all credentials before setting session Example secure session configuration: @@ -243,9 +342,31 @@ use time::Duration; let session_layer = SessionManagerLayer::new(session_store) .with_secure(true) .with_http_only(true) + .with_same_site(cookie::SameSite::Lax) .with_expiry(Expiry::OnInactivity(Duration::hours(2))); ``` +## Troubleshooting + +### "No session found" errors + +Make sure the `SessionManagerLayer` is applied AFTER your routes but BEFORE the `SfAuthLayer`: + +```rust +let app = Router::new() + .route("/protected", get(handler)) + .layer(SfAuthLayer::new(|_| "...".to_string())) // Auth layer first + .layer(session_layer); // Session layer last +``` + +### Callback route requires authentication + +The `/auth/callback` route must be defined BEFORE applying the auth layer, or it should be in a separate router that doesn't have the auth middleware. + +### Session not persisting + +Ensure your session store is properly configured and that cookies are being set correctly (check HTTPS requirements for secure cookies). + ## License This project is licensed under the MIT License. diff --git a/examples/basic.rs b/examples/basic.rs index a146996..704ffb0 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,5 +1,5 @@ -use axum::{middleware, response::Html, routing::get, Router}; -use sf_auth_middleware_axum::{auth_callback, sf_auth_middleware, SfAuthConfig, SfUser}; +use axum::{response::Html, routing::get, Router}; +use sf_auth_middleware_axum::{create_auth_callback, SfAuthLayer, SfUser}; use tower_sessions::{MemoryStore, SessionManagerLayer}; #[tokio::main] @@ -7,10 +7,6 @@ async fn main() { // Set up tracing for debugging tracing_subscriber::fmt::init(); - // Configure the SF authentication middleware - // The redirect_uri should point to where users should land after authentication - let config = SfAuthConfig::new("http://localhost:3000/dashboard"); - // Set up session store using in-memory storage // In production, you'd want to use a persistent store like Redis or PostgreSQL let session_store = MemoryStore::default(); @@ -22,13 +18,17 @@ async fn main() { .route("/", get(home)) // Authentication callback route - must be publicly accessible // This is where the SF auth server redirects users after authentication - .route("/auth/callback", get(auth_callback)) + // After validation, users will be redirected to /dashboard + .route("/auth/callback", get(create_auth_callback("/dashboard"))) // Protected routes - require authentication .route("/dashboard", get(dashboard)) .route("/profile", get(profile)) - // Apply authentication middleware to protected routes - .layer(middleware::from_fn(move |session, req, next| { - sf_auth_middleware(config.clone(), session, req, next) + // Apply authentication middleware + // The redirect URI should point to the callback route in your app + // This is where the SF auth server will send users after they authenticate + .layer(SfAuthLayer::new(|_req| { + // Point to the auth callback route defined above + "http://localhost:3000/auth/callback".to_string() })) // Apply session layer (must be after the routes) .layer(session_layer); @@ -43,6 +43,13 @@ async fn main() { println!(" - http://localhost:3000/ (public)"); println!(" - http://localhost:3000/dashboard (protected, will redirect to SF auth)"); println!(" - http://localhost:3000/profile (protected, will redirect to SF auth)"); + println!(); + println!("Authentication flow:"); + println!(" 1. Access /dashboard (protected)"); + println!(" 2. Redirect to SF auth with redirect_uri=http://localhost:3000/auth/callback"); + println!(" 3. SF auth validates and redirects to /auth/callback with credentials"); + println!(" 4. Callback validates credentials and redirects to /dashboard"); + println!(" 5. Access granted to /dashboard"); axum::serve(listener, app).await.unwrap(); } diff --git a/src/callback.rs b/src/callback.rs index b3dcbba..7aa371c 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -1,8 +1,11 @@ use axum::{ extract::Query, - response::{IntoResponse, Response}, + response::{IntoResponse, Redirect, Response}, }; use serde::Deserialize; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; use tower_sessions::Session; use crate::{client::validate_user, error::SfAuthError}; @@ -15,52 +18,65 @@ pub struct CallbackQuery { key: String, } -/// Handler for the authentication callback route. +/// Creates an authentication callback handler that redirects to the specified URI on success. /// -/// This route should be mounted at `/auth/callback` in your application. +/// This function returns a handler that you can mount at any route in your application. /// It receives `user_id`, `username`, and `key` as query parameters, /// validates the credentials with the SF authentication server, and -/// sets the session if validation succeeds. +/// redirects to the provided URI if validation succeeds. +/// +/// # Arguments +/// +/// * `redirect_uri` - Where to redirect users after successful authentication /// /// # Example /// /// ```ignore /// use axum::{routing::get, Router}; -/// use sf_auth_middleware_axum::auth_callback; +/// use sf_auth_middleware_axum::create_auth_callback; /// /// let app = Router::new() -/// .route("/auth/callback", get(auth_callback)); +/// .route("/auth/callback", get(create_auth_callback("/dashboard"))); /// ``` /// /// # Query Parameters /// +/// The handler expects these query parameters: /// - `user_id`: The user's ID /// - `username`: The user's username /// - `key`: The authentication key to validate /// /// # Returns /// -/// Returns a 200 OK response with a success message if validation succeeds, -/// or an error response if validation fails. -pub async fn auth_callback( - session: Session, - Query(params): Query, -) -> Result { - // Validate the credentials with the SF server - let validated_user_id = validate_user(params.user_id.clone(), params.key).await?; +/// Returns a redirect to the specified URI on success, or an error response if validation fails. +pub fn create_auth_callback( + redirect_uri: impl Into, +) -> impl Fn(Session, Query) -> Pin> + Send>> + + Clone + + Send + + 'static { + let redirect_uri = Arc::new(redirect_uri.into()); - // Set session values only if validation succeeded - session - .insert("sf_username", params.username.clone()) - .await - .map_err(|e| SfAuthError::Session(e.to_string()))?; + move |session: Session, Query(params): Query| { + let redirect_uri = Arc::clone(&redirect_uri); + Box::pin(async move { + // Validate the credentials with the SF server + let validated_user_id = validate_user(params.user_id.clone(), params.key).await?; - session - .insert("sf_user_id", validated_user_id) - .await - .map_err(|e| SfAuthError::Session(e.to_string()))?; + // Set session values only if validation succeeded + session + .insert("sf_username", params.username.clone()) + .await + .map_err(|e| SfAuthError::Session(e.to_string()))?; - // Return success response - // Note: The SF auth server handles the redirect, so we just confirm success - Ok("Authentication successful".into_response()) + session + .insert("sf_user_id", validated_user_id) + .await + .map_err(|e| SfAuthError::Session(e.to_string()))?; + + // Redirect to the specified URI + Ok(Redirect::to(&redirect_uri).into_response()) + }) + as Pin> + Send>> + } } diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index f064479..0000000 --- a/src/config.rs +++ /dev/null @@ -1,41 +0,0 @@ -/// Configuration for SF authentication middleware -#[derive(Debug, Clone)] -pub struct SfAuthConfig { - /// The redirect URI to pass to the authentication endpoint. - /// This is where users will be redirected after successful authentication. - redirect_uri: String, -} - -impl SfAuthConfig { - /// Creates a new `SfAuthConfig` with the specified redirect URI. - /// - /// # Arguments - /// - /// * `redirect_uri` - The URI where users should be redirected after authentication - /// - /// # Example - /// - /// ```ignore - /// use sf_auth_middleware_axum::SfAuthConfig; - /// - /// let config = SfAuthConfig::new("https://myapp.com/dashboard"); - /// ``` - pub fn new(redirect_uri: impl Into) -> Self { - Self { - redirect_uri: redirect_uri.into(), - } - } - - /// Returns the configured redirect URI - pub fn redirect_uri(&self) -> &str { - &self.redirect_uri - } - - /// Builds the full authentication URL with the redirect_uri query parameter - pub(crate) fn auth_url(&self) -> String { - format!( - "https://snazzyfellas.com/api/redirect/authenticate?redirect_uri={}", - urlencoding::encode(&self.redirect_uri) - ) - } -} diff --git a/src/lib.rs b/src/lib.rs index 155bc12..c4a479a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,39 +1,37 @@ //! # SF Auth Middleware for Axum //! //! This library provides authentication middleware for Axum applications using -//! the SnazzyFellas authentication service with tower_session for session management. +//! the SnazzyFellas authentication service with tower_sessions for session management. //! //! ## Features //! -//! - **Middleware**: Automatically redirect unauthenticated users to the SF auth endpoint +//! - **Struct Middleware**: Tower-based middleware with dynamic redirect URI callback //! - **Extractor**: Type-safe access to authenticated user information //! - **Callback Handler**: Ready-to-use route for handling authentication callbacks -//! - **Session Integration**: Seamless integration with tower_session +//! - **Session Integration**: Seamless integration with tower_sessions //! //! ## Quick Start //! -//! ```no_run -//! use axum::{routing::get, Router, middleware}; -//! use sf_auth_middleware_axum::{SfAuthConfig, sf_auth_middleware, auth_callback, SfUser}; -//! use tower_session::{SessionManagerLayer, MemoryStore}; +//! ```ignore +//! use axum::{routing::get, Router}; +//! use sf_auth_middleware_axum::{SfAuthLayer, create_auth_callback, SfUser}; +//! use tower_sessions::{SessionManagerLayer, MemoryStore}; //! //! #[tokio::main] //! async fn main() { -//! // Configure the authentication middleware -//! let config = SfAuthConfig::new("https://myapp.com/dashboard"); -//! //! // Set up session store //! let session_store = MemoryStore::default(); //! let session_layer = SessionManagerLayer::new(session_store); //! //! // Build your application //! let app = Router::new() -//! // Public callback route (no auth required) -//! .route("/auth/callback", get(auth_callback)) +//! // Public callback route (no auth required) - redirects to /dashboard after auth +//! .route("/auth/callback", get(create_auth_callback("/dashboard"))) //! // Protected routes -//! .route("/protected", get(protected_handler)) -//! .layer(middleware::from_fn(move |session, req, next| { -//! sf_auth_middleware(config.clone(), session, req, next) +//! .route("/dashboard", get(dashboard)) +//! // Apply authentication middleware - redirect to callback route +//! .layer(SfAuthLayer::new(|_req| { +//! "http://localhost:3000/auth/callback".to_string() //! })) //! // Add session layer //! .layer(session_layer); @@ -43,16 +41,16 @@ //! axum::serve(listener, app).await.unwrap(); //! } //! -//! async fn protected_handler(user: SfUser) -> String { +//! async fn dashboard(user: SfUser) -> String { //! format!("Hello, {}! Your ID: {}", user.username(), user.user_id()) //! } //! ``` //! //! ## How It Works //! -//! 1. **Protection**: Apply the middleware to routes that require authentication +//! 1. **Protection**: Apply the middleware layer to routes that require authentication //! 2. **Check**: The middleware checks for `sf_username` and `sf_user_id` in the session -//! 3. **Redirect**: If not authenticated, redirects to `https://snazzyfellas.com/api/redirect/authenticate?redirect_uri={your_uri}` +//! 3. **Redirect**: If not authenticated, calls your callback to get the redirect URI, then redirects to SF auth //! 4. **Callback**: The SF server redirects back to `/auth/callback` with credentials //! 5. **Validation**: The callback handler validates credentials with the SF server //! 6. **Session**: On success, sets `sf_username` and `sf_user_id` in the session @@ -60,14 +58,12 @@ mod callback; mod client; -mod config; mod error; mod extractor; mod middleware; // Public exports -pub use callback::auth_callback; -pub use config::SfAuthConfig; +pub use callback::create_auth_callback; pub use error::SfAuthError; pub use extractor::SfUser; -pub use middleware::sf_auth_middleware; +pub use middleware::{SfAuthLayer, SfAuthMiddleware}; diff --git a/src/middleware.rs b/src/middleware.rs index cc25e75..374024d 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,47 +1,146 @@ use axum::{ extract::Request, - middleware::Next, response::{IntoResponse, Redirect, Response}, }; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; use tower_sessions::Session; -use crate::config::SfAuthConfig; - -/// Middleware function that enforces SF authentication. +/// Type alias for the redirect URI callback function. /// -/// This middleware checks if the user has valid session credentials (`sf_username` and `sf_user_id`). -/// If not authenticated, it redirects to the SF authentication endpoint. +/// This function receives the current request and returns the redirect URI +/// where users should be sent after successful authentication. +pub type RedirectUriCallback = Box String + Send + Sync>; + +/// Middleware layer that enforces SF authentication. +/// +/// This layer checks if the user has valid session credentials (`sf_username` and `sf_user_id`). +/// If not authenticated, it redirects to the SF authentication endpoint with a configurable redirect URI. /// /// # Example /// /// ```ignore -/// use axum::{routing::get, Router, middleware}; -/// use sf_auth_middleware_axum::{SfAuthConfig, sf_auth_middleware}; -/// -/// let config = SfAuthConfig::new("https://myapp.com/dashboard"); +/// use axum::{routing::get, Router}; +/// use sf_auth_middleware_axum::SfAuthLayer; /// /// let app = Router::new() /// .route("/protected", get(|| async { "Protected!" })) -/// .layer(middleware::from_fn(move |session, req, next| { -/// sf_auth_middleware(config.clone(), session, req, next) +/// .layer(SfAuthLayer::new(|_req| { +/// "https://myapp.com/dashboard".to_string() /// })); /// ``` -pub async fn sf_auth_middleware( - config: SfAuthConfig, - session: Session, - req: Request, - next: Next, -) -> Response { - // Try to get username and user_id from session - let username: Option = session.get("sf_username").await.unwrap_or(None); - let user_id: Option = session.get("sf_user_id").await.unwrap_or(None); +#[derive(Clone)] +pub struct SfAuthLayer { + redirect_uri_fn: std::sync::Arc, +} - // Check if both are present - if username.is_some() && user_id.is_some() { - // User is authenticated, proceed with the request - next.run(req).await - } else { - // User is not authenticated, redirect to auth endpoint - Redirect::to(&config.auth_url()).into_response() +impl SfAuthLayer { + /// Creates a new `SfAuthLayer` with a callback function to determine the redirect URI. + /// + /// # Arguments + /// + /// * `redirect_uri_fn` - A function that takes a request reference and returns the redirect URI + /// + /// # Example + /// + /// ```ignore + /// use sf_auth_middleware_axum::SfAuthLayer; + /// + /// // Simple static redirect + /// let layer = SfAuthLayer::new(|_req| "https://myapp.com/dashboard".to_string()); + /// + /// // Dynamic redirect based on request + /// let layer = SfAuthLayer::new(|req| { + /// format!("https://myapp.com{}", req.uri().path()) + /// }); + /// ``` + pub fn new(redirect_uri_fn: F) -> Self + where + F: Fn(&Request) -> String + Send + Sync + 'static, + { + Self { + redirect_uri_fn: std::sync::Arc::new(Box::new(redirect_uri_fn)), + } } } + +impl Layer for SfAuthLayer { + type Service = SfAuthMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + SfAuthMiddleware { + inner, + redirect_uri_fn: self.redirect_uri_fn.clone(), + } + } +} + +/// The actual middleware service implementation. +#[derive(Clone)] +pub struct SfAuthMiddleware { + inner: S, + redirect_uri_fn: std::sync::Arc, +} + +impl Service for SfAuthMiddleware +where + S: Service + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let redirect_uri_fn = self.redirect_uri_fn.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + // Extract session from request extensions + let session = match req.extensions().get::() { + Some(session) => session.clone(), + None => { + // No session found, redirect to auth + let redirect_uri = (redirect_uri_fn)(&req); + let auth_url = build_auth_url(&redirect_uri); + return Ok(Redirect::to(&auth_url).into_response()); + } + }; + + // Try to get username and user_id from session + let username: Option = session.get("sf_username").await.unwrap_or(None); + let user_id: Option = session.get("sf_user_id").await.unwrap_or(None); + + tracing::info!( + "Username: {}, User ID: {}", + username.as_deref().unwrap_or("None"), + user_id.as_deref().unwrap_or("None") + ); + + // Check if both are present + if username.is_some() && user_id.is_some() { + // User is authenticated, proceed with the request + inner.call(req).await + } else { + // User is not authenticated, redirect to auth endpoint + let redirect_uri = (redirect_uri_fn)(&req); + let auth_url = build_auth_url(&redirect_uri); + Ok(Redirect::to(&auth_url).into_response()) + } + }) + } +} + +/// Builds the authentication URL with the redirect_uri query parameter. +fn build_auth_url(redirect_uri: &str) -> String { + format!( + "https://snazzyfellas.com/api/redirect/authenticate?redirect_uri={}", + urlencoding::encode(redirect_uri) + ) +}