@@ -279,10 +279,19 @@ def get_inputs(fixture_name, batch_size=None):
279
279
@pytest .mark .parametrize ("model_fixture" , ALL_MODEL_FIXTURES )
280
280
def test_generate_text (request , model_fixture , sampler_name ):
281
281
model = request .getfixturevalue (model_fixture )
282
- generator = generate .text (model , getattr (samplers , sampler_name )())
283
282
with enforce_not_implemented (model_fixture , sampler_name ):
284
- res = generator (** get_inputs (model_fixture ), max_tokens = 10 )
285
- assert isinstance (res , str )
283
+ if sampler_name == "beam_search" :
284
+ num_head = 2
285
+ generator = generate .text (model , getattr (samplers , sampler_name )(num_head ))
286
+ res = generator (** get_inputs (model_fixture ), max_tokens = 10 )
287
+ assert isinstance (res , list )
288
+ assert len (res ) == num_head
289
+ for elt in res :
290
+ assert isinstance (elt , str )
291
+ else :
292
+ generator = generate .text (model , getattr (samplers , sampler_name )())
293
+ res = generator (** get_inputs (model_fixture ), max_tokens = 10 )
294
+ assert isinstance (res , str )
286
295
287
296
288
297
@pytest .mark .parametrize ("pattern" , REGEX_PATTERNS )
0 commit comments