Skip to content

Commit 5d3142d

Browse files
committed
fix(test): correctly handle beam_search in generate text
1 parent 69ec787 commit 5d3142d

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tests/generate/test_generate.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,19 @@ def get_inputs(fixture_name, batch_size=None):
279279
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
280280
def test_generate_text(request, model_fixture, sampler_name):
281281
model = request.getfixturevalue(model_fixture)
282-
generator = generate.text(model, getattr(samplers, sampler_name)())
283282
with enforce_not_implemented(model_fixture, sampler_name):
284-
res = generator(**get_inputs(model_fixture), max_tokens=10)
285-
assert isinstance(res, str)
283+
if sampler_name == "beam_search":
284+
num_head = 2
285+
generator = generate.text(model, getattr(samplers, sampler_name)(num_head))
286+
res = generator(**get_inputs(model_fixture), max_tokens=10)
287+
assert isinstance(res, list)
288+
assert len(res) == num_head
289+
for elt in res:
290+
assert isinstance(elt, str)
291+
else:
292+
generator = generate.text(model, getattr(samplers, sampler_name)())
293+
res = generator(**get_inputs(model_fixture), max_tokens=10)
294+
assert isinstance(res, str)
286295

287296

288297
@pytest.mark.parametrize("pattern", REGEX_PATTERNS)

0 commit comments

Comments
 (0)