forked from TheR1D/shell_gpt
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improved tests with mocked API responses (TheR1D#442)
* Changed name of --top-probability parameter to --top-p. * Fixed bug in --repl --shell when describing shell command.
- Loading branch information
Showing
12 changed files
with
600 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import datetime | ||
|
||
import pytest | ||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk | ||
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice | ||
from openai.types.chat.chat_completion_chunk import ChoiceDelta | ||
|
||
from sgpt.config import cfg | ||
|
||
|
||
@pytest.fixture
def completion(request):
    """Build a fake streaming response: one ChatCompletionChunk per character.

    The parametrized string (``request.param``) is broken into single
    characters, each wrapped as an assistant-role delta chunk so tests can
    feed it to mocked ``Completions.create`` as a stream.
    """
    text = request.param
    chunks = []
    for piece in text:
        delta = ChoiceDelta(content=piece, role="assistant")
        choice = StreamChoice(index=0, finish_reason=None, delta=delta)
        chunks.append(
            ChatCompletionChunk(
                id="foo",
                model=cfg.get("DEFAULT_MODEL"),
                object="chat.completion.chunk",
                choices=[choice],
                created=int(datetime.datetime.now().timestamp()),
            )
        )
    return chunks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from pathlib import Path | ||
from unittest.mock import patch | ||
|
||
from sgpt.config import cfg | ||
from sgpt.role import DefaultRoles, SystemRole | ||
|
||
from .utils import app, comp_args, comp_chunks, make_args, parametrize, runner | ||
|
||
role = SystemRole.get(DefaultRoles.CODE.value) | ||
|
||
|
||
@parametrize("completion", ["print('Hello World')"], indirect=True)
@patch("openai.resources.chat.Completions.create")
def test_code_generation(mock, completion):
    """A --code prompt should emit the mocked code snippet and exit cleanly."""
    mock.return_value = completion

    prompt = "hello world python"
    result = runner.invoke(app, make_args(**{"prompt": prompt, "--code": True}))

    # The API must be hit exactly once, with the CODE role and the raw prompt.
    mock.assert_called_once_with(**comp_args(role, prompt))
    assert result.exit_code == 0
    assert "print('Hello World')" in result.stdout
|
||
|
||
@parametrize("completion", ["# Hello\nprint('Hello')"], indirect=True)
@patch("openai.resources.chat.Completions.create")
def test_code_generation_stdin(mock, completion):
    """Piped stdin is prepended to the prompt before reaching the API."""
    mock.return_value = completion

    prompt = "make comments for code"
    piped = "print('Hello')"
    result = runner.invoke(
        app, make_args(**{"prompt": prompt, "--code": True}), input=piped
    )

    # stdin content comes first, separated from the prompt by a blank line.
    mock.assert_called_once_with(**comp_args(role, f"{piped}\n\n{prompt}"))
    assert result.exit_code == 0
    assert "# Hello" in result.stdout
    assert "print('Hello')" in result.stdout
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_code_chat(mock):
    """Chained --code --chat calls persist history and replay it to the API."""
    mock.side_effect = [
        comp_chunks("print('hello')"),
        comp_chunks("print('hello')\nprint('world')"),
    ]
    session = "_test"
    session_file = Path(cfg.get("CHAT_CACHE_PATH")) / session
    session_file.unlink(missing_ok=True)

    cli = {"prompt": "print hello", "--code": True, "--chat": session}
    first = runner.invoke(app, make_args(**cli))
    assert first.exit_code == 0
    assert "print('hello')" in first.stdout
    assert session_file.exists()

    cli["prompt"] = "also print world"
    second = runner.invoke(app, make_args(**cli))
    assert second.exit_code == 0
    assert "print('hello')" in second.stdout
    assert "print('world')" in second.stdout

    # Second call must carry the full conversation, not just the new prompt.
    history = [
        {"role": "system", "content": role.role},
        {"role": "user", "content": "print hello"},
        {"role": "assistant", "content": "print('hello')"},
        {"role": "user", "content": "also print world"},
        {"role": "assistant", "content": "print('hello')\nprint('world')"},
    ]
    mock.assert_called_with(**comp_args(role, "", messages=history))
    assert mock.call_count == 2

    # Mixing --shell into an existing code chat must be rejected with an error.
    cli["--shell"] = True
    rejected = runner.invoke(app, make_args(**cli))
    assert rejected.exit_code == 2
    assert "Error" in rejected.stdout
    session_file.unlink()
    # TODO: Code chat can be recalled without --code option.
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_code_repl(mock_completion):
    """--repl --code accumulates turns and echoes each prompt and reply."""
    mock_completion.side_effect = [
        comp_chunks("print('hello')"),
        comp_chunks("print('hello')\nprint('world')"),
    ]
    session = "_test"
    session_file = Path(cfg.get("CHAT_CACHE_PATH")) / session
    session_file.unlink(missing_ok=True)

    repl_inputs = "\n".join(["print hello", "also print world", "exit()"])
    result = runner.invoke(
        app, make_args(**{"--repl": session, "--code": True}), input=repl_inputs
    )

    # Final API call must include the whole conversation so far.
    history = [
        {"role": "system", "content": role.role},
        {"role": "user", "content": "print hello"},
        {"role": "assistant", "content": "print('hello')"},
        {"role": "user", "content": "also print world"},
        {"role": "assistant", "content": "print('hello')\nprint('world')"},
    ]
    mock_completion.assert_called_with(**comp_args(role, "", messages=history))
    assert mock_completion.call_count == 2

    assert result.exit_code == 0
    assert ">>> print hello" in result.stdout
    assert "print('hello')" in result.stdout
    assert ">>> also print world" in result.stdout
    assert "print('world')" in result.stdout
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_code_and_shell(mock):
    """--code and --shell are mutually exclusive: error out before any API call."""
    result = runner.invoke(app, make_args(**{"--code": True, "--shell": True}))

    mock.assert_not_called()
    assert result.exit_code == 2
    assert "Error" in result.stdout
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_code_and_describe_shell(mock):
    """--code with --describe-shell is invalid: error out before any API call."""
    result = runner.invoke(
        app, make_args(**{"--code": True, "--describe-shell": True})
    )

    mock.assert_not_called()
    assert result.exit_code == 2
    assert "Error" in result.stdout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
from pathlib import Path | ||
from unittest.mock import patch | ||
|
||
from sgpt import config | ||
from sgpt.__version__ import __version__ | ||
from sgpt.role import DefaultRoles, SystemRole | ||
|
||
from .utils import app, comp_args, comp_chunks, make_args, parametrize, runner | ||
|
||
role = SystemRole.get(DefaultRoles.DEFAULT.value) | ||
cfg = config.cfg | ||
|
||
|
||
@parametrize("completion", ["Prague"], indirect=True)
@patch("openai.resources.chat.Completions.create")
def test_default(mock, completion):
    """A bare prompt uses the default role and streams the mocked answer."""
    mock.return_value = completion

    prompt = "capital of the Czech Republic?"
    result = runner.invoke(app, make_args(**{"prompt": prompt}))

    mock.assert_called_once_with(**comp_args(role, prompt=prompt))
    assert result.exit_code == 0
    assert "Prague" in result.stdout
|
||
|
||
@parametrize("completion", ["Prague"], indirect=True)
@patch("openai.resources.chat.Completions.create")
def test_default_stdin(mock, completion):
    """With no prompt argument, stdin alone becomes the prompt."""
    mock.return_value = completion

    piped = "capital of the Czech Republic?"
    result = runner.invoke(app, make_args(), input=piped)

    mock.assert_called_once_with(**comp_args(role, piped))
    assert result.exit_code == 0
    assert "Prague" in result.stdout
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_default_chat(mock):
    """--chat persists turns, lists/shows them, and rejects mode switches."""
    mock.side_effect = [comp_chunks("ok"), comp_chunks("4")]
    session = "_test"
    session_file = Path(cfg.get("CHAT_CACHE_PATH")) / session
    session_file.unlink(missing_ok=True)

    cli = {"prompt": "my number is 2", "--chat": session}
    first = runner.invoke(app, make_args(**cli))
    assert first.exit_code == 0
    assert "ok" in first.stdout
    assert session_file.exists()

    cli["prompt"] = "my number + 2?"
    second = runner.invoke(app, make_args(**cli))
    assert second.exit_code == 0
    assert "4" in second.stdout

    # The follow-up request must replay the full stored conversation.
    history = [
        {"role": "system", "content": role.role},
        {"role": "user", "content": "my number is 2"},
        {"role": "assistant", "content": "ok"},
        {"role": "user", "content": "my number + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    mock.assert_called_with(**comp_args(role, "", messages=history))
    assert mock.call_count == 2

    # The session shows up in --list-chats ...
    listing = runner.invoke(app, ["--list-chats"])
    assert listing.exit_code == 0
    assert "_test" in listing.stdout

    # ... and --show-chat prints every turn verbatim.
    transcript = runner.invoke(app, ["--show-chat", session])
    assert transcript.exit_code == 0
    assert "my number is 2" in transcript.stdout
    assert "ok" in transcript.stdout
    assert "my number + 2?" in transcript.stdout
    assert "4" in transcript.stdout

    # Switching an existing default chat to --shell or --code is an error.
    cli["--shell"] = True
    assert runner.invoke(app, make_args(**cli)).exit_code == 2
    assert "Error" in runner.invoke(app, make_args(**cli)).stdout

    cli["--code"] = True
    rejected = runner.invoke(app, make_args(**cli))
    assert rejected.exit_code == 2
    assert "Error" in rejected.stdout
    session_file.unlink()
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_default_repl(mock):
    """--repl keeps conversation state across turns and echoes each prompt."""
    mock.side_effect = [comp_chunks("ok"), comp_chunks("8")]
    session = "_test"
    session_file = Path(cfg.get("CHAT_CACHE_PATH")) / session
    session_file.unlink(missing_ok=True)

    repl_inputs = "\n".join(["my number is 6", "my number + 2?", "exit()"])
    result = runner.invoke(app, make_args(**{"--repl": session}), input=repl_inputs)

    # Final API call must include both turns plus the system role.
    history = [
        {"role": "system", "content": role.role},
        {"role": "user", "content": "my number is 6"},
        {"role": "assistant", "content": "ok"},
        {"role": "user", "content": "my number + 2?"},
        {"role": "assistant", "content": "8"},
    ]
    mock.assert_called_with(**comp_args(role, "", messages=history))
    assert mock.call_count == 2

    assert result.exit_code == 0
    assert ">>> my number is 6" in result.stdout
    assert "ok" in result.stdout
    assert ">>> my number + 2?" in result.stdout
    assert "8" in result.stdout
|
||
|
||
@parametrize("completion", ["Berlin"], indirect=True)
@patch("openai.resources.chat.Completions.create")
def test_llm_options(mock, completion):
    """CLI model/temperature/top-p/functions flags must reach the API verbatim."""
    mock.return_value = completion

    cli = {
        "prompt": "capital of the Germany?",
        "--model": "gpt-4-test",
        "--temperature": 0.5,
        "--top-p": 0.5,
        "--no-functions": True,
    }
    result = runner.invoke(app, make_args(**cli))

    # --no-functions maps to functions=None in the API call.
    mock.assert_called_once_with(
        **comp_args(
            role=role,
            prompt=cli["prompt"],
            model=cli["--model"],
            temperature=cli["--temperature"],
            top_p=cli["--top-p"],
            functions=None,
        )
    )
    assert result.exit_code == 0
    assert "Berlin" in result.stdout
|
||
|
||
@patch("openai.resources.chat.Completions.create")
def test_version(mock):
    """--version prints the package version without calling the API.

    Fix: also assert a zero exit code — previously only the version string
    was checked, so a crash after printing would have gone unnoticed.
    """
    args = {"--version": True}
    result = runner.invoke(app, make_args(**args))

    mock.assert_not_called()
    assert result.exit_code == 0
    assert __version__ in result.stdout
Oops, something went wrong.