diff --git a/src/shrpx.cc b/src/shrpx.cc index f9d9c3bf..7c00d9aa 100644 --- a/src/shrpx.cc +++ b/src/shrpx.cc @@ -48,6 +48,7 @@ #include #include #include +#include #include #include @@ -90,6 +91,13 @@ const int GRACEFUL_SHUTDOWN_SIGNAL = SIGQUIT; // binary is listening to. #define ENV_PORT "NGHTTPX_PORT" +// Environment variable to tell new binary the listening socket's file +// descriptor if frontend listens UNIX domain socket. +#define ENV_UNIX_FD "NGHTTP2_UNIX_FD" +// Environment variable to tell new binary the UNIX domain socket +// path. +#define ENV_UNIX_PATH "NGHTTP2_UNIX_PATH" + namespace { int resolve_hostname(sockaddr_union *addr, size_t *addrlen, const char *hostname, uint16_t port, int family) { @@ -137,6 +145,85 @@ int resolve_hostname(sockaddr_union *addr, size_t *addrlen, } } // namespace +namespace { +void close_env_fd(std::initializer_list envnames) { + for (auto envname : envnames) { + auto envfd = getenv(envname); + if (!envfd) { + continue; + } + auto fd = strtol(envfd, nullptr, 10); + close(fd); + } +} +} // namespace + +namespace { +std::unique_ptr +create_unix_domain_acceptor(ConnectionHandler *handler) { + auto path = get_config()->host.get(); + auto pathlen = strlen(path); + { + auto envfd = getenv(ENV_UNIX_FD); + auto envpath = getenv(ENV_UNIX_PATH); + if (envfd && envpath) { + auto fd = strtoul(envfd, nullptr, 10); + + if (util::streq(envpath, path)) { + LOG(NOTICE) << "Listening on UNIX domain socket " << path; + + return make_unique(fd, handler); + } + + LOG(WARN) << "UNIX domain socket path was changed between old binary (" + << envpath << ") and new binary (" << path << ")"; + close(fd); + } + } + +#ifdef SOCK_NONBLOCK + auto fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0); + if (fd == -1) { + return nullptr; + } +#else // !SOCK_NONBLOCK + auto fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) { + return nullptr; + } + util::make_socket_nonblocking(fd); +#endif // !SOCK_NONBLOCK + int val = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, + static_cast(sizeof(val))) == -1) { + close(fd); + return nullptr; + } + + sockaddr_union addr; + addr.un.sun_family = AF_UNIX; + if (pathlen + 1 > sizeof(addr.un.sun_path)) { + LOG(FATAL) << "UNIX domain socket path " << path << " is too long > " + << sizeof(addr.un.sun_path); + return nullptr; + } + // copy path including terminal NULL + std::copy_n(path, pathlen + 1, addr.un.sun_path); + + // unlink (remove) already existing UNIX domain socket path + unlink(path); + + if (bind(fd, &addr.sa, sizeof(addr.un)) != 0 || + listen(fd, get_config()->backlog) != 0) { + return nullptr; + } + + LOG(NOTICE) << "Listening on UNIX domain socket " << path; + + return make_unique(fd, handler); +} +} // namespace + namespace { std::unique_ptr create_acceptor(ConnectionHandler *handler, int family) { @@ -367,32 +454,45 @@ void exec_binary_signal_cb(struct ev_loop *loop, ev_signal *w, int revents) { size_t envlen = 0; for (char **p = environ; *p; ++p, ++envlen) ; - // 3 for missing fd4, fd6 and port. + // 3 for missing (fd4, fd6 and port) or (unix fd and unix path) auto envp = make_unique(envlen + 3 + 1); size_t envidx = 0; - auto acceptor4 = conn_handler->get_acceptor4(); - if (acceptor4) { - std::string fd4 = ENV_LISTENER4_FD "="; - fd4 += util::utos(acceptor4->get_fd()); - envp[envidx++] = strdup(fd4.c_str()); - } + if (get_config()->host_unix) { + auto acceptor = conn_handler->get_acceptor4(); + std::string fd = ENV_UNIX_FD "="; + fd += util::utos(acceptor->get_fd()); + envp[envidx++] = strdup(fd.c_str()); - auto acceptor6 = conn_handler->get_acceptor6(); - if (acceptor6) { - std::string fd6 = ENV_LISTENER6_FD "="; - fd6 += util::utos(acceptor6->get_fd()); - envp[envidx++] = strdup(fd6.c_str()); - } + std::string path = ENV_UNIX_PATH "="; + path += get_config()->host.get(); + envp[envidx++] = strdup(path.c_str()); + } else { + auto acceptor4 = conn_handler->get_acceptor4(); + if (acceptor4) { + std::string fd4 = ENV_LISTENER4_FD "="; + fd4 += util::utos(acceptor4->get_fd()); + envp[envidx++] = strdup(fd4.c_str()); + } - std::string port = ENV_PORT "="; - port += util::utos(get_config()->port); - envp[envidx++] = strdup(port.c_str()); + auto acceptor6 = conn_handler->get_acceptor6(); + if (acceptor6) { + std::string fd6 = ENV_LISTENER6_FD "="; + fd6 += util::utos(acceptor6->get_fd()); + envp[envidx++] = strdup(fd6.c_str()); + } + + std::string port = ENV_PORT "="; + port += util::utos(get_config()->port); + envp[envidx++] = strdup(port.c_str()); + } for (size_t i = 0; i < envlen; ++i) { - if (strcmp(ENV_LISTENER4_FD, environ[i]) == 0 || - strcmp(ENV_LISTENER6_FD, environ[i]) == 0 || - strcmp(ENV_PORT, environ[i]) == 0) { + if (util::startsWith(environ[i], ENV_LISTENER4_FD) || + util::startsWith(environ[i], ENV_LISTENER6_FD) || + util::startsWith(environ[i], ENV_PORT) || + util::startsWith(environ[i], ENV_UNIX_FD) || + util::startsWith(environ[i], ENV_UNIX_PATH)) { continue; } @@ -528,16 +628,29 @@ int event_loop() { save_pid(); } - auto acceptor6 = create_acceptor(conn_handler.get(), AF_INET6); - auto acceptor4 = create_acceptor(conn_handler.get(), AF_INET); - if (!acceptor6 && !acceptor4) { - LOG(FATAL) << "Failed to listen on address " << get_config()->host.get() - << ", port " << get_config()->port; - exit(EXIT_FAILURE); - } + if (get_config()->host_unix) { + close_env_fd({ENV_LISTENER4_FD, ENV_LISTENER6_FD}); + auto acceptor = create_unix_domain_acceptor(conn_handler.get()); + if (!acceptor) { + LOG(FATAL) << "Failed to listen on UNIX domain socket " + << get_config()->host.get(); + exit(EXIT_FAILURE); + } - conn_handler->set_acceptor4(std::move(acceptor4)); - conn_handler->set_acceptor6(std::move(acceptor6)); + conn_handler->set_acceptor4(std::move(acceptor)); + } else { + close_env_fd({ENV_UNIX_FD}); + auto acceptor6 = create_acceptor(conn_handler.get(), AF_INET6); + auto acceptor4 = create_acceptor(conn_handler.get(), AF_INET); + if (!acceptor6 && !acceptor4) { + LOG(FATAL) << "Failed to listen on address " << get_config()->host.get() + << ", port " << get_config()->port; + exit(EXIT_FAILURE); + } + + conn_handler->set_acceptor4(std::move(acceptor4)); + conn_handler->set_acceptor6(std::move(acceptor6)); + } ev_timer renew_ticket_key_timer; if (!get_config()->upstream_no_tls) { @@ -787,6 +900,7 @@ void fill_default_config() { mod_config()->downstream_request_buffer_size = 16 * 1024; mod_config()->downstream_response_buffer_size = 16 * 1024; mod_config()->no_server_push = false; + mod_config()->host_unix = false; } } // namespace @@ -829,7 +943,9 @@ Connections: << DEFAULT_DOWNSTREAM_PORT << R"( -f, --frontend= Set frontend host and port. If is '*', it - assumes all addresses including both IPv4 and IPv6. + assumes all addresses including both IPv4 and IPv6. + UNIX domain socket can be specified by prefixing path + name with "unix:" (e.g., -funix:/var/run/nghttpx.sock) Default: )" << get_config()->host.get() << "," << get_config()->port << R"( --backlog= @@ -1885,7 +2001,7 @@ int main(int argc, char **argv) { auto pathlen = strlen(path); if (pathlen + 1 > sizeof(addr.addr.un.sun_path)) { - LOG(FATAL) << "path unix domain socket is bound to is too long > " + LOG(FATAL) << "UNIX domain socket path " << path << " is too long > " << sizeof(addr.addr.un.sun_path); exit(EXIT_FAILURE); } diff --git a/src/shrpx_config.cc b/src/shrpx_config.cc index a5e09ef0..2407e9b9 100644 --- a/src/shrpx_config.cc +++ b/src/shrpx_config.cc @@ -545,12 +545,22 @@ int parse_config(const char *opt, const char *optarg) { } if (util::strieq(opt, SHRPX_OPT_FRONTEND)) { + if (util::istartsWith(optarg, SHRPX_UNIX_PATH_PREFIX)) { + auto path = optarg + str_size(SHRPX_UNIX_PATH_PREFIX); + mod_config()->host = strcopy(path); + mod_config()->port = 0; + mod_config()->host_unix = true; + + return 0; + } + if (split_host_port(host, sizeof(host), &port, optarg) == -1) { return -1; } mod_config()->host = strcopy(host); mod_config()->port = port; + mod_config()->host_unix = false; return 0; } diff --git a/src/shrpx_config.h b/src/shrpx_config.h index 5bb05861..0e704e4f 100644 --- a/src/shrpx_config.h +++ b/src/shrpx_config.h @@ -203,6 +203,8 @@ struct Config { ev_tstamp stream_write_timeout; ev_tstamp downstream_idle_read_timeout; ev_tstamp listener_disable_timeout; + // address of frontend connection. This could be a path to UNIX + // domain socket. In this case, |host_unix| must be true. std::unique_ptr host; std::unique_ptr private_key_file; std::unique_ptr private_key_passwd; @@ -279,6 +281,8 @@ struct Config { uid_t uid; gid_t gid; pid_t pid; + // frontend listening port. 0 if frontend listens on UNIX domain + // socket, in this case |host_unix| must be true. uint16_t port; // port in http proxy URI uint16_t downstream_http_proxy_port; @@ -309,6 +313,8 @@ struct Config { bool no_host_rewrite; bool tls_ctx_per_worker; bool no_server_push; + // true if host contains UNIX domain socket path + bool host_unix; }; const Config *get_config(); diff --git a/src/util.h b/src/util.h index e2dbef68..5bddbd55 100644 --- a/src/util.h +++ b/src/util.h @@ -237,6 +237,10 @@ inline bool startsWith(const std::string &a, const std::string &b) { return startsWith(std::begin(a), std::end(a), std::begin(b), std::end(b)); } +inline bool startsWith(const char *a, const char *b) { + return startsWith(a, a + strlen(a), b, b + strlen(b)); +} + struct CaseCmp { bool operator()(char lhs, char rhs) const { return lowcase(lhs) == lowcase(rhs); @@ -342,6 +346,13 @@ bool streq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { return std::equal(a, a + alen, b); } +inline bool streq(const char *a, const char *b) { + if (!a || !b) { + return false; + } + return streq(a, strlen(a), b, strlen(b)); +} + template bool streq_l(const char (&a)[N], InputIt b, size_t blen) { return streq(a, N - 1, b, blen);