
Conversation

@AntonOresten
Contributor

This is intended to reduce boilerplate reshaping by making trailing singletons implicit, i.e.:

q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (D_K[], TILE_M[], 1, 1))
q = reshape(q, (D_K[], TILE_M[]))
...
acc = reshape(acc, (D_V[], TILE_M[], 1, 1))
ct.store(Out, (1, bid_x, head_idx, batch_idx), acc)

can become:

q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (D_K[], TILE_M[]))
...
ct.store(Out, (1, bid_x, head_idx, batch_idx), acc)

There is precedent for this, in that ct.store already removes (and in specific cases adds?) singletons. As per the current README:

Store reshaping

ct.store automatically reshapes the tile to match the target array's rank by dropping singleton dimensions (e.g., storing a (1, N) tile into a 1D array reshapes it to (N,)). Scalar () tiles are reshaped to (1,).
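In plain-Julia terms, the reshaping rule quoted above can be sketched as follows. This illustrates the semantics with ordinary Base array operations, not the actual ct.store implementation:

```julia
# Illustration only: the singleton-dropping described in the README,
# expressed with plain-Julia array operations.
tile = ones(Float32, 1, 4)        # a (1, N) tile with N = 4
vec1d = dropdims(tile; dims=1)    # for a 1D target, drop the singleton
@show size(vec1d)                 # (4,)

scalar = fill(1.0f0)              # a scalar () tile (zero-dimensional)
one_el = reshape(scalar, 1)       # reshaped to (1,)
@show size(one_el)                # (1,)
```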

@maleadt
Member

maleadt commented Feb 6, 2026

I'm not sure I like this, at least not as a generic shape-matching implementation. Dropping trailing singletons is probably fine, and we do this often in Julia. But it's incompatible with the existing store shape matching, which drops leading singletons, which is fishy. That said, for most of these ops (the ct.-prefixed ones) I'd like to stay as close as possible to Python semantics.

@maleadt
Member

maleadt commented Feb 6, 2026

Pushed a commit that restricts the matching to consecutive trailing singletons, which matches Julia better. Thoughts?
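For reference, trailing singletons are already implicit in several places in Julia itself, which is the convention the commit aligns with; a small example:

```julia
A = reshape(1:12, 3, 4)
# Indexing may supply extra trailing indices, provided they are 1:
@assert A[2, 3] == A[2, 3, 1, 1]
# Adding trailing singleton dimensions via reshape is likewise harmless:
B = reshape(A, 3, 4, 1, 1)
@assert B[2, 3, 1, 1] == A[2, 3]
```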

@AntonOresten
Contributor Author

AntonOresten commented Feb 6, 2026

Yes, I think that's much better.

Did dropping leading singletons on store come from Python?

@maleadt
Member

maleadt commented Feb 7, 2026

Did dropping leading singletons on store come from Python?

No, I added it after making reductions preserve singleton dimensions (as Julia does), which resulted in lots of shape mismatches between reductions and stores. But by allowing trailing singleton dimensions to be dropped, we can still get rid of the dropdims calls previously needed, by switching layouts around.

AntonOresten and others added 6 commits February 8, 2026 07:35
Switch layernorm backward's DW/DB partial buffers from (GROUP_SIZE_M, N) to
(N, GROUP_SIZE_M) so that sum(; dims=2) produces trailing singletons that
auto-squeeze, removing the need for explicit dropdims calls.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
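The layout switch works because Julia's dims-preserving reductions leave the singleton wherever the reduced dimension was, so transposing the buffer layout moves it from leading to trailing position. A sketch with hypothetical sizes:

```julia
# Hypothetical sizes for illustration; the point is the singleton's position.
GROUP_SIZE_M, N = 4, 8
old_layout = ones(GROUP_SIZE_M, N)
@show size(sum(old_layout; dims=1))   # (1, 8): leading singleton, needed dropdims
new_layout = ones(N, GROUP_SIZE_M)
@show size(sum(new_layout; dims=2))   # (8, 1): trailing singleton, auto-squeezed
```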
Codegen tests verify dropdims emits a reshape op for both dim=1 and dim=2
on tiles with singleton dimensions. Execution test verifies correctness
of sum + dropdims pattern on GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@maleadt maleadt merged commit 468065e into JuliaGPU:main Feb 8, 2026
8 checks passed