Skip to content

Commit 52ace2b

Browse files
committed
fix sign: adhere to mathematical definitions and comply with glsl standards
1 parent ffd0a5b commit 52ace2b

File tree

2 files changed

+43
-34
lines changed

2 files changed

+43
-34
lines changed

include/luisa/core/mathematics.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace luisa {
1212

1313
/**
1414
* @brief Find next 2^n of v
15-
*
15+
*
1616
* @tparam uint32 or uint64
1717
* @param v input number
1818
* @return same as v
@@ -695,7 +695,9 @@ template<size_t N>
695695
return scaling(make_float3(s));
696696
}
697697

698-
[[nodiscard]] constexpr auto sign(float x) noexcept { return x < 0.f ? -1.f : 1.f; }
698+
[[nodiscard]] constexpr auto sign(float x) noexcept {
699+
return static_cast<float>(x > 0.0f) - static_cast<float>(x < 0.0f);
700+
}
699701

700702
[[nodiscard]] constexpr auto sign(float2 v) noexcept {
701703
return make_float2(sign(v.x), sign(v.y));
@@ -755,7 +757,9 @@ template<size_t N>
755757
return scaling(make_double3(s));
756758
}
757759

758-
[[nodiscard]] constexpr auto sign(double x) noexcept { return x < 0. ? -1. : 1.; }
760+
[[nodiscard]] constexpr auto sign(double x) noexcept {
761+
return static_cast<double>(x > 0.0) - static_cast<double>(x < 0.0);
762+
}
759763

760764
[[nodiscard]] constexpr auto sign(double2 v) noexcept {
761765
return make_double2(sign(v.x), sign(v.y));
@@ -769,8 +773,9 @@ template<size_t N>
769773
return make_double4(sign(v.x), sign(v.y), sign(v.z), sign(v.w));
770774
}
771775

772-
773-
[[nodiscard]] constexpr auto sign(int x) noexcept { return x < 0 ? -1 : 1; }
776+
[[nodiscard]] constexpr auto sign(int x) noexcept {
777+
return static_cast<int>(x > 0) - static_cast<int>(x < 0);
778+
}
774779

775780
[[nodiscard]] constexpr auto sign(int2 v) noexcept {
776781
return make_int2(sign(v.x), sign(v.y));

include/luisa/dsl/builtin.h

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@ inline namespace dsl {
2222
* @tparam Src Source expression type.
2323
* @param s Source expression.
2424
* @return A new DSL expression of type Dest.
25-
*
25+
*
2626
* Performs a value conversion (like static_cast in C++).
2727
* Supported conversions include:
2828
* - Between numeric types (int <-> float)
2929
* - Between vector types of same dimension
30-
*
30+
*
3131
* Example:
3232
* @code
3333
* Float f = 3.7f;
3434
* Int i = cast<int>(f); // i = 3 (truncated)
35-
*
35+
*
3636
* Float3 f3 = make_float3(1.5f, 2.5f, 3.5f);
3737
* Int3 i3 = cast<int>(f3); // (1, 2, 3)
3838
* @endcode
39-
*
39+
*
4040
* @see as() for bitwise reinterpretation
4141
*/
4242
template<typename Dest, typename Src>
@@ -58,20 +58,20 @@ template<typename Dest, typename Src>
5858
* @tparam Src Source expression type.
5959
* @param s Source expression.
6060
* @return A new DSL expression of type Dest with the same bit pattern.
61-
*
61+
*
6262
* Performs a bitwise reinterpretation (like bit_cast/std::bit_cast in C++).
6363
* The source and destination types must have the same size.
64-
*
64+
*
6565
* Example:
6666
* @code
6767
* Float f = 1.0f;
6868
// Reinterpret float bits as uint
6969
* UInt bits = as<uint>(f); // bits = 0x3f800000
70-
*
70+
*
7171
* Float2 f2 = make_float2(1.0f, 2.0f);
7272
* UInt2 u2 = as<uint2>(f2); // Reinterpret as uints
7373
* @endcode
74-
*
74+
*
7575
* @see cast() for value conversion
7676
*/
7777
template<typename Dest, typename Src>
@@ -90,19 +90,19 @@ template<typename Dest, typename Src>
9090
/**
9191
* @brief Provide a boolean assumption hint to the compiler.
9292
* @param pred Boolean expression that is assumed to be true.
93-
*
93+
*
9494
* The assume statement tells the optimizer that the condition is always
9595
* true, allowing it to generate more efficient code. Use with caution -
9696
* if the assumption is violated, undefined behavior occurs.
97-
*
97+
*
9898
* Example:
9999
* @code
100100
* Var<int> index = ...;
101101
assume(index >= 0 && index < buffer_size);
102102
* // Compiler can now optimize knowing index is in bounds
103103
* Float value = buffer.read(index);
104104
* @endcode
105-
*
105+
*
106106
* @see unreachable() for marking unreachable code
107107
*/
108108
inline void assume(Expr<bool> pred) noexcept {
@@ -112,11 +112,11 @@ inline void assume(Expr<bool> pred) noexcept {
112112

113113
/**
114114
* @brief Mark code as unreachable.
115-
*
115+
*
116116
* Tells the compiler that this code path should never be executed.
117117
* Useful after branches that always return/exit or for switch defaults
118118
* that should never be hit.
119-
*
119+
*
120120
* Example:
121121
* @code
122122
* $switch (value) {
@@ -127,7 +127,7 @@ inline void assume(Expr<bool> pred) noexcept {
127127
* };
128128
* };
129129
* @endcode
130-
*
130+
*
131131
* @param msg Optional message for debugging
132132
* @see assume() for providing optimization hints
133133
*/
@@ -160,18 +160,18 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
160160
/**
161161
* @brief Get the thread index within its block.
162162
* @return uint3 containing (x, y, z) thread coordinates within the block.
163-
*
163+
*
164164
* The thread_id identifies a thread's position within its thread block.
165165
* It ranges from (0, 0, 0) to block_size() - 1.
166-
*
166+
*
167167
* Example:
168168
* @code
169169
* Kernel1D kernel = [&]() noexcept {
170170
* UInt tid = thread_id().x; // 0 to block_size().x - 1
171171
* // Use tid for shared memory indexing...
172172
* };
173173
* @endcode
174-
*
174+
*
175175
* @see block_id() for block position in the grid
176176
* @see dispatch_id() for global thread position
177177
*/
@@ -197,10 +197,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
197197
/**
198198
* @brief Get the block index within the dispatch grid.
199199
* @return uint3 containing (x, y, z) block coordinates.
200-
*
200+
*
201201
* The block_id identifies which thread block this thread belongs to.
202202
* It ranges from (0, 0, 0) to (grid_dim - 1).
203-
*
203+
*
204204
* @see thread_id() for position within the block
205205
* @see dispatch_id() for global thread position
206206
*/
@@ -226,13 +226,13 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
226226
/**
227227
* @brief Get the global thread index in the dispatch grid.
228228
* @return uint3 containing (x, y, z) global coordinates.
229-
*
229+
*
230230
* The dispatch_id is the global thread identifier, computed as:
231231
* dispatch_id = block_id * block_size + thread_id
232-
*
232+
*
233233
* This is the most commonly used coordinate for indexing into
234234
* buffers and images.
235-
*
235+
*
236236
* Example:
237237
* @code
238238
* Kernel2D process_image = [&](ImageFloat img) noexcept {
@@ -242,7 +242,7 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
242242
* };
243243
* stream << shader(image).dispatch(width, height);
244244
* @endcode
245-
*
245+
*
246246
* @see dispatch_size() for total grid dimensions
247247
* @see thread_id() for local thread position
248248
*/
@@ -276,10 +276,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
276276
/**
277277
* @brief Get the total dispatch grid size.
278278
* @return uint3 containing (width, height, depth) of the dispatch grid.
279-
*
279+
*
280280
* The dispatch_size represents the total number of threads in each dimension.
281281
* Useful for normalizing coordinates or computing global indices.
282-
*
282+
*
283283
* Example:
284284
* @code
285285
* Kernel2D render = [&](ImageFloat image) noexcept {
@@ -289,7 +289,7 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
289289
* // uv is now in [0, 1] range...
290290
* };
291291
* @endcode
292-
*
292+
*
293293
* @see dispatch_id() for current thread position
294294
* @see set_block_size() for configuring block dimensions
295295
*/
@@ -315,10 +315,10 @@ inline void device_assert(Expr<bool> pred, luisa::string_view msg) noexcept {
315315
/**
316316
* @brief Get the thread block size.
317317
* @return uint3 containing (x, y, z) dimensions of each thread block.
318-
*
318+
*
319319
* The block_size represents how many threads are in each block.
320320
* Use this for computing local indices or shared memory offsets.
321-
*
321+
*
322322
* @see set_block_size() for configuring block dimensions at compile time
323323
*/
324324
[[nodiscard]] inline const auto block_size() noexcept {
@@ -1722,8 +1722,12 @@ template<typename X, typename Y>
17221722
template<typename X>
17231723
requires is_dsl_v<X> && is_floating_point_or_vector_expr_v<X>
17241724
[[nodiscard]] inline auto sign(X &&x) noexcept {
1725-
return copysign(1.0f, std::forward<X>(x));
1725+
using Scalar = expr_value_t<decltype(x)>;
1726+
auto zero = def<Scalar>(0);
1727+
auto value = std::forward<X>(x);
1728+
return cast<Scalar>(value > zero) - cast<Scalar>(value < zero);
17261729
}
1730+
17271731
template<typename X>
17281732
requires is_dsl_v<X> && (is_scalar_v<expr_value_t<X>> || is_matrix_v<expr_value_t<X>> || is_vector_v<expr_value_t<X>>)
17291733
[[nodiscard]] inline auto ddx(X &&x) noexcept {

0 commit comments

Comments
 (0)