about summary refs log tree commit diff
path: root/src/nix-worker/main.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/nix-worker/main.cc')
-rw-r--r--src/nix-worker/main.cc183
1 files changed, 150 insertions, 33 deletions
diff --git a/src/nix-worker/main.cc b/src/nix-worker/main.cc
index c8576ddb67f5..d104ea8406b3 100644
--- a/src/nix-worker/main.cc
+++ b/src/nix-worker/main.cc
@@ -9,6 +9,10 @@
 #include <iostream>
 #include <unistd.h>
 #include <signal.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/socket.h>
+#include <sys/un.h>
 #include <fcntl.h>
 
 using namespace nix;
@@ -43,9 +47,6 @@ bool canSendStderr;
    socket. */
 static void tunnelStderr(const unsigned char * buf, size_t count)
 {
-    if (canSendStderr)
-        writeFull(STDERR_FILENO, (unsigned char *) "L: ", 3);
-    writeFull(STDERR_FILENO, buf, count);
     if (canSendStderr) {
         try {
             writeInt(STDERR_NEXT, to);
@@ -65,12 +66,20 @@ static void tunnelStderr(const unsigned char * buf, size_t count)
    socket.  This handler is enabled at precisely those moments in the
    protocol when we're doing work and the client is supposed to be
    quiet.  Thus, if we get a SIGIO signal, it means that the client
-   has quit.  So we should quit as well. */
+   has quit.  So we should quit as well.
+
+   Too bad most operating systems don't support the POLL_HUP value for
+   si_code in siginfo_t.  That would make most of the SIGIO complexity
+   unnecessary, i.e., we could just enable SIGIO all the time and
+   wouldn't have to worry about races. */
 static void sigioHandler(int sigNo)
 {
-    _isInterrupted = 1;
-    canSendStderr = false;
-    write(STDERR_FILENO, "SIGIO\n", 6);
+    if (!blockInt) {
+        _isInterrupted = 1;
+        blockInt = 1;
+        canSendStderr = false;
+        write(STDERR_FILENO, "SIGIO\n", 6);
+    }
 }
 
 
@@ -97,14 +106,14 @@ static void startWork()
 
     fd_set fds;
     FD_ZERO(&fds);
-    FD_SET(STDIN_FILENO, &fds);
+    FD_SET(from.fd, &fds);
         
-    if (select(STDIN_FILENO + 1, &fds, 0, 0, &timeout) == -1)
+    if (select(from.fd + 1, &fds, 0, 0, &timeout) == -1)
         throw SysError("select()");
 
-    if (FD_ISSET(STDIN_FILENO, &fds)) {
+    if (FD_ISSET(from.fd, &fds)) {
         char c;
-        if (read(STDIN_FILENO, &c, 1) != 0)
+        if (read(from.fd, &c, 1) != 0)
             throw Error("EOF expected (protocol error?)");
         _isInterrupted = 1;
         checkInterrupt();
@@ -114,7 +123,7 @@ static void startWork()
 
 /* stopWork() means that we're done; stop sending stderr to the
    client. */
-static void stopWork()
+static void stopWork(bool success = true, const string & msg = "")
 {
     /* Stop handling async client death; we're going to a state where
        we're either sending or receiving from the client, so we'll be
@@ -123,7 +132,13 @@ static void stopWork()
         throw SysError("ignoring SIGIO");
     
     canSendStderr = false;
-    writeInt(STDERR_LAST, to);
+
+    if (success)
+        writeInt(STDERR_LAST, to);
+    else {
+        writeInt(STDERR_ERROR, to);
+        writeString(msg, to);
+    }
 }
 
 
@@ -237,11 +252,17 @@ static void processConnection()
 
     /* Allow us to receive SIGIO for events on the client socket. */
     signal(SIGIO, SIG_IGN);
-    if (fcntl(STDIN_FILENO, F_SETOWN, getpid()) == -1)
+    if (fcntl(from.fd, F_SETOWN, getpid()) == -1)
         throw SysError("F_SETOWN");
-    if (fcntl(STDIN_FILENO, F_SETFL, fcntl(STDIN_FILENO, F_GETFL, 0) | FASYNC) == -1)
+    if (fcntl(from.fd, F_SETFL, fcntl(from.fd, F_GETFL, 0) | FASYNC) == -1)
         throw SysError("F_SETFL");
 
+    /* Exchange the greeting. */
+    unsigned int magic = readInt(from);
+    if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch");
+    verbosity = (Verbosity) readInt(from);
+    writeInt(WORKER_MAGIC_2, to);
+
     /* Send startup error messages to the client. */
     startWork();
 
@@ -258,40 +279,137 @@ static void processConnection()
         stopWork();
         
     } catch (Error & e) {
-        writeInt(STDERR_ERROR, to);
-        writeString(e.msg(), to);
+        stopWork(false, e.msg());
         return;
     }
 
-    /* Exchange the greeting. */
-    unsigned int magic = readInt(from);
-    if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch");
-    writeInt(WORKER_MAGIC_2, to);
-    debug("greeting exchanged");
-
     /* Process client requests. */
-    bool quit = false;
-
     unsigned int opCount = 0;
     
-    do {
-        WorkerOp op = (WorkerOp) readInt(from);
+    while (true) {
+        WorkerOp op;
+        try {
+            op = (WorkerOp) readInt(from);
+        } catch (EndOfFile & e) {
+            break;
+        }
 
         opCount++;
 
         try {
             performOp(from, to, op);
         } catch (Error & e) {
-            writeInt(STDERR_ERROR, to);
-            writeString(e.msg(), to);
+            stopWork(false, e.msg());
         }
-
-    } while (!quit);
+    };
 
     printMsg(lvlError, format("%1% worker operations") % opCount);
 }
 
 
+static void setSigChldAction(bool ignore)
+{
+    struct sigaction act, oact;
+    act.sa_handler = ignore ? SIG_IGN : SIG_DFL;
+    sigfillset(&act.sa_mask);
+    act.sa_flags = 0;
+    if (sigaction(SIGCHLD, &act, &oact))
+        throw SysError("setting SIGCHLD handler");
+}
+
+
+static void daemonLoop()
+{
+    /* Get rid of children automatically; don't let them become
+       zombies. */
+    setSigChldAction(true);
+    
+    /* Create and bind to a Unix domain socket. */
+    AutoCloseFD fdSocket = socket(PF_UNIX, SOCK_STREAM, 0);
+    if (fdSocket == -1)
+        throw SysError("cannot create Unix domain socket");
+
+    string socketPath = nixStateDir + DEFAULT_SOCKET_PATH;
+
+    struct sockaddr_un addr;
+    addr.sun_family = AF_UNIX;
+    if (socketPath.size() >= sizeof(addr.sun_path))
+        throw Error(format("socket path `%1%' is too long") % socketPath);
+    strcpy(addr.sun_path, socketPath.c_str());
+
+    unlink(socketPath.c_str());
+
+    /* Make sure that the socket is created with 0666 permission
+       (everybody can connect). */
+    mode_t oldMode = umask(0111);
+    int res = bind(fdSocket, (struct sockaddr *) &addr, sizeof(addr));
+    umask(oldMode);
+    if (res == -1)
+        throw SysError(format("cannot bind to socket `%1%'") % socketPath);
+
+    if (listen(fdSocket, 5) == -1)
+        throw SysError(format("cannot listen on socket `%1%'") % socketPath);
+
+    /* Loop accepting connections. */
+    while (1) {
+
+        try {
+            /* Important: the server process *cannot* open the
+               Berkeley DB environment, because it doesn't like forks
+               very much. */
+            assert(!store);
+            
+            /* Accept a connection. */
+            struct sockaddr_un remoteAddr;
+            socklen_t remoteAddrLen = sizeof(remoteAddr);
+
+            AutoCloseFD remote = accept(fdSocket,
+                (struct sockaddr *) &remoteAddr, &remoteAddrLen);
+            checkInterrupt();
+            if (remote == -1)
+                throw SysError("accepting connection");
+
+            printMsg(lvlInfo, format("accepted connection %1%") % remote);
+
+            /* Fork a child to handle the connection. */
+            pid_t child;
+            child = fork();
+    
+            switch (child) {
+        
+            case -1:
+                throw SysError("unable to fork");
+
+            case 0:
+                try { /* child */
+                    
+                    /* Background the worker. */
+                    if (setsid() == -1)
+                        throw SysError(format("creating a new session"));
+
+                    /* Restore normal handling of SIGCHLD. */
+                    setSigChldAction(false);
+                    
+                    /* Handle the connection. */
+                    from.fd = remote;
+                    to.fd = remote;
+                    processConnection();
+                    
+                } catch (std::exception & e) {
+                    std::cerr << format("child error: %1%\n") % e.what();
+                }
+                exit(0);
+            }
+
+        } catch (Interrupted & e) {
+            throw;
+        } catch (Error & e) {
+            printMsg(lvlError, format("error processing connection: %1%") % e.msg());
+        }
+    }
+}
+
+
 void run(Strings args)
 {
     bool slave = false;
@@ -315,8 +433,7 @@ void run(Strings args)
     else if (daemon) {
         if (setuidMode)
             throw Error("daemon cannot be started in setuid mode");
-        
-        throw Error("daemon mode not implemented");
+        daemonLoop();
     }
 
     else