@vkuzo vkuzo commented Apr 28, 2023

This is an example of implementing basic fp8 support with a Python tensor subclass.

tl;dr:

  1. FP8Tensor is the Python tensor subclass which holds raw fp8 data (as torch.bits8), a scale, and a flavor (e4m3/e5m2).
  2. FP8Tensor.__torch_dispatch__ knows how to add gradients, but converts to fp32 for everything else.
  3. FP8Linear is a module which can do stateful delayed scaling. Users are expected to manually swap their linear modules for something like this.

Note: E4M3 support has not been numerically validated, and E5M2 support is absent entirely.
Note: no testing has been done beyond the bare-bones checks at the bottom of the PR.
Note: delayed scaling is not implemented; currently the scale is just 1.0 everywhere.
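To make the FP8Linear point concrete, here is a sketch of what the stateful module and the manual swap could look like. This is an assumption-laden illustration, not the PR's implementation: the `fp8_amax_x` buffer name and the `swap_linears` helper are hypothetical, and, matching the note above, the scale stays at 1.0 while the module only records the amax that a real delayed-scaling scheme would use to derive the next step's scale.

```python
import torch

class FP8Linear(torch.nn.Linear):
    """Sketch of a linear module with stub delayed scaling.

    Keeps a running amax buffer (hypothetical name: fp8_amax_x) that a
    real implementation would use to compute the scale for the *next*
    forward; here the effective scale stays 1.0.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # Persistent state for delayed scaling: amax seen on the
        # previous forward pass.
        self.register_buffer("fp8_amax_x", torch.tensor(0.0))

    def forward(self, x):
        # Record amax for the next step's scale (delayed scaling);
        # detach so the bookkeeping stays out of the autograd graph.
        self.fp8_amax_x.fill_(x.abs().max().detach())
        return super().forward(x)

def swap_linears(model):
    """Hypothetical helper: manually swap nn.Linear children for
    FP8Linear, reusing the existing weights.  The PR expects users to
    perform this swap themselves."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            new = FP8Linear(child.in_features, child.out_features,
                            child.bias is not None)
            new.weight = child.weight
            if child.bias is not None:
                new.bias = child.bias
            setattr(model, name, new)
        else:
            swap_linears(child)  # recurse into submodules
    return model
```

Usage: `model = swap_linears(model)` after construction; subsequent forwards update each FP8Linear's amax buffer in place.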

