
Commit 62ce47c

changes
1 parent 0b460b3 · commit 62ce47c

5 files changed: +88 additions, -37 deletions

optimum/graphcore/fx/utils.py (+39, -20)

@@ -25,7 +25,7 @@
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_CTC_MAPPING_NAMES,
 )
-from transformers.utils.fx import HFTracer, get_concrete_args
+from transformers.utils.fx import HFAttribute, HFProxy, HFTracer, get_concrete_args

 from ..modeling_utils import PipelineMixin

@@ -34,11 +34,30 @@
     from transformers import PreTrainedModel


+# TODO: keep this until transformers >= 4.23.2
+class GCProxy(HFProxy):
+
+    @property
+    def dtype(self):
+        return self.__getattr__("dtype")
+
+    def __getattr__(self, k):
+        if k == "_metadata":
+            return self.__getattribute__(k)
+        # note: not added to the graph yet, if this is a method call
+        # we peephole optimize to the method invocation
+        hf_attribute = HFAttribute(self, k)
+        if hasattr(self, "_metadata"):
+            hf_attribute.install_metadata(getattr(self._metadata, k))
+        return hf_attribute
+
+
 class PipelinedTracer(HFTracer):
     # TODO: keep this until transformers >= 4.23.2
     _TORCH_METHODS_TO_PATCH = list(HFTracer._TORCH_METHODS_TO_PATCH)
     _TORCH_METHODS_TO_PATCH.append("clamp")
     _TORCH_METHODS_TO_PATCH.append("rand")
+    _TORCH_METHODS_TO_PATCH.append("finfo")
     """
     Tracer that enables tracing and transforming models to run them on IPUs.
     Compared to the HFTracer, this one adds the following features:
@@ -79,8 +98,9 @@ def proxy(self, node):
         # it is easier to use this one, and equivalent.
         node.parent_module_qualified_name = self.current_module_qualified_name[-1]
         node.parent_module_type = self.current_module_type[-1]
-        proxy = super().proxy(node)
-        return proxy
+        return GCProxy(node, self)
+        # return gc_proxy
+        return super().proxy(node)

     def call_module(self, m, forward, args, kwargs):
         # Could be done in a "cleaner" fashion by inlining the content of Tracer.call_module.
@@ -98,22 +118,22 @@ def call_module(self, m, forward, args, kwargs):
         return proxy

     def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
-        if self.root_is_in_half_precision:
-            float32_dtype_in_args = any(a is torch.float32 for a in args)
-            float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32
-            node_types_to_inspect = [
-                ("call_method", "to"),
-                ("call_function", torch.full),
-            ]
-            torch_methods_to_patched_version = {
-                orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()
-            }
-            for (k, t) in node_types_to_inspect:
-                if kind == k and target == torch_methods_to_patched_version.get(t, t):
-                    if float32_dtype_in_args:
-                        args = tuple(a if a is not torch.float32 else torch.float16 for a in args)
-                    if float32_dtype_in_kwargs:
-                        kwargs["dtype"] = torch.float16
+        # if self.root_is_in_half_precision:
+        #     float32_dtype_in_args = any(a is torch.float32 for a in args)
+        #     float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32
+        #     node_types_to_inspect = [
+        #         ("call_method", "to"),
+        #         ("call_function", torch.full),
+        #     ]
+        #     torch_methods_to_patched_version = {
+        #         orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()
+        #     }
+        #     for (k, t) in node_types_to_inspect:
+        #         if kind == k and target == torch_methods_to_patched_version.get(t, t):
+        #             if float32_dtype_in_args:
+        #                 args = tuple(a if a is not torch.float32 else torch.float16 for a in args)
+        #             if float32_dtype_in_kwargs:
+        #                 kwargs["dtype"] = torch.float16
         return super().create_proxy(
             kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn
         )
@@ -149,7 +169,6 @@ def symbolic_trace_with_pipelined_tracer(
     model: PipelineMixin,
     input_names: Optional[List[str]] = None,
 ) -> torch.fx.GraphModule:
-
     """
     Performs symbolic tracing on the model.
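
Note on the tracer changes above: during symbolic tracing, tensor attributes such as .dtype are proxies, so any code that needs the concrete value (for example torch.finfo(x.dtype)) breaks unless the tracer either patches the call or attaches real metadata to the proxy, which is what the new "finfo" entry in _TORCH_METHODS_TO_PATCH and the install_metadata call in GCProxy.__getattr__ address. A minimal sketch of the failure mode with plain torch.fx (the MaskedFill module is made up for illustration and is not part of this commit):

import torch
import torch.fx


class MaskedFill(torch.nn.Module):
    # Toy module: the fill value depends on the input's dtype, a pattern
    # common in attention-mask code in transformer models.
    def forward(self, x, mask):
        min_value = torch.finfo(x.dtype).min  # needs a concrete dtype
        return x.masked_fill(mask, min_value)


try:
    torch.fx.symbolic_trace(MaskedFill())
except Exception as e:
    # x.dtype is a Proxy at trace time, so torch.finfo() cannot evaluate it.
    print(f"plain symbolic_trace fails: {type(e).__name__}: {e}")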

optimum/graphcore/models/deberta/modeling_deberta.py (+42, -10)

@@ -37,7 +37,7 @@
 )
 from transformers.utils.fx import _gen_constructor_wrapper

-from ....fx.optimization import MergeLinears, compose
+from ....fx.optimization import MergeLinears, ReversibleTransformation, compose
 from ....utils import logging
 from ...fx import (
     DEFAULT_TRANSFORMATION_MANAGER,
@@ -46,6 +46,8 @@
     OutlineAttribute,
     RecomputationCheckpoint,
     VocabEmbeddingToSerializedEmbedding,
+    LinearToSerializedLinear,
+    TieWeights,
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register
@@ -107,7 +109,32 @@ def _get_rel_embedding(self):
         return self.rel_embeddings.weight + 0.0 if self.relative_attention else None


-gather_last_dim = FastGatherLastDim()
+def faster_gather_last_dim(input, dim, index, *args, **kwargs):
+    target = torch.zeros_like(index).to(input.dtype)
+    target.requires_grad_()
+    o = poptorch.custom_op(
+        [input, index],
+        "FastGatherLastDim",
+        "poptorch.custom_ops",
+        1,
+        example_outputs=[target],
+        attributes={"axis": -1},
+    )
+    return o[0]
+
+
+class ChangeTorchGather(ReversibleTransformation):
+    def transform(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is torch.gather:
+                node.target = faster_gather_last_dim
+        return graph_module
+
+    def reverse(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is faster_gather_last_dim:
+                node.target = torch.gather
+        return graph_module


 class IPUDisentangledSelfAttention(DisentangledSelfAttention):
@@ -124,8 +151,6 @@ class IPUDisentangledSelfAttention(DisentangledSelfAttention):
     def __init__(self, config):
         super().__init__(config)
         self.xsoftmax = XSoftmax(-1)
-        # self.gather_last_dim = FastGatherLastDim()
-        self.gather_last_dim = gather_last_dim

     def forward(
         self,
@@ -248,7 +273,8 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             index = c2p_pos.expand(
                 [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
             )
-            c2p_att = self.gather_last_dim(c2p_att, index)
+            # c2p_att = gather_last_dim(c2p_att, index)
+            c2p_att = torch.gather(c2p_att, -1, index)
             score += c2p_att

         # position->content
@@ -263,12 +289,12 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
             index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
             p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-            p2c_att = self.gather_last_dim(p2c_att, index).transpose(-1, -2)
+            p2c_att = torch.gather(p2c_att, -1, index).transpose(-1, -2)

             if query_layer.size(-2) != key_layer.size(-2):
                 pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
                 index = pos_index.expand(pos_index, p2c_att, key_layer)
-                p2c_att = self.gather_last_dim(p2c_att, index)
+                p2c_att = torch.gather(p2c_att, -1, index)
             score += p2c_att

         return score
@@ -283,7 +309,6 @@ def change_modules_for_ipu(self, restore: bool):
                 del mod.xsoftmax
             else:
                 mod.add_module("xsoftmax", XSoftmax(-1))
-                mod.add_module("gather_last_dim", FastGatherLastDim())
         if restore:
             if isinstance(mod, nn.Dropout):
                 mod.__class__ = StableDropout
@@ -302,10 +327,10 @@
     def get_transformations(self):
        log_insertions = self.ipu_config.log_insertions
        layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
+        # TODO: handle DebertaForMaskedLM
        transformations = [
            AddPoptorchBlock("Embedding", 0, "deberta.embeddings", log_insertions=log_insertions),
            OutlineAttribute("deberta.embeddings.LayerNorm", "Embedding"),
-            AddPoptorchBlock("Before Encoder", 0, "deberta.encoder", log_insertions=log_insertions),
            AddPoptorchBlocksInSeries(
                "Encoder", layer_ipu, r"deberta.encoder.layer.[0-9]+", log_insertions=log_insertions
            ),
@@ -322,7 +347,13 @@ def get_transformations(self):
                )
            )
        if self.ipu_config.embedding_serialization_factor > 1:
-            transformations.append(VocabEmbeddingToSerializedEmbedding())
+            if isinstance(self, DebertaForMaskedLM):
+                transformations += [
+                    LinearToSerializedLinear("cls.predictions.decoder"),
+                    TieWeights("deberta.embeddings.word_embeddings", "cls.predictions.decoder"),
+                ]
+            else:
+                transformations.append(VocabEmbeddingToSerializedEmbedding())
        return transformations

     def parallelize(self):
@@ -339,6 +370,7 @@ def parallelize(self):
        torch.nn.functional.one_hot = orig
        transformations = self.get_transformations()
        transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level)
+        transformations.append(ChangeTorchGather())
        composition = compose(*transformations)
        non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(
            self.ipu_config.optimization_level
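
Note on ChangeTorchGather above: it follows the standard torch.fx rewrite pattern — walk the traced graph, swap the target of matching call_function nodes, recompile — and reverse() simply swaps the target back. A self-contained sketch of that pattern on a toy module (gather_last_dim below is a CPU stand-in that keeps torch.gather's signature; it is not the poptorch FastGatherLastDim custom op used in the commit):

import torch
import torch.fx


def gather_last_dim(input, dim, index):
    # Stand-in with the same (input, dim, index) signature as torch.gather,
    # so the rewritten node's args carry over unchanged.
    return torch.gather(input, dim, index)


class Toy(torch.nn.Module):
    def forward(self, x, index):
        return torch.gather(x, -1, index)


gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.gather:
        node.target = gather_last_dim  # the rewrite: same args, new callee
gm.recompile()

x = torch.randn(2, 3, 4)
index = torch.zeros(2, 3, 4, dtype=torch.long)
print(torch.equal(gm(x, index), torch.gather(x, -1, index)))  # True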

optimum/graphcore/models/gpt2/modeling_gpt2.py (+3, -3)

@@ -131,7 +131,7 @@ def deparallelize(self):


 @register(GPT2LMHeadModel)
-class PipelinedGPT2LMHeadModel(GPT2LMHeadModel, GPT2PipelineMixin):
+class PipelinedGPT2LMHeadModel(GPT2PipelineMixin, GPT2LMHeadModel):
     def get_transformations(self):
         log_insertions = self.ipu_config.log_insertions
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
@@ -241,7 +241,7 @@ def forward(


 @register(GPT2ForSequenceClassification)
-class PipelinedGPT2ForSequenceClassification(GPT2ForSequenceClassification, GPT2PipelineMixin):
+class PipelinedGPT2ForSequenceClassification(GPT2PipelineMixin, GPT2ForSequenceClassification):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -290,5 +290,5 @@ def forward(


 @register(GPT2ForTokenClassification)
-class PipelinedGPT2ForTokenClassification(GPT2ForTokenClassification, GPT2PipelineMixin):
+class PipelinedGPT2ForTokenClassification(GPT2PipelineMixin, GPT2ForTokenClassification):
     pass
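
Note on the GPT-2 changes above: only the base-class order changes, but under Python's method resolution order that decides whose overrides win. With the pipeline mixin listed first, its methods shadow the transformers base class rather than the other way around, which appears to be the intent here. A toy illustration (class names made up):

class Base:
    def forward(self):
        return "base forward"


class PipelineMixinToy:
    def forward(self):
        return "pipelined forward"


class MixinLast(Base, PipelineMixinToy):  # old ordering: base class wins
    pass


class MixinFirst(PipelineMixinToy, Base):  # new ordering: mixin wins
    pass


print(MixinLast().forward())   # "base forward"
print(MixinFirst().forward())  # "pipelined forward"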

optimum/graphcore/trainer.py (+4, -2)

@@ -282,6 +282,8 @@ def __init__(
         if args.ipu_config_overrides:
             logger.info(f"Overriding IPU config: {args.ipu_config_overrides}")
             self.ipu_config.update_from_string(args.ipu_config_overrides)
+        if self.args.gradient_accumulation_steps is None:
+            self.args.gradient_accumulation_steps = self.ipu_config.gradient_accumulation_steps
         self.ipu_config.seed = self.args.seed
         self.opts = self.ipu_config.to_options(compile_only=args.compile_only)
         self.eval_opts = self.ipu_config.to_options(for_inference=True, compile_only=args.compile_only)
@@ -1116,7 +1118,7 @@ def _inner_training_loop(
         logger.info(f" Num Epochs = {num_train_epochs}")
         logger.info(f" Instantaneous batch size per device = {batch_size}")
         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+        logger.info(f" Gradient Accumulation steps = {self.ipu_config.gradient_accumulation_steps}")
         logger.info(f" Total optimization steps = {max_steps}")

         self.state.epoch = 0
@@ -1208,7 +1210,7 @@ def _inner_training_loop(
             steps_in_epoch = (
                 len(epoch_iterator)
                 if has_length(train_dataloader)
-                else args.max_steps * args.gradient_accumulation_steps
+                else args.max_steps * self.ipu_config.gradient_accumulation_steps
             )

             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

optimum/graphcore/training_args.py (-2)

@@ -750,8 +750,6 @@ def __post_init__(self):
         override_str = []
         if self.gradient_accumulation_steps is not None:
             override_str.append(f"gradient_accumulation_steps={self.gradient_accumulation_steps}")
-        else:
-            self.gradient_accumulation_steps = 1

         if self.auto_loss_scaling:
             override_str.append(f"auto_loss_scaling={self.auto_loss_scaling}")
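
Note on the trainer.py and training_args.py changes above: taken together, they stop defaulting gradient_accumulation_steps to 1 in the training arguments and instead fall back to the value stored in the IPU config when the user did not set it, while an explicit user setting still takes precedence. A toy sketch of the resulting resolution order (the dataclasses below are stand-ins, not the optimum-graphcore classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class IPUConf:
    gradient_accumulation_steps: int = 16


@dataclass
class TrainingArgs:
    # None means "not set by the user on the command line"
    gradient_accumulation_steps: Optional[int] = None


def resolve_gradient_accumulation(args: TrainingArgs, ipu_config: IPUConf) -> int:
    # Mirrors the new trainer logic: the explicit args value wins when given,
    # otherwise the IPU config value is used (and also drives the logging).
    if args.gradient_accumulation_steps is None:
        args.gradient_accumulation_steps = ipu_config.gradient_accumulation_steps
    return args.gradient_accumulation_steps


print(resolve_gradient_accumulation(TrainingArgs(), IPUConf()))   # 16, from the IPU config
print(resolve_gradient_accumulation(TrainingArgs(8), IPUConf()))  # 8, user override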
