about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEelco Dolstra <e.dolstra@tudelft.nl>2006-12-04T17·17+0000
committerEelco Dolstra <e.dolstra@tudelft.nl>2006-12-04T17·17+0000
commit0130ef88ea280e67037fa76bcedc59db17d9a8ca (patch)
tree120a616c9a9ee296d2c1832f1b238c281df016ae
parent4740baf3a61c48c07f12f23927c84ca9892088a8 (diff)
* Daemon mode (`nix-worker --daemon'). Clients connect to the server
  via the Unix domain socket in /nix/var/nix/daemon.socket.  The
  server forks a worker process per connection.
* readString(): use the heap, not the stack.
* Some protocol fixes.

-rw-r--r--src/libmain/shared.hh6
-rw-r--r--src/libstore/remote-store.cc3
-rw-r--r--src/libutil/serialise.cc7
-rw-r--r--src/libutil/types.hh2
-rw-r--r--src/libutil/util.cc17
-rw-r--r--src/libutil/util.hh17
-rw-r--r--src/nix-worker/main.cc183
7 files changed, 182 insertions, 53 deletions
diff --git a/src/libmain/shared.hh b/src/libmain/shared.hh
index 2c574d148f6b..fa45645fef40 100644
--- a/src/libmain/shared.hh
+++ b/src/libmain/shared.hh
@@ -3,6 +3,8 @@
 
 #include "types.hh"
 
+#include <signal.h>
+
 
 /* These are not implemented here, but must be implemented by a
    program linking against libmain. */
@@ -27,6 +29,10 @@ void printGCWarning();
 /* Whether we're running setuid. */
 extern bool setuidMode;
 
+extern volatile ::sig_atomic_t blockInt;
+
+MakeError(UsageError, nix::Error)
+
 }
 
 
diff --git a/src/libstore/remote-store.cc b/src/libstore/remote-store.cc
index 4d4189be0c15..b9ed1fdbc09d 100644
--- a/src/libstore/remote-store.cc
+++ b/src/libstore/remote-store.cc
@@ -39,10 +39,11 @@ RemoteStore::RemoteStore()
     
     /* Send the magic greeting, check for the reply. */
     try {
-        processStderr();
         writeInt(WORKER_MAGIC_1, to);
+        writeInt(verbosity, to);
         unsigned int magic = readInt(from);
         if (magic != WORKER_MAGIC_2) throw Error("protocol mismatch");
+        processStderr();
     } catch (Error & e) {
         throw Error(format("cannot start worker (%1%)")
             % e.msg());
diff --git a/src/libutil/serialise.cc b/src/libutil/serialise.cc
index 969f638ef408..c0e1c17af066 100644
--- a/src/libutil/serialise.cc
+++ b/src/libutil/serialise.cc
@@ -85,10 +85,11 @@ unsigned int readInt(Source & source)
 string readString(Source & source)
 {
     unsigned int len = readInt(source);
-    char buf[len];
-    source((unsigned char *) buf, len);
+    unsigned char * buf = new unsigned char[len];
+    AutoDeleteArray<unsigned char> d(buf);
+    source(buf, len);
     readPadding(len, source);
-    return string(buf, len);
+    return string((char *) buf, len);
 }
 
  
diff --git a/src/libutil/types.hh b/src/libutil/types.hh
index 1de378961e46..257871a82b35 100644
--- a/src/libutil/types.hh
+++ b/src/libutil/types.hh
@@ -44,8 +44,6 @@ public:
         newClass(const format & f) : superClass(f) { }; \
     };
 
-MakeError(UsageError, Error)
-
 
 typedef list<string> Strings;
 typedef set<string> StringSet;
diff --git a/src/libutil/util.cc b/src/libutil/util.cc
index 7c1138720cd3..08385e5d96e0 100644
--- a/src/libutil/util.cc
+++ b/src/libutil/util.cc
@@ -191,18 +191,6 @@ Strings readDirectory(const Path & path)
 }
 
 
-template <class T>
-struct AutoDeleteArray
-{
-    T * p;
-    AutoDeleteArray(T * p) : p(p) { }
-    ~AutoDeleteArray() 
-    {
-        delete [] p;
-    }
-};
-
-
 string readFile(int fd)
 {
     struct stat st;
@@ -468,7 +456,7 @@ void readFull(int fd, unsigned char * buf, size_t count)
             if (errno == EINTR) continue;
             throw SysError("reading from file");
         }
-        if (res == 0) throw Error("unexpected end-of-file");
+        if (res == 0) throw EndOfFile("unexpected end-of-file");
         count -= res;
         buf += res;
     }
@@ -707,6 +695,7 @@ int Pid::wait(bool block)
         if (res == 0 && !block) return -1;
         if (errno != EINTR)
             throw SysError("cannot get child exit status");
+        checkInterrupt();
     }
 }
 
@@ -793,7 +782,7 @@ void _interrupted()
        kills the program! */
     if (!std::uncaught_exception()) {
         _isInterrupted = 0;
-        throw Error("interrupted by the user");
+        throw Interrupted("interrupted by the user");
     }
 }
 
diff --git a/src/libutil/util.hh b/src/libutil/util.hh
index 0d39ffee9eae..b88508dec30d 100644
--- a/src/libutil/util.hh
+++ b/src/libutil/util.hh
@@ -139,6 +139,8 @@ extern void (*writeToStderr) (const unsigned char * buf, size_t count);
 void readFull(int fd, unsigned char * buf, size_t count);
 void writeFull(int fd, const unsigned char * buf, size_t count);
 
+MakeError(EndOfFile, Error)
+
 
 /* Read a file descriptor until EOF occurs. */
 string drainFD(int fd);
@@ -147,6 +149,19 @@ string drainFD(int fd);
 
 /* Automatic cleanup of resources. */
 
+
+template <class T>
+struct AutoDeleteArray
+{
+    T * p;
+    AutoDeleteArray(T * p) : p(p) { }
+    ~AutoDeleteArray() 
+    {
+        delete [] p;
+    }
+};
+
+
 class AutoDelete
 {
     string path;
@@ -229,6 +244,8 @@ void inline checkInterrupt()
     if (_isInterrupted) _interrupted();
 }
 
+MakeError(Interrupted, Error)
+
 
 /* String packing / unpacking. */
 string packStrings(const Strings & strings);
diff --git a/src/nix-worker/main.cc b/src/nix-worker/main.cc
index c8576ddb67f5..d104ea8406b3 100644
--- a/src/nix-worker/main.cc
+++ b/src/nix-worker/main.cc
@@ -9,6 +9,10 @@
 #include <iostream>
 #include <unistd.h>
 #include <signal.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/socket.h>
+#include <sys/un.h>
 #include <fcntl.h>
 
 using namespace nix;
@@ -43,9 +47,6 @@ bool canSendStderr;
    socket. */
 static void tunnelStderr(const unsigned char * buf, size_t count)
 {
-    if (canSendStderr)
-        writeFull(STDERR_FILENO, (unsigned char *) "L: ", 3);
-    writeFull(STDERR_FILENO, buf, count);
     if (canSendStderr) {
         try {
             writeInt(STDERR_NEXT, to);
@@ -65,12 +66,20 @@ static void tunnelStderr(const unsigned char * buf, size_t count)
    socket.  This handler is enabled at precisely those moments in the
    protocol when we're doing work and the client is supposed to be
    quiet.  Thus, if we get a SIGIO signal, it means that the client
-   has quit.  So we should quit as well. */
+   has quit.  So we should quit as well.
+
+   Too bad most operating systems don't support the POLL_HUP value for
+   si_code in siginfo_t.  That would make most of the SIGIO complexity
+   unnecessary, i.e., we could just enable SIGIO all the time and
+   wouldn't have to worry about races. */
 static void sigioHandler(int sigNo)
 {
-    _isInterrupted = 1;
-    canSendStderr = false;
-    write(STDERR_FILENO, "SIGIO\n", 6);
+    if (!blockInt) {
+        _isInterrupted = 1;
+        blockInt = 1;
+        canSendStderr = false;
+        write(STDERR_FILENO, "SIGIO\n", 6);
+    }
 }
 
 
@@ -97,14 +106,14 @@ static void startWork()
 
     fd_set fds;
     FD_ZERO(&fds);
-    FD_SET(STDIN_FILENO, &fds);
+    FD_SET(from.fd, &fds);
         
-    if (select(STDIN_FILENO + 1, &fds, 0, 0, &timeout) == -1)
+    if (select(from.fd + 1, &fds, 0, 0, &timeout) == -1)
         throw SysError("select()");
 
-    if (FD_ISSET(STDIN_FILENO, &fds)) {
+    if (FD_ISSET(from.fd, &fds)) {
         char c;
-        if (read(STDIN_FILENO, &c, 1) != 0)
+        if (read(from.fd, &c, 1) != 0)
             throw Error("EOF expected (protocol error?)");
         _isInterrupted = 1;
         checkInterrupt();
@@ -114,7 +123,7 @@ static void startWork()
 
 /* stopWork() means that we're done; stop sending stderr to the
    client. */
-static void stopWork()
+static void stopWork(bool success = true, const string & msg = "")
 {
     /* Stop handling async client death; we're going to a state where
        we're either sending or receiving from the client, so we'll be
@@ -123,7 +132,13 @@ static void stopWork()
         throw SysError("ignoring SIGIO");
     
     canSendStderr = false;
-    writeInt(STDERR_LAST, to);
+
+    if (success)
+        writeInt(STDERR_LAST, to);
+    else {
+        writeInt(STDERR_ERROR, to);
+        writeString(msg, to);
+    }
 }
 
 
@@ -237,11 +252,17 @@ static void processConnection()
 
     /* Allow us to receive SIGIO for events on the client socket. */
     signal(SIGIO, SIG_IGN);
-    if (fcntl(STDIN_FILENO, F_SETOWN, getpid()) == -1)
+    if (fcntl(from.fd, F_SETOWN, getpid()) == -1)
         throw SysError("F_SETOWN");
-    if (fcntl(STDIN_FILENO, F_SETFL, fcntl(STDIN_FILENO, F_GETFL, 0) | FASYNC) == -1)
+    if (fcntl(from.fd, F_SETFL, fcntl(from.fd, F_GETFL, 0) | FASYNC) == -1)
         throw SysError("F_SETFL");
 
+    /* Exchange the greeting. */
+    unsigned int magic = readInt(from);
+    if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch");
+    verbosity = (Verbosity) readInt(from);
+    writeInt(WORKER_MAGIC_2, to);
+
     /* Send startup error messages to the client. */
     startWork();
 
@@ -258,40 +279,137 @@ static void processConnection()
         stopWork();
         
     } catch (Error & e) {
-        writeInt(STDERR_ERROR, to);
-        writeString(e.msg(), to);
+        stopWork(false, e.msg());
         return;
     }
 
-    /* Exchange the greeting. */
-    unsigned int magic = readInt(from);
-    if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch");
-    writeInt(WORKER_MAGIC_2, to);
-    debug("greeting exchanged");
-
     /* Process client requests. */
-    bool quit = false;
-
     unsigned int opCount = 0;
     
-    do {
-        WorkerOp op = (WorkerOp) readInt(from);
+    while (true) {
+        WorkerOp op;
+        try {
+            op = (WorkerOp) readInt(from);
+        } catch (EndOfFile & e) {
+            break;
+        }
 
         opCount++;
 
         try {
             performOp(from, to, op);
         } catch (Error & e) {
-            writeInt(STDERR_ERROR, to);
-            writeString(e.msg(), to);
+            stopWork(false, e.msg());
         }
-
-    } while (!quit);
+    };
 
     printMsg(lvlError, format("%1% worker operations") % opCount);
 }
 
 
+static void setSigChldAction(bool ignore)
+{
+    struct sigaction act, oact;
+    act.sa_handler = ignore ? SIG_IGN : SIG_DFL;
+    sigfillset(&act.sa_mask);
+    act.sa_flags = 0;
+    if (sigaction(SIGCHLD, &act, &oact))
+        throw SysError("setting SIGCHLD handler");
+}
+
+
+static void daemonLoop()
+{
+    /* Get rid of children automatically; don't let them become
+       zombies. */
+    setSigChldAction(true);
+    
+    /* Create and bind to a Unix domain socket. */
+    AutoCloseFD fdSocket = socket(PF_UNIX, SOCK_STREAM, 0);
+    if (fdSocket == -1)
+        throw SysError("cannot create Unix domain socket");
+
+    string socketPath = nixStateDir + DEFAULT_SOCKET_PATH;
+
+    struct sockaddr_un addr;
+    addr.sun_family = AF_UNIX;
+    if (socketPath.size() >= sizeof(addr.sun_path))
+        throw Error(format("socket path `%1%' is too long") % socketPath);
+    strcpy(addr.sun_path, socketPath.c_str());
+
+    unlink(socketPath.c_str());
+
+    /* Make sure that the socket is created with 0666 permission
+       (everybody can connect). */
+    mode_t oldMode = umask(0111);
+    int res = bind(fdSocket, (struct sockaddr *) &addr, sizeof(addr));
+    umask(oldMode);
+    if (res == -1)
+        throw SysError(format("cannot bind to socket `%1%'") % socketPath);
+
+    if (listen(fdSocket, 5) == -1)
+        throw SysError(format("cannot listen on socket `%1%'") % socketPath);
+
+    /* Loop accepting connections. */
+    while (1) {
+
+        try {
+            /* Important: the server process *cannot* open the
+               Berkeley DB environment, because it doesn't like forks
+               very much. */
+            assert(!store);
+            
+            /* Accept a connection. */
+            struct sockaddr_un remoteAddr;
+            socklen_t remoteAddrLen = sizeof(remoteAddr);
+
+            AutoCloseFD remote = accept(fdSocket,
+                (struct sockaddr *) &remoteAddr, &remoteAddrLen);
+            checkInterrupt();
+            if (remote == -1)
+                throw SysError("accepting connection");
+
+            printMsg(lvlInfo, format("accepted connection %1%") % remote);
+
+            /* Fork a child to handle the connection. */
+            pid_t child;
+            child = fork();
+    
+            switch (child) {
+        
+            case -1:
+                throw SysError("unable to fork");
+
+            case 0:
+                try { /* child */
+                    
+                    /* Background the worker. */
+                    if (setsid() == -1)
+                        throw SysError(format("creating a new session"));
+
+                    /* Restore normal handling of SIGCHLD. */
+                    setSigChldAction(false);
+                    
+                    /* Handle the connection. */
+                    from.fd = remote;
+                    to.fd = remote;
+                    processConnection();
+                    
+                } catch (std::exception & e) {
+                    std::cerr << format("child error: %1%\n") % e.what();
+                }
+                exit(0);
+            }
+
+        } catch (Interrupted & e) {
+            throw;
+        } catch (Error & e) {
+            printMsg(lvlError, format("error processing connection: %1%") % e.msg());
+        }
+    }
+}
+
+
 void run(Strings args)
 {
     bool slave = false;
@@ -315,8 +433,7 @@ void run(Strings args)
     else if (daemon) {
         if (setuidMode)
             throw Error("daemon cannot be started in setuid mode");
-        
-        throw Error("daemon mode not implemented");
+        daemonLoop();
     }
 
     else