
Commit cbbd9a7

feat(audio): integrate audio transformers

1 parent 36f1bf2 commit cbbd9a7

11 files changed: +489 -6 lines

outlines/generate/api.py (+91)

@@ -621,3 +621,94 @@ def valid_types(prompts, media):
         )

     return prompts, media
+
+
+class AudioSequenceGeneratorAdapter(SequenceGeneratorAdapter):
+    def __call__(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[str, Any],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        seed: Optional[int] = None,
+        **model_specific_params,
+    ):
+        """
+        Generate text from a prompt or list of prompts.
+
+        Media: a URI to construct media from, or the media object itself.
+        Passed through as an `AutoProcessor` argument.
+        """
+        prompts, media = self._validate_prompt_media_types(prompts, media)
+
+        generation_params = self.prepare_generation_parameters(
+            max_tokens, stop_at, seed
+        )
+
+        completions = self.model.generate(
+            prompts,
+            media,
+            generation_params,
+            copy(self.logits_processor),
+            self.sampling_params,
+            **model_specific_params,
+        )
+
+        return self._format(completions)
+
+    def stream(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: List[Union[str, Any, List[Union[str, Any]]]],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        seed: Optional[int] = None,
+        **model_specific_params,
+    ):
+        """Return a text generator from a prompt or a list of prompts."""
+        prompts, media = self._validate_prompt_media_types(prompts, media)
+        generation_params = self.prepare_generation_parameters(
+            max_tokens, stop_at, seed
+        )
+        return self.model.stream(
+            prompts,
+            media,
+            generation_params,
+            copy(self.logits_processor),
+            self.sampling_params,
+            **model_specific_params,
+        )
+
+    @classmethod
+    def _validate_prompt_media_types(
+        cls,
+        prompts: Union[str, List[str]],
+        media: Union[str, Any, List[Union[str, Any]]],
+    ) -> Union[Any, List[Any]]:
+        """
+        Prepare media as np.ndarray and ensure there is one List[np.ndarray]
+        for every prompt str.
+        """
+
+        def valid_types(prompts, media):
+            import numpy as np  # type: ignore
+
+            if isinstance(prompts, list):
+                if not isinstance(media, list) or len(prompts) != len(media):
+                    return False
+                for subprompt, submedia in zip(prompts, media):
+                    if not isinstance(subprompt, str) or not all(
+                        isinstance(m, np.ndarray) for m in submedia
+                    ):
+                        return False
+            elif isinstance(prompts, str):
+                if not all(isinstance(m, np.ndarray) for m in media):
+                    return False
+            return True
+
+        if not valid_types(prompts, media):
+            raise TypeError(
+                "Expected (prompts, media) to be of type "
+                "(str, List[np.ndarray]) or (List[str], List[List[np.ndarray]]), "
+                f"instead got prompts={prompts}, media={media}"
+            )
+
+        return prompts, media
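
For reference, a minimal sketch of the (prompts, media) shapes the new validator accepts; the silent placeholder clip is illustrative, not part of the commit:

import numpy as np

from outlines.generate.api import AudioSequenceGeneratorAdapter

clip = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

# Single prompt: media is a flat list of numpy arrays.
AudioSequenceGeneratorAdapter._validate_prompt_media_types("Describe: ", [clip])

# Batch: one list of arrays per prompt.
AudioSequenceGeneratorAdapter._validate_prompt_media_types(
    ["Describe: ", "Transcribe: "], [[clip], [clip]]
)

# Anything else raises TypeError, e.g. a bare array instead of a list.
try:
    AudioSequenceGeneratorAdapter._validate_prompt_media_types("Describe: ", clip)
except TypeError as e:
    print(e)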

outlines/generate/cfg.py (+10 -1)

@@ -1,10 +1,11 @@
 from functools import singledispatch

 from outlines.generate.api import (
+    AudioSequenceGeneratorAdapter,
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import LlamaCpp, OpenAI, TransformersVision
+from outlines.models import LlamaCpp, OpenAI, TransformersAudio, TransformersVision
 from outlines.samplers import Sampler, multinomial


@@ -33,6 +34,14 @@ def cfg(
     return SequenceGeneratorAdapter(model, logits_processor, sampler)


+@cfg.register(TransformersAudio)
+def cfg_audio(model, cfg_str: str, sampler: Sampler = multinomial()):
+    from outlines.processors import CFGLogitsProcessor
+
+    logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
+    return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @cfg.register(TransformersVision)
 def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
     from outlines.processors import CFGLogitsProcessor
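
A hedged usage sketch of the new cfg path; `model` (a TransformersAudio instance) and `clip` (a decoded numpy waveform) are assumed to exist, and the Lark grammar is illustrative:

import outlines

yes_no_grammar = """
?start: "yes" | "no"
"""
generator = outlines.generate.cfg(model, yes_no_grammar)
answer = generator("Is anyone speaking in this clip? ", [clip])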

outlines/generate/fsm.py (+11 -1)

@@ -4,10 +4,11 @@

 from outlines.fsm.guide import RegexGuide
 from outlines.generate.api import (
+    AudioSequenceGeneratorAdapter,
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import TransformersVision
+from outlines.models import TransformersAudio, TransformersVision
 from outlines.samplers import Sampler, multinomial


@@ -22,6 +23,15 @@ def fsm(
     return SequenceGeneratorAdapter(model, logits_processor, sampler)


+@fsm.register(TransformersAudio)
+def fsm_audio(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
+    from outlines.processors import GuideLogitsProcessor
+
+    guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
+    logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide)
+    return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @fsm.register(TransformersVision)
 def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
     from outlines.processors import GuideLogitsProcessor
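
A hedged sketch of the FSM path: `interegular` builds the automaton, while `model` and `clip` are assumed as above:

import interegular

import outlines

# Constrain output to a small closed vocabulary via an interegular FSM.
fsm = interegular.parse_pattern(r"(speech|music|silence)").to_fsm()

generator = outlines.generate.fsm(model, fsm)
label = generator("Classify the audio: ", [clip])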

outlines/generate/regex.py (+14 -1)

@@ -1,10 +1,11 @@
 from functools import singledispatch

 from outlines.generate.api import (
+    AudioSequenceGeneratorAdapter,
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import OpenAI, TransformersVision
+from outlines.models import OpenAI, TransformersAudio, TransformersVision
 from outlines.samplers import Sampler, multinomial


@@ -35,6 +36,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, logits_processor, sampler)


+@regex.register(TransformersAudio)
+def regex_audio(
+    model,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    from outlines.processors import RegexLogitsProcessor
+
+    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
+    return AudioSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @regex.register(TransformersVision)
 def regex_vision(
     model,
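
A hedged sketch of the regex path, again assuming `model` and `clip` from the earlier examples:

import outlines

# Force the answer into a digits-only pattern, mirroring regex_vision.
generator = outlines.generate.regex(model, r"[0-9]{1,3}")
count = generator("How many speakers do you hear? ", [clip])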

outlines/generate/text.py (+7 -1)

@@ -1,10 +1,11 @@
 from functools import singledispatch

 from outlines.generate.api import (
+    AudioSequenceGeneratorAdapter,
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import OpenAI, TransformersVision
+from outlines.models import OpenAI, TransformersAudio, TransformersVision
 from outlines.samplers import Sampler, multinomial


@@ -34,6 +35,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
     return SequenceGeneratorAdapter(model, None, sampler)


+@text.register(TransformersAudio)
+def text_audio(model, sampler: Sampler = multinomial()):
+    return AudioSequenceGeneratorAdapter(model, None, sampler)
+
+
 @text.register(TransformersVision)
 def text_vision(model, sampler: Sampler = multinomial()):
     return VisionSequenceGeneratorAdapter(model, None, sampler)
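
A hedged sketch of unconstrained audio generation, with `model` and `clip` assumed as before:

import outlines

# logits_processor is None here, so the adapter only handles prompt/media
# validation, batching, and decoding.
generator = outlines.generate.text(model)
transcript = generator("Transcribe the audio. ", [clip], max_tokens=256)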

outlines/models/__init__.py (+1)

@@ -13,6 +13,7 @@
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, TransformerTokenizer, mamba, transformers
+from .transformers_audio import TransformersAudio, transformers_audio
 from .transformers_vision import TransformersVision, transformers_vision
 from .vllm import VLLM, vllm
outlines/models/transformers_audio.py (new file, +136)

@@ -0,0 +1,136 @@
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
+
+from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models import Transformers
+
+if TYPE_CHECKING:
+    from outlines.processors import OutlinesLogitsProcessor
+
+
+class TransformersAudio(Transformers):
+    def __init__(self, model, tokenizer, processor):
+        super().__init__(model, tokenizer)
+        self.processor = processor
+
+    def generate(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[List[Any], List[List[Any]]],
+        generation_parameters: GenerationParameters,
+        logits_processor: Optional["OutlinesLogitsProcessor"],
+        sampling_parameters: SamplingParameters,
+    ) -> Union[str, List[str], List[List[str]]]:
+        """Generate text using `transformers`.
+
+        Arguments
+        ---------
+        prompts
+            A prompt or list of prompts.
+        media
+            A List[numpy.ndarray] or List[List[numpy.ndarray]].
+        generation_parameters
+            An instance of `GenerationParameters` that contains the prompt,
+            the maximum number of tokens, stop sequences and seed. All the
+            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
+        logits_processor
+            The logits processor to use when generating text.
+        sampling_parameters
+            An instance of `SamplingParameters`, a dataclass that contains
+            the name of the sampler to use and related parameters as available
+            in Outlines.
+
+        Returns
+        -------
+        The generated text.
+        """
+        inputs = self.processor(
+            text=prompts, audios=media, padding=True, return_tensors="pt"
+        ).to(self.model.device)
+
+        generation_kwargs = self._get_generation_kwargs(
+            prompts,
+            generation_parameters,
+            logits_processor,
+            sampling_parameters,
+        )
+        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
+
+        # if single str input and single sample per input, convert to a 1D output
+        if isinstance(prompts, str):
+            # Should always be true until NotImplementedError above is fixed
+            generated_ids = generated_ids.squeeze(0)
+
+        return self._decode_generation(generated_ids)
+
+    def stream(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[Any, List[Any]],  # TODO: docstring
+        generation_parameters: GenerationParameters,
+        logits_processor: Optional["OutlinesLogitsProcessor"],
+        sampling_parameters: SamplingParameters,
+    ) -> Iterator[Union[str, List[str]]]:
+        raise NotImplementedError
+
+
+def transformers_audio(
+    model_name: str,
+    model_class,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    processor_kwargs: dict = {},
+    tokenizer_class=None,
+    processor_class=None,
+):
+    """Instantiate a model from the `transformers` library and its tokenizer.
+
+    Parameters
+    ----------
+    model_name
+        The name of the model as listed on Hugging Face's model page.
+    model_class
+        The `PreTrainedModel` class from transformers to use in initializing
+        the audio model from `model_name`.
+        https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel
+    device
+        The device(s) on which the model should be loaded. This overrides
+        the `device_map` entry in `model_kwargs` when provided.
+    model_kwargs
+        A dictionary that contains the keyword arguments to pass to the
+        `from_pretrained` method when loading the model.
+    processor_kwargs
+        A dictionary that contains the keyword arguments to pass to the
+        `from_pretrained` method when loading the processor.
+
+    Returns
+    -------
+    A `TransformersAudio` model instance.
+
+    """
+    if processor_class is None or tokenizer_class is None:
+        try:
+            from transformers import AutoProcessor, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+    if processor_class is None:
+        processor_class = AutoProcessor
+
+    if device is not None:
+        model_kwargs["device_map"] = device
+
+    model = model_class.from_pretrained(model_name, **model_kwargs)
+
+    processor_kwargs.setdefault("padding_side", "left")
+    processor_kwargs.setdefault("pad_token", "[PAD]")
+    processor = processor_class.from_pretrained(model_name, **processor_kwargs)
+
+    if tokenizer_class is None:
+        if getattr(processor, "tokenizer", None):
+            tokenizer = processor.tokenizer
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs)
+    else:
+        tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs)
+
+    return TransformersAudio(model, tokenizer, processor)
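
An end-to-end sketch of the new loader. The checkpoint, prompt tokens, and file name are assumptions, not part of this commit; any `transformers` audio model with a compatible `AutoProcessor` should fit the same shape:

import librosa
from transformers import Qwen2AudioForConditionalGeneration

from outlines import generate, models

# Assumed checkpoint; any audio-capable PreTrainedModel subclass works here.
model = models.transformers_audio(
    "Qwen/Qwen2-Audio-7B-Instruct",
    model_class=Qwen2AudioForConditionalGeneration,
    model_kwargs={"device_map": "auto"},
)

# Decode a local file (placeholder name) to the mono float32 array that the
# processor's `audios=` argument expects.
audio, _sr = librosa.load("sample.wav", sr=16000)

generator = generate.text(model)
result = generator(
    "<|audio_bos|><|AUDIO|><|audio_eos|>What is happening in this clip?",
    [audio],
)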

pyproject.toml (+2)

@@ -72,6 +72,7 @@ test = [
     "pillow",
     "exllamav2",
-    "jax"
+    "jax",
+    "librosa",
 ]
 serve = [
     "vllm>=0.3.0",

@@ -147,6 +148,7 @@ module = [
     "pycountry.*",
     "airportsdata.*",
     "outlines_core.*",
+    "librosa",
 ]
 ignore_missing_imports = true
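
librosa is added so the test suite can decode real audio into the np.ndarray media the adapter validates; a minimal sketch (the bundled example clip triggers a one-time download):

import librosa
import numpy as np

# librosa.load resamples and returns a float32 mono waveform by default.
audio, sr = librosa.load(librosa.ex("trumpet"), sr=16000)
assert isinstance(audio, np.ndarray) and audio.dtype == np.float32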
