[package]
name = "sqlx-sqlite-bank"
version = "0.1.0"
edition = "2024"
[dependencies]
sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-rustls", "migrate"] }
tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread"] }
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
amount INTEGER NOT NULL CHECK (amount >= 0)
);
use sqlx::{
Executor, Row, SqlitePool,
migrate::Migrator,
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
};
use std::str::FromStr;
/// Schema migrations embedded into the binary at compile time by `sqlx::migrate!()` and applied
/// by `initialize_database`.
///
/// NOTE(review): with no path argument, sqlx reads the crate's `migrations/` directory (per the
/// sqlx docs). Never edit an already-applied migration file — sqlx checksums each migration and
/// refuses to run against databases whose recorded checksum no longer matches.
static MIGRATOR: Migrator = sqlx::migrate!();
/// Database location used when the `DATABASE_URL` environment variable cannot be read.
const DATABASE_URL: &str = "sqlite://bank.db";

/// Command-line synopsis reported (as the error text) when the arguments match no command.
const USAGE: &str = concat!(
    "Usage: cargo run -- list",
    " | cargo run -- add NAME AMOUNT",
    " | cargo run -- transfer AMOUNT FROM_NAME TO_NAME"
);
/// Binary entry point.
///
/// Collects the command-line arguments (without the program name) and the `DATABASE_URL`
/// environment variable, falling back to the built-in default when that variable cannot be
/// read, then delegates to `run`. On failure the error is printed to stderr and the process
/// exits with status 1.
#[tokio::main]
async fn main() {
    let database_url = match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => DATABASE_URL.to_string(),
    };
    let cli_args = std::env::args().skip(1).collect::<Vec<String>>();

    let Err(failure) = run(cli_args, &database_url).await else {
        return;
    };
    eprintln!("{failure}");
    std::process::exit(1);
}
async fn run(args: Vec<String>, database_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let options = SqliteConnectOptions::from_str(database_url)?.create_if_missing(true);
let pool = SqlitePoolOptions::new().connect_with(options).await?;
initialize_database(&pool).await?;
match args.as_slice() {
[cmd] if cmd == "list" => list_accounts(&pool).await?,
[cmd, name, amount] if cmd == "add" => {
add_account(&pool, name, parse_non_negative_amount(amount)?).await?;
}
[cmd, amount, from_name, to_name] if cmd == "transfer" => {
transfer(
&pool,
parse_positive_amount(amount)?,
from_name,
to_name,
should_panic(),
)
.await?;
}
_ => {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, USAGE).into());
}
}
Ok(())
}
/// Applies the embedded schema migrations (`MIGRATOR`) to `pool`.
///
/// Foreign-key enforcement is intentionally not configured here. It is a per-connection SQLite
/// setting, so a `PRAGMA foreign_keys = ON` sent through the pool would reach only whichever
/// pooled connection happened to execute it. sqlx's `SqliteConnectOptions` already enables
/// foreign keys on every connection it opens (its documented default).
///
/// # Errors
/// Returns the migration failure converted into a `sqlx::Error`.
async fn initialize_database(pool: &SqlitePool) -> Result<(), sqlx::Error> {
    MIGRATOR.run(pool).await?;
    Ok(())
}
/// Inserts a new account called `name` with the opening balance `amount`.
///
/// # Errors
/// Propagates any database error, for example a violation of the `accounts` table's
/// `UNIQUE` constraint on `name`.
async fn add_account(pool: &SqlitePool, name: &str, amount: i64) -> Result<(), sqlx::Error> {
    const INSERT_ACCOUNT: &str = "INSERT INTO accounts (name, amount)\nVALUES (?, ?)";

    let insert = sqlx::query(INSERT_ACCOUNT).bind(name).bind(amount);
    insert.execute(pool).await.map(|_outcome| ())
}
async fn list_accounts(pool: &SqlitePool) -> Result<(), sqlx::Error> {
let rows = sqlx::query(
"SELECT name, amount
FROM accounts
ORDER BY name",
)
.fetch_all(pool)
.await?;
println!("name | amount");
println!("---- | ------");
for row in rows {
let name: String = row.get("name");
let amount: i64 = row.get("amount");
println!("{name} | {amount}");
}
Ok(())
}
/// Boxes an `InvalidInput` I/O error carrying `message`; this is the error shape `transfer`
/// uses for rejected requests.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into()).into()
}

/// Moves `amount` from the account `from_name` to the account `to_name`.
///
/// Both balance look-ups and both `UPDATE`s run on one transaction. Every early return, and the
/// optional simulated panic, drops `tx` without committing, so nothing is persisted.
///
/// When `panic_in_middle` is true, the function panics after the debit and before the credit.
/// The tests use this to show that the transaction is rolled back.
///
/// # Errors
/// Returns `InvalidInput` when the two names are equal, when either account is unknown, when
/// `from_name` holds less than `amount`, or when crediting `to_name` would overflow `i64`.
/// Database failures are propagated.
async fn transfer(
    pool: &SqlitePool,
    amount: i64,
    from_name: &str,
    to_name: &str,
    panic_in_middle: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    if from_name == to_name {
        return Err(invalid_input(
            "Source and destination accounts must be different",
        ));
    }

    let mut tx = pool.begin().await?;

    let Some(from_balance) = find_amount(&mut *tx, from_name).await? else {
        return Err(invalid_input(format!("Unknown account: {from_name}")));
    };
    let Some(to_balance) = find_amount(&mut *tx, to_name).await? else {
        return Err(invalid_input(format!("Unknown account: {to_name}")));
    };

    if from_balance < amount {
        return Err(invalid_input(format!(
            "Insufficient funds in {from_name}: has {from_balance}, needs {amount}"
        )));
    }
    // SQLite does not fail on integer overflow: `amount + ?` would silently become a REAL.
    // That value would be stored as-is in the INTEGER-affinity column and could no longer be
    // read back as i64, so reject the transfer up front instead.
    if to_balance.checked_add(amount).is_none() {
        return Err(invalid_input(format!(
            "Transfer would overflow the balance of {to_name}"
        )));
    }

    sqlx::query("UPDATE accounts\nSET amount = amount - ?\nWHERE name = ?")
        .bind(amount)
        .bind(from_name)
        .execute(&mut *tx)
        .await?;

    if panic_in_middle {
        panic!("simulated crash in the middle of transfer");
    }

    sqlx::query("UPDATE accounts\nSET amount = amount + ?\nWHERE name = ?")
        .bind(amount)
        .bind(to_name)
        .execute(&mut *tx)
        .await?;

    tx.commit().await?;
    Ok(())
}
/// Looks up the balance of the account called `name`.
///
/// Accepts any SQLite executor, so it works with a pool or with the connection behind an open
/// transaction. Returns `Ok(None)` when no account has that name.
///
/// # Errors
/// Returns a `sqlx::Error` if the query fails or the stored `amount` cannot be decoded as
/// `i64`. The decode case is reported as an error rather than panicking.
async fn find_amount<'a, E>(executor: E, name: &str) -> Result<Option<i64>, sqlx::Error>
where
    E: Executor<'a, Database = sqlx::Sqlite>,
{
    sqlx::query("SELECT amount FROM accounts WHERE name = ?")
        .bind(name)
        .fetch_optional(executor)
        .await?
        .map(|row| row.try_get::<i64, _>("amount"))
        .transpose()
}
/// Parses `value` as an `i64` that must be zero or greater.
///
/// # Errors
/// Returns a boxed `InvalidInput` I/O error with the same message whether `value` is not an
/// integer or is negative.
fn parse_non_negative_amount(value: &str) -> Result<i64, Box<dyn std::error::Error>> {
    match value.parse::<i64>() {
        Ok(parsed) if parsed >= 0 => Ok(parsed),
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Amount must be a non-negative integer: {value}"),
        )
        .into()),
    }
}
/// Parses `value` as a strictly positive `i64`.
///
/// # Errors
/// Returns a boxed `InvalidInput` I/O error with the same message whether `value` is not an
/// integer or is zero or negative.
fn parse_positive_amount(value: &str) -> Result<i64, Box<dyn std::error::Error>> {
    value
        .parse::<i64>()
        .ok()
        .filter(|&parsed| parsed > 0)
        .ok_or_else(|| -> Box<dyn std::error::Error> {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Amount must be a positive integer: {value}"),
            )
            .into()
        })
}
/// Reports whether the `PANIC` environment variable is exactly `"true"`.
///
/// Any other value, an unset variable, or a non-Unicode value yields `false`.
fn should_panic() -> bool {
    std::env::var("PANIC").is_ok_and(|flag| flag == "true")
}
#![allow(unused)]
fn main() {
use std::path::PathBuf;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};
fn binary_path() -> PathBuf {
PathBuf::from(env!("CARGO_BIN_EXE_sqlx-sqlite-bank"))
}
fn unique_database_url() -> String {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_nanos();
let db_path = std::env::temp_dir().join(format!(
"sqlx-sqlite-bank-cli-test-{}-{}.db",
std::process::id(),
nanos
));
format!("sqlite://{}", db_path.display())
}
#[test]
fn add_list_and_transfer_money() {
let database_url = unique_database_url();
let add_alice = Command::new(binary_path())
.args(["add", "Alice", "100"])
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to add Alice");
assert!(add_alice.status.success());
assert_eq!(String::from_utf8_lossy(&add_alice.stdout), "");
assert_eq!(String::from_utf8_lossy(&add_alice.stderr), "");
let add_bob = Command::new(binary_path())
.args(["add", "Bob", "40"])
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to add Bob");
assert!(add_bob.status.success());
assert_eq!(String::from_utf8_lossy(&add_bob.stdout), "");
assert_eq!(String::from_utf8_lossy(&add_bob.stderr), "");
let transfer = Command::new(binary_path())
.args(["transfer", "25", "Alice", "Bob"])
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to transfer money");
assert!(transfer.status.success());
assert_eq!(String::from_utf8_lossy(&transfer.stdout), "");
assert_eq!(String::from_utf8_lossy(&transfer.stderr), "");
let list = Command::new(binary_path())
.arg("list")
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to list accounts");
assert!(list.status.success());
assert_eq!(
String::from_utf8_lossy(&list.stdout),
"name | amount\n---- | ------\nAlice | 75\nBob | 65\n"
);
assert_eq!(String::from_utf8_lossy(&list.stderr), "");
}
#[test]
fn panic_during_transfer_rolls_back_transaction() {
let database_url = unique_database_url();
for args in [["add", "Alice", "100"], ["add", "Bob", "40"]] {
let output = Command::new(binary_path())
.args(args)
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to seed accounts");
assert!(output.status.success());
}
let transfer = Command::new(binary_path())
.args(["transfer", "25", "Alice", "Bob"])
.env("DATABASE_URL", &database_url)
.env("PANIC", "true")
.output()
.expect("failed to run crashing transfer");
assert!(!transfer.status.success());
assert_eq!(String::from_utf8_lossy(&transfer.stdout), "");
assert!(
String::from_utf8_lossy(&transfer.stderr)
.contains("simulated crash in the middle of transfer")
);
let list = Command::new(binary_path())
.arg("list")
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to list accounts after crash");
assert!(list.status.success());
assert_eq!(
String::from_utf8_lossy(&list.stdout),
"name | amount\n---- | ------\nAlice | 100\nBob | 40\n"
);
}
#[test]
fn invalid_arguments_print_usage() {
let database_url = unique_database_url();
let invalid = Command::new(binary_path())
.args(["hello", "world"])
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to run invalid command");
assert!(!invalid.status.success());
assert_eq!(String::from_utf8_lossy(&invalid.stdout), "");
assert_eq!(
String::from_utf8_lossy(&invalid.stderr),
"Usage: cargo run -- list | cargo run -- add NAME AMOUNT | cargo run -- transfer AMOUNT FROM_NAME TO_NAME\n"
);
}
#[test]
fn transfer_requires_enough_money() {
let database_url = unique_database_url();
for args in [["add", "Alice", "10"], ["add", "Bob", "40"]] {
let output = Command::new(binary_path())
.args(args)
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to seed accounts");
assert!(output.status.success());
}
let transfer = Command::new(binary_path())
.args(["transfer", "25", "Alice", "Bob"])
.env("DATABASE_URL", &database_url)
.output()
.expect("failed to run transfer");
assert!(!transfer.status.success());
assert_eq!(String::from_utf8_lossy(&transfer.stdout), "");
assert_eq!(
String::from_utf8_lossy(&transfer.stderr),
"Insufficient funds in Alice: has 10, needs 25\n"
);
}
}