Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

SQLite bank transaction

[package]
name = "sqlx-sqlite-bank"
version = "0.1.0"
edition = "2024"

[dependencies]
sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-rustls", "migrate"] }
tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread"] }
CREATE TABLE IF NOT EXISTS accounts (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE,
    amount INTEGER NOT NULL CHECK (amount >= 0)
);
use sqlx::{
    Executor, Row, SqlitePool,
    migrate::Migrator,
    sqlite::{SqliteConnectOptions, SqlitePoolOptions},
};
use std::str::FromStr;

/// Migrations embedded into the binary at compile time by `sqlx::migrate!()`
/// (default directory `./migrations`); applied by `initialize_database` on
/// every start-up.
static MIGRATOR: Migrator = sqlx::migrate!();

/// Database used when the `DATABASE_URL` environment variable is not set.
const DATABASE_URL: &str = "sqlite://bank.db";
/// Returned as an error (and printed to stderr by `main`) when the command-line
/// arguments match no known sub-command. Tests compare this text exactly.
const USAGE: &str = "Usage: cargo run -- list | cargo run -- add NAME AMOUNT | cargo run -- transfer AMOUNT FROM_NAME TO_NAME";

/// CLI entry point.
///
/// Resolves the database URL (the `DATABASE_URL` environment variable, or the
/// built-in default when it is unset or not valid Unicode) and hands the
/// arguments after the program name to `run`. Any error is printed to stderr
/// and the process exits with status 1.
#[tokio::main]
async fn main() {
    let database_url = match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => DATABASE_URL.to_string(),
    };
    let cli_args = std::env::args().skip(1).collect::<Vec<String>>();

    match run(cli_args, &database_url).await {
        Ok(()) => {}
        Err(failure) => {
            eprintln!("{failure}");
            std::process::exit(1);
        }
    }
}

/// Opens the SQLite database at `database_url`, creating it if missing, and
/// applies migrations. It then dispatches the sub-command in `args`.
///
/// Supported commands: `list`, `add NAME AMOUNT`, `transfer AMOUNT FROM TO`.
/// Any other shape (including no arguments at all) yields an `InvalidInput`
/// error carrying [`USAGE`]. The database is opened and migrated *before* the
/// arguments are inspected, so even a bad invocation creates the file.
async fn run(args: Vec<String>, database_url: &str) -> Result<(), Box<dyn std::error::Error>> {
    let connect_options = SqliteConnectOptions::from_str(database_url)?.create_if_missing(true);
    let pool = SqlitePoolOptions::new().connect_with(connect_options).await?;
    initialize_database(&pool).await?;

    let Some((command, params)) = args.split_first() else {
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, USAGE).into());
    };

    match (command.as_str(), params) {
        ("list", []) => list_accounts(&pool).await?,
        ("add", [name, amount]) => {
            let opening_balance = parse_non_negative_amount(amount)?;
            add_account(&pool, name, opening_balance).await?;
        }
        ("transfer", [amount, from_name, to_name]) => {
            // Validate the amount before consulting the PANIC switch, as before.
            let moved = parse_positive_amount(amount)?;
            let crash_midway = should_panic();
            transfer(&pool, moved, from_name, to_name, crash_midway).await?;
        }
        _ => {
            return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, USAGE).into());
        }
    }

    Ok(())
}

/// Brings the schema up to date by running any pending embedded migrations.
///
/// Foreign-key enforcement is intentionally *not* toggled here.
/// `PRAGMA foreign_keys` is a per-connection setting, so issuing it through the
/// pool would only affect whichever single pooled connection happened to run
/// it. `SqliteConnectOptions` (used to build the pool) already enables foreign
/// keys on every connection it opens by default.
///
/// # Errors
/// Returns the migration error, converted into `sqlx::Error`, if any migration
/// fails to apply.
async fn initialize_database(pool: &SqlitePool) -> Result<(), sqlx::Error> {
    MIGRATOR.run(pool).await?;
    Ok(())
}

/// Creates an account called `name` with an opening balance of `amount`.
///
/// # Errors
/// Propagates any database error unchanged. Given the `accounts` schema, that
/// includes a `UNIQUE` violation for a duplicate `name` and a `CHECK`
/// violation for a negative `amount`.
async fn add_account(pool: &SqlitePool, name: &str, amount: i64) -> Result<(), sqlx::Error> {
    let insert_statement = sqlx::query(
        "INSERT INTO accounts (name, amount)
         VALUES (?, ?)",
    );

    // The query result (rows affected, last insert id) is not needed.
    insert_statement
        .bind(name)
        .bind(amount)
        .execute(pool)
        .await
        .map(|_| ())
}

/// Prints every account as a `name | amount` table, sorted by name.
///
/// All rows are fetched before anything is printed, so a query failure
/// produces no partial output.
///
/// # Errors
/// Propagates any database error from the `SELECT`.
///
/// # Panics
/// `Row::get` panics if a stored value cannot be decoded as the requested type.
async fn list_accounts(pool: &SqlitePool) -> Result<(), sqlx::Error> {
    let accounts = sqlx::query(
        "SELECT name, amount
         FROM accounts
         ORDER BY name",
    )
    .fetch_all(pool)
    .await?;

    // Header row followed by a markdown-style separator.
    println!("name | amount");
    println!("---- | ------");

    accounts.iter().for_each(|account| {
        let holder = account.get::<String, _>("name");
        let balance = account.get::<i64, _>("amount");
        println!("{holder} | {balance}");
    });

    Ok(())
}

/// Moves `amount` from the account `from_name` to the account `to_name` inside
/// one database transaction.
///
/// The debit and the credit become visible together or not at all. Every
/// early return below, as well as the simulated panic, drops `tx` before
/// `commit`, and sqlx rolls back a transaction dropped without being committed.
///
/// # Errors
/// Returns an `InvalidInput` I/O error in these cases:
/// - the two names are identical;
/// - either account does not exist;
/// - the source balance is smaller than `amount`;
/// - crediting the destination would overflow `i64`.
///
/// Database errors are propagated unchanged.
///
/// # Panics
/// Panics right after the debit when `panic_in_middle` is true. Tests use this
/// to demonstrate the rollback.
async fn transfer(
    pool: &SqlitePool,
    amount: i64,
    from_name: &str,
    to_name: &str,
    panic_in_middle: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    if from_name == to_name {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "Source and destination accounts must be different",
        )
        .into());
    }

    let mut tx = pool.begin().await?;

    let from_balance = find_amount(&mut *tx, from_name).await?.ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Unknown account: {from_name}"),
        )
    })?;

    // The destination balance is kept (not just checked for existence) so the
    // credit can be checked for overflow below.
    let to_balance = find_amount(&mut *tx, to_name).await?.ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Unknown account: {to_name}"),
        )
    })?;

    if from_balance < amount {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Insufficient funds in {from_name}: has {from_balance}, needs {amount}"),
        )
        .into());
    }

    // SQLite does not reject integer overflow in `amount + ?`; it falls back to
    // floating-point arithmetic. The resulting REAL would pass the
    // `amount >= 0` CHECK, be stored in the INTEGER column, and later fail to
    // decode as i64. Reject it up front instead.
    if to_balance.checked_add(amount).is_none() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "Transfer would overflow the balance of {to_name}: has {to_balance}, receives {amount}"
            ),
        )
        .into());
    }

    sqlx::query(
        "UPDATE accounts
         SET amount = amount - ?
         WHERE name = ?",
    )
    .bind(amount)
    .bind(from_name)
    .execute(&mut *tx)
    .await?;

    if panic_in_middle {
        panic!("simulated crash in the middle of transfer");
    }

    sqlx::query(
        "UPDATE accounts
         SET amount = amount + ?
         WHERE name = ?",
    )
    .bind(amount)
    .bind(to_name)
    .execute(&mut *tx)
    .await?;

    tx.commit().await?;

    Ok(())
}

/// Looks up the balance of the account called `name`.
///
/// It is generic over the executor so it can run inside an open transaction,
/// e.g. with `&mut *tx`. Returns `Ok(None)` when no account has that name.
///
/// # Panics
/// `Row::get` panics if the stored `amount` cannot be decoded as `i64`.
async fn find_amount<'a, E>(executor: E, name: &str) -> Result<Option<i64>, sqlx::Error>
where
    E: Executor<'a, Database = sqlx::Sqlite>,
{
    let found = sqlx::query("SELECT amount FROM accounts WHERE name = ?")
        .bind(name)
        .fetch_optional(executor)
        .await?;

    let Some(account_row) = found else {
        return Ok(None);
    };
    Ok(Some(account_row.get::<i64, _>("amount")))
}

/// Parses an opening balance: any `i64` that is zero or greater.
///
/// # Errors
/// Returns an `InvalidInput` I/O error with the message
/// `"Amount must be a non-negative integer: <value>"` when `value` is not a
/// valid `i64` or is negative.
fn parse_non_negative_amount(value: &str) -> Result<i64, Box<dyn std::error::Error>> {
    let rejection = || {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Amount must be a non-negative integer: {value}"),
        )
    };

    match value.parse::<i64>() {
        Ok(parsed) if parsed >= 0 => Ok(parsed),
        _ => Err(rejection().into()),
    }
}

/// Parses a transfer amount: any `i64` strictly greater than zero.
///
/// # Errors
/// Returns an `InvalidInput` I/O error with the message
/// `"Amount must be a positive integer: <value>"` when `value` is not a valid
/// `i64` or is zero or negative.
fn parse_positive_amount(value: &str) -> Result<i64, Box<dyn std::error::Error>> {
    let amount = value
        .parse::<i64>()
        .ok()
        .filter(|&parsed| parsed > 0)
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Amount must be a positive integer: {value}"),
            )
        })?;

    Ok(amount)
}

/// Returns `true` only when the `PANIC` environment variable is exactly
/// `"true"`. Tests set it to simulate a crash halfway through a transfer.
fn should_panic() -> bool {
    std::env::var("PANIC").is_ok_and(|flag| flag == "true")
}
#![allow(unused)]
fn main() {
use std::path::PathBuf;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};

/// Path of the `sqlx-sqlite-bank` executable that Cargo builds for this
/// integration-test crate. It is resolved at compile time from
/// `CARGO_BIN_EXE_<name>`.
fn binary_path() -> PathBuf {
    let executable: &str = env!("CARGO_BIN_EXE_sqlx-sqlite-bank");
    executable.into()
}

/// Builds a `sqlite://` URL pointing at a fresh database file in the system
/// temp directory, so each test gets an isolated database.
///
/// The file name combines three parts:
/// - the process id, distinguishing concurrent test binaries;
/// - a timestamp, distinguishing re-runs;
/// - a per-process counter. Tests inside one binary run on parallel threads
///   that share the PID, and the clock's real resolution can be as coarse as
///   microseconds, so two calls may see the same timestamp.
///
/// The file itself is created later by the binary under test.
fn unique_database_url() -> String {
    static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before unix epoch")
        .as_nanos();
    let db_path = std::env::temp_dir().join(format!(
        "sqlx-sqlite-bank-cli-test-{}-{}-{}.db",
        std::process::id(),
        nanos,
        id
    ));

    format!("sqlite://{}", db_path.display())
}

/// Happy path end to end: seed two accounts, move money between them, and
/// check the listing. Every mutating command must succeed silently, with
/// empty stdout and stderr.
#[test]
fn add_list_and_transfer_money() {
    let database_url = unique_database_url();
    let cli = |args: &[&str], context: &str| {
        Command::new(binary_path())
            .args(args)
            .env("DATABASE_URL", &database_url)
            .output()
            .expect(context)
    };

    for (args, context) in [
        (&["add", "Alice", "100"][..], "failed to add Alice"),
        (&["add", "Bob", "40"][..], "failed to add Bob"),
        (&["transfer", "25", "Alice", "Bob"][..], "failed to transfer money"),
    ] {
        let outcome = cli(args, context);
        assert!(outcome.status.success());
        assert_eq!(String::from_utf8_lossy(&outcome.stdout), "");
        assert_eq!(String::from_utf8_lossy(&outcome.stderr), "");
    }

    let listing = cli(&["list"][..], "failed to list accounts");
    assert!(listing.status.success());
    assert_eq!(
        String::from_utf8_lossy(&listing.stdout),
        "name | amount\n---- | ------\nAlice | 75\nBob | 65\n"
    );
    assert_eq!(String::from_utf8_lossy(&listing.stderr), "");
}

/// A transfer that panics after the debit but before the credit must leave
/// both balances exactly as they were, because the transaction is never
/// committed.
#[test]
fn panic_during_transfer_rolls_back_transaction() {
    let database_url = unique_database_url();
    let invoke = |args: &[&str], crash: bool, context: &str| {
        let mut command = Command::new(binary_path());
        command.args(args).env("DATABASE_URL", &database_url);
        if crash {
            command.env("PANIC", "true");
        }
        command.output().expect(context)
    };

    for seed in [["add", "Alice", "100"], ["add", "Bob", "40"]] {
        let seeded = invoke(&seed[..], false, "failed to seed accounts");
        assert!(seeded.status.success());
    }

    let crashed = invoke(
        &["transfer", "25", "Alice", "Bob"][..],
        true,
        "failed to run crashing transfer",
    );
    assert!(!crashed.status.success());
    assert_eq!(String::from_utf8_lossy(&crashed.stdout), "");
    let crash_report = String::from_utf8_lossy(&crashed.stderr);
    assert!(crash_report.contains("simulated crash in the middle of transfer"));

    let after_crash = invoke(&["list"][..], false, "failed to list accounts after crash");
    assert!(after_crash.status.success());
    assert_eq!(
        String::from_utf8_lossy(&after_crash.stdout),
        "name | amount\n---- | ------\nAlice | 100\nBob | 40\n"
    );
}

/// An unrecognised sub-command must fail. It must print nothing on stdout
/// and exactly the usage line, newline-terminated, on stderr.
#[test]
fn invalid_arguments_print_usage() {
    let expected_stderr = "Usage: cargo run -- list | cargo run -- add NAME AMOUNT | cargo run -- transfer AMOUNT FROM_NAME TO_NAME\n";
    let database_url = unique_database_url();

    let rejected = Command::new(binary_path())
        .env("DATABASE_URL", &database_url)
        .args(["hello", "world"])
        .output()
        .expect("failed to run invalid command");

    assert!(!rejected.status.success());
    assert_eq!(String::from_utf8_lossy(&rejected.stdout), "");
    assert_eq!(String::from_utf8_lossy(&rejected.stderr), expected_stderr);
}

/// A transfer larger than the source balance must fail with a descriptive
/// error and must leave both balances unchanged.
#[test]
fn transfer_requires_enough_money() {
    let database_url = unique_database_url();

    for args in [["add", "Alice", "10"], ["add", "Bob", "40"]] {
        let output = Command::new(binary_path())
            .args(args)
            .env("DATABASE_URL", &database_url)
            .output()
            .expect("failed to seed accounts");
        assert!(output.status.success());
    }

    let transfer = Command::new(binary_path())
        .args(["transfer", "25", "Alice", "Bob"])
        .env("DATABASE_URL", &database_url)
        .output()
        .expect("failed to run transfer");

    assert!(!transfer.status.success());
    assert_eq!(String::from_utf8_lossy(&transfer.stdout), "");
    assert_eq!(
        String::from_utf8_lossy(&transfer.stderr),
        "Insufficient funds in Alice: has 10, needs 25\n"
    );

    // The rejection must not have touched either account.
    let list = Command::new(binary_path())
        .arg("list")
        .env("DATABASE_URL", &database_url)
        .output()
        .expect("failed to list accounts after rejected transfer");
    assert!(list.status.success());
    assert_eq!(
        String::from_utf8_lossy(&list.stdout),
        "name | amount\n---- | ------\nAlice | 10\nBob | 40\n"
    );
    assert_eq!(String::from_utf8_lossy(&list.stderr), "");
}
}