Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/linux/tnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ struct tnum tnum_or(struct tnum a, struct tnum b);
struct tnum tnum_xor(struct tnum a, struct tnum b);
/* Multiply two tnums, return @a * @b */
struct tnum tnum_mul(struct tnum a, struct tnum b);
/* Unsigned division, return @a / @b */
struct tnum tnum_udiv(struct tnum a, struct tnum b);
/* Signed division, return @a / @b */
struct tnum tnum_sdiv(struct tnum a, struct tnum b, bool alu32);

/* Return true if the known bits of both tnums have the same value */
bool tnum_overlap(struct tnum a, struct tnum b);
Expand Down
159 changes: 158 additions & 1 deletion kernel/bpf/tnum.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
#define TNUM(_v, _m) (struct tnum){.value = _v, .mask = _m}
/* A completely unknown value */
const struct tnum tnum_unknown = { .value = 0, .mask = -1 };
/* Tnum bottom */
const struct tnum tnum_bottom = { .value = -1, .mask = -1 };

static bool __tnum_eqb(struct tnum a, struct tnum b)
{
return a.value == b.value && a.mask == b.mask;
}

struct tnum tnum_const(u64 value)
{
Expand Down Expand Up @@ -83,9 +90,23 @@ struct tnum tnum_sub(struct tnum a, struct tnum b)
return TNUM(dv & ~mu, mu);
}

/* __tnum_neg_width: tnum negation with given bit width.
* @a: the tnum to be negated.
* @width: the bit width to perform negation, 32 or 64.
*/
static struct tnum __tnum_neg_width(struct tnum a, int width)
{
if (width == 32)
return tnum_sub(TNUM(U32_MAX, 0), a);
else if (width == 64)
return tnum_sub(TNUM(0, 0), a);
else
return tnum_unknown;
}

struct tnum tnum_neg(struct tnum a)
{
return tnum_sub(TNUM(0, 0), a);
return __tnum_neg_width(a, 64);
}

struct tnum tnum_and(struct tnum a, struct tnum b)
Expand Down Expand Up @@ -167,6 +188,138 @@ bool tnum_overlap(struct tnum a, struct tnum b)
return (a.value & mu) == (b.value & mu);
}

/* __get_mask: get a mask that covers all bits up to the highest set bit in x.
* For example:
* x = 0b0000...0000 -> return 0b0000...0000
* x = 0b0000...0001 -> return 0b0000...0001
* x = 0b0000...1001 -> return 0b0000...1111
* x = 0b1111...1111 -> return 0b1111...1111
*/
static u64 __get_mask(u64 x)
{
int width = 0;

if (x > 0)
width = 64 - __builtin_clzll(x);
if (width == 0)
return 0;
else if (width == 64)
return U64_MAX;
else
return (1ULL << width) - 1;
}

struct tnum tnum_udiv(struct tnum a, struct tnum b)
{
if (tnum_is_const(b)) {
/* BPF div specification: x / 0 = 0 */
if (b.value == 0)
return TNUM(0, 0);
if (tnum_is_const(a))
return TNUM(a.value / b.value, 0);
}

if (b.value == 0)
return tnum_unknown;

u64 a_max = a.value + a.mask;
u64 b_min = b.value;
u64 max_res = a_max / b_min;
return TNUM(0, __get_mask(max_res));
}

static u64 __msb(u64 x, int width)
{
return x & (1ULL << (width - 1));
}

static struct tnum __tnum_get_positive(struct tnum x, int width)
{
if (__msb(x.value, width))
return tnum_bottom;
if (__msb(x.mask, width))
return TNUM(x.value, x.mask & ~(1ULL << (width - 1)));
return x;
}

static struct tnum __tnum_get_negative(struct tnum x, int width)
{
if (__msb(x.value, width))
return x;
if (__msb(x.mask, width))
return TNUM(x.value | (1ULL << (width - 1)), x.mask & ~(1ULL << (width - 1)));
return tnum_bottom;
}

static struct tnum __tnum_abs(struct tnum x, int width)
{
if (__msb(x.value, width))
return __tnum_neg_width(x, width);
else
return x;
}

/* __tnum_sdiv, a helper for tnum_sdiv.
* @a: tnum a, a's sign is fixed, __msb(a.mask) == 0
* @b: tnum b, b's sign is fixed, __msb(b.mask) == 0
*
* This function reuses tnum_udiv by operating on the absolute values of a and b,
* and then adjusting the sign of the result based on C's division rules.
* Here we don't need to specially handle the case of [S64_MIN / -1], because
* after __tnum_abs, S64_MIN becomes (S64_MAX + 1), and the behavior of
* unsigned [(S64_MAX + 1) / 1] is normal.
*/
static struct tnum __tnum_sdiv(struct tnum a, struct tnum b, int width)
{
if (__tnum_eqb(a, tnum_bottom) || __tnum_eqb(b, tnum_bottom))
return tnum_bottom;

struct tnum a_abs = __tnum_abs(a, width);
struct tnum b_abs = __tnum_abs(b, width);
struct tnum res_abs = tnum_udiv(a_abs, b_abs);

if (__msb(a.value, width) == __msb(b.value, width))
return res_abs;
else
return __tnum_neg_width(res_abs, width);
}

struct tnum tnum_sdiv(struct tnum a, struct tnum b, bool alu32)
{
if (tnum_is_const(b)) {
/* BPF div specification: x / 0 = 0 */
if (b.value == 0)
return TNUM(0, 0);
if (tnum_is_const(a)) {
/* BPF div specification: S32_MIN / -1 = S32_MIN */
if (alu32 && (u32)a.value == (u32)S32_MIN && (u32)b.value == (u32)-1)
return TNUM((u32)S32_MIN, 0);
/* BPF div specification: S64_MIN / -1 = S64_MIN */
if (!alu32 && a.value == S64_MIN && b.value == (u64)-1)
return TNUM((u64)S64_MIN, 0);
s64 sval = (s64)a.value / (s64)b.value;
return TNUM((u64)sval, 0);
}
}

if (b.value == 0)
return tnum_unknown;

int width = alu32 ? 32 : 64;
struct tnum a_pos = __tnum_get_positive(a, width);
struct tnum a_neg = __tnum_get_negative(a, width);
struct tnum b_pos = __tnum_get_positive(b, width);
struct tnum b_neg = __tnum_get_negative(b, width);

struct tnum res_pos = __tnum_sdiv(a_pos, b_pos, width);
struct tnum res_neg = __tnum_sdiv(a_neg, b_neg, width);
struct tnum res_mix1 = __tnum_sdiv(a_pos, b_neg, width);
struct tnum res_mix2 = __tnum_sdiv(a_neg, b_pos, width);

return tnum_union(tnum_union(res_pos, res_neg),
tnum_union(res_mix1, res_mix2));
}

/* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
* a 'known 0' - this will return a 'known 1' for that bit.
*/
Expand All @@ -186,6 +339,10 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
*/
struct tnum tnum_union(struct tnum a, struct tnum b)
{
if (__tnum_eqb(a, tnum_bottom))
return b;
if (__tnum_eqb(b, tnum_bottom))
return a;
u64 v = a.value & b.value;
u64 mu = (a.value ^ b.value) | a.mask | b.mask;

Expand Down
Loading
Loading