diff options
Diffstat (limited to 'nix/stateMonad')
-rw-r--r-- | nix/stateMonad/default.nix | 76 | ||||
-rw-r--r-- | nix/stateMonad/tests/default.nix | 110 |
2 files changed, 186 insertions, 0 deletions
diff --git a/nix/stateMonad/default.nix b/nix/stateMonad/default.nix new file mode 100644 index 000000000000..8f92241753df --- /dev/null +++ b/nix/stateMonad/default.nix @@ -0,0 +1,76 @@ +# Simple state monad represented as +# +# stateMonad s a = s -> { state : s; value : a } +# +{ ... }: + +rec { + # + # Monad + # + + # Type: stateMonad s a -> (a -> stateMonad s b) -> stateMonad s b + bind = action: f: state: + let + afterAction = action state; + in + (f afterAction.value) afterAction.state; + + # Type: stateMonad s a -> stateMonad s b -> stateMonad s b + after = action1: action2: bind action1 (_: action2); + + # Type: stateMonad s (stateMonad s a) -> stateMonad s a + join = action: bind action (action': action'); + + # Type: [a] -> (a -> stateMonad s b) -> stateMonad s null + for_ = xs: f: + builtins.foldl' + (laterAction: x: + after (f x) laterAction + ) + (pure null) + xs; + + # + # Applicative + # + + # Type: a -> stateMonad s a + pure = value: state: { inherit state value; }; + + # TODO(sterni): <*>, lift2, … + + # + # Functor + # + + # Type: (a -> b) -> stateMonad s a -> stateMonad s b + fmap = f: action: bind action (result: pure (f result)); + + # + # State Monad + # + + # Type: (s -> s) -> stateMonad s null + modify = f: state: { value = null; state = f state; }; + + # Type: stateMonad s s + get = state: { value = state; inherit state; }; + + # Type: s -> stateMonad s null + set = new: modify (_: new); + + # Type: str -> stateMonad set set.${str} + getAttr = attr: fmap (state: state.${attr}) get; + + # Type: str -> (any -> any) -> stateMonad s null + modifyAttr = attr: f: modify (state: state // { + ${attr} = f state.${attr}; + }); + + # Type: str -> any -> stateMonad s null + setAttr = attr: value: modifyAttr attr (_: value); + + # Type: s -> stateMonad s a -> a + run = state: action: (action state).value; +} diff --git a/nix/stateMonad/tests/default.nix b/nix/stateMonad/tests/default.nix new file mode 100644 index 000000000000..c3cb5c99b550 --- /dev/null +++ b/nix/stateMonad/tests/default.nix @@ -0,0 +1,110 @@ +{ depot, ... }: + +let + inherit (depot.nix.runTestsuite) + runTestsuite + it + assertEq + ; + + inherit (depot.nix.stateMonad) + pure + run + join + fmap + bind + get + set + modify + after + for_ + getAttr + setAttr + modifyAttr + ; + + runStateIndependent = run (throw "This should never be evaluated!"); +in + +runTestsuite "stateMonad" [ + (it "behaves correctly independent of state" [ + (assertEq "pure" (runStateIndependent (pure 21)) 21) + (assertEq "join pure" (runStateIndependent (join (pure (pure 42)))) 42) + (assertEq "fmap pure" (runStateIndependent (fmap (builtins.mul 2) (pure 21))) 42) + (assertEq "bind pure" (runStateIndependent (bind (pure 12) (x: pure x))) 12) + ]) + (it "behaves correctly with an integer state" [ + (assertEq "get" (run 42 get) 42) + (assertEq "after set get" (run 21 (after (set 42) get)) 42) + (assertEq "after modify get" (run 21 (after (modify (builtins.mul 2)) get)) 42) + (assertEq "fmap get" (run 40 (fmap (builtins.add 2) get)) 42) + (assertEq "stateful sum list" + (run 0 (after + (for_ + [ + 15 + 12 + 10 + 5 + ] + (x: modify (builtins.add x))) + get)) + 42) + ]) + (it "behaves correctly with an attr set state" [ + (assertEq "getAttr" (run { foo = 42; } (getAttr "foo")) 42) + (assertEq "after setAttr getAttr" + (run { foo = 21; } (after (setAttr "foo" 42) (getAttr "foo"))) + 42) + (assertEq "after modifyAttr getAttr" + (run { foo = 10.5; } + (after + (modifyAttr "foo" (builtins.mul 4)) + (getAttr "foo"))) + 42) + (assertEq "fmap getAttr" + (run { foo = 21; } (fmap (builtins.mul 2) (getAttr "foo"))) + 42) + (assertEq "after setAttr to insert getAttr" + (run { } (after (setAttr "foo" 42) (getAttr "foo"))) + 42) + (assertEq "insert permutations" + (run + { + a = 2; + b = 3; + c = 5; + } + (after + (bind get + (state: + let + names = builtins.attrNames state; + in + for_ names (name1: + for_ names (name2: + # this is of course a bit silly, but making it more cumbersome + # makes sure the test exercises more of the code. + (bind (getAttr name1) + (value1: + (bind (getAttr name2) + (value2: + setAttr "${name1}_${name2}" (value1 * value2))))))))) + get)) + { + a = 2; + b = 3; + c = 5; + a_a = 4; + a_b = 6; + a_c = 10; + b_a = 6; + b_b = 9; + b_c = 15; + c_c = 25; + c_a = 10; + c_b = 15; + } + ) + ]) +] |