Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove jl_getch() to fix race condition in getpass() #45954

Merged
merged 4 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 74 additions & 16 deletions base/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,84 @@ graphical interface.
"""
function getpass end

_getch() = UInt8(ccall(:jl_getch, Cint, ()))
# Note, this helper only works within `with_raw_tty()` on POSIX platforms!
function _getch()
@static if Sys.iswindows()
return UInt8(ccall(:_getch, Cint, ()))
else
return read(stdin, UInt8)
end
end

const termios_size = Int(ccall(:jl_termios_size, Cint, ()))
make_termios() = zeros(UInt8, termios_size)

# These values seem to hold on all OSes we care about:
# glibc Linux, musl Linux, macOS, FreeBSD
@enum TCSETATTR_FLAGS TCSANOW=0 TCSADRAIN=1 TCSAFLUSH=2

function tcgetattr(fd, termios)
ret = ccall(:tcgetattr, Cint, (Cint, Ptr{Cvoid}), Cint(fd), termios)
staticfloat marked this conversation as resolved.
Show resolved Hide resolved
if ret != 0
throw(IOError("tcgetattr failed", ret))
end
end
function tcsetattr(fd, termios, mode::TCSETATTR_FLAGS = TCSADRAIN)
ret = ccall(:tcsetattr, Cint, (Cint, Cint, Ptr{Cvoid}), Cint(fd), Cint(mode), termios)
staticfloat marked this conversation as resolved.
Show resolved Hide resolved
if ret != 0
throw(IOError("tcsetattr failed", ret))
end
end
cfmakeraw(termios) = ccall(:cfmakeraw, Cvoid, (Ptr{Cvoid},), termios)

function with_raw_tty(f::Function, input::TTY)
input === stdin || throw(ArgumentError("with_raw_tty only works for stdin"))
fd = 0
staticfloat marked this conversation as resolved.
Show resolved Hide resolved

# If we're on windows, we do nothing, as we have access to `_getch()` quite easily
@static if Sys.iswindows()
f()
return
staticfloat marked this conversation as resolved.
Show resolved Hide resolved
end

# Get the current terminal mode
old_termios = make_termios()
tcgetattr(fd, old_termios)
try
# Set a new, raw, terminal mode
new_termios = make_termios()
staticfloat marked this conversation as resolved.
Show resolved Hide resolved
cfmakeraw(new_termios)
tcsetattr(fd, new_termios)

# Call the user-supplied callback
f()
finally
# Always restore the terminal mode
tcsetattr(fd, old_termios)
end
end

function getpass(input::TTY, output::IO, prompt::AbstractString)
input === stdin || throw(ArgumentError("getpass only works for stdin"))
print(output, prompt, ": ")
flush(output)
s = SecretBuffer()
plen = 0
while true
c = _getch()
if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04
break # EOF or return
elseif c == 0x00 || c == 0xe0
_getch() # ignore function/arrow keys
elseif c == UInt8('\b') && plen > 0
plen -= 1 # delete last character on backspace
elseif !iscntrl(Char(c)) && plen < 128
write(s, c)
with_raw_tty(stdin) do
print(output, prompt, ": ")
flush(output)
s = SecretBuffer()
plen = 0
while true
c = _getch()
if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04
break # EOF or return
elseif c == 0x00 || c == 0xe0
_getch() # ignore function/arrow keys
elseif c == UInt8('\b') && plen > 0
plen -= 1 # delete last character on backspace
elseif !iscntrl(Char(c)) && plen < 128
write(s, c)
end
end
return seekstart(s)
end
return seekstart(s)
end

# allow new getpass methods to be defined if stdin has been
Expand Down
2 changes: 1 addition & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@
XX(jl_generic_function_def) \
XX(jl_gensym) \
XX(jl_getallocationgranularity) \
XX(jl_getch) \
XX(jl_getnameinfo) \
XX(jl_getpagesize) \
XX(jl_get_ARCH) \
Expand Down Expand Up @@ -462,6 +461,7 @@
XX(jl_take_buffer) \
XX(jl_task_get_next) \
XX(jl_task_stack_buffer) \
XX(jl_termios_size) \
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
XX(jl_threadpoolid) \
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -2058,7 +2058,7 @@ extern JL_DLLEXPORT JL_STREAM *JL_STDERR;
JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void);
JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void);
JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void);
JL_DLLEXPORT int jl_getch(void);
JL_DLLEXPORT int jl_termios_size(void);

// showing and std streams
JL_DLLEXPORT void jl_flush_cstdio(void) JL_NOTSAFEPOINT;
Expand Down
23 changes: 3 additions & 20 deletions src/sys.c
Original file line number Diff line number Diff line change
Expand Up @@ -517,28 +517,11 @@ JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void) { return JL_STDIN; }
JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void) { return JL_STDOUT; }
JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void) { return JL_STDERR; }

// terminal workarounds
JL_DLLEXPORT int jl_getch(void) JL_NOTSAFEPOINT
{
JL_DLLEXPORT int jl_termios_size(void) {
#if defined(_OS_WINDOWS_)
// Windows has an actual `_getch()`, use that:
return _getch();
return 0;
#else
// On all other platforms, we do the POSIX terminal manipulation dance
char c;
int r;
struct termios old_termios = {0};
struct termios new_termios = {0};
if (tcgetattr(0, &old_termios) != 0)
return -1;
new_termios = old_termios;
cfmakeraw(&new_termios);
if (tcsetattr(0, TCSADRAIN, &new_termios) != 0)
return -1;
r = read(0, &c, 1);
if (tcsetattr(0, TCSADRAIN, &old_termios) != 0)
return -1;
return r == 1 ? c : -1;
return sizeof(struct termios);
#endif
}

Expand Down