/*
 * piper: run two commands joined by a pipe and keep them running.
 *
 * The first command's stdout and stderr are connected to the second
 * command's stdin.  Whenever either command exits it is started again
 * (at most about once per RESPAWN_DELAY seconds per command).
 * SIGINT, SIGHUP, SIGTERM, SIGUSR1 and SIGUSR2 are forwarded to both
 * commands; the first three also make piper stop restarting them and
 * exit once both have terminated.
 *
 * Each command string is split on spaces and tabs (no shell quoting)
 * and run with execvp().
 */
#define _POSIX_C_SOURCE 200809L

#include <sys/types.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#define NUM_PROCS 2

/* Minimum number of seconds between two starts of the same command. */
#define RESPAWN_DELAY 1

/* Seconds reap_all() gives children after SIGTERM before sending SIGKILL. */
#define KILL_GRACE 5

/* Exit status of a child that could not set up or exec its command. */
#define CHILD_FAIL_STATUS 127

static struct proc {
	pid_t pid;      /* running child, or -1 when not running */
	char *cmd;      /* command string as given on the command line */
	char **argv;    /* cmd split on blanks, NULL-terminated */
	int _stdin;     /* fd installed as the child's stdin, or -1 */
	int _stdout;    /* fd installed as the child's stdout, or -1 */
	int _stderr;    /* fd installed as the child's stderr, or -1 */
	time_t started; /* time of the last successful fork, 0 if never */
} procs[NUM_PROCS];

/* The pipe: procs[0] writes to fds[1], procs[1] reads from fds[0]. */
static int fds[2];

/* Signals forwarded to the children (SIGINT/SIGHUP/SIGTERM are fatal). */
static const int forwarded_signals[] = {
	SIGINT, SIGHUP, SIGTERM, SIGUSR1, SIGUSR2
};
#define NUM_FORWARDED (sizeof(forwarded_signals) / sizeof(forwarded_signals[0]))

/*
 * Set from signal handlers.  The handled signals are blocked everywhere
 * except inside sigsuspend(), so the main loop reads and clears these
 * flags without racing the handlers.  One flag per forwarded signal so
 * that two different signals arriving together are both delivered.
 */
static volatile sig_atomic_t terminate = 0;
static volatile sig_atomic_t pending[NUM_FORWARDED];

/* Signal mask inherited at startup; restored in children before exec. */
static sigset_t orig_mask;
/* Mask used while waiting: orig_mask with all of our signals unblocked. */
static sigset_t wait_mask;

/* Return 1 if any command is currently running, 0 otherwise. */
static int
have_child(void)
{
	for (int i = 0; i < NUM_PROCS; i++)
		if (procs[i].pid != -1)
			return 1;
	return 0;
}

/* Return the slot whose running child has the given pid, or NULL. */
static struct proc *
proc_by_pid(pid_t pid)
{
	for (int i = 0; i < NUM_PROCS; i++)
		if (procs[i].pid == pid)
			return &procs[i];
	return NULL;
}

/* Send sig to every running child; failures (e.g. child gone) are ignored. */
static void
kill_all(int sig)
{
	for (int i = 0; i < NUM_PROCS; i++)
		if (procs[i].pid != -1)
			kill(procs[i].pid, sig);
}

/*
 * Terminate all running children: SIGTERM, a KILL_GRACE second grace
 * period, then SIGKILL, and finally wait for every child.  Returns
 * immediately when nothing is running.
 */
static void
reap_all(void)
{
	if (!have_child())
		return;
	kill_all(SIGTERM);
	sleep(KILL_GRACE);
	kill_all(SIGKILL);
	while (wait(NULL) != -1 || errno == EINTR)
		;
}

/*
 * Parent-side fatal error: stop all children, then print msg (printf
 * style) followed by strerror(errno) and exit with status.  errno is
 * saved before reap_all() so the reported error is the caller's, not
 * one left behind by kill()/wait().
 */
static void
fail(int status, const char *msg, ...)
{
	int saved_errno = errno;
	va_list args;

	reap_all();
	errno = saved_errno;
	va_start(args, msg);
	verr(status, msg, args);
	/* NOTREACHED: verr() does not return. */
}

/*
 * Child-side fatal error (after fork, before exec): report msg with
 * strerror(errno) and leave via _exit().  Must not use fail(): the
 * child's copy of procs[] lists the sibling command, which reap_all()
 * would kill, and exit() would flush stdio buffers copied from the parent.
 */
static void
child_fail(const char *msg, ...)
{
	va_list args;

	va_start(args, msg);
	vwarn(msg, args);
	va_end(args);
	_exit(CHILD_FAIL_STATUS);
}

/*
 * Split proc->cmd on spaces/tabs into a freshly allocated,
 * NULL-terminated proc->argv (each element strdup'ed).  Exits on
 * allocation failure or if the command contains no words.
 */
static void
gen_argv(struct proc *proc)
{
	size_t n = 0, cap = 8;
	char *saveptr, *tok, *cmd;
	char **argv;

	argv = malloc(cap * sizeof *argv);
	if (argv == NULL)
		err(1, "malloc");
	cmd = strdup(proc->cmd);
	if (cmd == NULL)
		err(1, "strdup");

	for (tok = strtok_r(cmd, " \t", &saveptr); tok != NULL;
	    tok = strtok_r(NULL, " \t", &saveptr)) {
		/* Keep room for this word plus the terminating NULL. */
		if (n + 2 > cap) {
			char **grown;

			if (cap > SIZE_MAX / 2 / sizeof *argv)
				errx(1, "too many arguments");
			cap *= 2;
			grown = realloc(argv, cap * sizeof *argv);
			if (grown == NULL)
				err(1, "realloc");
			argv = grown;
		}
		argv[n] = strdup(tok);
		if (argv[n] == NULL)
			err(1, "strdup");
		n++;
	}
	argv[n] = NULL;
	free(cmd);

	/* execvp() needs argv[0]. */
	if (n == 0)
		errx(1, "empty command: \"%s\"", proc->cmd);
	proc->argv = argv;
}

/*
 * In a freshly forked child: put every signal piper handles back to its
 * default action and restore the inherited mask, so the command starts
 * with the environment piper itself received.  Dispositions are reset
 * before unblocking so a signal already pending (e.g. one forwarded by
 * the parent right after fork) acts on the child instead of running
 * piper's handler in it.
 */
static void
reset_child_signals(void)
{
	struct sigaction dfl;

	memset(&dfl, 0, sizeof dfl);
	dfl.sa_handler = SIG_DFL;
	sigemptyset(&dfl.sa_mask);
	for (size_t i = 0; i < NUM_FORWARDED; i++)
		sigaction(forwarded_signals[i], &dfl, NULL);
	sigaction(SIGCHLD, &dfl, NULL);
	sigprocmask(SIG_SETMASK, &orig_mask, NULL);
}

/*
 * Fork and exec proc's command with its configured stdio fds.  In the
 * parent, records the child's pid and start time; fork failure is fatal.
 */
static void
run_cmd(struct proc *proc)
{
	pid_t cpid = fork();

	if (cpid == -1)
		fail(1, "fork");
	if (cpid != 0) {
		proc->pid = cpid;
		proc->started = time(NULL);
		return;
	}

	/* Child. dup2() closes the target fd itself, so no close() first. */
	reset_child_signals();
	if (proc->_stdin != -1 && dup2(proc->_stdin, STDIN_FILENO) == -1)
		child_fail("dup2");
	if (proc->_stdout != -1 && dup2(proc->_stdout, STDOUT_FILENO) == -1)
		child_fail("dup2");
	if (proc->_stderr != -1 && dup2(proc->_stderr, STDERR_FILENO) == -1)
		child_fail("dup2");

	/*
	 * The command only needs the copies installed on 0-2.  Skip the
	 * close if pipe() happened to return a standard fd, which would
	 * otherwise close a stream just installed.  Failures do not matter.
	 */
	if (fds[0] > STDERR_FILENO)
		close(fds[0]);
	if (fds[1] > STDERR_FILENO)
		close(fds[1]);

	execvp(proc->argv[0], proc->argv);
	child_fail("execvp %s", proc->argv[0]);
}

/* Forward every signal received since the last call to all children. */
static void
signal_procs(void)
{
	for (size_t i = 0; i < NUM_FORWARDED; i++) {
		if (pending[i]) {
			pending[i] = 0;
			kill_all(forwarded_signals[i]);
		}
	}
}

/* Handler for the forwarded signals: record it; fatal ones end the run. */
static void
handle_forwarded(int sig)
{
	for (size_t i = 0; i < NUM_FORWARDED; i++)
		if (forwarded_signals[i] == sig)
			pending[i] = 1;
	if (sig == SIGINT || sig == SIGHUP || sig == SIGTERM)
		terminate = 1;
}

/* SIGCHLD handler: exists only so that sigsuspend() returns on child exit. */
static void
handle_chld(int sig)
{
	(void)sig;
}

/*
 * Install the handlers and block the handled signals.  They are
 * delivered only inside sigsuspend() in the main loop, which closes the
 * window between checking the flags and going to sleep.
 */
static void
setup_signals(void)
{
	struct sigaction act;
	sigset_t block;

	sigemptyset(&block);
	for (size_t i = 0; i < NUM_FORWARDED; i++)
		sigaddset(&block, forwarded_signals[i]);
	sigaddset(&block, SIGCHLD);
	if (sigprocmask(SIG_BLOCK, &block, &orig_mask) == -1)
		err(1, "sigprocmask");

	/* Wait with our signals open even if they were inherited blocked. */
	wait_mask = orig_mask;
	for (size_t i = 0; i < NUM_FORWARDED; i++)
		sigdelset(&wait_mask, forwarded_signals[i]);
	sigdelset(&wait_mask, SIGCHLD);

	memset(&act, 0, sizeof act);
	sigemptyset(&act.sa_mask);
	act.sa_handler = handle_forwarded;
	for (size_t i = 0; i < NUM_FORWARDED; i++)
		if (sigaction(forwarded_signals[i], &act, NULL) == -1)
			err(1, "sigaction");

	act.sa_handler = handle_chld;
	act.sa_flags = SA_NOCLDSTOP;
	if (sigaction(SIGCHLD, &act, NULL) == -1)
		err(1, "sigaction");
}

/*
 * Start every command that is not running.  A command that was started
 * less than RESPAWN_DELAY seconds ago is delayed first, so one that
 * exits (or fails to exec) immediately is not respawned in a tight loop.
 */
static void
start_missing_procs(void)
{
	for (int i = 0; i < NUM_PROCS; i++) {
		if (procs[i].pid != -1)
			continue;
		if (procs[i].started != 0 &&
		    difftime(time(NULL), procs[i].started) < RESPAWN_DELAY)
			sleep(RESPAWN_DELAY);
		run_cmd(&procs[i]);
	}
}

/*
 * Collect every child that has already exited, without blocking, and
 * mark its slot as not running.  Returns the number collected.
 */
static int
reap_children(void)
{
	int reaped = 0;
	pid_t pid;

	while ((pid = waitpid(-1, NULL, WNOHANG)) > 0) {
		struct proc *p = proc_by_pid(pid);

		reaped++;
		if (p != NULL)	/* XXX: log unknown children? */
			p->pid = -1;
	}
	if (pid == -1 && errno != EINTR)
		fail(1, "waitpid");
	return reaped;
}

int
main(int argc, char **argv)
{
	if (argc < 3) {
		fprintf(stderr, "Usage: piper <command1> <command2>\n");
		return 1;
	}

	setup_signals();

	for (int i = 0; i < NUM_PROCS; i++) {
		procs[i].pid = -1;
		procs[i].cmd = argv[i + 1];
		procs[i]._stdin = -1;
		procs[i]._stdout = -1;
		procs[i]._stderr = -1;
		procs[i].started = 0;
		gen_argv(&procs[i]);
	}

	if (pipe(fds) == -1)
		err(1, "pipe");
	procs[0]._stdout = fds[1];
	procs[0]._stderr = fds[1];
	procs[1]._stdin = fds[0];

	while (!terminate || have_child()) {
		if (!terminate)
			start_missing_procs();
		signal_procs();
		/* Sleep only when nothing happened; signals are delivered here. */
		if (reap_children() == 0)
			sigsuspend(&wait_mask);
	}
	return 0;
}