Skip to content

Commit fb73f5c

Browse files
authored
Merge branch 'main' into fix/make-convert-token-to-string-pickleable
2 parents 2f9aeee + e9485cf commit fb73f5c

18 files changed

+504
-19
lines changed

README.md

+27
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,33 @@ print(add(**result))
300300

301301
A great advantage of passing functions directly to specify the structure is that the structure of the LLM will change with the function's definition. No need to change the code at several places!
302302

303+
You can also embed various functions into an enum to generate params:
304+
305+
```python
306+
from enum import Enum
307+
from functools import partial
308+
309+
import outlines
310+
311+
312+
def add(a: int, b: int) -> int:
313+
return a + b
314+
315+
def mul(c: float, d: float) -> float:
316+
return c * d
317+
318+
class Operation(Enum):
319+
add = partial(add)
320+
mul = partial(mul)
321+
322+
model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
323+
generator = outlines.generate.json(model, add)
324+
result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.")
325+
326+
print(result)
327+
# {'c': -3.14, 'd': 1.5}
328+
```
329+
303330
## Prompting
304331

305332
Building prompts can get messy. **Outlines** makes it easier to write and manage

benchmarks/bench_processors.py

+23
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
except ImportError:
1010
pass
1111

12+
try:
13+
import jax
14+
import jax.numpy as jnp
15+
except ImportError:
16+
pass
17+
1218

1319
def is_mlx_lm_allowed():
1420
try:
@@ -18,6 +24,14 @@ def is_mlx_lm_allowed():
1824
return mx.metal.is_available()
1925

2026

27+
def is_jax_allowed():
28+
try:
29+
import jax # noqa: F401
30+
except ImportError:
31+
return False
32+
return True
33+
34+
2135
def get_mock_processor_inputs(array_library, num_tokens=30000):
2236
"""
2337
logits: (4, 30,000 ) dtype=float
@@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
4357
input_ids = mx.random.randint(
4458
low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
4559
)
60+
elif array_library == "jax":
61+
logits = jnp.random.uniform(
62+
key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32
63+
)
64+
input_ids = jnp.random.randint(
65+
key=jax.random.PRNGKey(0), low=0, high=num_tokens, shape=(4, 2048)
66+
)
4667
else:
4768
raise ValueError
4869

@@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark:
6788
params += ["mlx"]
6889
if torch.cuda.is_available():
6990
params += ["torch_cuda"]
91+
if is_jax_allowed():
92+
params += ["jax"]
7093

7194
def setup(self, array_library):
7295
self.logits_processor = HalvingLogitsProcessor()
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
This recipe demonstrates how to use the `outlines` library to extract structured event details from a text message.
2+
We will extract the title, location, and start date and time from messages like the following:
3+
4+
```plaintext
5+
Hello Kitty, my grandmother will be here, I think it's better to postpone
6+
our appointment to review math lessons to next Monday at 2pm at the same
7+
place, 3 avenue des tanneurs, one hour will be enough see you 😘
8+
```
9+
10+
Let see how to extract the event details from the message with the MLX
11+
library dedicated to Apple Silicon processor (M series).
12+
13+
```python
14+
--8<-- "docs/cookbook/extract_event_details.py"
15+
```
16+
17+
The output will be:
18+
19+
```plaintext
20+
Today: Saturday 16 November 2024 and it's 10:55
21+
```
22+
23+
and the extracted event information will be:
24+
25+
```json
26+
{
27+
"title":"Math Review",
28+
"location":"3 avenue des tanneurs",
29+
"start":"2024-11-22T14:00:00Z"
30+
}
31+
```
32+
33+
34+
To find out more about this use case, we recommend the project developped by [Joseph Rudoler](https://x.com/JRudoler) the [ICS Generator](https://github.com/jrudoler/ics-generator)
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from datetime import datetime
2+
3+
from pydantic import BaseModel, Field
4+
5+
from outlines import generate, models
6+
7+
# Load the model
8+
model = models.mlxlm("mlx-community/Hermes-3-Llama-3.1-8B-8bit")
9+
10+
11+
# Define the event schema using Pydantic
12+
class Event(BaseModel):
13+
title: str = Field(description="title of the event")
14+
location: str
15+
start: datetime = Field(
16+
default=None, description="date of the event if available in iso format"
17+
)
18+
19+
20+
# Get the current date and time
21+
now = datetime.now().strftime("%A %d %B %Y and it's %H:%M")
22+
23+
# Define the prompt
24+
prompt = f"""
25+
Today's date and time are {now}
26+
Given a user message, extract information of the event like date and time in iso format, location and title.
27+
If the given date is relative, think step by step to find the right date.
28+
Here is the message:
29+
"""
30+
31+
# Sample message
32+
message = """Hello Kitty, my grandmother will be here , I think it's better to postpone our
33+
appointment to review math lessons to next Friday at 2pm at the same place, 3 avenue des tanneurs, I think that one hour will be enough
34+
see you 😘 """
35+
36+
# Create the generator
37+
generator = generate.json(model, Event)
38+
39+
# Extract the event information
40+
event = generator(prompt + message)
41+
42+
# Print the current date and time
43+
print(f"Today: {now}")
44+
45+
# Print the extracted event information in JSON format
46+
print(event.json())

examples/beam-cloud/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
## Deploy Outlines on Beam
2+
3+
1. Create an account [here](https://beam.cloud) and install the Beam SDK
4+
2. Download the `app.py` file to your computer
5+
3. Deploy it as a serverless API by running: `beam deploy app.py:predict`

examples/beam-cloud/app.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from beam import Image, endpoint, env
2+
3+
if env.is_remote():
4+
import outlines
5+
6+
7+
# Pre-load models when the container first starts
8+
def load_models():
9+
import outlines
10+
11+
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
12+
return model
13+
14+
15+
@endpoint(
16+
name="outlines-serverless",
17+
gpu="A10G",
18+
cpu=1,
19+
memory="16Gi",
20+
on_start=load_models,
21+
image=Image().add_python_packages(
22+
["outlines", "torch", "transformers", "accelerate"]
23+
),
24+
)
25+
def predict(context, **inputs):
26+
default_prompt = """You are a sentiment-labelling assistant.
27+
Is the following review positive or negative?
28+
29+
Review: This restaurant is just awesome!
30+
"""
31+
32+
prompt = inputs.get("prompt", default_prompt)
33+
34+
# Unpack cached model from context
35+
model = context.on_start_value
36+
# Inference
37+
generator = outlines.generate.choice(model, ["Positive", "Negative"])
38+
answer = generator(prompt)
39+
return {"answer": answer}

mkdocs.yml

+2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ markdown_extensions:
7474
- pymdownx.emoji:
7575
emoji_index: !!python/name:material.extensions.emoji.twemoji
7676
emoji_generator: !!python/name:material.extensions.emoji.to_svg
77+
- pymdownx.snippets:
7778

7879

7980
extra_css:
@@ -123,6 +124,7 @@ nav:
123124
- Structured Generation from PDFs: cookbook/read-pdfs.md
124125
- Earnings reports to CSV: cookbook/earnings-reports.md
125126
- Digitizing receipts with vision models: cookbook/receipt-digitization.md
127+
- Extract events details from text: cookbook/extract_event_details.md
126128
- Run on the cloud:
127129
- BentoML: cookbook/deploy-using-bentoml.md
128130
- Cerebrium: cookbook/deploy-using-cerebrium.md

outlines/base.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,23 @@
55
from typing import Callable, Optional
66

77
import numpy as np
8-
from numpy.lib.function_base import (
9-
_calculate_shapes,
10-
_parse_gufunc_signature,
11-
_parse_input_dimensions,
12-
_update_dim_sizes,
13-
)
8+
9+
# Import required functions based on NumPy version
10+
np_major_version = int(np.__version__.split(".")[0])
11+
if np_major_version >= 2:
12+
from numpy.lib._function_base_impl import (
13+
_calculate_shapes,
14+
_parse_gufunc_signature,
15+
_parse_input_dimensions,
16+
_update_dim_sizes,
17+
)
18+
else:
19+
from numpy.lib.function_base import (
20+
_calculate_shapes,
21+
_parse_gufunc_signature,
22+
_parse_input_dimensions,
23+
_update_dim_sizes,
24+
)
1425

1526
# Allow nested loops for running in notebook. We don't enable it globally as it
1627
# may interfere with other libraries that use asyncio.

outlines/fsm/json_schema.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import re
44
import warnings
5+
from enum import Enum
56
from typing import Callable, Optional, Tuple, Type, Union
67

78
from jsonschema.protocols import Validator
@@ -306,6 +307,8 @@ def to_regex(
306307
for choice in instance["enum"]:
307308
if type(choice) in [int, float, bool, type(None), str]:
308309
choices.append(re.escape(json.dumps(choice)))
310+
elif isinstance(choice, dict):
311+
choices.append(to_regex(resolver, choice, whitespace_pattern))
309312
else:
310313
raise TypeError(f"Unsupported data type in enum: {type(choice)}")
311314
return f"({'|'.join(choices)})"
@@ -524,7 +527,7 @@ def to_regex(
524527
)
525528

526529

527-
def get_schema_from_signature(fn: Callable) -> str:
530+
def get_schema_from_signature(fn: Callable) -> dict:
528531
"""Turn a function signature into a JSON schema.
529532
530533
Every JSON object valid to the output JSON Schema can be passed
@@ -550,3 +553,16 @@ def get_schema_from_signature(fn: Callable) -> str:
550553
model = create_model(fn_name, **arguments)
551554

552555
return model.model_json_schema()
556+
557+
558+
def get_schema_from_enum(myenum: type[Enum]) -> dict:
559+
if len(myenum) == 0:
560+
raise ValueError(
561+
f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
562+
)
563+
choices = [
564+
get_schema_from_signature(elt.value.func) if callable(elt.value) else elt.value
565+
for elt in myenum
566+
]
567+
schema = {"title": myenum.__name__, "enum": choices}
568+
return schema

outlines/generate/json.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import json as pyjson
2+
from enum import Enum
23
from functools import singledispatch
34
from typing import Callable, Optional, Union
45

56
from pydantic import BaseModel
67

7-
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
8+
from outlines.fsm.json_schema import (
9+
build_regex_from_schema,
10+
get_schema_from_enum,
11+
get_schema_from_signature,
12+
)
813
from outlines.generate.api import SequenceGeneratorAdapter
914
from outlines.models import OpenAI
1015
from outlines.samplers import Sampler, multinomial
@@ -48,6 +53,11 @@ def json(
4853
regex_str = build_regex_from_schema(schema, whitespace_pattern)
4954
generator = regex(model, regex_str, sampler)
5055
generator.format_sequence = lambda x: schema_object.parse_raw(x)
56+
elif isinstance(schema_object, type(Enum)):
57+
schema = pyjson.dumps(get_schema_from_enum(schema_object))
58+
regex_str = build_regex_from_schema(schema, whitespace_pattern)
59+
generator = regex(model, regex_str, sampler)
60+
generator.format_sequence = lambda x: pyjson.loads(x)
5161
elif callable(schema_object):
5262
schema = pyjson.dumps(get_schema_from_signature(schema_object))
5363
regex_str = build_regex_from_schema(schema, whitespace_pattern)

outlines/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
from .transformers_vision import TransformersVision, transformers_vision
1717
from .vllm import VLLM, vllm
1818

19-
LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
19+
LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM]

outlines/models/mlxlm.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,7 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
167167
prob = softmax_logits[0, token]
168168
return token, prob
169169

170-
kv_heads = (
171-
[self.model.n_kv_heads] * len(self.model.layers)
172-
if isinstance(self.model.n_kv_heads, int)
173-
else self.model.n_kv_heads
174-
)
175-
cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]
170+
cache = mlx_lm.models.cache.make_prompt_cache(self.model)
176171

177172
# kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
178173
unprocessed_input_ids = prompt

outlines/processors/base_logits_processor.py

+21
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ def is_mlx_array_type(array_type):
2020
return issubclass(array_type, mx.array)
2121

2222

23+
def is_jax_array_type(array_type):
24+
try:
25+
import jaxlib
26+
except ImportError:
27+
return False
28+
return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
29+
array_type, jaxlib.xla_extension.ArrayImpl
30+
)
31+
32+
2333
class OutlinesLogitsProcessor(Protocol):
2434
"""
2535
Base class for logits processors which normalizes types of logits:
@@ -101,6 +111,12 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
101111
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html
102112
return torch.from_dlpack(tensor_like)
103113

114+
elif is_jax_array_type(type(tensor_like)):
115+
import jax
116+
117+
torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
118+
return torch_tensor
119+
104120
else:
105121
raise TypeError(
106122
"LogitsProcessor must be called with either np.NDArray, "
@@ -129,6 +145,11 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
129145
# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
130146
return mx.array(tensor.float().numpy())
131147

148+
elif is_jax_array_type(target_type):
149+
import jax
150+
151+
return jax.dlpack.from_dlpack(tensor)
152+
132153
else:
133154
raise TypeError(
134155
f"Failed to convert torch tensors to target_type `{target_type}`"

0 commit comments

Comments
 (0)