Skip to content

Commit

Permalink
Remove jl_getch() to fix race condition in getpass()
Browse files Browse the repository at this point in the history
We accidentally introduced a race condition in `getpass()` by having
`jl_getch()` toggle terminal modes for each keystroke.  Not only is this
slower and wasteful, it allows the kernel to receive keystrokes within a
TTY in canonical mode (where it replaces certain characters [0]) and
then reads from the kernel buffer in non-canonical mode.  This results
in us reading a `0x00` when we expected a `0x04` in certain cases on CI,
which breaks some of our tests.

The fix is to switch the TTY into raw mode once, before we ever print
the password prompt, which closes the race condition.  To do this, we
moved more code from C to Julia, and removed the `jl_getch()` export,
instead providing `jl_termios_size()`.

[0] https://github.com/torvalds/linux/blob/e35e5b6f695d241ffb1d223207da58a1fbcdff4b/drivers/tty/n_tty.c#L1318
  • Loading branch information
staticfloat committed Jul 6, 2022
1 parent 301b62a commit f93e1c8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 38 deletions.
90 changes: 74 additions & 16 deletions base/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,84 @@ graphical interface.
"""
function getpass end

_getch() = UInt8(ccall(:jl_getch, Cint, ()))
# Note, this helper only works within `with_raw_tty()` on POSIX platforms!
function _getch()
@static if Sys.iswindows()
return UInt8(ccall(:_getch, Cint, ()))
else
return only(read(stdin, 1))
end
end

const termios_size = Int(ccall(:jl_termios_size, Cint, ()))
make_termios() = zeros(UInt8, termios_size)

# These values seem to hold on all OSes we care about:
# glibc Linux, musl Linux, macOS, FreeBSD
@enum TCSETATTR_FLAGS TCSANOW=0 TCSADRAIN=1 TCSAFLUSH=2

function tcgetattr(fd, termios)
ret = ccall(:tcgetattr, Cint, (Cint, Ptr{Cvoid}), Cint(fd), termios)
if ret != 0
throw(IOError("tcgetattr failed", ret))
end
end
function tcsetattr(fd, termios, mode::TCSETATTR_FLAGS = TCSADRAIN)
ret = ccall(:tcsetattr, Cint, (Cint, Cint, Ptr{Cvoid}), Cint(fd), Cint(mode), termios)
if ret != 0
throw(IOError("tcsetattr failed", ret))
end
end
cfmakeraw(termios) = ccall(:cfmakeraw, Cvoid, (Ptr{Cvoid},), termios)

function with_raw_tty(f::Function, input::TTY)
input === stdin || throw(ArgumentError("with_raw_tty only works for stdin"))
fd = 0

# If we're on windows, we do nothing, as we have access to `_getch()` quite easily
@static if Sys.iswindows()
f()
return
end

# Get the current terminal mode
old_termios = make_termios()
tcgetattr(fd, old_termios)
try
# Set a new, raw, terminal mode
new_termios = make_termios()
cfmakeraw(new_termios)
tcsetattr(fd, new_termios)

# Call the user-supplied callback
f()
finally
# Always restore the terminal mode
tcsetattr(fd, old_termios)
end
end

function getpass(input::TTY, output::IO, prompt::AbstractString)
input === stdin || throw(ArgumentError("getpass only works for stdin"))
print(output, prompt, ": ")
flush(output)
s = SecretBuffer()
plen = 0
while true
c = _getch()
if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04
break # EOF or return
elseif c == 0x00 || c == 0xe0
_getch() # ignore function/arrow keys
elseif c == UInt8('\b') && plen > 0
plen -= 1 # delete last character on backspace
elseif !iscntrl(Char(c)) && plen < 128
write(s, c)
with_raw_tty(stdin) do
print(output, prompt, ": ")
flush(output)
s = SecretBuffer()
plen = 0
while true
c = _getch()
if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04
break # EOF or return
elseif c == 0x00 || c == 0xe0
_getch() # ignore function/arrow keys
elseif c == UInt8('\b') && plen > 0
plen -= 1 # delete last character on backspace
elseif !iscntrl(Char(c)) && plen < 128
write(s, c)
end
end
return seekstart(s)
end
return seekstart(s)
end

# allow new getpass methods to be defined if stdin has been
Expand Down
2 changes: 1 addition & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@
XX(jl_generic_function_def) \
XX(jl_gensym) \
XX(jl_getallocationgranularity) \
XX(jl_getch) \
XX(jl_getnameinfo) \
XX(jl_getpagesize) \
XX(jl_get_ARCH) \
Expand Down Expand Up @@ -462,6 +461,7 @@
XX(jl_take_buffer) \
XX(jl_task_get_next) \
XX(jl_task_stack_buffer) \
XX(jl_termios_size) \
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
XX(jl_threadpoolid) \
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -2058,7 +2058,7 @@ extern JL_DLLEXPORT JL_STREAM *JL_STDERR;
JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void);
JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void);
JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void);
JL_DLLEXPORT int jl_getch(void);
JL_DLLEXPORT int jl_termios_size(void);

// showing and std streams
JL_DLLEXPORT void jl_flush_cstdio(void) JL_NOTSAFEPOINT;
Expand Down
23 changes: 3 additions & 20 deletions src/sys.c
Original file line number Diff line number Diff line change
Expand Up @@ -517,28 +517,11 @@ JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void) { return JL_STDIN; }
JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void) { return JL_STDOUT; }
JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void) { return JL_STDERR; }

// terminal workarounds
JL_DLLEXPORT int jl_getch(void) JL_NOTSAFEPOINT
{
JL_DLLEXPORT int jl_termios_size(void) {
#if defined(_OS_WINDOWS_)
// Windows has an actual `_getch()`, use that:
return _getch();
return 0;
#else
// On all other platforms, we do the POSIX terminal manipulation dance
char c;
int r;
struct termios old_termios = {0};
struct termios new_termios = {0};
if (tcgetattr(0, &old_termios) != 0)
return -1;
new_termios = old_termios;
cfmakeraw(&new_termios);
if (tcsetattr(0, TCSADRAIN, &new_termios) != 0)
return -1;
r = read(0, &c, 1);
if (tcsetattr(0, TCSADRAIN, &old_termios) != 0)
return -1;
return r == 1 ? c : -1;
return sizeof(struct termios);
#endif
}

Expand Down

0 comments on commit f93e1c8

Please sign in to comment.