about summary refs log tree commit diff
path: root/users/sterni
diff options
context:
space:
mode:
authorsterni <sternenseemann@systemli.org>2022-02-14T11·01+0100
committerclbot <clbot@tvl.fyi>2022-02-15T12·52+0000
commit36f6322d16ad0c2341bc37235bf327db516ef97d (patch)
tree3c45e82c02f90028879f0ea12a51aeec859c86d2 /users/sterni
parent79f93f3d857416f58f779680519a8bd6c99a2fac (diff)
feat(sterni/nix/fun): implement tail call “optimization” for Nix r/3834
I've had the notion that builtins.genericClosure can be used to express
any recursive algorithm, but a proof is much better than a notion of
course! In this case we can easily show this by implementing a function
that converts a tail recursive function into an application of
builtins.genericClosure.

This is possible if the function resolves its self reference using a
fixed point which allows us to pass a function that encodes the call to
self in a returned attribute set, leaving the actual call to
genericClosure's operator. Additionally, some tools for collecting meta
data about functions (argCount) and calling arbitrary functions (apply,
unapply) are necessary.

Change-Id: I7d455db66d0a55e8639856ccc207639d371a5eb8
Reviewed-on: https://cl.tvl.fyi/c/depot/+/5292
Tested-by: BuildkiteCI
Reviewed-by: sterni <sternenseemann@systemli.org>
Autosubmit: sterni <sternenseemann@systemli.org>
Diffstat (limited to 'users/sterni')
-rw-r--r--users/sterni/nix/fun/default.nix198
-rw-r--r--users/sterni/nix/fun/tests/default.nix53
2 files changed, 251 insertions, 0 deletions
diff --git a/users/sterni/nix/fun/default.nix b/users/sterni/nix/fun/default.nix
index 6b3541ed4c65..bb10f9e6c1bf 100644
--- a/users/sterni/nix/fun/default.nix
+++ b/users/sterni/nix/fun/default.nix
@@ -39,6 +39,200 @@ let
     builtins.match ".*<attrspat ellipsis=\"1\">.*"
       (builtins.toXML f) != null;
 
+  /* Return the number of arguments the given function accepts or 0 if the value
+     is not a function.
+
+     Example:
+
+       argCount argCount
+       => 1
+
+       argCount builtins.add
+       => 2
+
+       argCount pkgs.stdenv.mkDerivation
+       => 1
+  */
+  argCount = f:
+    let
+      # N.B. since we are only interested if the result of calling is a function
+      # as opposed to a normal value or evaluation failure, we never need to
+      # check success, as value will be false (i.e. not a function) in the
+      # failure case.
+      called = builtins.tryEval (
+        f (builtins.throw "You should never see this error message")
+      );
+    in
+    if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
+    then 0
+    else 1 + argCount called.value;
+
+  /* Call a given function with a given list of arguments.
+
+     Example:
+
+       apply builtins.sub [ 20 10 ]
+       => 10
+  */
+  apply = f: args:
+    builtins.foldl' (f: x: f x) f args;
+
+  # TODO(sterni): think of a better name for unapply
+  /* Collect n arguments into a list and pass them to the given function.
+     Allows calling a function that expects a list by feeding it the list
+     elements individually as function arguments - the limitation is
+     that the list must be of constant length.
+
+     This is mainly useful for functions that wrap other, arbitrary functions
+     in conjunction with argCount and apply, since lists of arguments are
+     easier to deal with usually.
+
+     Example:
+
+       (unapply 3 lib.id) 1 2 3
+       => [ 1 2 3 ]
+
+       (unapply 5 lib.reverse) 1 2 null 4 5
+       => [ 5 4 null 2 1 ]
+
+       # unapply and apply compose the identity relation together
+
+       unapply (argCount f) (apply f)
+       # is equivalent to f (if the function has a constant number of arguments)
+
+       (unapply 2 (apply builtins.sub)) 20 10
+       => 10
+  */
+  unapply =
+    let
+      unapply' = acc: n: f: x:
+        if n == 1
+        then f (acc ++ [ x ])
+        else unapply' (acc ++ [ x ]) (n - 1) f;
+    in
+    unapply' [ ];
+
+  /* Optimize a tail recursive Nix function by intercepting the recursive
+     function application and expressing it in terms of builtins.genericClosure
+     instead. The main benefit of this optimization is that even a naively
+     written recursive algorithm won't overflow the stack.
+
+     For this to work the following things prerequisites are necessary:
+
+     - The passed function needs to be a fix point for its self reference,
+       i. e. the argument to tailCallOpt needs to be of the form
+       `self: # function body that uses self to call itself`.
+       This is because tailCallOpt needs to manipulate the call to self
+       which otherwise wouldn't be possible due to Nix's lexical scoping.
+
+     - The passed function may only call itself as a tail call, all other
+       forms of recursions will fail evaluation.
+
+     This function was mainly written to prove that builtins.genericClosure
+     can be used to express any (tail) recursive algorithm. It can be used
+     to avoid stack overflows for deeply recursive, but naively written
+     functions (in the context of Nix this mainly means using recursion
+     instead of (ab)using more performant and less limited builtins).
+     A better alternative to using this function is probably translating
+     the algorithm to builtins.genericClosure manually. Also note that
+     using tailCallOpt doesn't mean that the stack won't ever overflow:
+     Data structures, especially lazy ones, can still cause all the
+     available stack space to be consumed.
+
+     The optimization also only concerns avoiding stack overflows,
+     tailCallOpt will make functions slower if anything.
+
+     Type: (F -> F) -> F where F is any tail recursive function.
+
+     Example:
+
+     let
+       label' = self: acc: n:
+         if n == 0
+         then "This is " + acc + "cursed."
+         else self (acc + "very ") (n - 1);
+
+       # Equivalent to a naive recursive implementation in Nix
+       label = (lib.fix label') "";
+
+       labelOpt = (tailCallOpt label') "";
+     in
+
+     label 5
+     => "This is very very very very very cursed."
+
+     labelOpt 5
+     => "This is very very very very very cursed."
+
+     label 10000
+     => error: stack overflow (possible infinite recursion)
+
+     labelOpt 10000
+     => "This is very very very very very very very very very…
+  */
+  tailCallOpt = f:
+    let
+      argc = argCount (lib.fix f);
+
+      # This function simulates being f for f's self reference. Instead of
+      # recursing, it will just return the arguments received as a specially
+      # tagged set, so the recursion step can be performed later.
+      fakef = unapply argc (args: {
+        __tailCall = true;
+        inherit args;
+      });
+      # Pass fakef to f so that it'll be called instead of recursing, ensuring
+      # only one recursion step is performed at a time.
+      encodedf = f fakef;
+
+      opt = args:
+        let
+          steps = builtins.genericClosure {
+            # This is how we encode a (tail) call: A set with final == false
+            # and the list of arguments to pass to be found in args.
+            startSet = [
+              {
+                key = "0";
+                id = 0;
+                final = false;
+                inherit args;
+              }
+            ];
+
+            operator =
+              { id, final, ... }@state:
+              let
+                # Plumbing to make genericClosure happy
+                newIds = {
+                  key = toString (id + 1);
+                  id = id + 1;
+                };
+
+                # Perform recursion step
+                call = apply encodedf state.args;
+
+                # If call encodes a new call, return the new encoded call,
+                # otherwise signal that we're done.
+                newState =
+                  if builtins.isAttrs call && call.__tailCall or false
+                  then newIds // {
+                    final = false;
+                    inherit (call) args;
+                  } else newIds // {
+                    final = true;
+                    value = call;
+                  };
+              in
+
+              if final
+              then [ ] # end condition for genericClosure
+              else [ newState ];
+          };
+        in
+        # The returned list contains intermediate steps we ignore.
+        (builtins.head (builtins.filter (x: x.final) steps)).value;
+    in
+    unapply argc opt;
 in
 
 {
@@ -55,5 +249,9 @@ in
     lr
     lrs
     hasEllipsis
+    argCount
+    tailCallOpt
+    apply
+    unapply
     ;
 }
diff --git a/users/sterni/nix/fun/tests/default.nix b/users/sterni/nix/fun/tests/default.nix
index f02f19943373..6b1e6fcc7b0b 100644
--- a/users/sterni/nix/fun/tests/default.nix
+++ b/users/sterni/nix/fun/tests/default.nix
@@ -7,6 +7,8 @@ let
     assertEq
     ;
 
+  inherit (depot.nix) escapeExecline;
+
   inherit (depot.users.sterni.nix)
     fun
     ;
@@ -23,7 +25,58 @@ let
     (assertEq "Ellipsis" true
       (fun.hasEllipsis ({ depot, pkgs, ... }: 42)))
   ];
+
+  argCountTests = it "checks fun.argCount" [
+    (assertEq "builtins.sub has two arguments" 2
+      (fun.argCount builtins.sub))
+    (assertEq "fun.argCount has one argument" 1
+      (fun.argCount fun.argCount))
+    (assertEq "runTestsuite has two arguments" 2
+      (fun.argCount runTestsuite))
+  ];
+
+  applyTests = it "checks that fun.apply is equivalent to calling" [
+    (assertEq "fun.apply builtins.sub" (builtins.sub 23 42)
+      (fun.apply builtins.sub [ 23 42 ]))
+    (assertEq "fun.apply escapeExecline" (escapeExecline [ "foo" [ "bar" ] ])
+      (fun.apply escapeExecline [ [ "foo" [ "bar" ] ] ]))
+  ];
+
+  unapplyTests = it "checks fun.unapply" [
+    (assertEq "fun.unapply 3 accepts 3 args" 3
+      (fun.argCount (fun.unapply 3 fun.id)))
+    (assertEq "fun.unapply 73 accepts 73 args" 73
+      (fun.argCount (fun.unapply 73 fun.id)))
+    (assertEq "fun.unapply 1 accepts 73 args" 1
+      (fun.argCount (fun.unapply 1 fun.id)))
+    (assertEq "fun.unapply collects arguments correctly"
+      (fun.unapply 5 fun.id 1 2 3 4 5)
+      [ 1 2 3 4 5 ])
+    (assertEq "fun.unapply calls the given function correctly" 1
+      (fun.unapply 1 builtins.head 1))
+  ];
+
+  fac' = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);
+
+  facPlain = fun.fix fac' 1;
+  facOpt = fun.tailCallOpt fac' 1;
+
+  tailCallOptTests = it "checks fun.tailCallOpt" [
+    (assertEq "optimized and unoptimized factorial have the same base case"
+      (facPlain 0)
+      (facOpt 0))
+    (assertEq "optimized and unoptimized factorial have same value for 1"
+      (facPlain 1)
+      (facOpt 1))
+    (assertEq "optimized and unoptimized factorial have same value for 100"
+      (facPlain 100)
+      (facOpt 100))
+  ];
 in
 runTestsuite "nix.fun" [
   hasEllipsisTests
+  argCountTests
+  applyTests
+  unapplyTests
+  tailCallOptTests
 ]