Skip to content

Commit 56b9666

Browse files
author
Gilad Barnea
committed
Added a -r, --raw option to use the LLM as-is without a specific role or system message.
Raw mode increases flexibility by enabling responses similar to those obtained directly through the LLM's API.
1 parent 7678afe commit 56b9666

File tree

4 files changed

+185
-4
lines changed

4 files changed

+185
-4
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ Possible options for `CODE_THEME`: https://pygments.org/styles/
436436
│ --interaction --no-interaction Interactive mode for --shell option. [default: interaction] │
437437
│ --describe-shell -d Describe a shell command. │
438438
│ --code -c Generate only code. │
439+
│ --raw -r Use the LLM as-is without a specific role or system message. │
439440
│ --functions --no-functions Allow function calls. [default: functions] │
440441
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
441442
╭─ Chat Options ───────────────────────────────────────────────────────────────────────────────────────────╮

sgpt/app.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def main(
7575
help="Generate only code.",
7676
rich_help_panel="Assistance Options",
7777
),
78+
raw: bool = typer.Option(
79+
False,
80+
"-r",
81+
"--raw",
82+
help="Use the LLM as-is without a specific role or system message.",
83+
rich_help_panel="Assistance Options",
84+
),
7885
functions: bool = typer.Option(
7986
cfg.get("OPENAI_USE_FUNCTIONS") == "true",
8087
help="Allow function calls.",
@@ -183,9 +190,9 @@ def main(
183190
# Non-interactive shell.
184191
pass
185192

186-
if sum((shell, describe_shell, code)) > 1:
193+
if sum((shell, describe_shell, code, raw)) > 1:
187194
raise BadArgumentUsage(
188-
"Only one of --shell, --describe-shell, and --code options can be used at a time."
195+
"Only one of --shell, --describe-shell, --code and --raw options can be used at a time."
189196
)
190197

191198
if chat and repl:
@@ -198,7 +205,7 @@ def main(
198205
prompt = get_edited_prompt()
199206

200207
role_class = (
201-
DefaultRoles.check_get(shell, describe_shell, code)
208+
DefaultRoles.check_get(shell, describe_shell, code, raw)
202209
if not role
203210
else SystemRole.get(role)
204211
)

sgpt/role.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
Provide short responses in about 100 words, unless you are specifically asked for more details.
4040
If you need to store any data, assume it will be stored in the conversation.
4141
APPLY MARKDOWN formatting when possible."""
42+
43+
RAW_ROLE = """APPLY MARKDOWN formatting when possible."""
4244
# Note that output for all roles containing "APPLY MARKDOWN" will be formatted as Markdown.
4345

4446
ROLE_TEMPLATE = "You are {name}\n{role}"
@@ -68,6 +70,7 @@ def create_defaults(cls) -> None:
6870
SystemRole("Shell Command Generator", SHELL_ROLE, variables),
6971
SystemRole("Shell Command Descriptor", DESCRIBE_SHELL_ROLE, variables),
7072
SystemRole("Code Generator", CODE_ROLE),
73+
SystemRole("GPT", RAW_ROLE),
7174
):
7275
if not default_role._exists:
7376
default_role._save()
@@ -167,15 +170,20 @@ class DefaultRoles(Enum):
167170
SHELL = "Shell Command Generator"
168171
DESCRIBE_SHELL = "Shell Command Descriptor"
169172
CODE = "Code Generator"
173+
RAW = "GPT"
170174

171175
@classmethod
172-
def check_get(cls, shell: bool, describe_shell: bool, code: bool) -> SystemRole:
176+
def check_get(
177+
cls, shell: bool, describe_shell: bool, code: bool, raw: bool
178+
) -> SystemRole:
173179
if shell:
174180
return SystemRole.get(DefaultRoles.SHELL.value)
175181
if describe_shell:
176182
return SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
177183
if code:
178184
return SystemRole.get(DefaultRoles.CODE.value)
185+
if raw:
186+
return SystemRole.get(DefaultRoles.RAW.value)
179187
return SystemRole.get(DefaultRoles.DEFAULT.value)
180188

181189
def get_role(self) -> SystemRole:

tests/test_raw.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from pathlib import Path
2+
from unittest.mock import patch
3+
4+
import typer
5+
from typer.testing import CliRunner
6+
7+
from sgpt import config, main
8+
from sgpt.role import DefaultRoles, SystemRole
9+
10+
from .utils import app, cmd_args, comp_args, mock_comp, runner
11+
12+
role = SystemRole.get(DefaultRoles.RAW.value)
13+
cfg = config.cfg
14+
15+
16+
@patch("sgpt.handlers.handler.completion")
17+
def test_raw_long_option(completion):
18+
completion.return_value = mock_comp("Prague")
19+
20+
args = {"prompt": "capital of the Czech Republic?", "--raw": True}
21+
result = runner.invoke(app, cmd_args(**args))
22+
23+
completion.assert_called_once_with(**comp_args(role, **args))
24+
assert result.exit_code == 0
25+
assert "Prague" in result.stdout
26+
27+
28+
@patch("sgpt.handlers.handler.completion")
29+
def test_raw_short_option(completion):
30+
completion.return_value = mock_comp("Prague")
31+
32+
args = {"prompt": "capital of the Czech Republic?", "-r": True}
33+
result = runner.invoke(app, cmd_args(**args))
34+
35+
completion.assert_called_once_with(**comp_args(role, **args))
36+
assert result.exit_code == 0
37+
assert "Prague" in result.stdout
38+
39+
40+
@patch("sgpt.handlers.handler.completion")
41+
def test_raw_stdin(completion):
42+
completion.return_value = mock_comp("Prague")
43+
44+
args = {"--raw": True}
45+
stdin = "capital of the Czech Republic?"
46+
result = runner.invoke(app, cmd_args(**args), input=stdin)
47+
48+
completion.assert_called_once_with(**comp_args(role, stdin))
49+
assert result.exit_code == 0
50+
assert "Prague" in result.stdout
51+
52+
53+
@patch("sgpt.handlers.handler.completion")
54+
def test_raw_chat(completion):
55+
completion.side_effect = [mock_comp("ok"), mock_comp("4")]
56+
chat_name = "_test"
57+
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
58+
chat_path.unlink(missing_ok=True)
59+
60+
args = {"prompt": "my number is 2", "--raw": True, "--chat": chat_name}
61+
result = runner.invoke(app, cmd_args(**args))
62+
assert result.exit_code == 0
63+
assert "ok" in result.stdout
64+
assert chat_path.exists()
65+
66+
args["prompt"] = "my number + 2?"
67+
result = runner.invoke(app, cmd_args(**args))
68+
assert result.exit_code == 0
69+
assert "4" in result.stdout
70+
71+
expected_messages = [
72+
{"role": "system", "content": role.role},
73+
{"role": "user", "content": "my number is 2"},
74+
{"role": "assistant", "content": "ok"},
75+
{"role": "user", "content": "my number + 2?"},
76+
{"role": "assistant", "content": "4"},
77+
]
78+
expected_args = comp_args(role, "", messages=expected_messages)
79+
completion.assert_called_with(**expected_args)
80+
assert completion.call_count == 2
81+
82+
result = runner.invoke(app, ["--list-chats"])
83+
assert result.exit_code == 0
84+
assert "_test" in result.stdout
85+
86+
result = runner.invoke(app, ["--show-chat", chat_name])
87+
assert result.exit_code == 0
88+
assert "my number is 2" in result.stdout
89+
assert "ok" in result.stdout
90+
assert "my number + 2?" in result.stdout
91+
assert "4" in result.stdout
92+
93+
args["--shell"] = True
94+
result = runner.invoke(app, cmd_args(**args))
95+
assert result.exit_code == 2
96+
assert "Error" in result.stdout
97+
98+
args["--code"] = True
99+
result = runner.invoke(app, cmd_args(**args))
100+
assert result.exit_code == 2
101+
assert "Error" in result.stdout
102+
chat_path.unlink()
103+
104+
105+
@patch("sgpt.handlers.handler.completion")
106+
def test_raw_repl(completion):
107+
completion.side_effect = [mock_comp("ok"), mock_comp("8")]
108+
chat_name = "_test"
109+
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
110+
chat_path.unlink(missing_ok=True)
111+
112+
args = {"--raw": True, "--repl": chat_name}
113+
inputs = ["__sgpt__eof__", "my number is 6", "my number + 2?", "exit()"]
114+
result = runner.invoke(app, cmd_args(**args), input="\n".join(inputs))
115+
116+
expected_messages = [
117+
{"role": "system", "content": role.role},
118+
{"role": "user", "content": "my number is 6"},
119+
{"role": "assistant", "content": "ok"},
120+
{"role": "user", "content": "my number + 2?"},
121+
{"role": "assistant", "content": "8"},
122+
]
123+
expected_args = comp_args(role, "", messages=expected_messages)
124+
completion.assert_called_with(**expected_args)
125+
assert completion.call_count == 2
126+
127+
assert result.exit_code == 0
128+
assert ">>> my number is 6" in result.stdout
129+
assert "ok" in result.stdout
130+
assert ">>> my number + 2?" in result.stdout
131+
assert "8" in result.stdout
132+
133+
134+
@patch("sgpt.handlers.handler.completion")
135+
def test_raw_repl_stdin(completion):
136+
completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")]
137+
chat_name = "_test"
138+
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
139+
chat_path.unlink(missing_ok=True)
140+
141+
my_runner = CliRunner()
142+
my_app = typer.Typer()
143+
my_app.command()(main)
144+
145+
args = {"--raw": True, "--repl": chat_name}
146+
inputs = ["this is stdin", "__sgpt__eof__", "prompt", "another", "exit()"]
147+
result = my_runner.invoke(my_app, cmd_args(**args), input="\n".join(inputs))
148+
149+
expected_messages = [
150+
{"role": "system", "content": role.role},
151+
{"role": "user", "content": "this is stdin\n\n\n\nprompt"},
152+
{"role": "assistant", "content": "ok init"},
153+
{"role": "user", "content": "another"},
154+
{"role": "assistant", "content": "ok another"},
155+
]
156+
expected_args = comp_args(role, "", messages=expected_messages)
157+
completion.assert_called_with(**expected_args)
158+
assert completion.call_count == 2
159+
160+
assert result.exit_code == 0
161+
assert "this is stdin" in result.stdout
162+
assert ">>> prompt" in result.stdout
163+
assert "ok init" in result.stdout
164+
assert ">>> another" in result.stdout
165+
assert "ok another" in result.stdout

0 commit comments

Comments
 (0)