Skip to content

Commit

Permalink
gh-111201: Speed up paste mode in the REPL
Browse files Browse the repository at this point in the history
  • Loading branch information
pablogsal committed May 21, 2024
1 parent b7f45a9 commit cef4312
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
5 changes: 2 additions & 3 deletions Lib/_pyrepl/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,18 +458,17 @@ def do(self) -> None:
class paste_mode(Command):

def do(self) -> None:
if not self.reader.paste_mode:
self.reader.was_paste_mode_activated = True
self.reader.paste_mode = not self.reader.paste_mode
self.reader.dirty = True


class enable_bracketed_paste(Command):
def do(self) -> None:
self.reader.paste_mode = True
self.reader.was_paste_mode_activated = True
self.reader.in_bracketed_paste = True

class disable_bracketed_paste(Command):
def do(self) -> None:
self.reader.paste_mode = False
self.reader.in_bracketed_paste = False
self.reader.dirty = True
8 changes: 4 additions & 4 deletions Lib/_pyrepl/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def disp_str(buffer: str) -> tuple[str, list[int]]:
b: list[int] = []
s: list[str] = []
for c in buffer:
if unicodedata.category(c).startswith("C"):
if ord(c) > 128 and unicodedata.category(c).startswith("C"):
c = r"\u%04x" % ord(c)
s.append(c)
b.append(wlen(c))
Expand Down Expand Up @@ -223,7 +223,7 @@ class Reader:
dirty: bool = False
finished: bool = False
paste_mode: bool = False
was_paste_mode_activated: bool = False
in_bracketed_paste: bool = False
commands: dict[str, type[Command]] = field(default_factory=make_default_commands)
last_command: type[Command] | None = None
syntax_table: dict[str, int] = field(default_factory=make_default_syntax_table)
Expand Down Expand Up @@ -422,7 +422,7 @@ def get_prompt(self, lineno: int, cursor_on_line: bool) -> str:
elif "\n" in self.buffer:
if lineno == 0:
prompt = self.ps2
elif lineno == self.buffer.count("\n"):
elif self.ps4 and lineno == self.buffer.count("\n"):
prompt = self.ps4
else:
prompt = self.ps3
Expand Down Expand Up @@ -585,7 +585,7 @@ def do_cmd(self, cmd: tuple[str, list[str]]) -> None:

self.after_command(command)

if self.dirty:
if self.dirty and not self.in_bracketed_paste:
self.refresh()
else:
self.update_cursor()
Expand Down
6 changes: 3 additions & 3 deletions Lib/_pyrepl/readline.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ def multiline_input(self, more_lines: MoreLinesCallable, ps1: str, ps2: str) ->
try:
reader.more_lines = more_lines
reader.ps1 = reader.ps2 = ps1
reader.ps3 = reader.ps4 = ps2
return reader.readline(), reader.was_paste_mode_activated
reader.ps3 = ps2
reader.ps4 = ""
return reader.readline()
finally:
reader.more_lines = saved
reader.paste_mode = False
reader.was_paste_mode_activated = False

def parse_and_bind(self, string: str) -> None:
pass # XXX we don't support parsing GNU-readline-style init files
Expand Down
2 changes: 1 addition & 1 deletion Lib/_pyrepl/simple_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def more_lines(unicodetext: str) -> bool:
ps1 = getattr(sys, "ps1", ">>> ")
ps2 = getattr(sys, "ps2", "... ")
try:
statement, contains_pasted_code = multiline_input(more_lines, ps1, ps2)
statement = multiline_input(more_lines, ps1, ps2)
except EOFError:
break

Expand Down
8 changes: 6 additions & 2 deletions Lib/_pyrepl/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import re
import unicodedata
import functools

ANSI_ESCAPE_SEQUENCE = re.compile(r"\x1b\[[ -@]*[A-~]")


@functools.cache
def str_width(c: str) -> int:
if ord(c) < 128:
return 1
w = unicodedata.east_asian_width(c)
if w in ('N', 'Na', 'H', 'A'):
return 1
Expand All @@ -13,6 +17,6 @@ def str_width(c: str) -> int:

def wlen(s: str) -> int:
length = sum(str_width(i) for i in s)

# remove lengths of any escape sequences
return length - sum(len(i) for i in ANSI_ESCAPE_SEQUENCE.findall(s))
sequence = ANSI_ESCAPE_SEQUENCE.findall(s)
return length - sum(len(i) for i in sequence)

0 comments on commit cef4312

Please sign in to comment.