From 36f6322d16ad0c2341bc37235bf327db516ef97d Mon Sep 17 00:00:00 2001 From: sterni Date: Mon, 14 Feb 2022 12:01:50 +0100 Subject: feat(sterni/nix/fun): implement tail call “optimization” for Nix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've had the notion that builtins.genericClosure can be used to express any recursive algorithm, but a proof is much better than a notion of course! In this case we can easily show this by implementing a function that converts a tail recursive function into an application of builtins.genericClosure. This is possible if the function resolves its self reference using a fixed point which allows us to pass a function that encodes the call to self in a returned attribute set, leaving the actual call to genericClosure's operator. Additionally, some tools for collecting meta data about functions (argCount) and calling arbitrary functions (apply, unapply) are necessary. Change-Id: I7d455db66d0a55e8639856ccc207639d371a5eb8 Reviewed-on: https://cl.tvl.fyi/c/depot/+/5292 Tested-by: BuildkiteCI Reviewed-by: sterni Autosubmit: sterni --- users/sterni/nix/fun/default.nix | 198 +++++++++++++++++++++++++++++++++ users/sterni/nix/fun/tests/default.nix | 53 +++++++++ 2 files changed, 251 insertions(+) diff --git a/users/sterni/nix/fun/default.nix b/users/sterni/nix/fun/default.nix index 6b3541ed4c..bb10f9e6c1 100644 --- a/users/sterni/nix/fun/default.nix +++ b/users/sterni/nix/fun/default.nix @@ -39,6 +39,200 @@ let builtins.match ".*.*" (builtins.toXML f) != null; + /* Return the number of arguments the given function accepts or 0 if the value + is not a function. + + Example: + + argCount argCount + => 1 + + argCount builtins.add + => 2 + + argCount pkgs.stdenv.mkDerivation + => 1 + */ + argCount = f: + let + # N.B. since we are only interested if the result of calling is a function + # as opposed to a normal value or evaluation failure, we never need to + # check success, as value will be false (i.e. not a function) in the + # failure case. + called = builtins.tryEval ( + f (builtins.throw "You should never see this error message") + ); + in + if !(builtins.isFunction f || builtins.isFunction (f.__functor or null)) + then 0 + else 1 + argCount called.value; + + /* Call a given function with a given list of arguments. + + Example: + + apply builtins.sub [ 20 10 ] + => 10 + */ + apply = f: args: + builtins.foldl' (f: x: f x) f args; + + # TODO(sterni): think of a better name for unapply + /* Collect n arguments into a list and pass them to the given function. + Allows calling a function that expects a list by feeding it the list + elements individually as function arguments - the limitation is + that the list must be of constant length. + + This is mainly useful for functions that wrap other, arbitrary functions + in conjunction with argCount and apply, since lists of arguments are + easier to deal with usually. + + Example: + + (unapply 3 lib.id) 1 2 3 + => [ 1 2 3 ] + + (unapply 5 lib.reverse) 1 2 null 4 5 + => [ 5 4 null 2 1 ] + + # unapply and apply compose the identity relation together + + unapply (argCount f) (apply f) + # is equivalent to f (if the function has a constant number of arguments) + + (unapply 2 (apply builtins.sub)) 20 10 + => 10 + */ + unapply = + let + unapply' = acc: n: f: x: + if n == 1 + then f (acc ++ [ x ]) + else unapply' (acc ++ [ x ]) (n - 1) f; + in + unapply' [ ]; + + /* Optimize a tail recursive Nix function by intercepting the recursive + function application and expressing it in terms of builtins.genericClosure + instead. The main benefit of this optimization is that even a naively + written recursive algorithm won't overflow the stack. + + For this to work the following things prerequisites are necessary: + + - The passed function needs to be a fix point for its self reference, + i. e. the argument to tailCallOpt needs to be of the form + `self: # function body that uses self to call itself`. + This is because tailCallOpt needs to manipulate the call to self + which otherwise wouldn't be possible due to Nix's lexical scoping. + + - The passed function may only call itself as a tail call, all other + forms of recursions will fail evaluation. + + This function was mainly written to prove that builtins.genericClosure + can be used to express any (tail) recursive algorithm. It can be used + to avoid stack overflows for deeply recursive, but naively written + functions (in the context of Nix this mainly means using recursion + instead of (ab)using more performant and less limited builtins). + A better alternative to using this function is probably translating + the algorithm to builtins.genericClosure manually. Also note that + using tailCallOpt doesn't mean that the stack won't ever overflow: + Data structures, especially lazy ones, can still cause all the + available stack space to be consumed. + + The optimization also only concerns avoiding stack overflows, + tailCallOpt will make functions slower if anything. + + Type: (F -> F) -> F where F is any tail recursive function. + + Example: + + let + label' = self: acc: n: + if n == 0 + then "This is " + acc + "cursed." + else self (acc + "very ") (n - 1); + + # Equivalent to a naive recursive implementation in Nix + label = (lib.fix label') ""; + + labelOpt = (tailCallOpt label') ""; + in + + label 5 + => "This is very very very very very cursed." + + labelOpt 5 + => "This is very very very very very cursed." + + label 10000 + => error: stack overflow (possible infinite recursion) + + labelOpt 10000 + => "This is very very very very very very very very very… + */ + tailCallOpt = f: + let + argc = argCount (lib.fix f); + + # This function simulates being f for f's self reference. Instead of + # recursing, it will just return the arguments received as a specially + # tagged set, so the recursion step can be performed later. + fakef = unapply argc (args: { + __tailCall = true; + inherit args; + }); + # Pass fakef to f so that it'll be called instead of recursing, ensuring + # only one recursion step is performed at a time. + encodedf = f fakef; + + opt = args: + let + steps = builtins.genericClosure { + # This is how we encode a (tail) call: A set with final == false + # and the list of arguments to pass to be found in args. + startSet = [ + { + key = "0"; + id = 0; + final = false; + inherit args; + } + ]; + + operator = + { id, final, ... }@state: + let + # Plumbing to make genericClosure happy + newIds = { + key = toString (id + 1); + id = id + 1; + }; + + # Perform recursion step + call = apply encodedf state.args; + + # If call encodes a new call, return the new encoded call, + # otherwise signal that we're done. + newState = + if builtins.isAttrs call && call.__tailCall or false + then newIds // { + final = false; + inherit (call) args; + } else newIds // { + final = true; + value = call; + }; + in + + if final + then [ ] # end condition for genericClosure + else [ newState ]; + }; + in + # The returned list contains intermediate steps we ignore. + (builtins.head (builtins.filter (x: x.final) steps)).value; + in + unapply argc opt; in { @@ -55,5 +249,9 @@ in lr lrs hasEllipsis + argCount + tailCallOpt + apply + unapply ; } diff --git a/users/sterni/nix/fun/tests/default.nix b/users/sterni/nix/fun/tests/default.nix index f02f199433..6b1e6fcc7b 100644 --- a/users/sterni/nix/fun/tests/default.nix +++ b/users/sterni/nix/fun/tests/default.nix @@ -7,6 +7,8 @@ let assertEq ; + inherit (depot.nix) escapeExecline; + inherit (depot.users.sterni.nix) fun ; @@ -23,7 +25,58 @@ let (assertEq "Ellipsis" true (fun.hasEllipsis ({ depot, pkgs, ... }: 42))) ]; + + argCountTests = it "checks fun.argCount" [ + (assertEq "builtins.sub has two arguments" 2 + (fun.argCount builtins.sub)) + (assertEq "fun.argCount has one argument" 1 + (fun.argCount fun.argCount)) + (assertEq "runTestsuite has two arguments" 2 + (fun.argCount runTestsuite)) + ]; + + applyTests = it "checks that fun.apply is equivalent to calling" [ + (assertEq "fun.apply builtins.sub" (builtins.sub 23 42) + (fun.apply builtins.sub [ 23 42 ])) + (assertEq "fun.apply escapeExecline" (escapeExecline [ "foo" [ "bar" ] ]) + (fun.apply escapeExecline [ [ "foo" [ "bar" ] ] ])) + ]; + + unapplyTests = it "checks fun.unapply" [ + (assertEq "fun.unapply 3 accepts 3 args" 3 + (fun.argCount (fun.unapply 3 fun.id))) + (assertEq "fun.unapply 73 accepts 73 args" 73 + (fun.argCount (fun.unapply 73 fun.id))) + (assertEq "fun.unapply 1 accepts 73 args" 1 + (fun.argCount (fun.unapply 1 fun.id))) + (assertEq "fun.unapply collects arguments correctly" + (fun.unapply 5 fun.id 1 2 3 4 5) + [ 1 2 3 4 5 ]) + (assertEq "fun.unapply calls the given function correctly" 1 + (fun.unapply 1 builtins.head 1)) + ]; + + fac' = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1); + + facPlain = fun.fix fac' 1; + facOpt = fun.tailCallOpt fac' 1; + + tailCallOptTests = it "checks fun.tailCallOpt" [ + (assertEq "optimized and unoptimized factorial have the same base case" + (facPlain 0) + (facOpt 0)) + (assertEq "optimized and unoptimized factorial have same value for 1" + (facPlain 1) + (facOpt 1)) + (assertEq "optimized and unoptimized factorial have same value for 100" + (facPlain 100) + (facOpt 100)) + ]; in runTestsuite "nix.fun" [ hasEllipsisTests + argCountTests + applyTests + unapplyTests + tailCallOptTests ] -- cgit 1.4.1