 )
 from transformers.utils.fx import _gen_constructor_wrapper

-from ....fx.optimization import MergeLinears, compose
+from ....fx.optimization import MergeLinears, ReversibleTransformation, compose
 from ....utils import logging
 from ...fx import (
     DEFAULT_TRANSFORMATION_MANAGER,
     OutlineAttribute,
     RecomputationCheckpoint,
     VocabEmbeddingToSerializedEmbedding,
+    LinearToSerializedLinear,
+    TieWeights,
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register
@@ -107,7 +109,32 @@ def _get_rel_embedding(self):
         return self.rel_embeddings.weight + 0.0 if self.relative_attention else None


-gather_last_dim = FastGatherLastDim()
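+# `faster_gather_last_dim` wraps the PopTorch "FastGatherLastDim" custom op: it builds a
+# zero tensor shaped like `index` (cast to the input dtype) as the example output and
+# gathers along the last axis. The `dim` argument and the trailing *args/**kwargs exist
+# only so the signature matches `torch.gather`, which this function is swapped in for.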
+def faster_gather_last_dim(input, dim, index, *args, **kwargs):
+    target = torch.zeros_like(index).to(input.dtype)
+    target.requires_grad_()
+    o = poptorch.custom_op(
+        [input, index],
+        "FastGatherLastDim",
+        "poptorch.custom_ops",
+        1,
+        example_outputs=[target],
+        attributes={"axis": -1},
+    )
+    return o[0]
+
+
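+# Reversible FX transformation: `transform` retargets every `call_function` node that
+# calls `torch.gather` to `faster_gather_last_dim`; `reverse` restores `torch.gather`.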
+class ChangeTorchGather(ReversibleTransformation):
+    def transform(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is torch.gather:
+                node.target = faster_gather_last_dim
+        return graph_module
+
+    def reverse(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is faster_gather_last_dim:
+                node.target = torch.gather
+        return graph_module


 class IPUDisentangledSelfAttention(DisentangledSelfAttention):
@@ -124,8 +151,6 @@ class IPUDisentangledSelfAttention(DisentangledSelfAttention):
     def __init__(self, config):
         super().__init__(config)
         self.xsoftmax = XSoftmax(-1)
-        # self.gather_last_dim = FastGatherLastDim()
-        self.gather_last_dim = gather_last_dim

     def forward(
         self,
@@ -248,7 +273,8 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             index = c2p_pos.expand(
                 [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
             )
-            c2p_att = self.gather_last_dim(c2p_att, index)
+            # c2p_att = gather_last_dim(c2p_att, index)
+            c2p_att = torch.gather(c2p_att, -1, index)
             score += c2p_att

         # position->content
@@ -263,12 +289,12 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
             index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
             p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-            p2c_att = self.gather_last_dim(p2c_att, index).transpose(-1, -2)
+            p2c_att = torch.gather(p2c_att, -1, index).transpose(-1, -2)

             if query_layer.size(-2) != key_layer.size(-2):
                 pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
                 index = pos_index.expand(pos_index, p2c_att, key_layer)
-                p2c_att = self.gather_last_dim(p2c_att, index)
+                p2c_att = torch.gather(p2c_att, -1, index)
             score += p2c_att

         return score
@@ -283,7 +309,6 @@ def change_modules_for_ipu(self, restore: bool):
                     del mod.xsoftmax
                 else:
                     mod.add_module("xsoftmax", XSoftmax(-1))
-                    mod.add_module("gather_last_dim", FastGatherLastDim())
             if restore:
                 if isinstance(mod, nn.Dropout):
                     mod.__class__ = StableDropout
@@ -302,10 +327,10 @@ def change_modules_for_ipu(self, restore: bool):
     def get_transformations(self):
         log_insertions = self.ipu_config.log_insertions
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
+        # TODO: handle DebertaForMaskedLM
         transformations = [
             AddPoptorchBlock("Embedding", 0, "deberta.embeddings", log_insertions=log_insertions),
             OutlineAttribute("deberta.embeddings.LayerNorm", "Embedding"),
-            AddPoptorchBlock("Before Encoder", 0, "deberta.encoder", log_insertions=log_insertions),
             AddPoptorchBlocksInSeries(
                 "Encoder", layer_ipu, r"deberta.encoder.layer.[0-9]+", log_insertions=log_insertions
             ),
@@ -322,7 +347,13 @@ def get_transformations(self):
             )
         )
         if self.ipu_config.embedding_serialization_factor > 1:
-            transformations.append(VocabEmbeddingToSerializedEmbedding())
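+            # For the masked-LM head, serialize the decoder projection and tie its weight
+            # back to the word embeddings instead of serializing the vocab embedding itself.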
+            if isinstance(self, DebertaForMaskedLM):
+                transformations += [
+                    LinearToSerializedLinear("cls.predictions.decoder"),
+                    TieWeights("deberta.embeddings.word_embeddings", "cls.predictions.decoder"),
+                ]
+            else:
+                transformations.append(VocabEmbeddingToSerializedEmbedding())
         return transformations

     def parallelize(self):
@@ -339,6 +370,7 @@ def parallelize(self):
         torch.nn.functional.one_hot = orig
         transformations = self.get_transformations()
         transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level)
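+        # Swap torch.gather for the FastGatherLastDim custom op in the traced graph; it is
+        # appended with the reversible transformations so it can later be undone via reverse().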
+        transformations.append(ChangeTorchGather())
         composition = compose(*transformations)
         non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(
             self.ipu_config.optimization_level