
Conversation

@anadeem2 (Contributor) commented Jul 7, 2022

Description

Matmul lowering for the native trace of BERT-like Transformer models. The matmul op is currently only around 60% complete (sufficient for now). It does not natively support conversion of higher-dimensional (4D+) matrices, so we fall back to CPU to ensure 100% coverage. The problem is that we cannot use view, since it modifies the same memory location and the values have not yet materialized in aten_raf_type. Potential fixes are:

  1. Implement matrix folding (similar to PyTorch); a minimal sketch of this idea is shown after the list.
  2. Lower to and utilize einsum.
  3. In raf_node_lowering, try reshape + transpositions to convert higher-dimensional matrices.
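A minimal sketch of the matrix-folding idea in fix (1), assuming PyTorch semantics and identical (non-broadcast) batch dimensions on both inputs; the function name and shapes are illustrative, not the PR's actual lowering:

```python
import torch

def folded_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Fold all leading batch dims into one, run a 3D bmm, then unfold."""
    batch_shape = a.shape[:-2]            # e.g. (B1, B2) for a 4D input
    m, k = a.shape[-2:]
    k2, n = b.shape[-2:]
    assert k == k2, "inner dimensions must match"
    a3 = a.reshape(-1, m, k)              # (B1*B2, m, k)
    b3 = b.reshape(-1, k, n)              # (B1*B2, k, n)
    out = torch.bmm(a3, b3)               # 3D bmm is already supported
    return out.reshape(*batch_shape, m, n)

x = torch.randn(2, 4, 5, 6)
y = torch.randn(2, 4, 6, 3)
assert torch.allclose(folded_matmul(x, y), x @ y, atol=1e-5)
```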

Checklist

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @awslabs/raf-reviewer

@anadeem2 requested a review from zachzzc on Jul 7, 2022 18:34
verify_step(Model(), [x])


@pytest.mark.parametrize("shape", [(3, 3, 3)])
Contributor (review comment on the diff above):

The 2-dimensional case should work, right? Can we add a test here?
And can you also add an fp16 dtype test?
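For illustration, a hedged sketch of what such a parametrization could look like, reusing the verify_step and Model helpers visible in this diff; the exact test body and torch usage here are assumptions:

```python
# Illustrative sketch only: extend the existing parametrization with a 2D shape
# and an fp16 dtype. Assumes pytest, torch, verify_step, and Model are already
# available in this test module.
@pytest.mark.parametrize("shape", [(3, 3), (3, 3, 3)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_matmul(shape, dtype):
    x = torch.randn(*shape).to(dtype)
    verify_step(Model(), [x])
```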

Contributor Author:

The test case permutes over the shapes, so it effectively covers (1x1, 1x2, 1x3, 2x1, 2x2, ..., 3x3). I tried running fp16, but bmm does not support it, so it fails.
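Concretely (illustrative only), these are the (rows, cols) pairs the parametrization walks through:

```python
# Illustrative: the (rows, cols) combinations the parametrized test covers.
import itertools
pairs = list(itertools.product(range(1, 4), range(1, 4)))
# [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)]
```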

@zachzzc (Contributor) commented Jul 7, 2022

I checked the XLA implementation: https://github.com/pytorch/xla/blob/cc19c3abcbb3f702d5f468ee08549edd926ef549/torch_xla/csrc/xla_lower_util.cpp#L386
We can revise ours in the future to support dim >= 4, referring to this.

@anadeem2 (Contributor Author)
I did reference aten::matmul; I'm not sure whether we can implement matrix folding the way they do, but here it is for reference: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp

@comaniac (Contributor)
Should this PR be closed due to #38?

@anadeem2 (Contributor Author) commented Aug 6, 2022

We can close this PR now. This was the implementation I used to get the full graph initially; I had to add the autograd piece, which is in my debug_branch. However, we now have a better implementation, so this is no longer needed.

@anadeem2 closed this Aug 6, 2022