fix: Gate empty result in GATConv#638
Conversation
|
How would this work if Also, we need the same patch on gatv2. Simply adding a new conditional in |
| Wxj = l.dense_x(xj) | ||
| Wxj = reshape(Wxj, chout, heads, :) | ||
| Wxi = l.dense_x(xi) | ||
| Wxi = reshape(Wxi, chout, heads, :) |
There was a problem hiding this comment.
| Wxj = l.dense_x(xj) | |
| Wxj = reshape(Wxj, chout, heads, :) | |
| Wxi = l.dense_x(xi) | |
| Wxi = reshape(Wxi, chout, heads, :) | |
| Wxj = l.dense_x(xj) | |
| Wxj = reshape(Wxj, chout, heads, :) | |
| if xi !== xj | |
| Wxi = l.dense_x(xi) | |
| Wxi = reshape(Wxi, chout, heads, :) | |
| else | |
| Wxi = Wxj | |
| end |
would work?
There was a problem hiding this comment.
This doesn't work. It seems like the above example triggers both branches and Zygote gets one branch confused for another. I have seen this kind of behaviour before with Zygote.
|
regarding the wrong gradient shape for empty array, do you have any clue why it is happening? ideally the chainrule's projection shouldn't be patched, it should just receive a dx in the correct shape. |
I have absolutely no idea. A hint there is the error only occurs for 0-sized arrays. I think it has something to do with computing gradients and is outside the scope of this package. Edit: Maybe its related to FluxML/Flux.jl#2648 where processing zero-sized array doesn't override an existing cache. I searched Flux.jl, ChainRules.jl (and core), and Zygote.jl and I could not find any suspicious and there is a specific bypass used for |
Resolves #637
Fixes 3 problems:
gat_convduring backpropagation, "DimensionMismatch: arrays could not be broadcast to a common size". This could be solved by removing the conditional used to calculateWxi, Wxj:reshape(x, :, size(x, 3))ingat_convcreates incompatible sizes whenxis empty. This PR determines the size using the first two axes ofxinstead.DimensionMismatch: variable with size(x) == (1, 1, 0) cannot have a gradient with size(dx) == (4, 1, 0). This could be fixed by a patch forChainRulesCore.jlfix: Allow arbitrary reshape in projection if array is zero sized JuliaDiff/ChainRulesCore.jl#702Test script: