about summary refs log tree commit diff
path: root/finito-postgres/src
diff options
context:
space:
mode:
Diffstat (limited to 'finito-postgres/src')
-rw-r--r--finito-postgres/src/error.rs15
-rw-r--r--finito-postgres/src/lib.rs36
2 files changed, 42 insertions, 9 deletions
diff --git a/finito-postgres/src/error.rs b/finito-postgres/src/error.rs
index 0bf7f4018591..e130d18361f1 100644
--- a/finito-postgres/src/error.rs
+++ b/finito-postgres/src/error.rs
@@ -7,8 +7,9 @@ use uuid::Uuid;
 use std::error::Error as StdError;
 
 // errors to chain:
-use serde_json::Error as JsonError;
 use postgres::Error as PgError;
+use r2d2_postgres::r2d2::Error as PoolError;
+use serde_json::Error as JsonError;
 
 pub type Result<T> = result::Result<T, Error>;
 
@@ -26,6 +27,9 @@ pub enum ErrorKind {
     /// Errors occuring during communication with the database.
     Database(String),
 
+    /// Errors with the database connection pool.
+    DBPool(String),
+
     /// State machine could not be found.
     FSMNotFound(Uuid),
 
@@ -43,6 +47,9 @@ impl fmt::Display for Error {
             Database(err) =>
                 format!("PostgreSQL error: {}", err),
 
+            DBPool(err) =>
+                format!("Database connection pool error: {}", err),
+
             FSMNotFound(id) =>
                 format!("FSM with ID {} not found", id),
 
@@ -80,6 +87,12 @@ impl From<PgError> for ErrorKind {
     }
 }
 
+impl From<PoolError> for ErrorKind {
+    fn from(err: PoolError) -> ErrorKind {
+        ErrorKind::DBPool(err.to_string())
+    }
+}
+
 /// Helper trait that makes it possible to supply contextual
 /// information with an error.
 pub trait ResultExt<T> {
diff --git a/finito-postgres/src/lib.rs b/finito-postgres/src/lib.rs
index 844e8f79fee3..eea6405c6f45 100644
--- a/finito-postgres/src/lib.rs
+++ b/finito-postgres/src/lib.rs
@@ -9,6 +9,7 @@
 
 extern crate chrono;
 extern crate finito;
+extern crate r2d2_postgres;
 extern crate serde;
 extern crate serde_json;
 extern crate uuid;
@@ -19,16 +20,20 @@ extern crate uuid;
 mod error;
 pub use error::{Result, Error, ErrorKind};
 
-use error::ResultExt;
 use chrono::prelude::{DateTime, Utc};
+use error::ResultExt;
 use finito::{FSM, FSMBackend};
-use postgres::{Connection, GenericConnection};
 use postgres::transaction::Transaction;
+use postgres::GenericConnection;
 use serde::Serialize;
 use serde::de::DeserializeOwned;
 use serde_json::Value;
 use std::marker::PhantomData;
 use uuid::Uuid;
+use r2d2_postgres::{r2d2, PostgresConnectionManager};
+
+type DBPool = r2d2::Pool<PostgresConnectionManager>;
+type DBConn = r2d2::PooledConnection<PostgresConnectionManager>;
 
 /// This struct represents rows in the database table in which events
 /// are persisted.
@@ -103,8 +108,16 @@ struct ActionT {
 /// now.
 pub struct FinitoPostgres<S> {
     state: S,
-    // TODO: Use connection pool?
-    conn: Connection,
+
+    db_pool: DBPool,
+}
+
+impl <S> FinitoPostgres<S> {
+    pub fn new(state: S, db_pool: DBPool, pool_size: usize) -> Self {
+        FinitoPostgres {
+            state, db_pool,
+        }
+    }
 }
 
 impl <State: 'static> FSMBackend<State> for FinitoPostgres<State> {
@@ -121,14 +134,14 @@ impl <State: 'static> FSMBackend<State> for FinitoPostgres<State> {
         let fsm = S::FSM_NAME.to_string();
         let state = serde_json::to_value(initial).context("failed to serialise FSM")?;
 
-        self.conn.execute(query, &[&id, &fsm, &state]).context("failed to insert FSM")?;
+        self.conn()?.execute(query, &[&id, &fsm, &state]).context("failed to insert FSM")?;
 
         return Ok(id);
 
     }
 
     fn get_machine<S: FSM + DeserializeOwned>(&self, key: Uuid) -> Result<S> {
-        get_machine_internal(&self.conn, key, false)
+        get_machine_internal(&*self.conn()?, key, false)
     }
 
     /// Advance a persisted state machine by applying an event, and
@@ -147,7 +160,8 @@ impl <State: 'static> FSMBackend<State> for FinitoPostgres<State> {
           S::State: From<&'a State>,
           S::Event: Serialize + DeserializeOwned,
           S::Action: Serialize + DeserializeOwned {
-        let tx = self.conn.transaction().context("could not begin transaction")?;
+        let conn = self.conn()?;
+        let tx = conn.transaction().context("could not begin transaction")?;
         let state = get_machine_internal(&tx, key, true)?;
 
         // Advancing the FSM consumes the event, so it is persisted first:
@@ -184,9 +198,10 @@ impl <State: 'static> FinitoPostgres<State> {
         S::Action: Serialize + DeserializeOwned,
         S::State: From<&'a State> {
         let state: S::State = (&self.state).into();
+        let conn = self.conn().expect("TODO");
 
         for action_id in action_ids {
-            let tx = self.conn.transaction().expect("TODO");
+            let tx = conn.transaction().expect("TODO");
 
             // TODO: Determine which concurrency setup we actually want.
             if let Ok(events) = run_action(tx, action_id, &state, PhantomData::<S>) {
@@ -196,6 +211,11 @@ impl <State: 'static> FinitoPostgres<State> {
             }
         }
     }
+
+    /// Retrieve a single connection from the database connection pool.
+    fn conn(&self) -> Result<DBConn> {
+        self.db_pool.get().context("failed to retrieve connection from pool")
+    }
 }