From f93e1c8851ec2b13638e8deb0d5b77da0d8bc010 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Wed, 6 Jul 2022 19:17:59 +0000 Subject: [PATCH] Remove `jl_getch()` to fix race condition in `getpass()` We accidentally introduced a race condition in `getpass()` by having `jl_getch()` toggle terminal modes for each keystroke. Not only is this slower and wasteful, it allows the kernel to receive keystrokes within a TTY in canonical mode (where it replaces certain characters [0]) and then reads from the kernel buffer in non-canonical mode. This results in us reading a `0x00` when we expected a `0x04` in certain cases on CI, which breaks some of our tests. The fix is to switch the TTY into raw mode once, before we ever print the password prompt, which closes the race condition. To do this, we moved more code from C to Julia, and removed the `jl_getch()` export, instead providing `jl_termios_size()`. [0] https://github.com/torvalds/linux/blob/e35e5b6f695d241ffb1d223207da58a1fbcdff4b/drivers/tty/n_tty.c#L1318 --- base/util.jl | 90 ++++++++++++++++++++++++++++++++------- src/jl_exported_funcs.inc | 2 +- src/julia.h | 2 +- src/sys.c | 23 ++-------- 4 files changed, 79 insertions(+), 38 deletions(-) diff --git a/base/util.jl b/base/util.jl index 46e7f36475b98..798f7c8866d13 100644 --- a/base/util.jl +++ b/base/util.jl @@ -257,26 +257,84 @@ graphical interface. """ function getpass end -_getch() = UInt8(ccall(:jl_getch, Cint, ())) +# Note, this helper only works within `with_raw_tty()` on POSIX platforms! +function _getch() + @static if Sys.iswindows() + return UInt8(ccall(:_getch, Cint, ())) + else + return only(read(stdin, 1)) + end +end + +const termios_size = Int(ccall(:jl_termios_size, Cint, ())) +make_termios() = zeros(UInt8, termios_size) + +# These values seem to hold on all OSes we care about: +# glibc Linux, musl Linux, macOS, FreeBSD +@enum TCSETATTR_FLAGS TCSANOW=0 TCSADRAIN=1 TCSAFLUSH=2 + +function tcgetattr(fd, termios) + ret = ccall(:tcgetattr, Cint, (Cint, Ptr{Cvoid}), Cint(fd), termios) + if ret != 0 + throw(IOError("tcgetattr failed", ret)) + end +end +function tcsetattr(fd, termios, mode::TCSETATTR_FLAGS = TCSADRAIN) + ret = ccall(:tcsetattr, Cint, (Cint, Cint, Ptr{Cvoid}), Cint(fd), Cint(mode), termios) + if ret != 0 + throw(IOError("tcsetattr failed", ret)) + end +end +cfmakeraw(termios) = ccall(:cfmakeraw, Cvoid, (Ptr{Cvoid},), termios) + +function with_raw_tty(f::Function, input::TTY) + input === stdin || throw(ArgumentError("with_raw_tty only works for stdin")) + fd = 0 + + # If we're on windows, we do nothing, as we have access to `_getch()` quite easily + @static if Sys.iswindows() + f() + return + end + + # Get the current terminal mode + old_termios = make_termios() + tcgetattr(fd, old_termios) + try + # Set a new, raw, terminal mode + new_termios = make_termios() + cfmakeraw(new_termios) + tcsetattr(fd, new_termios) + + # Call the user-supplied callback + f() + finally + # Always restore the terminal mode + tcsetattr(fd, old_termios) + end +end + function getpass(input::TTY, output::IO, prompt::AbstractString) input === stdin || throw(ArgumentError("getpass only works for stdin")) - print(output, prompt, ": ") - flush(output) - s = SecretBuffer() - plen = 0 - while true - c = _getch() - if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04 - break # EOF or return - elseif c == 0x00 || c == 0xe0 - _getch() # ignore function/arrow keys - elseif c == UInt8('\b') && plen > 0 - plen -= 1 # delete last character on backspace - elseif !iscntrl(Char(c)) && plen < 128 - write(s, c) + with_raw_tty(stdin) do + print(output, prompt, ": ") + flush(output) + s = SecretBuffer() + plen = 0 + while true + c = _getch() + if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04 + break # EOF or return + elseif c == 0x00 || c == 0xe0 + _getch() # ignore function/arrow keys + elseif c == UInt8('\b') && plen > 0 + plen -= 1 # delete last character on backspace + elseif !iscntrl(Char(c)) && plen < 128 + write(s, c) + end end + return seekstart(s) end - return seekstart(s) end # allow new getpass methods to be defined if stdin has been diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 72d385329ce49..f0cc94d22ba68 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -199,7 +199,6 @@ XX(jl_generic_function_def) \ XX(jl_gensym) \ XX(jl_getallocationgranularity) \ - XX(jl_getch) \ XX(jl_getnameinfo) \ XX(jl_getpagesize) \ XX(jl_get_ARCH) \ @@ -462,6 +461,7 @@ XX(jl_take_buffer) \ XX(jl_task_get_next) \ XX(jl_task_stack_buffer) \ + XX(jl_termios_size) \ XX(jl_test_cpu_feature) \ XX(jl_threadid) \ XX(jl_threadpoolid) \ diff --git a/src/julia.h b/src/julia.h index ada09fe61fadd..f8c39c7ab448b 100644 --- a/src/julia.h +++ b/src/julia.h @@ -2058,7 +2058,7 @@ extern JL_DLLEXPORT JL_STREAM *JL_STDERR; JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void); JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void); JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void); -JL_DLLEXPORT int jl_getch(void); +JL_DLLEXPORT int jl_termios_size(void); // showing and std streams JL_DLLEXPORT void jl_flush_cstdio(void) JL_NOTSAFEPOINT; diff --git a/src/sys.c b/src/sys.c index 2f512888c1873..2de4bc61a20b8 100644 --- a/src/sys.c +++ b/src/sys.c @@ -517,28 +517,11 @@ JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void) { return JL_STDIN; } JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void) { return JL_STDOUT; } JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void) { return JL_STDERR; } -// terminal workarounds -JL_DLLEXPORT int jl_getch(void) JL_NOTSAFEPOINT -{ +JL_DLLEXPORT int jl_termios_size(void) { #if defined(_OS_WINDOWS_) - // Windows has an actual `_getch()`, use that: - return _getch(); + return 0; #else - // On all other platforms, we do the POSIX terminal manipulation dance - char c; - int r; - struct termios old_termios = {0}; - struct termios new_termios = {0}; - if (tcgetattr(0, &old_termios) != 0) - return -1; - new_termios = old_termios; - cfmakeraw(&new_termios); - if (tcsetattr(0, TCSADRAIN, &new_termios) != 0) - return -1; - r = read(0, &c, 1); - if (tcsetattr(0, TCSADRAIN, &old_termios) != 0) - return -1; - return r == 1 ? c : -1; + return sizeof(struct termios); #endif }