process: migrate arg validation and checking to Lua

This commit is contained in:
takase1121 2024-07-09 18:04:58 +08:00 committed by Guldoman
parent 1f0533482b
commit a25ab3c535
2 changed files with 111 additions and 312 deletions

View File

@ -1,4 +1,5 @@
local config = require "core.config" local config = require "core.config"
local common = require "core.common"
---An abstraction over the standard input and outputs of a process ---An abstraction over the standard input and outputs of a process
@ -149,10 +150,53 @@ function process:__index(k)
return self.process[k] return self.process[k]
end end
local function env_key(str)
if PLATFORM == "Windows" then return str:upper() else return str end
end
---Sorts the environment variable by its key, converted to uppercase.
---This is only needed on Windows.
local function compare_env(a, b)
return env_key(a:match("([^=]*)=")) < env_key(b:match("([^=]*)="))
end
local old_start = process.start local old_start = process.start
function process.start(...) function process.start(command, options)
local self = setmetatable({ process = old_start(...) }, process) assert(type(command) == "table" or type(command) == "string", "invalid argument #1 to process.start(), expected string or table, got "..type(command))
assert(type(options) == "table" or type(options) == "nil", "invalid argument #2 to process.start(), expected table or nil, got "..type(options))
if PLATFORM == "Windows" then
if type(command) == "table" then
-- escape the arguments into a command line string
-- https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/subprocess.py#L531
local arglist = {}
for _, v in ipairs(command) do
local backslash, arg = 0, {}
for c in v:gmatch(".") do
if c == "\\" then backslash = backslash + 1
elseif c == '"' then arg[#arg+1] = string.rep("\\", backslash * 2 + 1)..'"'; backslash = 0
else arg[#arg+1] = string.rep("\\", backslash) .. c; backslash = 0 end
end
arg[#arg+1] = string.rep("\\", backslash) -- add remaining backslashes
if #v == 0 or v:find("[\t\v\r\n ]") then arglist[#arglist+1] = '"'..table.concat(arg, "")..'"'
else arglist[#arglist+1] = table.concat(arg, "") end
end
command = table.concat(arglist, " ")
end
else
command = type(command) == "table" and command or { command }
end
if type(options) == "table" and options.env then
local user_env = options.env --[[@as table]]
options.env = function(system_env)
local final_env, envlist = {}, {}
for k, v in pairs(system_env) do final_env[env_key(k)] = k.."="..v end
for k, v in pairs(user_env) do final_env[env_key(k)] = k.."="..v end
for _, v in pairs(final_env) do envlist[#envlist+1] = v end
if PLATFORM == "Windows" then table.sort(envlist, compare_env) end
return table.concat(envlist, "\0").."\0\0"
end
end
local self = setmetatable({ process = old_start(command, options) }, process)
self.stdout = process.stream.new(self, process.STREAM_STDOUT) self.stdout = process.stream.new(self, process.STREAM_STDOUT)
self.stderr = process.stream.new(self, process.STREAM_STDERR) self.stderr = process.stream.new(self, process.STREAM_STDERR)
self.stdin = process.stream.new(self, process.STREAM_STDIN) self.stdin = process.stream.new(self, process.STREAM_STDIN)

View File

@ -12,6 +12,7 @@
// https://stackoverflow.com/questions/60645/overlapped-i-o-on-anonymous-pipe // https://stackoverflow.com/questions/60645/overlapped-i-o-on-anonymous-pipe
// https://docs.microsoft.com/en-us/windows/win32/procthread/creating-a-child-process-with-redirected-input-and-output // https://docs.microsoft.com/en-us/windows/win32/procthread/creating-a-child-process-with-redirected-input-and-output
#include <windows.h> #include <windows.h>
#include "../utfconv.h"
#else #else
#include <errno.h> #include <errno.h>
#include <unistd.h> #include <unistd.h>
@ -21,6 +22,8 @@
#include <sys/wait.h> #include <sys/wait.h>
#endif #endif
#include "../arena_allocator.h"
#define READ_BUF_SIZE 2048 #define READ_BUF_SIZE 2048
#define PROCESS_TERM_TRIES 3 #define PROCESS_TERM_TRIES 3
#define PROCESS_TERM_DELAY 50 #define PROCESS_TERM_DELAY 50
@ -351,268 +354,35 @@ static bool signal_process(process_t* proc, signal_e sig) {
return true; return true;
} }
static UNUSED char *xstrdup(const char *str) {
char *result = str ? malloc(strlen(str) + 1) : NULL;
if (result) strcpy(result, str);
return result;
}
static int process_arglist_init(process_arglist_t *list, size_t *list_len, size_t nargs) {
*list_len = 0;
#ifdef _WIN32
memset(*list, 0, sizeof(process_arglist_t));
#else
*list = calloc(sizeof(char *), nargs + 1);
if (!*list) return ENOMEM;
#endif
return 0;
}
static int process_arglist_add(process_arglist_t *list, size_t *list_len, const char *arg, bool escape) {
size_t len = *list_len;
#ifdef _WIN32
int arg_len;
wchar_t *cmdline = *list;
wchar_t arg_w[32767];
// this length includes the null terminator!
if (!(arg_len = MultiByteToWideChar(CP_UTF8, 0, arg, -1, arg_w, 32767)))
return GetLastError();
if (arg_len + len > 32767)
return ERROR_NOT_ENOUGH_MEMORY;
if (!escape) {
// replace the current null terminator with a space
if (len > 0) cmdline[len-1] = ' ';
memcpy(cmdline + len, arg_w, arg_len * sizeof(wchar_t));
len += arg_len;
} else {
// if the string contains spaces, then we must quote it
bool quote = wcspbrk(arg_w, L" \t\v\r\n");
int backslash = 0, escaped_len = quote ? 2 : 0;
for (int i = 0; i < arg_len; i++) {
if (arg_w[i] == L'\\') {
backslash++;
} else if (arg_w[i] == L'"') {
escaped_len += backslash + 1;
backslash = 0;
} else {
backslash = 0;
}
escaped_len++;
}
// escape_len contains NUL terminator
if (escaped_len + len > 32767)
return ERROR_NOT_ENOUGH_MEMORY;
// replace our previous NUL terminator with space
if (len > 0) cmdline[len-1] = L' ';
if (quote) cmdline[len++] = L'"';
// we are not going to iterate over NUL terminator
for (int i = 0;arg_w[i]; i++) {
if (arg_w[i] == L'\\') {
backslash++;
} else if (arg_w[i] == L'"') {
// add backslash + 1 backslashes
for (int j = 0; j < backslash; j++)
cmdline[len++] = L'\\';
cmdline[len++] = L'\\';
backslash = 0;
} else {
backslash = 0;
}
cmdline[len++] = arg_w[i];
}
if (quote) cmdline[len++] = L'"';
cmdline[len++] = L'\0';
}
#else
char **cmd = *list;
cmd[len] = xstrdup(arg);
if (!cmd[len]) return ENOMEM;
len++;
#endif
*list_len = len;
return 0;
}
static void process_arglist_free(process_arglist_t *list) {
if (!*list) return;
#ifndef _WIN32
char **cmd = *list;
for (int i = 0; cmd[i]; i++)
free(cmd[i]);
free(cmd);
*list = NULL;
#endif
}
static int process_env_init(process_env_t *env_list, size_t *env_len, size_t nenv) {
*env_len = 0;
#ifdef _WIN32
*env_list = NULL;
#else
*env_list = calloc(sizeof(char *), nenv * 2);
if (!*env_list) return ENOMEM;
#endif
return 0;
}
#ifdef _WIN32
static int cmp_name(wchar_t *a, wchar_t *b) {
wchar_t _A[32767], _B[32767], *A = _A, *B = _B, *a_eq, *b_eq;
int na, nb, r;
a_eq = wcschr(a, L'=');
b_eq = wcschr(b, L'=');
assert(a_eq);
assert(b_eq);
na = a_eq - a;
nb = b_eq - b;
r = LCMapStringW(LOCALE_INVARIANT, LCMAP_UPPERCASE, a, na, A, na);
assert(r == na);
A[na] = L'\0';
r = LCMapStringW(LOCALE_INVARIANT, LCMAP_UPPERCASE, b, nb, B, nb);
assert(r == nb);
B[nb] = L'\0';
for (;;) {
wchar_t AA = *A++, BB = *B++;
if (AA > BB)
return 1;
else if (AA < BB)
return -1;
else if (!AA && !BB)
return 0;
}
}
static int process_env_add_variable(process_env_t *env_list, size_t *env_list_len, wchar_t *var, size_t var_len) {
wchar_t *list, *list_p;
size_t block_var_len, list_len;
list = list_p = *env_list;
list_len = *env_list_len;
if (list_len) {
// check if it is already in the block
while ((block_var_len = wcslen(list_p))) {
if (cmp_name(list_p, var) == 0)
return -1; // already installed
list_p += block_var_len + 1;
}
}
// allocate list + 1 characters for the block terminator
list = realloc(list, (list_len + var_len + 1) * sizeof(wchar_t));
if (!list) return ERROR_NOT_ENOUGH_MEMORY;
// copy the env variable to the block
memcpy(list + list_len, var, var_len * sizeof(wchar_t));
// terminate the block again
list[list_len + var_len] = L'\0';
*env_list = list;
*env_list_len = (list_len + var_len);
return 0;
}
static int process_env_add_system(process_env_t *env_list, size_t *env_list_len) {
int retval = 0;
wchar_t *proc_env_block, *proc_env_block_p;
int proc_env_len;
proc_env_block = proc_env_block_p = GetEnvironmentStringsW();
while ((proc_env_len = wcslen(proc_env_block_p))) {
// try to add it to the list
if ((retval = process_env_add_variable(env_list, env_list_len, proc_env_block_p, proc_env_len + 1)) > 0)
goto cleanup;
proc_env_block_p += proc_env_len + 1;
}
retval = 0;
cleanup:
if (proc_env_block) FreeEnvironmentStringsW(proc_env_block);
return retval;
}
#endif
static int process_env_add(process_env_t *env_list, size_t *env_len, const char *key, const char *value) {
#ifdef _WIN32
wchar_t env_var[32767];
int r, var_len = 0;
if (!(r = MultiByteToWideChar(CP_UTF8, 0, key, -1, env_var, 32767)))
return GetLastError();
var_len += r;
env_var[var_len-1] = L'=';
if (!(r = MultiByteToWideChar(CP_UTF8, 0, value, -1, env_var + var_len, 32767 - var_len)))
return GetLastError();
var_len += r;
return process_env_add_variable(env_list, env_len, env_var, var_len);
#else
(*env_list)[*env_len] = xstrdup(key);
if (!(*env_list)[*env_len])
return ENOMEM;
(*env_list)[*env_len + 1] = xstrdup(value);
if (!(*env_list)[*env_len + 1])
return ENOMEM;
*env_len += 2;
#endif
return 0;
}
static void process_env_free(process_env_t *list, size_t list_len) {
if (!*list) return;
#ifndef _WIN32
for (size_t i = 0; i < list_len; i++)
free((*list)[i]);
#endif
free(*list);
*list = NULL;
}
static int process_start(lua_State* L) { static int process_start(lua_State* L) {
int r, retval = 1; int r, retval = 1;
size_t env_len = 0, cmd_len = 0, arglist_len = 0, env_vars_len = 0;
process_t *self = NULL; process_t *self = NULL;
process_arglist_t arglist = PROCESS_ARGLIST_INITIALIZER; int deadline = 10, detach = false, new_fds[3] = { STDIN_FD, STDOUT_FD, STDERR_FD };
process_env_t env_vars = NULL;
const char *cwd = NULL;
bool detach = false, escape = true;
int deadline = 10, new_fds[3] = { STDIN_FD, STDOUT_FD, STDERR_FD };
if (lua_isstring(L, 1)) { #ifdef _WIN32
escape = false; wchar_t *commandline = NULL, *env = NULL, *cwd = NULL;
// create a table that contains the string as the value #else
lua_createtable(L, 1, 0); const char **cmd = NULL, *env = NULL, *cwd = NULL;
lua_pushvalue(L, 1); #endif
lua_rawseti(L, -2, 1);
lua_replace(L, 1);
}
lua_settop(L, 3);
lxl_arena A; lxl_arena_init(L, &A);
// copy command line arguments
#ifdef _WIN32
if ( !(commandline = utfconv_fromutf8(&A, luaL_checkstring(L, 1))) )
return luaL_error(L, "%s", UTFCONV_ERROR_INVALID_CONVERSION);
#else
luaL_checktype(L, 1, LUA_TTABLE); luaL_checktype(L, 1, LUA_TTABLE);
#if LUA_VERSION_NUM > 501 int len = luaL_len(L, 1);
lua_len(L, 1); cmd = lxl_arena_zero(&A, (len + 1) * sizeof(char *));
#else for (int i = 0; i < len; i++) {
lua_pushinteger(L, (int)lua_objlen(L, 1)); cmd[i] = lxl_arena_strdup(&A, (lua_rawgeti(L, 1, i+1), luaL_checkstring(L, -1)));
#endif
cmd_len = luaL_checknumber(L, -1); lua_pop(L, 1);
if (!cmd_len)
return luaL_argerror(L, 1, "table cannot be empty");
// check if each arguments is a string
for (size_t i = 1; i <= cmd_len; ++i) {
lua_rawgeti(L, 1, i);
luaL_checkstring(L, -1);
lua_pop(L, 1);
} }
#endif
if (lua_istable(L, 2)) { if (lua_istable(L, 2)) {
lua_getfield(L, 2, "detach"); detach = lua_toboolean(L, -1); lua_getfield(L, 2, "detach"); detach = lua_toboolean(L, -1);
lua_getfield(L, 2, "timeout"); deadline = luaL_optnumber(L, -1, deadline); lua_getfield(L, 2, "timeout"); deadline = luaL_optnumber(L, -1, deadline);
lua_getfield(L, 2, "cwd"); cwd = luaL_optstring(L, -1, NULL);
lua_getfield(L, 2, "stdin"); new_fds[STDIN_FD] = luaL_optnumber(L, -1, STDIN_FD); lua_getfield(L, 2, "stdin"); new_fds[STDIN_FD] = luaL_optnumber(L, -1, STDIN_FD);
lua_getfield(L, 2, "stdout"); new_fds[STDOUT_FD] = luaL_optnumber(L, -1, STDOUT_FD); lua_getfield(L, 2, "stdout"); new_fds[STDOUT_FD] = luaL_optnumber(L, -1, STDOUT_FD);
lua_getfield(L, 2, "stderr"); new_fds[STDERR_FD] = luaL_optnumber(L, -1, STDERR_FD); lua_getfield(L, 2, "stderr"); new_fds[STDERR_FD] = luaL_optnumber(L, -1, STDERR_FD);
@ -620,52 +390,42 @@ static int process_start(lua_State* L) {
if (new_fds[stream] > STDERR_FD || new_fds[stream] < REDIRECT_PARENT) if (new_fds[stream] > STDERR_FD || new_fds[stream] < REDIRECT_PARENT)
return luaL_error(L, "error: redirect to handles, FILE* and paths are not supported"); return luaL_error(L, "error: redirect to handles, FILE* and paths are not supported");
} }
lua_pop(L, 6); // pop all the values above lua_pop(L, 5); // pop all the values above
luaL_getsubtable(L, 2, "env"); #ifdef _WIN32
// count environment variobles if (lua_getfield(L, 2, "env") == LUA_TFUNCTION) {
lua_pushnil(L); lua_newtable(L);
while (lua_next(L, -2) != 0) { LPWCH system_env = GetEnvironmentStringsW(), envp = system_env;
luaL_checkstring(L, -2); while (wcslen(envp) > 0) {
luaL_checkstring(L, -1); const char *env = utfconv_fromwstr(&A, envp), *eq = env ? strchr(env, '=') : NULL;
lua_pop(L, 1); if (!env) return (FreeEnvironmentStringsW(system_env), luaL_error(L, "%s", UTFCONV_ERROR_INVALID_CONVERSION));
env_len++; if (!eq) return (FreeEnvironmentStringsW(system_env), luaL_error(L, "invalid environment variable"));
lua_pushlstring(L, env, eq - env); lua_pushstring(L, eq+1);
lxl_arena_free(&A, (void *) env);
lua_rawset(L, -3);
envp += wcslen(envp) + 1;
} }
FreeEnvironmentStringsW(system_env);
if (env_len) { lua_call(L, 1, 1);
if ((r = process_env_init(&env_vars, &env_vars_len, env_len)) != 0) { size_t len = 0; const char *env_mb = luaL_checklstring(L, -1, &len);
retval = -1; if (!(env = utfconv_fromlutf8(&A, env_mb, len)))
push_error(L, "cannot allocate environment list", r); return luaL_error(L, "%s", UTFCONV_ERROR_INVALID_CONVERSION);
goto cleanup;
} }
if (lua_getfield(L, 2, "cwd"), luaL_optstring(L, -1, NULL)) {
lua_pushnil(L); if ( !(cwd = utfconv_fromutf8(&A, lua_tostring(L, -1))) )
while (lua_next(L, -2) != 0) { return luaL_error(L, UTFCONV_ERROR_INVALID_CONVERSION);
if ((r = process_env_add(&env_vars, &env_vars_len, lua_tostring(L, -2), lua_tostring(L, -1))) != 0) {
retval = -1;
push_error(L, "cannot copy environment variable", r);
goto cleanup;
} }
lua_pop(L, 1); lua_pop(L, 2);
env_len++; #else
if (lua_getfield(L, 2, "env") == LUA_TFUNCTION) {
lua_newtable(L);
lua_call(L, 1, 1);
size_t len = 0; env = lua_tolstring(L, -1, &len);
env = lxl_arena_copy(&A, (void *) env, len+1);
} }
} cwd = lxl_arena_strdup(&A, (lua_getfield(L, 2, "cwd"), luaL_optstring(L, -1, NULL)));
} lua_pop(L, 2);
#endif
// allocate and copy commands
if ((r = process_arglist_init(&arglist, &arglist_len, cmd_len)) != 0) {
retval = -1;
push_error(L, "cannot create argument list", r);
goto cleanup;
}
for (size_t i = 1; i <= cmd_len; i++) {
lua_rawgeti(L, 1, i);
if ((r = process_arglist_add(&arglist, &arglist_len, lua_tostring(L, -1), escape)) != 0) {
retval = -1;
push_error(L, "cannot add argument", r);
goto cleanup;
}
lua_pop(L, 1);
} }
self = lua_newuserdata(L, sizeof(process_t)); self = lua_newuserdata(L, sizeof(process_t));
@ -674,13 +434,6 @@ static int process_start(lua_State* L) {
self->deadline = deadline; self->deadline = deadline;
self->detached = detach; self->detached = detach;
#if _WIN32 #if _WIN32
if (env_vars) {
if ((r = process_env_add_system(&env_vars, &env_vars_len)) != 0) {
retval = -1;
push_error(L, "cannot add environment variable", r);
goto cleanup;
}
}
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
switch (new_fds[i]) { switch (new_fds[i]) {
case REDIRECT_PARENT: case REDIRECT_PARENT:
@ -739,10 +492,7 @@ static int process_start(lua_State* L) {
siStartInfo.hStdInput = self->child_pipes[STDIN_FD][0]; siStartInfo.hStdInput = self->child_pipes[STDIN_FD][0];
siStartInfo.hStdOutput = self->child_pipes[STDOUT_FD][1]; siStartInfo.hStdOutput = self->child_pipes[STDOUT_FD][1];
siStartInfo.hStdError = self->child_pipes[STDERR_FD][1]; siStartInfo.hStdError = self->child_pipes[STDERR_FD][1];
wchar_t cwd_w[MAX_PATH]; if (!CreateProcessW(NULL, commandline, NULL, NULL, true, (detach ? DETACHED_PROCESS : CREATE_NO_WINDOW) | CREATE_UNICODE_ENVIRONMENT, env, cwd, &siStartInfo, &self->process_information)) {
if (cwd) // TODO: error handling
MultiByteToWideChar(CP_UTF8, 0, cwd, -1, cwd_w, MAX_PATH);
if (!CreateProcessW(NULL, arglist, NULL, NULL, true, (detach ? DETACHED_PROCESS : CREATE_NO_WINDOW) | CREATE_UNICODE_ENVIRONMENT, env_vars, cwd ? cwd_w : NULL, &siStartInfo, &self->process_information)) {
push_error(L, NULL, GetLastError()); push_error(L, NULL, GetLastError());
retval = -1; retval = -1;
goto cleanup; goto cleanup;
@ -789,10 +539,17 @@ static int process_start(lua_State* L) {
dup2(self->child_pipes[new_fds[stream]][new_fds[stream] == STDIN_FD ? 0 : 1], stream); dup2(self->child_pipes[new_fds[stream]][new_fds[stream] == STDIN_FD ? 0 : 1], stream);
close(self->child_pipes[stream][stream == STDIN_FD ? 1 : 0]); close(self->child_pipes[stream][stream == STDIN_FD ? 1 : 0]);
} }
size_t set; if (env) {
for (set = 0; set < env_vars_len && setenv(env_vars[set], env_vars[set+1], 1) == 0; set += 2); size_t len = 0;
if (set == env_vars_len && (!detach || setsid() != -1) && (!cwd || chdir(cwd) != -1)) while ((len = strlen(env)) != 0) {
execvp(arglist[0], (char** const)arglist); char *value = strchr(env, '=');
*value = '\0'; value++; // change the '=' into '\0', forming 2 strings side by side
setenv(env, value, 1);
env += len+1;
}
}
if ((!detach || setsid() != -1) && (!cwd || chdir(cwd) != -1))
execvp(cmd[0], (char** const) cmd);
write(control_pipe[1], &errno, sizeof(errno)); write(control_pipe[1], &errno, sizeof(errno));
_exit(-1); _exit(-1);
} }
@ -834,8 +591,6 @@ static int process_start(lua_State* L) {
} }
} }
} }
process_arglist_free(&arglist);
process_env_free(&env_vars, env_vars_len);
if (retval == -1) if (retval == -1)
return lua_error(L); return lua_error(L);