about summary refs log tree commit diff
path: root/nix/stateMonad
diff options
context:
space:
mode:
Diffstat (limited to 'nix/stateMonad')
-rw-r--r--nix/stateMonad/default.nix76
-rw-r--r--nix/stateMonad/tests/default.nix110
2 files changed, 186 insertions, 0 deletions
diff --git a/nix/stateMonad/default.nix b/nix/stateMonad/default.nix
new file mode 100644
index 000000000000..8f92241753df
--- /dev/null
+++ b/nix/stateMonad/default.nix
@@ -0,0 +1,76 @@
+# Simple state monad represented as
+#
+#     stateMonad s a = s -> { state : s; value : a }
+#
+{ ... }:
+
+rec {
+  #
+  # Monad
+  #
+
+  # Type: stateMonad s a -> (a -> stateMonad s b) -> stateMonad s b
+  bind = action: f: state:
+    let
+      afterAction = action state;
+    in
+    (f afterAction.value) afterAction.state;
+
+  # Type: stateMonad s a -> stateMonad s b -> stateMonad s b
+  after = action1: action2: bind action1 (_: action2);
+
+  # Type: stateMonad s (stateMonad s a) -> stateMonad s a
+  join = action: bind action (action': action');
+
+  # Type: [a] -> (a -> stateMonad s b) -> stateMonad s null
+  for_ = xs: f:
+    builtins.foldl'
+      (laterAction: x:
+        after (f x) laterAction
+      )
+      (pure null)
+      xs;
+
+  #
+  # Applicative
+  #
+
+  # Type: a -> stateMonad s a
+  pure = value: state: { inherit state value; };
+
+  # TODO(sterni): <*>, lift2, …
+
+  #
+  # Functor
+  #
+
+  # Type: (a -> b) -> stateMonad s a -> stateMonad s b
+  fmap = f: action: bind action (result: pure (f result));
+
+  #
+  # State Monad
+  #
+
+  # Type: (s -> s) -> stateMonad s null
+  modify = f: state: { value = null; state = f state; };
+
+  # Type: stateMonad s s
+  get = state: { value = state; inherit state; };
+
+  # Type: s -> stateMonad s null
+  set = new: modify (_: new);
+
+  # Type: str -> stateMonad set set.${str}
+  getAttr = attr: fmap (state: state.${attr}) get;
+
+  # Type: str -> (any -> any) -> stateMonad s null
+  modifyAttr = attr: f: modify (state: state // {
+    ${attr} = f state.${attr};
+  });
+
+  # Type: str -> any -> stateMonad s null
+  setAttr = attr: value: modifyAttr attr (_: value);
+
+  # Type: s -> stateMonad s a -> a
+  run = state: action: (action state).value;
+}
diff --git a/nix/stateMonad/tests/default.nix b/nix/stateMonad/tests/default.nix
new file mode 100644
index 000000000000..c3cb5c99b550
--- /dev/null
+++ b/nix/stateMonad/tests/default.nix
@@ -0,0 +1,110 @@
+{ depot, ... }:
+
+let
+  inherit (depot.nix.runTestsuite)
+    runTestsuite
+    it
+    assertEq
+    ;
+
+  inherit (depot.nix.stateMonad)
+    pure
+    run
+    join
+    fmap
+    bind
+    get
+    set
+    modify
+    after
+    for_
+    getAttr
+    setAttr
+    modifyAttr
+    ;
+
+  runStateIndependent = run (throw "This should never be evaluated!");
+in
+
+runTestsuite "stateMonad" [
+  (it "behaves correctly independent of state" [
+    (assertEq "pure" (runStateIndependent (pure 21)) 21)
+    (assertEq "join pure" (runStateIndependent (join (pure (pure 42)))) 42)
+    (assertEq "fmap pure" (runStateIndependent (fmap (builtins.mul 2) (pure 21))) 42)
+    (assertEq "bind pure" (runStateIndependent (bind (pure 12) (x: pure x))) 12)
+  ])
+  (it "behaves correctly with an integer state" [
+    (assertEq "get" (run 42 get) 42)
+    (assertEq "after set get" (run 21 (after (set 42) get)) 42)
+    (assertEq "after modify get" (run 21 (after (modify (builtins.mul 2)) get)) 42)
+    (assertEq "fmap get" (run 40 (fmap (builtins.add 2) get)) 42)
+    (assertEq "stateful sum list"
+      (run 0 (after
+        (for_
+          [
+            15
+            12
+            10
+            5
+          ]
+          (x: modify (builtins.add x)))
+        get))
+      42)
+  ])
+  (it "behaves correctly with an attr set state" [
+    (assertEq "getAttr" (run { foo = 42; } (getAttr "foo")) 42)
+    (assertEq "after setAttr getAttr"
+      (run { foo = 21; } (after (setAttr "foo" 42) (getAttr "foo")))
+      42)
+    (assertEq "after modifyAttr getAttr"
+      (run { foo = 10.5; }
+        (after
+          (modifyAttr "foo" (builtins.mul 4))
+          (getAttr "foo")))
+      42)
+    (assertEq "fmap getAttr"
+      (run { foo = 21; } (fmap (builtins.mul 2) (getAttr "foo")))
+      42)
+    (assertEq "after setAttr to insert getAttr"
+      (run { } (after (setAttr "foo" 42) (getAttr "foo")))
+      42)
+    (assertEq "insert permutations"
+      (run
+        {
+          a = 2;
+          b = 3;
+          c = 5;
+        }
+        (after
+          (bind get
+            (state:
+              let
+                names = builtins.attrNames state;
+              in
+              for_ names (name1:
+                for_ names (name2:
+                  # this is of course a bit silly, but making it more cumbersome
+                  # makes sure the test exercises more of the code.
+                  (bind (getAttr name1)
+                    (value1:
+                      (bind (getAttr name2)
+                        (value2:
+                          setAttr "${name1}_${name2}" (value1 * value2)))))))))
+          get))
+      {
+        a = 2;
+        b = 3;
+        c = 5;
+        a_a = 4;
+        a_b = 6;
+        a_c = 10;
+        b_a = 6;
+        b_b = 9;
+        b_c = 15;
+        c_c = 25;
+        c_a = 10;
+        c_b = 15;
+      }
+    )
+  ])
+]