@@ -22,21 +22,21 @@ inline namespace dsl {
2222 * @tparam Src Source expression type.
2323 * @param s Source expression.
2424 * @return A new DSL expression of type Dest.
25- *
25+ *
2626 * Performs a value conversion (like static_cast in C++).
2727 * Supported conversions include:
2828 * - Between numeric types (int <-> float)
2929 * - Between vector types of same dimension
30- *
30+ *
3131 * Example:
3232 * @code
3333 * Float f = 3.7f;
3434 * Int i = cast<int>(f); // i = 3 (truncated)
35- *
35+ *
3636 * Float3 f3 = make_float3(1.5f, 2.5f, 3.5f);
3737 * Int3 i3 = cast<int>(f3); // (1, 2, 3)
3838 * @endcode
39- *
39+ *
4040 * @see as() for bitwise reinterpretation
4141 */
4242template <typename Dest, typename Src>
@@ -58,20 +58,20 @@ template<typename Dest, typename Src>
5858 * @tparam Src Source expression type.
5959 * @param s Source expression.
6060 * @return A new DSL expression of type Dest with the same bit pattern.
61- *
61+ *
6262 * Performs a bitwise reinterpretation (like bit_cast/std::bit_cast in C++).
6363 * The source and destination types must have the same size.
64- *
64+ *
6565 * Example:
6666 * @code
6767 * Float f = 1.0f;
6868 // Reinterpret float bits as uint
6969 * UInt bits = as<uint>(f); // bits = 0x3f800000
70- *
70+ *
7171 * Float2 f2 = make_float2(1.0f, 2.0f);
7272 * UInt2 u2 = as<uint2>(f2); // Reinterpret as uints
7373 * @endcode
74- *
74+ *
7575 * @see cast() for value conversion
7676 */
7777template <typename Dest, typename Src>
@@ -90,19 +90,19 @@ template<typename Dest, typename Src>
9090/* *
9191 * @brief Provide a boolean assumption hint to the compiler.
9292 * @param pred Boolean expression that is assumed to be true.
93- *
93+ *
9494 * The assume statement tells the optimizer that the condition is always
9595 * true, allowing it to generate more efficient code. Use with caution -
9696 * if the assumption is violated, undefined behavior occurs.
97- *
97+ *
9898 * Example:
9999 * @code
100100 * Var<int> index = ...;
101101 assume(index >= 0 && index < buffer_size);
102102 * // Compiler can now optimize knowing index is in bounds
103103 * Float value = buffer.read(index);
104104 * @endcode
105- *
105+ *
106106 * @see unreachable() for marking unreachable code
107107 */
108108inline void assume (Expr<bool > pred) noexcept {
@@ -112,11 +112,11 @@ inline void assume(Expr<bool> pred) noexcept {
112112
113113/* *
114114 * @brief Mark code as unreachable.
115- *
115+ *
116116 * Tells the compiler that this code path should never be executed.
117117 * Useful after branches that always return/exit or for switch defaults
118118 * that should never be hit.
119- *
119+ *
120120 * Example:
121121 * @code
122122 * $switch (value) {
@@ -127,7 +127,7 @@ inline void assume(Expr<bool> pred) noexcept {
127127 * };
128128 * };
129129 * @endcode
130- *
130+ *
131131 * @param msg Optional message for debugging
132132 * @see assume() for providing optimization hints
133133 */
@@ -160,18 +160,18 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
160160/* *
161161 * @brief Get the thread index within its block.
162162 * @return uint3 containing (x, y, z) thread coordinates within the block.
163- *
163+ *
164164 * The thread_id identifies a thread's position within its thread block.
165165 * It ranges from (0, 0, 0) to block_size() - 1.
166- *
166+ *
167167 * Example:
168168 * @code
169169 * Kernel1D kernel = [&]() noexcept {
170170 * UInt tid = thread_id().x; // 0 to block_size().x - 1
171171 * // Use tid for shared memory indexing...
172172 * };
173173 * @endcode
174- *
174+ *
175175 * @see block_id() for block position in the grid
176176 * @see dispatch_id() for global thread position
177177 */
@@ -197,10 +197,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
197197/* *
198198 * @brief Get the block index within the dispatch grid.
199199 * @return uint3 containing (x, y, z) block coordinates.
200- *
200+ *
201201 * The block_id identifies which thread block this thread belongs to.
202202 * It ranges from (0, 0, 0) to (grid_dim - 1).
203- *
203+ *
204204 * @see thread_id() for position within the block
205205 * @see dispatch_id() for global thread position
206206 */
@@ -226,13 +226,13 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
226226/* *
227227 * @brief Get the global thread index in the dispatch grid.
228228 * @return uint3 containing (x, y, z) global coordinates.
229- *
229+ *
230230 * The dispatch_id is the global thread identifier, computed as:
231231 * dispatch_id = block_id * block_size + thread_id
232- *
232+ *
233233 * This is the most commonly used coordinate for indexing into
234234 * buffers and images.
235- *
235+ *
236236 * Example:
237237 * @code
238238 * Kernel2D process_image = [&](ImageFloat img) noexcept {
@@ -242,7 +242,7 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
242242 * };
243243 * stream << shader(image).dispatch(width, height);
244244 * @endcode
245- *
245+ *
246246 * @see dispatch_size() for total grid dimensions
247247 * @see thread_id() for local thread position
248248 */
@@ -276,10 +276,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
276276/* *
277277 * @brief Get the total dispatch grid size.
278278 * @return uint3 containing (width, height, depth) of the dispatch grid.
279- *
279+ *
280280 * The dispatch_size represents the total number of threads in each dimension.
281281 * Useful for normalizing coordinates or computing global indices.
282- *
282+ *
283283 * Example:
284284 * @code
285285 * Kernel2D render = [&](ImageFloat image) noexcept {
@@ -289,7 +289,7 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
289289 * // uv is now in [0, 1] range...
290290 * };
291291 * @endcode
292- *
292+ *
293293 * @see dispatch_id() for current thread position
294294 * @see set_block_size() for configuring block dimensions
295295 */
@@ -315,10 +315,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
315315/* *
316316 * @brief Get the thread block size.
317317 * @return uint3 containing (x, y, z) dimensions of each thread block.
318- *
318+ *
319319 * The block_size represents how many threads are in each block.
320320 * Use this for computing local indices or shared memory offsets.
321- *
321+ *
322322 * @see set_block_size() for configuring block dimensions at compile time
323323 */
324324[[nodiscard]] inline const auto block_size () noexcept {
@@ -1722,8 +1722,12 @@ template<typename X, typename Y>
17221722template <typename X>
17231723 requires is_dsl_v<X> && is_floating_point_or_vector_expr_v<X>
17241724[[nodiscard]] inline auto sign (X &&x) noexcept {
1725- return copysign (1 .0f , std::forward<X>(x));
1725+ using Scalar = expr_value_t <decltype (x)>;
1726+ auto zero = def<Scalar>(0 );
1727+ auto value = std::forward<X>(x);
1728+ return cast<Scalar>(value > zero) - cast<Scalar>(value < zero);
17261729}
1730+
17271731template <typename X>
17281732 requires is_dsl_v<X> && (is_scalar_v<expr_value_t <X>> || is_matrix_v<expr_value_t <X>> || is_vector_v<expr_value_t <X>>)
17291733[[nodiscard]] inline auto ddx (X &&x) noexcept {
0 commit comments