diff options
-rw-r--r-- | src/libexpr/eval-test.cc | 130 |
1 files changed, 117 insertions, 13 deletions
diff --git a/src/libexpr/eval-test.cc b/src/libexpr/eval-test.cc index 68a7380e5da9..27cf685f9de4 100644 --- a/src/libexpr/eval-test.cc +++ b/src/libexpr/eval-test.cc @@ -26,7 +26,8 @@ typedef enum { tInt = 1, tAttrs, tThunk, - tLambda + tLambda, + tCopy } ValueType; @@ -46,6 +47,7 @@ struct Value Pattern pat; Expr body; } lambda; + Value * val; }; }; @@ -65,6 +67,9 @@ std::ostream & operator << (std::ostream & str, Value & v) case tThunk: str << "<CODE>"; break; + case tLambda: + str << "<LAMBDA>"; + break; default: abort(); } @@ -72,13 +77,16 @@ std::ostream & operator << (std::ostream & str, Value & v) } -void eval(Env * env, Expr e, Value & v); +static void eval(Env * env, Expr e, Value & v); void forceValue(Value & v) { - if (v.type != tThunk) return; - eval(v.thunk.env, v.thunk.expr, v); + if (v.type == tThunk) eval(v.thunk.env, v.thunk.expr, v); + else if (v.type == tCopy) { + forceValue(*v.val); + v = *v.val; + } } @@ -103,7 +111,92 @@ Env * allocEnv() } -void eval(Env * env, Expr e, Value & v) +static bool patternIsStrict(Pattern pat) +{ + ATerm name, ellipsis, pat1, pat2; + ATermList formals; + if (matchVarPat(pat, name)) return false; + else if (matchAttrsPat(pat, formals, ellipsis)) return true; + else if (matchAtPat(pat, pat1, pat2)) + return patternIsStrict(pat1) || patternIsStrict(pat2); + else abort(); +} + + +static void bindVarPats(Pattern pat, Env & newEnv, + Env * argEnv, Expr argExpr, Value * & vArg) +{ + Pattern pat1, pat2; + if (matchAtPat(pat, pat1, pat2)) { + bindVarPats(pat1, newEnv, argEnv, argExpr, vArg); + bindVarPats(pat2, newEnv, argEnv, argExpr, vArg); + return; + } + + ATerm name; + if (!matchVarPat(pat, name)) abort(); + + if (vArg) { + Value & v = newEnv.bindings[aterm2String(name)]; + v.type = tCopy; + v.val = vArg; + } else { + vArg = &newEnv.bindings[aterm2String(name)]; + vArg->type = tThunk; + vArg->thunk.env = argEnv; + vArg->thunk.expr = argExpr; + } +} + + +static void bindAttrPats(Pattern pat, Env & newEnv, + Value & vArg, Value * & vArgInEnv) +{ + Pattern pat1, pat2; + if (matchAtPat(pat, pat1, pat2)) { + bindAttrPats(pat1, newEnv, vArg, vArgInEnv); + bindAttrPats(pat2, newEnv, vArg, vArgInEnv); + return; + } + + ATerm name; + if (matchVarPat(pat, name)) { + if (vArgInEnv) { + Value & v = newEnv.bindings[aterm2String(name)]; + v.type = tCopy; + v.val = vArgInEnv; + } else { + vArgInEnv = &newEnv.bindings[aterm2String(name)]; + *vArgInEnv = vArg; + } + return; + } + + ATerm ellipsis; + ATermList formals; + if (matchAttrsPat(pat, formals, ellipsis)) { + for (ATermIterator i(formals); i; ++i) { + Expr name, def; + DefaultValue def2; + if (!matchFormal(*i, name, def2)) abort(); /* can't happen */ + + Bindings::iterator j = vArg.attrs->find(aterm2String(name)); + if (j == vArg.attrs->end()) + throw TypeError(format("the argument named `%1%' required by the function is missing") + % aterm2String(name)); + + Value & v = newEnv.bindings[aterm2String(name)]; + v.type = tCopy; + v.val = &j->second; + } + return; + } + + abort(); +} + + +static void eval(Env * env, Expr e, Value & v) { printMsg(lvlError, format("eval: %1%") % e); @@ -182,17 +275,21 @@ void eval(Env * env, Expr e, Value & v) if (matchCall(e, fun, arg)) { eval(env, fun, v); if (v.type != tLambda) throw TypeError("expected function"); - if (!matchVarPat(v.lambda.pat, name)) throw Error("not implemented"); Env * env2 = allocEnv(); env2->up = env; - - Value & arg_ = env2->bindings[aterm2String(name)]; - nrValues++; - arg_.type = tThunk; - arg_.thunk.env = env; - arg_.thunk.expr = arg; + if (patternIsStrict(v.lambda.pat)) { + Value vArg; + eval(env, arg, vArg); + if (vArg.type != tAttrs) throw TypeError("expected attribute set"); + Value * vArg2 = 0; + bindAttrPats(v.lambda.pat, *env2, vArg, vArg2); + } else { + Value * vArg = 0; + bindVarPats(v.lambda.pat, *env2, env, arg, vArg); + } + eval(env2, v.lambda.body, v); return; } @@ -205,7 +302,7 @@ void doTest(string s) { EvalState state; Expr e = parseExprFromString(state, s, "/"); - printMsg(lvlError, format("%1%") % e); + printMsg(lvlError, format(">>>>> %1%") % e); Value v; eval(0, e, v); printMsg(lvlError, format("result: %1%") % v); @@ -222,6 +319,13 @@ void run(Strings args) doTest("rec { x = 1; y = x; }.y"); doTest("(x: x) 1"); doTest("(x: y: y) 1 2"); + doTest("(x@y: x) 1"); + doTest("(x@y: y) 2"); + doTest("(x@y@z: y) 3"); + doTest("x: x"); + doTest("({x, y}: x) { x = 1; y = 2; }"); + doTest("({x, y}@args: args.x) { x = 1; y = 2; }"); + doTest("({x, y}@args@args2: args2.x) { x = 1; y = 2; }"); //Expr e = parseExprFromString(state, "let x = \"a\"; in x + \"b\"", "/"); //Expr e = parseExprFromString(state, "(x: x + \"b\") \"a\"", "/"); |