    symbolic_trace_pipelined_model,
)
from ...modeling_utils import PipelineMixin, get_layer_ipu, register
+ from .optimized_gpt2_attn import OptimizedGPT2Attention


logger = logging.get_logger(__name__)
@@ -69,7 +70,7 @@ def get_transformations(self):
      layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
      transformations = [
          AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-         AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+         AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
          OutlineAttribute("transformer.ln_f", "LayerNorm"),
          AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
          # Only one of the following AddPoptorchBlock calls will actually add a block.
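For reference, the renamed target above matches the Hugging Face GPT-2 module layout: the position-embedding table lives at `transformer.wpe` (the token table at `transformer.wte`); there is no `wtp` attribute, so the previous pattern never matched a module. The block index also moves from 1 to 0 so both embeddings are placed together with the token embedding. A quick sanity check, assuming a stock `transformers` install:

```python
from transformers import GPT2LMHeadModel

# Inspect the embedding attributes on a stock GPT-2 checkpoint.
model = GPT2LMHeadModel.from_pretrained("gpt2")
print(model.transformer.wte)              # Embedding(50257, 768), token embeddings
print(model.transformer.wpe)              # Embedding(1024, 768), position embeddings
print(hasattr(model.transformer, "wtp"))  # False: the old pattern matched nothing
```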
@@ -84,7 +85,7 @@ def get_transformations(self):
              )
          )
      if self.ipu_config.embedding_serialization_factor > 1:
-         transformations.append(VocabEmbeddingToSerializedEmbedding())
+         transformations.append(VocabEmbeddingToSerializedEmbedding("transformer.wte"))

      return transformations
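`VocabEmbeddingToSerializedEmbedding` is now told explicitly which embedding to serialize (`transformer.wte`). The idea behind embedding serialization is to split the lookup over the large vocabulary dimension into `embedding_serialization_factor` slices, so the gather runs piecewise instead of as one huge operation. A rough, framework-free sketch of the concept (illustrative only, not the optimum-graphcore transformation; all names below are made up):

```python
import torch
import torch.nn as nn


class SerializedEmbeddingSketch(nn.Module):
    """Toy version of a serialized embedding: look up ids slice by slice."""

    def __init__(self, num_embeddings: int, embedding_dim: int, factor: int):
        super().__init__()
        assert num_embeddings % factor == 0, "vocab size must divide evenly"
        self.slice_size = num_embeddings // factor
        self.slices = nn.ModuleList(
            nn.Embedding(self.slice_size, embedding_dim) for _ in range(factor)
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        out = torch.zeros(*input_ids.shape, self.slices[0].embedding_dim)
        for i, emb in enumerate(self.slices):
            lo = i * self.slice_size
            in_slice = (input_ids >= lo) & (input_ids < lo + self.slice_size)
            # Shift ids into the slice's local range, look them up, and mask
            # out positions whose ids belong to a different slice.
            local_ids = (input_ids - lo).clamp(0, self.slice_size - 1)
            out = out + emb(local_ids) * in_slice.unsqueeze(-1).to(out.dtype)
        return out
```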
@@ -96,6 +97,9 @@ def parallelize(self):
        - Adds recomputation checkpoints
        """
        PipelineMixin.parallelize(self)
+       if not isinstance(self, torch.fx.GraphModule):
+           for layer in self.transformer.h:
+               layer.attn.__class__ = OptimizedGPT2Attention
        if self.ipu_config.embedding_serialization_factor > 1:
            self.resize_vocab(False)
        traced = symbolic_trace_pipelined_model(self)
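The added lines swap each block's attention implementation by reassigning the module's `__class__` rather than constructing a new module, so existing parameters, buffers, and hooks are kept and only the methods change; the `isinstance` guard skips the swap when `parallelize` is called on an already-traced `torch.fx.GraphModule`, presumably because the traced graph has already captured the attention code. A minimal sketch of the class-swap pattern, with made-up classes:

```python
import torch
import torch.nn as nn


class SlowBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


class FastBlock(SlowBlock):
    # Same attributes as SlowBlock; only forward() differs.
    def forward(self, x):
        return torch.relu(self.proj(x))


block = SlowBlock()
weight_before = block.proj.weight
block.__class__ = FastBlock            # swap behaviour in place
assert block.proj.weight is weight_before  # weights untouched
print(isinstance(block, FastBlock))    # True
```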
@@ -137,7 +141,7 @@ def get_transformations(self):
      layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
      transformations = [
          AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-         AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+         AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
          OutlineAttribute("transformer.ln_f", "LayerNorm"),
          AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
          AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions),