Skip to content

Commit 143e1ce

Browse files
committed
Making DAC decoder torch.compileable
The original torch.nn.utils.weight_norm() function is deprecated and it prevents the DAC decoder model from being torch.compiled. The newer torch.nn.utils.parametrizations.weight_norm() works the same way as the old one and plus it is compatible with torch.compile (with CUDA graphs).
1 parent c7cfc5d commit 143e1ce

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

dac/nn/layers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44
import torch.nn.functional as F
55
from einops import rearrange
6-
from torch.nn.utils import weight_norm
6+
from torch.nn.utils.parametrizations import weight_norm
77

88

99
def WNConv1d(*args, **kwargs):

0 commit comments

Comments
 (0)