Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
382 changes: 62 additions & 320 deletions Cargo.lock

Large diffs are not rendered by default.

21 changes: 8 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,21 @@ default = []
api-mocks = ["dep:trillium-testing"]
integration-testing = []
# Enables a non-production axum middleware that reads an `X-Integration-Testing-User`
# header and injects the decoded [`User`] into request extensions. Used by
# `test-support` to impersonate specific users in tests; never enable in deployed
# builds. TODO(Part 9): fold into `integration-testing` when test-support is rewritten.
# header and injects the decoded [`User`] into request extensions. Strictly for
# use by the test harness (`test-support`); never enable in deployed builds.
# TODO: fold into `integration-testing` in Part 9 (test-support rewrite).
test-header-injection = []
otlp-trace = ["opentelemetry/trace", "opentelemetry-otlp", "opentelemetry_sdk/trace", "trillium-opentelemetry/trace"]
otlp-trace = ["opentelemetry/trace", "opentelemetry-otlp", "opentelemetry_sdk/trace"]

[dependencies]
aes-gcm = "0.10.3"
async-trait = "0.1"
axum = "0.8"
async-lock = "3.4.1"
async-session = "3.0.0"
base64 = "0.22.1"
console-subscriber = "0.5.0"
# Enables Key::derive_from in tower-sessions's re-exported cookie crate (via Cargo feature unification)
cookie = { version = "0.18", features = ["key-expansion"] }
educe = "0.6.0"
email_address = "0.2.9"
fastrand = "2.3.0"
Expand All @@ -66,6 +68,7 @@ subtle = "2.6.1"
thiserror = "2.0.12"
time = { version = "0.3.41", features = ["serde", "serde-well-known"] }
tokio = { version = "1.47.1", features = ["full"] }
tokio-util = { version = "0.7", features = ["rt"] }
tracing = "0.1.41"
trillium = "0.2.20"
tracing-chrome = "0.7.2"
Expand All @@ -80,26 +83,18 @@ tracing-subscriber = { version = "0.3.19", features = [
trillium-api = { version = "0.2.0-rc.12", default-features = false }
trillium-caching-headers = "0.2.3"
trillium-client = { version = "0.6.2", features = ["json"] }
trillium-compression = "0.1.3"
trillium-conn-id = "0.2.3"
trillium-cookies = "0.4.2"
trillium-forwarding = "0.2.4"
trillium-http = { version = "0.3.14", features = ["http-compat-1", "serde"] }
trillium-logger = "0.4.5"
trillium-macros = "0.0.6"
trillium-prometheus = "0.2.0"
trillium-redirect = "0.1.2"
trillium-router = "0.4.1"
trillium-rustls = "0.9.0"
trillium-sessions = "0.4.3"
trillium-static-compiled = "0.5.2"
trillium-testing = { version = "0.7.0", optional = true }
trillium-tokio = "0.4.0"
typenum = "1.18.0"
url = "2.5.2"
uuid = { version = "1.16.0", features = ["v4", "fast-rng", "serde"] }
validator = { version = "0.20.0", features = ["derive"] }
trillium-opentelemetry = { version = "0.10.0", default-features = false, features = ["metrics"] }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
tower = { version = "0.5", features = ["util"] }
tower-http = { version = "0.6", features = ["trace", "cors", "compression-full", "set-header"] }
Expand Down
135 changes: 99 additions & 36 deletions src/bin.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
use divviup_api::{
trace::{install_trace_subscriber, traceconfig_handler},
Config, DivviupApi, Queue,
handler::build_app, telemetry, trace, trace::install_trace_subscriber, Config, Queue,
};
use trillium::HttpConfig;
use trillium_http::Stopper;
use trillium_tokio::CloneCounterObserver;
use prometheus::Registry;
use std::sync::Arc;
use tokio::{
net::TcpListener,
signal::{
self,
unix::{signal, SignalKind},
},
};
use tokio_util::sync::CancellationToken;

/// Maximum request body size: 1 MiB. This JSON API never needs bodies
/// larger than this under normal operation.
const MAX_REQUEST_BODY_SIZE: u64 = 1024 * 1024;
#[derive(Clone, Debug)]
struct MonitoringState {
registry: Registry,
trace_reload_handle: Arc<trace::TraceReloadHandle>,
}

impl axum::extract::FromRef<MonitoringState> for Registry {
fn from_ref(state: &MonitoringState) -> Self {
state.registry.clone()
}
}

impl axum::extract::FromRef<MonitoringState> for Arc<trace::TraceReloadHandle> {
fn from_ref(state: &MonitoringState) -> Self {
state.trace_reload_handle.clone()
}
}

#[tokio::main]
async fn main() {
// Choose aws-lc-rs as the default rustls crypto provider. This is what's currently enabled by
// the default Cargo feature. Specifying a default provider here prevents runtime errors if
// another dependency also enables the ring feature.
// Choose aws-lc-rs as the default rustls crypto provider. This is what's
// currently enabled by the default Cargo feature. Specifying a default
// provider here prevents runtime errors if another dependency also enables
// the ring feature.
// TODO: switch to a direct `rustls` dep when trillium-rustls is removed
let _ = trillium_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default();

let config = match Config::from_env() {
Expand All @@ -26,31 +48,72 @@ async fn main() {
};

let (_guards, trace_reload_handle) = install_trace_subscriber(&config.trace_config()).unwrap();
let cancel = CancellationToken::new();

// Monitoring server (metrics + traceconfig)
let registry = telemetry::install_metrics().expect("failed to install metrics provider");
let monitoring_state = MonitoringState {
registry,
trace_reload_handle: Arc::new(trace_reload_handle),
};
let monitoring_router = axum::Router::new()
.route("/metrics", axum::routing::get(telemetry::metrics_handler))
.route(
"/traceconfig",
axum::routing::get(trace::get_traceconfig).put(trace::put_traceconfig),
)
.with_state(monitoring_state);
let monitoring_listener = TcpListener::bind(config.monitoring_listen_address)
.await
.expect("failed to bind monitoring listener");
let monitoring_cancel = cancel.clone();
let monitoring_handle = tokio::spawn(async move {
if let Err(e) = axum::serve(monitoring_listener, monitoring_router)
.with_graceful_shutdown(monitoring_cancel.cancelled_owned())
.await
{
tracing::error!("monitoring server error: {e}");
}
});

let stopper = Stopper::new();
let observer = CloneCounterObserver::default();

trillium_tokio::config()
.without_signals()
.with_socketaddr(config.monitoring_listen_address)
.with_observer(observer.clone())
.with_stopper(stopper.clone())
.spawn((
divviup_api::telemetry::metrics_exporter().unwrap(),
traceconfig_handler(trace_reload_handle),
));

let app = DivviupApi::new(config).await;

Queue::new(app.db(), app.config())
.with_observer(observer.clone())
.with_stopper(stopper.clone())
.spawn_workers();

trillium_tokio::config()
.with_http_config(HttpConfig::default().with_received_body_max_len(MAX_REQUEST_BODY_SIZE))
.with_stopper(stopper)
.with_observer(observer)
.spawn(app)
// Main application
let listen_address = config.listen_address;
let app = build_app(config).await;

tracing::info!(
"divviup-api {} listening on {listen_address}",
env!("CARGO_PKG_VERSION")
);

let queue_handle = Queue::new(&app.db, &app.config, cancel.clone()).spawn_workers();

let listener = TcpListener::bind(listen_address)
.await
.expect("failed to bind main listener");

let serve_result = axum::serve(listener, app.router)
.with_graceful_shutdown(shutdown_signal(cancel.clone()))
.await;
cancel.cancel();

if let Err(e) = serve_result {
tracing::error!("server error: {e}");
}

if let Err(e) = queue_handle.await {
tracing::error!("queue worker panic: {e}");
}

let _ = monitoring_handle.await;
}

async fn shutdown_signal(cancel: CancellationToken) {
let ctrl_c = signal::ctrl_c();
let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
tokio::select! {
_ = ctrl_c => {},
_ = sigterm.recv() => {},
}
tracing::info!("shutdown signal received, draining connections");
cancel.cancel();
}
14 changes: 11 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
collections::VecDeque,
env::{self, VarError},
error::Error,
net::SocketAddr,
net::{IpAddr, Ipv6Addr, SocketAddr},
str::FromStr,
};
use thiserror::Error;
Expand Down Expand Up @@ -48,6 +48,8 @@ pub struct Config {
pub postmark_token: String,
/// The URL to postmark.
pub postmark_url: Url,
/// The address to listen on for the main HTTP server.
pub listen_address: SocketAddr,
/// The address to listen on for prometheus metrics and tracing configuration.
pub monitoring_listen_address: SocketAddr,
/// Comma-joined unpadded base64url encoded cryptographically random secrets, 32 bytes long
Expand Down Expand Up @@ -158,9 +160,14 @@ impl Config {
email_address: var("EMAIL_ADDRESS")?,
postmark_token: var("POSTMARK_TOKEN")?,
postmark_url: Url::parse(POSTMARK_URL).unwrap(),
listen_address: {
let host: IpAddr = var_optional("HOST", IpAddr::from(Ipv6Addr::UNSPECIFIED))?;
let port: u16 = var_optional("PORT", 8080)?;
SocketAddr::new(host, port)
},
monitoring_listen_address: var_optional(
"MONITORING_LISTEN_ADDRESS",
"127.0.0.1:9464".parse().unwrap(),
SocketAddr::from((Ipv6Addr::LOCALHOST, 9464)),
)?,
session_secrets: var("SESSION_SECRETS")?,
trace_use_test_writer: false,
Expand Down Expand Up @@ -282,7 +289,8 @@ mod tests {
email_address: "test@example.test".parse().unwrap(),
postmark_token: "pmak-secret-token".into(),
postmark_url: "https://postmark.example".parse().unwrap(),
monitoring_listen_address: "127.0.0.1:9464".parse().unwrap(),
listen_address: SocketAddr::from((Ipv6Addr::UNSPECIFIED, 8080)),
monitoring_listen_address: SocketAddr::from((Ipv6Addr::LOCALHOST, 9464)),
session_secrets: vec![0u8; 32].into(),
trace_use_test_writer: false,
trace_force_json_writer: false,
Expand Down
24 changes: 1 addition & 23 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use log::LevelFilter;
use sea_orm::{ConnectOptions, ConnectionTrait, Database, DbConn};
use std::ops::{Deref, DerefMut};
use trillium::{async_trait, Conn, Handler};
use trillium_api::FromConn;

#[derive(Clone, Debug)]
pub struct Db(DbConn);
Expand All @@ -15,26 +13,6 @@ impl Db {
}
}

impl From<DbConn> for Db {
fn from(value: DbConn) -> Self {
Self(value)
}
}

#[async_trait]
impl FromConn for Db {
async fn from_conn(conn: &mut Conn) -> Option<Self> {
conn.state().cloned()
}
}

#[async_trait]
impl Handler for Db {
async fn run(&self, conn: Conn) -> Conn {
conn.with_state(self.clone())
}
}

impl Deref for Db {
type Target = DbConn;

Expand All @@ -49,7 +27,7 @@ impl DerefMut for Db {
}
}

#[async_trait]
#[async_trait::async_trait]
impl ConnectionTrait for Db {
fn get_database_backend(&self) -> sea_orm::DbBackend {
self.0.get_database_backend()
Expand Down
Loading
Loading