-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy path: dx2_rnn.patch
More file actions
41 lines (34 loc) · 1.42 KB
/
dx2_rnn.patch
File metadata and controls
41 lines (34 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 477c551..c034e3b 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -11,9 +11,6 @@ except ImportError:
pass
-force_unfused = False
-
-
def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
return hy
@@ -25,7 +22,7 @@ def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
- if input.is_cuda and not force_unfused:
+ if input.is_cuda:
igates = F.linear(input, w_ih)
hgates = F.linear(hidden[0], w_hh)
state = fusedBackend.LSTMFused.apply
@@ -49,7 +46,7 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
- if input.is_cuda and not force_unfused:
+ if input.is_cuda:
gi = F.linear(input, w_ih)
gh = F.linear(hidden, w_hh)
state = fusedBackend.GRUFused.apply
@@ -373,7 +370,7 @@ def hack_onnx_rnn(fargs, output, args, kwargs):
def RNN(*args, **kwargs):
def forward(input, *fargs, **fkwargs):
- if not force_unfused and cudnn.is_acceptable(input.data):
+ if cudnn.is_acceptable(input.data):
func = CudnnRNN(*args, **kwargs)
else:
func = AutogradRNN(*args, **kwargs)