We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 7280e4e + b56c8df commit 599fed0Copy full SHA for 599fed0
bark/model_fine.py
@@ -26,9 +26,9 @@ def __init__(self, config):
26
self.n_head = config.n_head
27
self.n_embd = config.n_embd
28
self.dropout = config.dropout
29
- # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
30
self.flash = (
31
- hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
+ hasattr(torch.nn.functional, "scaled_dot_product_attention")
32
)
33
34
def forward(self, x):
0 commit comments