Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

possible improvement for read until handling #193

Merged
Merged 1 commit on Nov 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions scrapli/channel/async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import re
import time
from io import SEEK_END, BytesIO
from io import BytesIO

try:
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -143,8 +143,7 @@ async def _read_until_prompt(self, buf: bytes = b"") -> bytes:
b = await self.read()
read_buf.write(b)

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

channel_match = re.search(
pattern=search_pattern,
Expand Down Expand Up @@ -185,8 +184,7 @@ async def _read_until_explicit_prompt(self, prompts: List[str]) -> bytes:
b = await self.read()
read_buf.write(b)

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

for search_pattern in search_patterns:
channel_match = re.search(
Expand Down Expand Up @@ -247,8 +245,7 @@ async def _read_until_prompt_or_time(
except ScrapliTimeout:
pass

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

if (time.time() - start) > read_duration:
break
Expand Down
31 changes: 30 additions & 1 deletion scrapli/channel/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from io import BytesIO
from io import SEEK_END, BytesIO
from typing import BinaryIO, List, Optional, Pattern, Tuple, Union

from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTypeError, ScrapliValueError
Expand Down Expand Up @@ -286,6 +286,35 @@ def close(self) -> None:
if self.channel_log:
self.channel_log.close()

def _process_read_buf(self, read_buf: BytesIO) -> bytes:
"""
Process the read buffer

Seeks backwards up to search depth then partitions on newlines. Partition is to ensure that
the resulting search_buf does not end up with partial lines in the output which can cause
prompt patterns to match places they should not match!

Args:
read_buf: bytesio object read from the transport

Returns:
bytes: cleaned up search buffer

Raises:
N/A

"""
read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()

before, _, search_buf = search_buf.partition(b"\n")

if not search_buf:
# didn't split on anything or nothing after partition
search_buf = before

return search_buf

def write(self, channel_input: str, redacted: bool = False) -> None:
"""
Write input to the underlying Transport session
Expand Down
11 changes: 4 additions & 7 deletions scrapli/channel/sync_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from contextlib import contextmanager
from datetime import datetime
from io import SEEK_END, BytesIO
from io import BytesIO
from threading import Lock
from typing import Iterator, List, Optional, Tuple

Expand Down Expand Up @@ -136,8 +136,7 @@ def _read_until_prompt(self, buf: bytes = b"") -> bytes:
while True:
read_buf.write(self.read())

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

channel_match = re.search(
pattern=search_pattern,
Expand Down Expand Up @@ -177,8 +176,7 @@ def _read_until_explicit_prompt(self, prompts: List[str]) -> bytes:
while True:
read_buf.write(self.read())

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

for search_pattern in search_patterns:
channel_match = re.search(
Expand Down Expand Up @@ -238,8 +236,7 @@ def _read_until_prompt_or_time(
except ScrapliTimeout:
pass

read_buf.seek(-self._base_channel_args.comms_prompt_search_depth, SEEK_END)
search_buf = read_buf.read()
search_buf = self._process_read_buf(read_buf=read_buf)

if (time.time() - start) > read_duration:
break
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/channel/test_async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ async def test_send_inputs_interact(monkeypatch, async_channel):
_read_counter = 0
_event_counter = 0

interact_events = [("clear logg", "[confirm]\n"), ("", "scrapli>")]
expected_buf = b"clear logg[confirm]\nscrapli>"
interact_events = [("clear logg", "[confirm]"), ("", "scrapli>")]
expected_buf = b"clear logg[confirm]scrapli>"

async def _read(cls):
nonlocal _read_counter
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/channel/test_base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,35 @@ def test_channel_log_user_bytesio(base_transport_no_abc):
assert chan.channel_log is bytes_log


@pytest.mark.parametrize(
    "test_data",
    (
        (
            b"blah basic stuff>",
            b"blah basic stuff>",
        ),
        (
            b"\noper-status></physical-interface>\nsomeprompt>",
            b"oper-status></physical-interface>\nsomeprompt>",
        ),
    ),
    ids=("simple_buf", "xml_out"),
)
def test_process_read_buf(test_data, base_channel):
    """
    Assert that the process read buf method always returns a search buf that is "rooted" on
    newlines -- meaning we never scan backwards through the read buf and return a partial line
    whose tail sits at the "start" of the search_buf looking exactly like a normal prompt we
    would match on
    """
    inbuf, expected_buf = test_data

    read_buf = BytesIO(inbuf)
    search_buf = base_channel._process_read_buf(read_buf=read_buf)

    assert search_buf == expected_buf


def test_channel_write(caplog, monkeypatch, base_channel):
caplog.set_level(logging.DEBUG, logger="scrapli.channel")

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/channel/test_sync_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ def test_send_inputs_interact(monkeypatch, sync_channel):
_read_counter = 0
_event_counter = 0

interact_events = [("clear logg", "[confirm]\n"), ("", "scrapli>")]
expected_buf = b"clear logg[confirm]\nscrapli>"
interact_events = [("clear logg", "[confirm]"), ("", "scrapli>")]
expected_buf = b"clear logg[confirm]scrapli>"

def _read(cls):
nonlocal _read_counter
Expand Down