Commit 599fed0

Merge pull request suno-ai#364 from no2chem/flash_attention
fix: perf: use sdpa if dropout > 0 on fine model
2 parents 7280e4e + b56c8df commit 599fed0

File tree

1 file changed: +2 −2 lines


bark/model_fine.py

Lines changed: 2 additions & 2 deletions

@@ -26,9 +26,9 @@ def __init__(self, config):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = (
-            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
+            hasattr(torch.nn.functional, "scaled_dot_product_attention")
        )

    def forward(self, x):
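A brief sketch of why the `dropout == 0.0` guard could be dropped: since PyTorch 2.0, `torch.nn.functional.scaled_dot_product_attention` itself accepts a `dropout_p` argument, so the SDPA fast path can be used even when the model's dropout is nonzero (the tensor shapes below are hypothetical, not taken from the Bark model):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# SDPA applies attention dropout internally via dropout_p,
# so there is no need to fall back to a manual attention path
# just because dropout > 0.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
print(out.shape)  # torch.Size([1, 4, 8, 16])
```

In eval mode (or with `dropout_p=0.0`) the dropout is a no-op, so the same call covers both training and inference.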
