Commit a582a46 (parent c94d930)

GPT-2 Fix

5 files changed: +22 -10 lines


optimum/graphcore/fx/transformation_manager.py (+2 -1)

@@ -16,6 +16,7 @@

 import copy
 import functools
+import operator
 from typing import Iterator, List, Tuple, Union

 import torch

@@ -123,6 +124,6 @@ def compose_reversible_transformations(self, optimization_level: int) -> Reversi
             (1, MergeLinears()),
             # (1, FuseBiasInLinear()),
             # Those change the computation, but are actually needed for fp16 stability.
-            (0, ClipValuesSymmetric(1e4, exclude_targets=("view",))),
+            (0, ClipValuesSymmetric(1e4, include_targets=(torch.add, torch.mul, operator.add, operator.mul))),
             (0, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))),
         )
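
The switch above changes the symmetric clip from "everything except view" to an explicit allow-list of the fp16-sensitive ops (tensor and Python adds/muls). As a rough illustration of that pattern, here is a minimal torch.fx sketch that clamps the output of allow-listed call_function nodes; the helper name clip_after_targets and the insertion details are assumptions for illustration, not the actual ClipValuesSymmetric implementation.

# Sketch only: clamp the outputs of selected call_function nodes in an FX graph.
import operator

import torch
import torch.fx


def clip_after_targets(gm: torch.fx.GraphModule, clip_value: float, include_targets) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target in include_targets:
            with gm.graph.inserting_after(node):
                clamp = gm.graph.call_function(
                    torch.clamp, args=(node,), kwargs={"min": -clip_value, "max": clip_value}
                )
            # Re-route every consumer of the original node through the clamp,
            # except the clamp node itself (its input must remain the original node).
            for user in list(node.users):
                if user is not clamp:
                    user.replace_input_with(node, clamp)
    gm.recompile()
    return gm


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y) * 3.0  # traces to torch.add and operator.mul nodes


gm = clip_after_targets(torch.fx.symbolic_trace(Toy()), 1e4, (torch.add, torch.mul, operator.add, operator.mul))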

optimum/graphcore/fx/transformations.py (+11 -3)

@@ -204,7 +204,9 @@ def __init__(
     ):
         if clip_value < 0:
             raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.")
-        return super().__init__(-clip_value, clip_value, exclude_targets=exclude_targets)
+        return super().__init__(
+            -clip_value, clip_value, include_targets=include_targets, exclude_targets=exclude_targets
+        )


 class OutlineAttribute(ReversibleTransformation):

@@ -406,7 +408,9 @@ def sort_nodes_function(node):

         embedding_node = max(embedding_nodes, key=sort_nodes_function)
         if embedding_node.op == "call_function":
-            raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+            raise NotImplementedError(
+                "VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet."
+            )

         split = embedding_node.target.rsplit(".", maxsplit=1)
         if len(split) == 1:

@@ -520,7 +524,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":


 class ShareEmbeddingComputation(Transformation):
-    def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+    def __init__(
+        self,
+        name_regex: Optional[str] = None,
+        allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding),
+    ):
         self.name_regex = re.compile(name_regex) if name_regex else None
         self.allowed_embedding_classes = allowed_embedding_classes
         if not isinstance(self.allowed_embedding_classes, tuple):
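
The rsplit call visible in the second hunk is the usual way to turn a qualified module name into a parent path plus attribute name so a module can be swapped in place. A small illustration of that pattern (the swap_module helper below is hypothetical, not the library's replacement logic):

import torch


def swap_module(model: torch.nn.Module, qualified_name: str, new_module: torch.nn.Module) -> None:
    split = qualified_name.rsplit(".", maxsplit=1)
    if len(split) == 1:
        # Top-level attribute, e.g. "wte" directly on the model.
        parent, attr = model, split[0]
    else:
        # "transformer.wte" -> parent "transformer", attribute "wte"
        parent, attr = model.get_submodule(split[0]), split[1]
    setattr(parent, attr, new_module)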

optimum/graphcore/fx/utils.py (-1)

@@ -36,7 +36,6 @@

 # TODO: keep this until transformers >= 4.23.2
 class GCProxy(HFProxy):
-
     @property
     def dtype(self):
         return self.__getattr__("dtype")

optimum/graphcore/models/deberta/modeling_deberta.py (+2 -2)

@@ -43,11 +43,11 @@
     DEFAULT_TRANSFORMATION_MANAGER,
     AddPoptorchBlock,
     AddPoptorchBlocksInSeries,
+    LinearToSerializedLinear,
     OutlineAttribute,
     RecomputationCheckpoint,
-    VocabEmbeddingToSerializedEmbedding,
-    LinearToSerializedLinear,
     TieWeights,
+    VocabEmbeddingToSerializedEmbedding,
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register

optimum/graphcore/models/gpt2/modeling_gpt2.py (+7 -3)

@@ -37,6 +37,7 @@
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import PipelineMixin, get_layer_ipu, register
+from .optimized_gpt2_attn import OptimizedGPT2Attention


 logger = logging.get_logger(__name__)

@@ -69,7 +70,7 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-            AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+            AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
             OutlineAttribute("transformer.ln_f", "LayerNorm"),
             AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
             # Only one of the following AddPoptorchBlock, will actually add a block.
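
The fix above corrects a module path: in the Hugging Face GPT-2 implementation the position embedding is transformer.wpe (there is no transformer.wtp, so the old block annotation could never match a module), and it is now placed on IPU 0 together with the token embedding. A quick sanity check, assuming the transformers GPT-2 model:

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
print(type(model.transformer.wte).__name__)   # Embedding (token embeddings)
print(type(model.transformer.wpe).__name__)   # Embedding (position embeddings)
print(hasattr(model.transformer, "wtp"))      # False
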
@@ -84,7 +85,7 @@ def get_transformations(self):
             )
         )
         if self.ipu_config.embedding_serialization_factor > 1:
-            transformations.append(VocabEmbeddingToSerializedEmbedding())
+            transformations.append(VocabEmbeddingToSerializedEmbedding("transformer.wte"))

         return transformations

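
Above, the serialized-embedding transformation is now pointed explicitly at transformer.wte instead of relying on auto-detection. For readers unfamiliar with the idea, the sketch below shows conceptually what serializing a vocabulary embedding means (split the lookup over the vocab axis and sum masked partial results); it is a simplified stand-in, not the SerializedEmbedding used by optimum-graphcore.

import torch


class NaiveSerializedEmbedding(torch.nn.Module):
    """Toy version: splits an embedding into `factor` row slices."""

    def __init__(self, embedding: torch.nn.Embedding, factor: int):
        super().__init__()
        assert embedding.num_embeddings % factor == 0, "vocab size must divide evenly in this sketch"
        self.chunk = embedding.num_embeddings // factor
        self.slices = torch.nn.ModuleList(
            [torch.nn.Embedding(self.chunk, embedding.embedding_dim) for _ in range(factor)]
        )
        with torch.no_grad():
            for i, s in enumerate(self.slices):
                s.weight.copy_(embedding.weight[i * self.chunk : (i + 1) * self.chunk])

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        out = 0
        for i, s in enumerate(self.slices):
            in_range = (input_ids >= i * self.chunk) & (input_ids < (i + 1) * self.chunk)
            local_ids = torch.where(in_range, input_ids - i * self.chunk, torch.zeros_like(input_ids))
            # Lookups outside this slice contribute zero; only one slice is "hot" per token.
            out = out + s(local_ids) * in_range.unsqueeze(-1)
        return out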

@@ -96,6 +97,9 @@ def parallelize(self):
         - Adds recomputation checkpoints
         """
         PipelineMixin.parallelize(self)
+        if not isinstance(self, torch.fx.GraphModule):
+            for layer in self.transformer.h:
+                layer.attn.__class__ = OptimizedGPT2Attention
         if self.ipu_config.embedding_serialization_factor > 1:
             self.resize_vocab(False)
         traced = symbolic_trace_pipelined_model(self)
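
The __class__ reassignment added above swaps each attention block's behaviour without touching its parameters: the instance keeps its weights and only method resolution changes, which is why it is done on the eager model before tracing (and skipped when self is already a GraphModule). A hedged sketch of the pattern with a placeholder subclass (MyOptimizedGPT2Attention is illustrative, not the OptimizedGPT2Attention shipped with optimum-graphcore):

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


class MyOptimizedGPT2Attention(GPT2Attention):
    # Same parameters and buffers as GPT2Attention; only the code path changes.
    def forward(self, hidden_states, **kwargs):
        # An IPU-friendly attention implementation would go here.
        return super().forward(hidden_states, **kwargs)


def swap_attention(model):
    for layer in model.transformer.h:
        layer.attn.__class__ = MyOptimizedGPT2Attention  # instance keeps its existing weights
    return model
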
@@ -137,7 +141,7 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-            AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+            AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
             OutlineAttribute("transformer.ln_f", "LayerNorm"),
             AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
             AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions),
