-
Notifications
You must be signed in to change notification settings - Fork 590
/
Copy pathtest_openai.py
125 lines (97 loc) · 4.73 KB
/
test_openai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import importlib
import json
from contextlib import contextmanager
from unittest import mock
from unittest.mock import MagicMock
import pytest
from openai import AsyncOpenAI
from outlines import generate
from outlines.models.openai import OpenAI, OpenAIConfig
def module_patch(path):
"""Patch functions that have the same name as the module in which they're implemented."""
target = path
components = target.split(".")
for i in range(len(components), 0, -1):
try:
# attempt to import the module
imported = importlib.import_module(".".join(components[:i]))
# module was imported, let's use it in the patch
patch = mock.patch(path)
patch.getter = lambda: imported
patch.attribute = ".".join(components[i:])
return patch
except Exception:
continue
# did not find a module, just return the default mock
return mock.patch(path)
def test_openai_call():
    """Calling the model forwards the config to generate_chat and tracks token usage."""
    with module_patch("outlines.models.openai.generate_chat") as mocked_generate_chat:
        mocked_generate_chat.return_value = ["foo"], 1, 2
        client = MagicMock(spec=AsyncOpenAI, api_key="key")
        config = OpenAIConfig(max_tokens=10, temperature=0.5, n=2, stop=["."])
        model = OpenAI(client, config)

        # First call: configured defaults flow through to generate_chat.
        result = model("bar")
        assert result[0] == "foo"
        assert model.prompt_tokens == 1
        assert model.completion_tokens == 2

        passed_config = mocked_generate_chat.call_args[0][3]
        assert isinstance(passed_config, OpenAIConfig)
        assert passed_config.max_tokens == 10
        assert passed_config.temperature == 0.5
        assert passed_config.n == 2
        assert passed_config.stop == ["."]

        # A per-call `samples` argument overrides the configured n.
        model("bar", samples=3)
        passed_config = mocked_generate_chat.call_args[0][3]
        assert passed_config.n == 3
@contextmanager
def patched_openai(completion, **oai_config):
    """Yield an ``OpenAI`` model whose chat completion always returns *completion*.

    NOTE(review): ``oai_config`` is accepted but currently unused — kept for
    interface compatibility with existing callers.
    """
    with module_patch("outlines.models.openai.generate_chat") as mocked_generate_chat:
        # generate_chat returns (completion, prompt_tokens, completion_tokens).
        mocked_generate_chat.return_value = completion, 1, 2
        client = MagicMock(spec=AsyncOpenAI, api_key="key")
        yield OpenAI(
            client,
            OpenAIConfig(max_tokens=10, temperature=0.5, n=2, stop=["."]),
        )
def test_openai_choice_call():
    """A valid JSON choice response is decoded down to the bare choice value."""
    with patched_openai(completion='{"result": "foo"}') as model:
        choice_generator = generate.choice(model, ["foo", "bar"])
        assert choice_generator("hi") == "foo"
def test_openai_choice_call_invalid_server_response():
    """A non-JSON server payload surfaces to the caller as a JSONDecodeError."""
    with patched_openai(completion="not actual json") as model:
        choice_generator = generate.choice(model, ["foo", "bar"])
        with pytest.raises(json.decoder.JSONDecodeError):
            choice_generator("hi")
def test_openai_json_call_pydantic():
    """generate.json parses a completion into a pydantic model, or raises on bad data."""
    from pydantic import BaseModel, ConfigDict, ValidationError

    class Person(BaseModel):
        model_config = ConfigDict(extra="forbid")  # required for openai
        first_name: str
        last_name: str
        age: int

    valid_completion = '{"first_name": "Usain", "last_name": "Bolt", "age": 38}'

    # A well-formed completion validates into the model.
    with patched_openai(completion=valid_completion) as model:
        json_generator = generate.json(model, Person)
        assert json_generator("fastest person") == Person.model_validate_json(
            valid_completion
        )

    # A completion that is not JSON fails pydantic validation.
    with patched_openai(completion="usain bolt") as model:
        json_generator = generate.json(model, Person)
        with pytest.raises(ValidationError):
            assert json_generator("fastest person")
def test_openai_json_call_str():
    """generate.json also accepts a raw JSON-schema string as the output spec."""
    person_schema = '{"additionalProperties": false, "properties": {"first_name": {"title": "First Name", "type": "string"}, "last_name": {"title": "Last Name", "type": "string"}, "age": {"title": "Age", "type": "integer"}}, "required": ["first_name", "last_name", "age"], "title": "Person", "type": "object"}'
    expected = {"first_name": "Usain", "last_name": "Bolt", "age": 38}

    # A well-formed completion is decoded to the expected dict.
    with patched_openai(completion=json.dumps(expected)) as model:
        json_generator = generate.json(model, person_schema)
        assert json_generator("fastest person") == expected

    # A non-JSON completion raises a JSONDecodeError.
    with patched_openai(completion="usain bolt") as model:
        json_generator = generate.json(model, person_schema)
        with pytest.raises(json.decoder.JSONDecodeError):
            assert json_generator("fastest person")