Commit 0b460b3

Fix BART
1 parent 29031ea commit 0b460b3

2 files changed (+35, −17 lines)

optimum/graphcore/fx/transformations.py

+33 −7
@@ -376,22 +376,38 @@ class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation):
     """

     def __init__(self, name_regex: Optional[str] = None):
-        self.name_regex = re.compile(name_regex) if name_regex else None
+        self.name_regex_for_module = re.compile(name_regex) if name_regex else None
+        self.name_regex_for_function = re.compile(name_regex.replace(".", "_")) if name_regex else None

     def transform(self, graph_module: "GraphModule") -> "GraphModule":
         embedding_nodes = []
         for node in graph_module.graph.nodes:
-            if node.op != "call_module":
-                continue
-            match = re.match(self.name_regex, node.target) if self.name_regex is not None else True
-            if match and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+            if node.op == "call_module":
+                if self.name_regex_for_module is not None and not re.match(self.name_regex_for_module, node.target):
+                    continue
+                elif not isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+                    continue
+                embedding_nodes.append(node)
+            elif node.op == "call_function":
+                if self.name_regex_for_function is not None and not re.match(self.name_regex_for_function, node.name):
+                    continue
+                elif node.target is not torch.nn.functional.embedding:
+                    continue
                 embedding_nodes.append(node)

         # We assume the vocab embedding to be the embedding with the maximum number of embeddings.
         if not embedding_nodes:
             raise RuntimeError("Could not find any embedding node")

-        embedding_node = max(embedding_nodes, key=lambda node: graph_module.get_submodule(node.target).num_embeddings)
+        def sort_nodes_function(node):
+            if node.op == "call_module":
+                return graph_module.get_submodule(node.target).num_embeddings
+            return node.args[1].shape[1]
+
+        embedding_node = max(embedding_nodes, key=sort_nodes_function)
+        if embedding_node.op == "call_function":
+            raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+
         split = embedding_node.target.rsplit(".", maxsplit=1)
         if len(split) == 1:
             split = [None] + split
@@ -504,6 +520,12 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":


 class ShareEmbeddingComputation(Transformation):
+    def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+        self.name_regex = re.compile(name_regex) if name_regex else None
+        self.allowed_embedding_classes = allowed_embedding_classes
+        if not isinstance(self.allowed_embedding_classes, tuple):
+            self.allowed_embedding_classes = (self.allowed_embedding_classes,)
+
     def _find_nodes_to_move(self, graph_module, embedding_input_node, shared_embedding_node):
         nodes_before_embedding_input_node = set()
         for node in graph_module.graph.nodes:
@@ -535,7 +557,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":
         candidates = collections.defaultdict(list)
         embedding_nodes = collections.defaultdict(list)
         for node in graph_module.graph.nodes:
-            if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+            if node.op == "call_module":
+                if self.name_regex is not None and not re.match(self.name_regex, node.target):
+                    continue
+                elif not isinstance(graph_module.get_submodule(node.target), self.allowed_embedding_classes):
+                    continue
                 candidates[node.target].append(node.args[0])
                 embedding_nodes[node.target].append(node)

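For context on the new call_function branch in VocabEmbeddingToSerializedEmbedding: in torch.fx, an embedding used through its module is traced as a call_module node whose target is the submodule's qualified name, while torch.nn.functional.embedding is traced as a call_function node whose target is the function itself and whose second argument is the weight. The snippet below is a minimal sketch, not part of this commit; the Toy module and the printed fields are made up and only meant to show the two node kinds the transformation now distinguishes.

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    # Made-up module: the same weight is used once through the module and once functionally.
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Embedding(100, 8)

    def forward(self, input_ids):
        module_out = self.shared(input_ids)  # traced as a call_module node, target "shared"
        functional_out = F.embedding(input_ids, self.shared.weight)  # traced as a call_function node
        return module_out + functional_out


gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.op, node.target, type(gm.get_submodule(node.target)).__name__)
    elif node.op == "call_function" and node.target is F.embedding:
        # node.args[1] is the weight node, which is what the new sort key inspects.
        print(node.op, node.name, node.args[1])
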
optimum/graphcore/models/bart/modeling_bart.py

+2 −10
@@ -252,10 +252,6 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"),
             AddPoptorchBlocksInSeries(
                 "Encoder",
                 layer_ipu[: self.config.encoder_layers],
@@ -351,10 +347,6 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"),
             AddPoptorchBlocksInSeries(
                 "Encoder",
                 layer_ipu[: self.config.encoder_layers],
@@ -383,8 +375,8 @@ def get_transformations(self):

         if not isinstance(self, torch.fx.GraphModule):
             if self.ipu_config.embedding_serialization_factor > 1:
-                transformations.append(VocabEmbeddingToSerializedEmbedding())
-            transformations += [ShareEmbeddingComputation()]
+                transformations.append(VocabEmbeddingToSerializedEmbedding("model.shared"))
+            transformations += [ShareEmbeddingComputation("model.shared")]
         return transformations

     def forward(

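The name-regex scoping that this commit threads through both transformations can also be seen in isolation. The helper and toy model below are made up for illustration and are not part of the repository; they only mimic the filtering loop added to ShareEmbeddingComputation.transform and the "model.shared" scoping used for BART above.

import collections
import re
from typing import Optional, Tuple, Type, Union

import torch
from torch.fx import symbolic_trace


def find_embedding_nodes(
    graph_module: torch.fx.GraphModule,
    name_regex: Optional[str] = None,
    allowed_embedding_classes: Union[Tuple[Type, ...], Type] = torch.nn.Embedding,
):
    # Illustration only: mirrors the regex-then-class filtering added in this commit.
    regex = re.compile(name_regex) if name_regex else None
    if not isinstance(allowed_embedding_classes, tuple):
        allowed_embedding_classes = (allowed_embedding_classes,)
    embedding_nodes = collections.defaultdict(list)
    for node in graph_module.graph.nodes:
        if node.op != "call_module":
            continue
        if regex is not None and not re.match(regex, node.target):
            continue
        if not isinstance(graph_module.get_submodule(node.target), allowed_embedding_classes):
            continue
        embedding_nodes[node.target].append(node)
    return embedding_nodes


class TinySeq2Seq(torch.nn.Module):
    # Made-up stand-in for BART: one shared vocab embedding used twice, plus positional embeddings.
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Embedding(50265, 64)
        self.embed_positions = torch.nn.Embedding(1024, 64)

    def forward(self, input_ids, decoder_input_ids, position_ids):
        encoder = self.shared(input_ids) + self.embed_positions(position_ids)
        decoder = self.shared(decoder_input_ids)
        return encoder, decoder


gm = symbolic_trace(TinySeq2Seq())
# Scoped to "shared", only the two uses of the vocab embedding are collected;
# embed_positions is filtered out, just as "model.shared" scopes the BART transformations above.
print(find_embedding_nodes(gm, name_regex="shared"))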