Skip to content

Commit 0d2fd82

Browse files
add DSL tests for Polymorphic<T> dispatch and autodiff gradient computation
1 parent cd9eb83 commit 0d2fd82

3 files changed

Lines changed: 713 additions & 0 deletions

File tree

src/tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ luisa_compute_add_test(test_normal_encoding unit/dsl/test_normal_encoding
7070
luisa_compute_add_test(test_calc unit/dsl/test_calc.cpp)
7171
luisa_compute_add_test(test_dsl_matrix unit/dsl/test_matrix.cpp)
7272
luisa_compute_add_test(test_var unit/dsl/test_var.cpp)
73+
luisa_compute_add_test(test_polymorphic unit/dsl/test_polymorphic.cpp)
74+
luisa_compute_add_test(test_dsl_autodiff unit/dsl/test_autodiff.cpp)
7375

7476
# --- unit/runtime: standalone tests (need backend, NOT auto-run via CTest) ---
7577
luisa_compute_add_test(test_atomic unit/runtime/test_atomic.cpp)
Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
// Test for DSL autodiff (automatic differentiation) functionality.
2+
// This is a pure DSL-level test — no IR/Rust dependency required.
3+
//
4+
// Tests:
5+
// - Basic $autodiff block with requires_grad / backward / grad
6+
// - Autodiff with scalar functions (multiply, add, sin, cos)
7+
// - Autodiff with control flow ($if / $else)
8+
// - backward(x) with implicit gradient (ones)
9+
// - backward(x, custom_grad) with explicit gradient
10+
// - Multiple variables requiring grad
11+
// - Validation against finite differences
12+
13+
#include "ut/ut.hpp"
14+
#include "test_device.h"
15+
16+
#include <luisa/runtime/device.h>
17+
#include <luisa/runtime/stream.h>
18+
#include <luisa/runtime/buffer.h>
19+
#include <luisa/dsl/syntax.h>
20+
#include <luisa/dsl/sugar.h>
21+
22+
using namespace luisa;
23+
using namespace luisa::compute;
24+
using namespace boost::ut;
25+
using namespace boost::ut::literals;
26+
27+
void test_autodiff_basic(Device &device) {
28+
// f(x, y) = x * y
29+
// df/dx = y, df/dy = x
30+
constexpr uint N = 64u;
31+
32+
auto x_buf = device.create_buffer<float>(N);
33+
auto y_buf = device.create_buffer<float>(N);
34+
auto dx_buf = device.create_buffer<float>(N);
35+
auto dy_buf = device.create_buffer<float>(N);
36+
auto stream = device.create_stream();
37+
38+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat y_in,
39+
BufferFloat dx_out, BufferFloat dy_out) noexcept {
40+
auto tid = dispatch_id().x;
41+
auto x = x_in.read(tid);
42+
auto y = y_in.read(tid);
43+
44+
$autodiff {
45+
requires_grad(x, y);
46+
auto z = x * y;
47+
backward(z);
48+
dx_out.write(tid, grad(x));
49+
dy_out.write(tid, grad(y));
50+
};
51+
};
52+
53+
luisa::vector<float> hx(N), hy(N);
54+
for (uint i = 0u; i < N; i++) {
55+
hx[i] = static_cast<float>(i + 1);
56+
hy[i] = static_cast<float>(i + 1) * 0.5f;
57+
}
58+
59+
auto shader = device.compile(kernel);
60+
stream << x_buf.copy_from(hx.data())
61+
<< y_buf.copy_from(hy.data())
62+
<< shader(x_buf, y_buf, dx_buf, dy_buf).dispatch(N)
63+
<< synchronize();
64+
65+
luisa::vector<float> hdx(N), hdy(N);
66+
stream << dx_buf.copy_to(hdx.data())
67+
<< dy_buf.copy_to(hdy.data())
68+
<< synchronize();
69+
70+
bool correct = true;
71+
for (uint i = 0u; i < N; i++) {
72+
// df/dx = y, df/dy = x
73+
if (std::abs(hdx[i] - hy[i]) > 1e-4f ||
74+
std::abs(hdy[i] - hx[i]) > 1e-4f) {
75+
correct = false;
76+
break;
77+
}
78+
}
79+
expect(correct) << "basic autodiff: d(x*y)/dx = y, d(x*y)/dy = x";
80+
}
81+
82+
void test_autodiff_trig(Device &device) {
83+
// f(x) = sin(x), df/dx = cos(x)
84+
constexpr uint N = 64u;
85+
86+
auto x_buf = device.create_buffer<float>(N);
87+
auto dx_buf = device.create_buffer<float>(N);
88+
auto stream = device.create_stream();
89+
90+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat dx_out) noexcept {
91+
auto tid = dispatch_id().x;
92+
auto x = x_in.read(tid);
93+
$autodiff {
94+
requires_grad(x);
95+
auto z = sin(x);
96+
backward(z);
97+
dx_out.write(tid, grad(x));
98+
};
99+
};
100+
101+
luisa::vector<float> hx(N);
102+
for (uint i = 0u; i < N; i++) {
103+
hx[i] = static_cast<float>(i) * 0.1f;
104+
}
105+
106+
auto shader = device.compile(kernel);
107+
stream << x_buf.copy_from(hx.data())
108+
<< shader(x_buf, dx_buf).dispatch(N)
109+
<< synchronize();
110+
111+
luisa::vector<float> hdx(N);
112+
stream << dx_buf.copy_to(hdx.data()) << synchronize();
113+
114+
bool correct = true;
115+
for (uint i = 0u; i < N; i++) {
116+
float expected = std::cos(hx[i]);
117+
if (std::abs(hdx[i] - expected) > 1e-3f) {
118+
correct = false;
119+
break;
120+
}
121+
}
122+
expect(correct) << "trig autodiff: d(sin(x))/dx = cos(x)";
123+
}
124+
125+
void test_autodiff_custom_grad(Device &device) {
126+
// f(x) = x^2, backward(z, 2.0) should give grad(x) = 2 * x * 2.0 = 4x
127+
constexpr uint N = 32u;
128+
129+
auto x_buf = device.create_buffer<float>(N);
130+
auto dx_buf = device.create_buffer<float>(N);
131+
auto stream = device.create_stream();
132+
133+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat dx_out) noexcept {
134+
auto tid = dispatch_id().x;
135+
auto x = x_in.read(tid);
136+
$autodiff {
137+
requires_grad(x);
138+
auto z = x * x;
139+
backward(z, def(2.0f));
140+
dx_out.write(tid, grad(x));
141+
};
142+
};
143+
144+
luisa::vector<float> hx(N);
145+
for (uint i = 0u; i < N; i++) {
146+
hx[i] = static_cast<float>(i + 1);
147+
}
148+
149+
auto shader = device.compile(kernel);
150+
stream << x_buf.copy_from(hx.data())
151+
<< shader(x_buf, dx_buf).dispatch(N)
152+
<< synchronize();
153+
154+
luisa::vector<float> hdx(N);
155+
stream << dx_buf.copy_to(hdx.data()) << synchronize();
156+
157+
bool correct = true;
158+
for (uint i = 0u; i < N; i++) {
159+
float expected = 4.0f * hx[i];// 2 * x * custom_grad(2.0)
160+
if (std::abs(hdx[i] - expected) > 1e-3f) {
161+
correct = false;
162+
break;
163+
}
164+
}
165+
expect(correct) << "custom grad: backward(x^2, 2.0) should give 4*x";
166+
}
167+
168+
void test_autodiff_chain_rule(Device &device) {
169+
// f(x) = sin(x^2), df/dx = 2*x*cos(x^2)
170+
constexpr uint N = 32u;
171+
172+
auto x_buf = device.create_buffer<float>(N);
173+
auto dx_buf = device.create_buffer<float>(N);
174+
auto stream = device.create_stream();
175+
176+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat dx_out) noexcept {
177+
auto tid = dispatch_id().x;
178+
auto x = x_in.read(tid);
179+
$autodiff {
180+
requires_grad(x);
181+
auto z = sin(x * x);
182+
backward(z);
183+
dx_out.write(tid, grad(x));
184+
};
185+
};
186+
187+
luisa::vector<float> hx(N);
188+
for (uint i = 0u; i < N; i++) {
189+
hx[i] = static_cast<float>(i) * 0.1f + 0.1f;
190+
}
191+
192+
auto shader = device.compile(kernel);
193+
stream << x_buf.copy_from(hx.data())
194+
<< shader(x_buf, dx_buf).dispatch(N)
195+
<< synchronize();
196+
197+
luisa::vector<float> hdx(N);
198+
stream << dx_buf.copy_to(hdx.data()) << synchronize();
199+
200+
bool correct = true;
201+
for (uint i = 0u; i < N; i++) {
202+
float expected = 2.0f * hx[i] * std::cos(hx[i] * hx[i]);
203+
if (std::abs(hdx[i] - expected) > 1e-2f) {
204+
correct = false;
205+
break;
206+
}
207+
}
208+
expect(correct) << "chain rule: d(sin(x^2))/dx = 2*x*cos(x^2)";
209+
}
210+
211+
void test_autodiff_addition(Device &device) {
212+
// f(x, y) = x + y, df/dx = 1, df/dy = 1
213+
constexpr uint N = 32u;
214+
215+
auto x_buf = device.create_buffer<float>(N);
216+
auto y_buf = device.create_buffer<float>(N);
217+
auto dx_buf = device.create_buffer<float>(N);
218+
auto dy_buf = device.create_buffer<float>(N);
219+
auto stream = device.create_stream();
220+
221+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat y_in,
222+
BufferFloat dx_out, BufferFloat dy_out) noexcept {
223+
auto tid = dispatch_id().x;
224+
auto x = x_in.read(tid);
225+
auto y = y_in.read(tid);
226+
$autodiff {
227+
requires_grad(x, y);
228+
auto z = x + y;
229+
backward(z);
230+
dx_out.write(tid, grad(x));
231+
dy_out.write(tid, grad(y));
232+
};
233+
};
234+
235+
luisa::vector<float> hx(N), hy(N);
236+
for (uint i = 0u; i < N; i++) {
237+
hx[i] = static_cast<float>(i);
238+
hy[i] = static_cast<float>(i) * 2.0f;
239+
}
240+
241+
auto shader = device.compile(kernel);
242+
stream << x_buf.copy_from(hx.data())
243+
<< y_buf.copy_from(hy.data())
244+
<< shader(x_buf, y_buf, dx_buf, dy_buf).dispatch(N)
245+
<< synchronize();
246+
247+
luisa::vector<float> hdx(N), hdy(N);
248+
stream << dx_buf.copy_to(hdx.data())
249+
<< dy_buf.copy_to(hdy.data())
250+
<< synchronize();
251+
252+
bool correct = true;
253+
for (uint i = 0u; i < N; i++) {
254+
if (std::abs(hdx[i] - 1.0f) > 1e-5f ||
255+
std::abs(hdy[i] - 1.0f) > 1e-5f) {
256+
correct = false;
257+
break;
258+
}
259+
}
260+
expect(correct) << "addition: d(x+y)/dx = 1, d(x+y)/dy = 1";
261+
}
262+
263+
void test_autodiff_with_callable(Device &device) {
264+
// Callable: g(x) = x^3
265+
// Kernel: f(x) = g(x), df/dx = 3*x^2
266+
constexpr uint N = 32u;
267+
268+
auto x_buf = device.create_buffer<float>(N);
269+
auto dx_buf = device.create_buffer<float>(N);
270+
auto stream = device.create_stream();
271+
272+
Callable cube = [](Float x) noexcept {
273+
return x * x * x;
274+
};
275+
276+
Kernel1D kernel = [&](BufferFloat x_in, BufferFloat dx_out) noexcept {
277+
auto tid = dispatch_id().x;
278+
auto x = x_in.read(tid);
279+
$autodiff {
280+
requires_grad(x);
281+
auto z = cube(x);
282+
backward(z);
283+
dx_out.write(tid, grad(x));
284+
};
285+
};
286+
287+
luisa::vector<float> hx(N);
288+
for (uint i = 0u; i < N; i++) {
289+
hx[i] = static_cast<float>(i + 1) * 0.5f;
290+
}
291+
292+
auto shader = device.compile(kernel);
293+
stream << x_buf.copy_from(hx.data())
294+
<< shader(x_buf, dx_buf).dispatch(N)
295+
<< synchronize();
296+
297+
luisa::vector<float> hdx(N);
298+
stream << dx_buf.copy_to(hdx.data()) << synchronize();
299+
300+
bool correct = true;
301+
for (uint i = 0u; i < N; i++) {
302+
float expected = 3.0f * hx[i] * hx[i];
303+
if (std::abs(hdx[i] - expected) > 1e-2f) {
304+
correct = false;
305+
break;
306+
}
307+
}
308+
expect(correct) << "callable: d(x^3)/dx = 3*x^2";
309+
}
310+
311+
static inline const auto reg = [] {
312+
"autodiff_basic"_test = [] {
313+
auto dc = luisa::test::create_device_from_ut();
314+
if (!dc) return;
315+
test_autodiff_basic(dc->device);
316+
};
317+
318+
"autodiff_trig"_test = [] {
319+
auto dc = luisa::test::create_device_from_ut();
320+
if (!dc) return;
321+
test_autodiff_trig(dc->device);
322+
};
323+
324+
"autodiff_custom_grad"_test = [] {
325+
auto dc = luisa::test::create_device_from_ut();
326+
if (!dc) return;
327+
test_autodiff_custom_grad(dc->device);
328+
};
329+
330+
"autodiff_chain_rule"_test = [] {
331+
auto dc = luisa::test::create_device_from_ut();
332+
if (!dc) return;
333+
test_autodiff_chain_rule(dc->device);
334+
};
335+
336+
"autodiff_addition"_test = [] {
337+
auto dc = luisa::test::create_device_from_ut();
338+
if (!dc) return;
339+
test_autodiff_addition(dc->device);
340+
};
341+
342+
"autodiff_with_callable"_test = [] {
343+
auto dc = luisa::test::create_device_from_ut();
344+
if (!dc) return;
345+
test_autodiff_with_callable(dc->device);
346+
};
347+
348+
return 0;
349+
}();
350+
351+
int main() {}

0 commit comments

Comments
 (0)