@@ -376,22 +376,38 @@ class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation):
     """
 
     def __init__(self, name_regex: Optional[str] = None):
-        self.name_regex = re.compile(name_regex) if name_regex else None
+        self.name_regex_for_module = re.compile(name_regex) if name_regex else None
+        self.name_regex_for_function = re.compile(name_regex.replace(".", "_")) if name_regex else None
 
     def transform(self, graph_module: "GraphModule") -> "GraphModule":
         embedding_nodes = []
         for node in graph_module.graph.nodes:
-            if node.op != "call_module":
-                continue
-            match = re.match(self.name_regex, node.target) if self.name_regex is not None else True
-            if match and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+            if node.op == "call_module":
+                if self.name_regex_for_module is not None and not re.match(self.name_regex_for_module, node.target):
+                    continue
+                elif not isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+                    continue
+                embedding_nodes.append(node)
+            elif node.op == "call_function":
+                if self.name_regex_for_function is not None and not re.match(self.name_regex_for_function, node.name):
+                    continue
+                elif node.target is not torch.nn.functional.embedding:
+                    continue
                 embedding_nodes.append(node)
 
         # We assume the vocab embedding to be the embedding with the maximum number of embeddings.
         if not embedding_nodes:
             raise RuntimeError("Could not find any embedding node")
 
-        embedding_node = max(embedding_nodes, key=lambda node: graph_module.get_submodule(node.target).num_embeddings)
+        def sort_nodes_function(node):
+            if node.op == "call_module":
+                return graph_module.get_submodule(node.target).num_embeddings
+            return node.args[1].shape[0]
+
+        embedding_node = max(embedding_nodes, key=sort_nodes_function)
+        if embedding_node.op == "call_function":
+            raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+
         split = embedding_node.target.rsplit(".", maxsplit=1)
         if len(split) == 1:
             split = [None] + split
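The pair of compiled patterns in __init__ mirrors a torch.fx naming detail: call_module nodes are matched against their dotted submodule path (node.target, e.g. "transformer.wte"), whereas node names are sanitized identifiers in which dots become underscores (node.name, e.g. "transformer_wte"). Replacing "." with "_" in the user-supplied regex therefore lets one pattern select both kinds of nodes. A minimal sketch of the matching, illustrative only (the pattern transformer\.wte is a hypothetical example, not taken from this diff):

import re

name_regex = r"transformer\.wte"
name_regex_for_module = re.compile(name_regex)
name_regex_for_function = re.compile(name_regex.replace(".", "_"))  # "\." becomes "\_", which matches a literal underscore

# call_module nodes expose the dotted path; call_function nodes only expose a sanitized name.
assert name_regex_for_module.match("transformer.wte") is not None
assert name_regex_for_function.match("transformer_wte") is not None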
@@ -504,6 +520,12 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":
 
 
 class ShareEmbeddingComputation(Transformation):
+    def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+        self.name_regex = re.compile(name_regex) if name_regex else None
+        self.allowed_embedding_classes = allowed_embedding_classes
+        if not isinstance(self.allowed_embedding_classes, tuple):
+            self.allowed_embedding_classes = (self.allowed_embedding_classes,)
+
     def _find_nodes_to_move(self, graph_module, embedding_input_node, shared_embedding_node):
         nodes_before_embedding_input_node = set()
         for node in graph_module.graph.nodes:
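Because allowed_embedding_classes may be passed either as a single class or as a tuple of classes, the constructor normalizes it to a tuple so that the isinstance check in transform (next hunk) works uniformly. A small illustrative snippet, assuming only what the diff above shows:

import torch

sharing = ShareEmbeddingComputation(allowed_embedding_classes=torch.nn.Embedding)
assert sharing.allowed_embedding_classes == (torch.nn.Embedding,)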
@@ -535,7 +557,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":
         candidates = collections.defaultdict(list)
         embedding_nodes = collections.defaultdict(list)
         for node in graph_module.graph.nodes:
-            if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+            if node.op == "call_module":
+                if self.name_regex is not None and not re.match(self.name_regex, node.target):
+                    continue
+                elif not isinstance(graph_module.get_submodule(node.target), self.allowed_embedding_classes):
+                    continue
                 candidates[node.target].append(node.args[0])
                 embedding_nodes[node.target].append(node)
 
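Taken together, the two transformations touched by this diff could be exercised end to end as follows. This is a hypothetical sketch, not part of the diff: it assumes a GPT-2-style model whose vocab embedding lives at transformer.wte, traces it with transformers' fx tracer, and relies on Transformation instances being callable on a GraphModule, as they are elsewhere in this codebase:

from transformers.utils.fx import symbolic_trace

# Trace the model so the transformations can rewrite its graph.
graph_module = symbolic_trace(model, input_names=["input_ids"])

# Both passes are restricted to the word-token embedding via the same dotted-path pattern.
serialize_vocab_embedding = VocabEmbeddingToSerializedEmbedding(name_regex=r"transformer\.wte")
share_embedding_computation = ShareEmbeddingComputation(
    name_regex=r"transformer\.wte",
    allowed_embedding_classes=(torch.nn.Embedding, SerializedEmbedding),
)

graph_module = serialize_vocab_embedding(graph_module)
graph_module = share_embedding_computation(graph_module)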