diff options
Diffstat (limited to 'src/nix-worker')
-rw-r--r-- | src/nix-worker/main.cc | 183 |
1 files changed, 150 insertions, 33 deletions
diff --git a/src/nix-worker/main.cc b/src/nix-worker/main.cc index c8576ddb67f5..d104ea8406b3 100644 --- a/src/nix-worker/main.cc +++ b/src/nix-worker/main.cc @@ -9,6 +9,10 @@ #include <iostream> #include <unistd.h> #include <signal.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/socket.h> +#include <sys/un.h> #include <fcntl.h> using namespace nix; @@ -43,9 +47,6 @@ bool canSendStderr; socket. */ static void tunnelStderr(const unsigned char * buf, size_t count) { - if (canSendStderr) - writeFull(STDERR_FILENO, (unsigned char *) "L: ", 3); - writeFull(STDERR_FILENO, buf, count); if (canSendStderr) { try { writeInt(STDERR_NEXT, to); @@ -65,12 +66,20 @@ static void tunnelStderr(const unsigned char * buf, size_t count) socket. This handler is enabled at precisely those moments in the protocol when we're doing work and the client is supposed to be quiet. Thus, if we get a SIGIO signal, it means that the client - has quit. So we should quit as well. */ + has quit. So we should quit as well. + + Too bad most operating systems don't support the POLL_HUP value for + si_code in siginfo_t. That would make most of the SIGIO complexity + unnecessary, i.e., we could just enable SIGIO all the time and + wouldn't have to worry about races. */ static void sigioHandler(int sigNo) { - _isInterrupted = 1; - canSendStderr = false; - write(STDERR_FILENO, "SIGIO\n", 6); + if (!blockInt) { + _isInterrupted = 1; + blockInt = 1; + canSendStderr = false; + write(STDERR_FILENO, "SIGIO\n", 6); + } } @@ -97,14 +106,14 @@ static void startWork() fd_set fds; FD_ZERO(&fds); - FD_SET(STDIN_FILENO, &fds); + FD_SET(from.fd, &fds); - if (select(STDIN_FILENO + 1, &fds, 0, 0, &timeout) == -1) + if (select(from.fd + 1, &fds, 0, 0, &timeout) == -1) throw SysError("select()"); - if (FD_ISSET(STDIN_FILENO, &fds)) { + if (FD_ISSET(from.fd, &fds)) { char c; - if (read(STDIN_FILENO, &c, 1) != 0) + if (read(from.fd, &c, 1) != 0) throw Error("EOF expected (protocol error?)"); _isInterrupted = 1; checkInterrupt(); @@ -114,7 +123,7 @@ static void startWork() /* stopWork() means that we're done; stop sending stderr to the client. */ -static void stopWork() +static void stopWork(bool success = true, const string & msg = "") { /* Stop handling async client death; we're going to a state where we're either sending or receiving from the client, so we'll be @@ -123,7 +132,13 @@ static void stopWork() throw SysError("ignoring SIGIO"); canSendStderr = false; - writeInt(STDERR_LAST, to); + + if (success) + writeInt(STDERR_LAST, to); + else { + writeInt(STDERR_ERROR, to); + writeString(msg, to); + } } @@ -237,11 +252,17 @@ static void processConnection() /* Allow us to receive SIGIO for events on the client socket. */ signal(SIGIO, SIG_IGN); - if (fcntl(STDIN_FILENO, F_SETOWN, getpid()) == -1) + if (fcntl(from.fd, F_SETOWN, getpid()) == -1) throw SysError("F_SETOWN"); - if (fcntl(STDIN_FILENO, F_SETFL, fcntl(STDIN_FILENO, F_GETFL, 0) | FASYNC) == -1) + if (fcntl(from.fd, F_SETFL, fcntl(from.fd, F_GETFL, 0) | FASYNC) == -1) throw SysError("F_SETFL"); + /* Exchange the greeting. */ + unsigned int magic = readInt(from); + if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch"); + verbosity = (Verbosity) readInt(from); + writeInt(WORKER_MAGIC_2, to); + /* Send startup error messages to the client. */ startWork(); @@ -258,40 +279,137 @@ static void processConnection() stopWork(); } catch (Error & e) { - writeInt(STDERR_ERROR, to); - writeString(e.msg(), to); + stopWork(false, e.msg()); return; } - /* Exchange the greeting. */ - unsigned int magic = readInt(from); - if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch"); - writeInt(WORKER_MAGIC_2, to); - debug("greeting exchanged"); - /* Process client requests. */ - bool quit = false; - unsigned int opCount = 0; - do { - WorkerOp op = (WorkerOp) readInt(from); + while (true) { + WorkerOp op; + try { + op = (WorkerOp) readInt(from); + } catch (EndOfFile & e) { + break; + } opCount++; try { performOp(from, to, op); } catch (Error & e) { - writeInt(STDERR_ERROR, to); - writeString(e.msg(), to); + stopWork(false, e.msg()); } - - } while (!quit); + }; printMsg(lvlError, format("%1% worker operations") % opCount); } +static void setSigChldAction(bool ignore) +{ + struct sigaction act, oact; + act.sa_handler = ignore ? SIG_IGN : SIG_DFL; + sigfillset(&act.sa_mask); + act.sa_flags = 0; + if (sigaction(SIGCHLD, &act, &oact)) + throw SysError("setting SIGCHLD handler"); +} + + +static void daemonLoop() +{ + /* Get rid of children automatically; don't let them become + zombies. */ + setSigChldAction(true); + + /* Create and bind to a Unix domain socket. */ + AutoCloseFD fdSocket = socket(PF_UNIX, SOCK_STREAM, 0); + if (fdSocket == -1) + throw SysError("cannot create Unix domain socket"); + + string socketPath = nixStateDir + DEFAULT_SOCKET_PATH; + + struct sockaddr_un addr; + addr.sun_family = AF_UNIX; + if (socketPath.size() >= sizeof(addr.sun_path)) + throw Error(format("socket path `%1%' is too long") % socketPath); + strcpy(addr.sun_path, socketPath.c_str()); + + unlink(socketPath.c_str()); + + /* Make sure that the socket is created with 0666 permission + (everybody can connect). */ + mode_t oldMode = umask(0111); + int res = bind(fdSocket, (struct sockaddr *) &addr, sizeof(addr)); + umask(oldMode); + if (res == -1) + throw SysError(format("cannot bind to socket `%1%'") % socketPath); + + if (listen(fdSocket, 5) == -1) + throw SysError(format("cannot listen on socket `%1%'") % socketPath); + + /* Loop accepting connections. */ + while (1) { + + try { + /* Important: the server process *cannot* open the + Berkeley DB environment, because it doesn't like forks + very much. */ + assert(!store); + + /* Accept a connection. */ + struct sockaddr_un remoteAddr; + socklen_t remoteAddrLen = sizeof(remoteAddr); + + AutoCloseFD remote = accept(fdSocket, + (struct sockaddr *) &remoteAddr, &remoteAddrLen); + checkInterrupt(); + if (remote == -1) + throw SysError("accepting connection"); + + printMsg(lvlInfo, format("accepted connection %1%") % remote); + + /* Fork a child to handle the connection. */ + pid_t child; + child = fork(); + + switch (child) { + + case -1: + throw SysError("unable to fork"); + + case 0: + try { /* child */ + + /* Background the worker. */ + if (setsid() == -1) + throw SysError(format("creating a new session")); + + /* Restore normal handling of SIGCHLD. */ + setSigChldAction(false); + + /* Handle the connection. */ + from.fd = remote; + to.fd = remote; + processConnection(); + + } catch (std::exception & e) { + std::cerr << format("child error: %1%\n") % e.what(); + } + exit(0); + } + + } catch (Interrupted & e) { + throw; + } catch (Error & e) { + printMsg(lvlError, format("error processing connection: %1%") % e.msg()); + } + } +} + + void run(Strings args) { bool slave = false; @@ -315,8 +433,7 @@ void run(Strings args) else if (daemon) { if (setuidMode) throw Error("daemon cannot be started in setuid mode"); - - throw Error("daemon mode not implemented"); + daemonLoop(); } else |