Conversation

@anadeem2
Contributor

@anadeem2 anadeem2 commented Jul 29, 2022

Description

Adds matrix multiplication support to Ratex. All matrix dimensionalities are supported (1-D/2-D/3-D/4-D/5-D+/etc.), as well as mixed-rank pairings (e.g. 4-D x 2-D, 3-D x 1-D). Batch matmul is not supported for half precision (fp16).
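
As a rough illustration of the rank handling this covers, here is a plain-PyTorch sketch (not the PR's test code; the shapes are arbitrary examples):

import torch

a = torch.randn(3)             # 1-D
b = torch.randn(3, 4)          # 2-D
c = torch.randn(2, 4, 5)       # 3-D (batched)
d = torch.randn(6, 2, 5, 7)    # 4-D (broadcast batch dims)

print(torch.matmul(a, b).shape)   # torch.Size([4])           1-D x 2-D
print(torch.matmul(b, c).shape)   # torch.Size([2, 3, 5])     2-D x 3-D
print(torch.matmul(c, d).shape)   # torch.Size([6, 2, 4, 7])  3-D x 4-D with batch broadcasting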

Checklist

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @awslabs/raf-reviewer @zachzzc

@anadeem2 anadeem2 requested a review from zachzzc July 29, 2022 22:41
Contributor

@comaniac comaniac left a comment

Otherwise LGTM.
Also cc @zachzzc @zhouyuan1119

Comment on lines 1020 to 1021
y = BindSymbol(
    raf::ir::Call(Op::Get("raf.op.transpose"), {y, MakeConstant(TupleInt(trans_axes))}));
Contributor

We should have all 4 variants (raf.op.batch_matmul_nn, raf.op.batch_matmul_nt, etc.), so you should be able to get rid of the transpose in all cases?

Contributor Author

I think we might need to create separate ops for each case. Originally I had a type variable in the MatMul node that we could set to nn/nt/tt and concatenate in node_lowering. But the conversions (reshapes/transposes/etc.) do not hold for all cases.

Contributor

I don't get the point. What's the issue with creating separate ops for each case?

Contributor Author

Sounds good, I can create separate ops for each case, but I will make separate PRs for those.

Contributor

Sorry if I missed anything. I'd prefer to correct it in this PR, as the current implementation introduces unnecessary overhead without improving readability IMHO.

Contributor

IIUC, the problem is that in PyTorch, matmul accepts all these shape combinations and does the reshape and transpose underneath. Awais is trying to mimic what PyTorch and XLA do.

Contributor

That makes sense. We did the same thing in Relay in the past so this is totally correct.
I'm just wondering if there's any urgent task that needs this PR. If so then it's fine to merge first and fix later; otherwise, I'd prefer to polish it before merging.

Contributor Author

I understand your point, and in fact we originally removed the trans_axes and used batch_matmul. But while running test cases for BERT-like/transformer-style models we hit errors that were resolved by this. That said, as we discussed yesterday, even though PyTorch does not have matmul_nt/tn/tt, we should be able to take advantage of RAF's capability via Ratex (this is especially useful in autodiff), so I added a custom matmul_xx op for these cases. For now the op is only a prototype (sufficient for BERT/transformers) acting as a bridge; it will need its own algorithm (reshapes/transposes) to handle all possible cases. I will open an issue for this.

Comment on lines 1054 to 1055
y = BindSymbol(
    raf::ir::Call(Op::Get("raf.op.transpose"), {y, MakeConstant(TupleInt(trans_axes))}));
Contributor

Ditto. Can we get rid of transpose?

Contributor Author

Actually, this is part of the reshape and of how we lower >3-D down to 3-D. We swap the last 2 dims and then do a reshape. This way, even after doing the batch_matmul, the results are correct when we bring back the final shapes.

Contributor

Let me try to understand what you meant. Suppose we have (m, k) x (b, k, n), so you will do (b, k, n) -> (b, n, k) -> (-1, k) = (b * n, k), so that you could apply dense((m, k), (b * n, k)).

If the above understanding is correct, you can get rid of the transpose by the following:

  1. Expand a to be (1, m, k).
  2. batch_matmul(a, b).
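
For concreteness, both paths for the (m, k) x (b, k, n) case can be checked in plain PyTorch (a sketch only; how the expand and reshapes map onto RAF ops is glossed over here):

import torch

m, k, b, n = 3, 4, 2, 5
a = torch.randn(m, k)
x = torch.randn(b, k, n)

ref = torch.matmul(a, x)                         # (b, m, n)

# Transpose-and-flatten path as described above: dense-style matmul, then restore the shape.
x_t = x.transpose(-1, -2).reshape(-1, k)         # (b * n, k)
via_dense = (a @ x_t.t()).reshape(m, b, n).transpose(0, 1)

# Suggested alternative: expand a into a 3-D batch and emit a single batch matmul.
via_batch = torch.bmm(a.expand(b, m, k), x)      # (b, m, n)

print(torch.allclose(ref, via_dense), torch.allclose(ref, via_batch))   # True True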

Contributor Author

Yes, your understanding is correct, and this solution would in fact work. However, as we discussed yesterday (to clue in Zach), the reshape that follows this would not hold. Instead we would need to add a squeeze or a special reshape for each case, whereas the current implementation is more general and the reshapes work correctly for most conditions.


std::string MatMul::ToString() const {
  std::stringstream ss;
  ss << Node::ToString() << " a_shape= " << a_shape_ << " b_shape " << b_shape_;
Contributor

Suggested change
ss << Node::ToString() << " a_shape= " << a_shape_ << " b_shape " << b_shape_;
ss << Node::ToString() << " a_shape= " << a_shape_ << " b_shape=" << b_shape_;

Comment on lines 112 to 118
for i in range(1, len(shape) + 1):
    x_s = shape[::i]
    x = torch.randn(x_s)
    for j in range(1, len(shape) + 1):
        y_s = shape[::j]
        y = torch.randn(y_s)
        verify_step(Model(), [x, y], jit_script=False)
Contributor

IIUC, you are trying to cover all the different n-dim cases in this loop. However, this is hard to understand and debug in the future. It would be better to explicitly assign shapes in pytest.mark.parametrize.

Contributor Author

Sure, I can adjust, but FYI the pytest will have a very long list of shapes. Unless you are thinking of a different approach or there is some macro, we will need to write something like ("x_shape", [(3,), (3,), (3,), (3,), (3, 3), ...]) and ("y_shape", [(3,), (3, 3), ...]), in other words 16 shapes.

Contributor

I don't see the problem with the long list of shapes.
Regarding the total number of tests that may slow down the CI, your current implementation actually has the same number of workloads.
Regarding the format, black should auto-format it onto multiple lines.

Contributor Author

Okay, sounds good. I just wanted to confirm whether there were any macros or a more elegant way to write the test case.
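
For reference, one way to keep the shape lists explicit but short is to stack parametrize decorators, which pytest expands into the full cross product; a sketch (the real test would call the repo's verify_step/Model helpers instead of the placeholder body):

import pytest
import torch

SHAPES = [(3,), (3, 3), (2, 3, 3), (2, 2, 3, 3)]

# Stacking two parametrize decorators yields the 4 x 4 = 16 shape pairs
# while keeping each list explicit and readable.
@pytest.mark.parametrize("x_shape", SHAPES)
@pytest.mark.parametrize("y_shape", SHAPES)
def test_matmul_shapes(x_shape, y_shape):
    x = torch.randn(x_shape)
    y = torch.randn(y_shape)
    torch.matmul(x, y)  # placeholder for verify_step(Model(), [x, y], jit_script=False)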

@comaniac comaniac mentioned this pull request Jul 29, 2022
- sigmoid_backward
- tanh_backward
- ger
- matmul
Contributor

Does this mean we don't need matmul_backward? I remember last time you said there was an issue if we don't add it to AutoGrad?

Contributor Author

Yes, you are correct: if we do not add a custom autograd op, the tracing stops. I was planning on creating separate PRs, which is why I did not include it here, but I just pushed the commit.
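
For reference, these are the identities such a matmul backward has to implement in the plain 2-D case, checked against PyTorch autograd (a sketch; the actual backward also has to cover the batched/broadcast shapes):

import torch

m, k, n = 3, 4, 5
a = torch.randn(m, k, requires_grad=True)
b = torch.randn(k, n, requires_grad=True)

c = a @ b
dc = torch.randn(m, n)      # incoming gradient for C
c.backward(dc)

# Manual backward for C = A @ B.
print(torch.allclose(a.grad, dc @ b.t()))   # dA = dC @ B^T
print(torch.allclose(b.grad, a.t() @ dc))   # dB = A^T @ dC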

@anadeem2
Contributor Author

anadeem2 commented Aug 2, 2022

Also, for the backward, I did try doing a transpose and calling regular matmul (i.e. matmul_nt(a, b) = matmul(a, b.T)), but for some reason even that does not pass the test case. permute was being converted to view, and I even added a custom transpose_axes, but still no luck. The B gradient is correct, but not A. I tried nn/nt/tn/tt etc., no luck.
https://github.com/awslabs/ratex/compare/main...anadeem2:ratex:matmul_dx_test?expand=1
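
For the record, the identity being attempted and its expected gradients do hold in plain PyTorch for the 2-D case (a sketch, just to pin down what matmul_nt(a, b) = matmul(a, b.T) implies for the backward):

import torch

m, k, n = 3, 4, 5
a = torch.randn(m, k, requires_grad=True)
b = torch.randn(n, k, requires_grad=True)    # nt layout: B is (n, k)

c = torch.matmul(a, b.t())                   # matmul_nt(a, b) expressed via a transpose
dc = torch.randn(m, n)
c.backward(dc)

# Expected gradients for C = A @ B^T.
print(torch.allclose(a.grad, dc @ b))        # dA = dC @ B
print(torch.allclose(b.grad, dc.t() @ a))    # dB = dC^T @ A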

@anadeem2 anadeem2 requested review from comaniac and zachzzc August 2, 2022 23:48
@anadeem2 anadeem2 force-pushed the matmul_full_lowering branch from 5e54de1 to facc61c August 3, 2022 22:42
@anadeem2
Contributor Author

anadeem2 commented Aug 3, 2022

Wow, what a crazy op. It turns out everything I was trying was failing because PyTorch also does reshapes/transposes/conversions for higher/lower/mixed-rank matrix multiplication. So although my implementations were correct, PyTorch effectively undid/inverted them, making the overall result incorrect. Anyway, I removed the custom matmul implementations, the custom autograd implementation, and the custom matmul_xx, and instead let PyTorch handle the conversion while I just lower correctly to RAF. Now all test cases pass. I put the custom matmul_xx code in a separate branch in case we need it for future reference. https://github.com/anadeem2/ratex/tree/matmul_xx_ref

@comaniac
Contributor

comaniac commented Aug 5, 2022

Is this up-to-date? Please fix the CI error if so.
