about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/fix.cc62
-rw-r--r--src/test.cc10
2 files changed, 47 insertions, 25 deletions
diff --git a/src/fix.cc b/src/fix.cc
index cf6d5617aaa7..6e3809f828e5 100644
--- a/src/fix.cc
+++ b/src/fix.cc
@@ -8,15 +8,24 @@
 
 typedef ATerm Expr;
 
+typedef map<ATerm, ATerm> NormalForms;
 
-static Strings searchDirs;
+struct EvalState 
+{
+    Strings searchDirs;
+    NormalForms normalForms;
+};
+
+
+static Expr evalFile(EvalState & state, string fileName);
+static Expr evalExpr(EvalState & state, Expr e);
 
 
-static string searchPath(string relPath)
+static string searchPath(const Strings & searchDirs, string relPath)
 {
     if (string(relPath, 0, 1) == "/") return relPath;
 
-    for (Strings::iterator i = searchDirs.begin();
+    for (Strings::const_iterator i = searchDirs.begin();
          i != searchDirs.end(); i++)
     {
         string path = *i + "/" + relPath;
@@ -29,9 +38,6 @@ static string searchPath(string relPath)
 }
 
 
-static Expr evalFile(string fileName);
-
-
 static Expr substExpr(string x, Expr rep, Expr e)
 {
     char * s;
@@ -98,7 +104,7 @@ static Expr substExprMany(ATermList formals, ATermList args, Expr body)
 }
 
 
-static Expr evalExpr(Expr e)
+static Expr evalExpr2(EvalState & state, Expr e)
 {
     char * s1;
     Expr e1, e2, e3, e4;
@@ -119,21 +125,22 @@ static Expr evalExpr(Expr e)
 
     /* Application. */
     if (ATmatch(e, "App(<term>, [<list>])", &e1, &e2)) {
-        e1 = evalExpr(e1);
+        e1 = evalExpr(state, e1);
         if (!ATmatch(e1, "Function([<list>], <term>)", &e3, &e4))
             throw badTerm("expecting a function", e1);
-        return evalExpr(substExprMany((ATermList) e3, (ATermList) e2, e4));
+        return evalExpr(state,
+            substExprMany((ATermList) e3, (ATermList) e2, e4));
     }
 
     /* Fix inclusion. */
     if (ATmatch(e, "IncludeFix(<str>)", &s1)) {
         string fileName(s1);
-        return evalFile(s1);
+        return evalFile(state, s1);
     }
 
     /* Relative files. */
     if (ATmatch(e, "Relative(<str>)", &s1)) {
-        string srcPath = searchPath(s1);
+        string srcPath = searchPath(state.searchDirs, s1);
         string dstPath;
         FSId id;
         addToStore(srcPath, dstPath, id, true);
@@ -160,7 +167,7 @@ static Expr evalExpr(Expr e)
             ATerm bnd = ATgetFirst(bnds);
             if (!ATmatch(bnd, "(<str>, <term>)", &s1, &e1))
                 throw badTerm("binding expected", bnd);
-            bndMap[s1] = evalExpr(e1);
+            bndMap[s1] = evalExpr(state, e1);
             bnds = ATgetNext(bnds);
         }
 
@@ -218,7 +225,7 @@ static Expr evalExpr(Expr e)
 
     /* BaseName primitive function. */
     if (ATmatch(e, "BaseName(<term>)", &e1)) {
-        e1 = evalExpr(e1);
+        e1 = evalExpr(state, e1);
         if (!ATmatch(e1, "<str>", &s1)) 
             throw badTerm("string expected", e1);
         return ATmake("<str>", baseNameOf(s1).c_str());
@@ -229,22 +236,37 @@ static Expr evalExpr(Expr e)
 }
 
 
-static Expr evalFile(string relPath)
+static Expr evalExpr(EvalState & state, Expr e)
+{
+    /* Consult the memo table to quickly get the normal form of
+       previously evaluated expressions. */
+    NormalForms::iterator i = state.normalForms.find(e);
+    if (i != state.normalForms.end()) return i->second;
+
+    /* Otherwise, evaluate and memoize. */
+    Expr nf = evalExpr2(state, e);
+    state.normalForms[e] = nf;
+    return nf;
+}
+
+
+static Expr evalFile(EvalState & state, string relPath)
 {
-    string path = searchPath(relPath);
+    string path = searchPath(state.searchDirs, relPath);
     Expr e = ATreadFromNamedFile(path.c_str());
     if (!e) 
         throw Error(format("unable to read a term from `%1%'") % path);
-    return evalExpr(e);
+    return evalExpr(state, e);
 }
 
 
 void run(Strings args)
 {
+    EvalState state;
     Strings files;
 
-    searchDirs.push_back(".");
-    searchDirs.push_back(nixDataDir + "/fix");
+    state.searchDirs.push_back(".");
+    state.searchDirs.push_back(nixDataDir + "/fix");
     
     for (Strings::iterator it = args.begin();
          it != args.end(); )
@@ -254,7 +276,7 @@ void run(Strings args)
         if (arg == "--includedir" || arg == "-I") {
             if (it == args.end())
                 throw UsageError(format("argument required in `%1%'") % arg);
-            searchDirs.push_back(*it++);
+            state.searchDirs.push_back(*it++);
         }
         else if (arg[0] == '-')
             throw UsageError(format("unknown flag `%1%`") % arg);
@@ -267,7 +289,7 @@ void run(Strings args)
     for (Strings::iterator it = files.begin();
          it != files.end(); it++)
     {
-        Expr e = evalFile(*it);
+        Expr e = evalFile(state, *it);
         char * s;
         if (ATmatch(e, "FSId(<str>)", &s)) {
             cout << format("%1%\n") % s;
diff --git a/src/test.cc b/src/test.cc
index 5c7559af6afd..6b567abe030f 100644
--- a/src/test.cc
+++ b/src/test.cc
@@ -117,7 +117,7 @@ void runTests()
         ((string) builder1id).c_str(),
         builder1fn.c_str(),
         ((string) builder1id).c_str());
-    FSId fs1id = writeTerm(fs1, "", 0);
+    FSId fs1id = writeTerm(fs1, "");
 
     realise(fs1id);
     realise(fs1id);
@@ -127,7 +127,7 @@ void runTests()
         ((string) builder1id).c_str(),
         (builder1fn + "_bla").c_str(),
         ((string) builder1id).c_str());
-    FSId fs2id = writeTerm(fs2, "", 0);
+    FSId fs2id = writeTerm(fs2, "");
 
     realise(fs2id);
     realise(fs2id);
@@ -143,7 +143,7 @@ void runTests()
         thisSystem.c_str(),
         out1fn.c_str());
     debug(printTerm(fs3));
-    FSId fs3id = writeTerm(fs3, "", 0);
+    FSId fs3id = writeTerm(fs3, "");
 
     realise(fs3id);
     realise(fs3id);
@@ -158,7 +158,7 @@ void runTests()
         ((string) builder4id).c_str(),
         builder4fn.c_str(),
         ((string) builder4id).c_str());
-    FSId fs4id = writeTerm(fs4, "", 0);
+    FSId fs4id = writeTerm(fs4, "");
 
     realise(fs4id);
 
@@ -174,7 +174,7 @@ void runTests()
         out5fn.c_str(),
         ((string) builder4fn).c_str());
     debug(printTerm(fs5));
-    FSId fs5id = writeTerm(fs5, "", 0);
+    FSId fs5id = writeTerm(fs5, "");
 
     realise(fs5id);
     realise(fs5id);