Skip to content

Commit ce949f1

Browse files
authored
Enhanced cmp eval for different types (#10)
1 parent e9fbc1a commit ce949f1

File tree

6 files changed

+1047
-125
lines changed

6 files changed

+1047
-125
lines changed

tachyon/compute/src/codegen.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,29 @@ impl CodeGen for Expr {
205205
}
206206
Expr::Cast { expr, to } => {
207207
let e_var = expr.build_nvrtc_code::<B>(schema, code_block)?;
208-
if to.kernel_type() == expr.infer_type(schema)?.kernel_type() {
208+
let from = expr.infer_type(schema)?;
209+
if *to == from {
209210
return Ok(e_var);
210211
}
211212
let var = code_block.next_var();
213+
let cast_fn = match (from, to) {
214+
//(DataType::I8, DataType::F16) => "__ushort2half_rn",
215+
(DataType::I16, DataType::F16) => "__short2half_rn",
216+
(DataType::I32, DataType::F16) => "__int2half_rn",
217+
(DataType::I64, DataType::F16) => "__ll2half_rn",
218+
//(DataType::U8, DataType::F16) => "__ushort2half_rn",
219+
(DataType::U16, DataType::F16) => "__ushort2half_rn",
220+
(DataType::U32, DataType::F16) => "__uint2half_rn",
221+
(DataType::U64, DataType::F16) => "__ull2half_rn",
222+
_ => &format!("({})", to.c_type()),
223+
};
212224
code_block
213225
.add_variable_decl(result_type.kernel_type(), &var)
214226
.add_validity_check(&var, &[&format!("{}.valid", e_var)])
215227
.add_conditional(&format!("{}.valid", var), |block| {
216228
block.add_code(&format!(
217-
"\t{}.value = ({})({}.value);\n",
218-
var,
219-
to.c_type(),
220-
e_var,
229+
"\t{}.value = {}({}.value);\n",
230+
var, cast_fn, e_var,
221231
));
222232
});
223233
var

tachyon/compute/src/expr.rs

Lines changed: 111 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -260,78 +260,7 @@ impl Expr {
260260

261261
match op {
262262
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
263-
match (&lt, &rt) {
264-
(lt_val, rt_val) if lt_val == rt_val => Ok(*lt_val),
265-
266-
//Integer Promotion: Promote to the wider integer type. (e.g., I32 + I64 -> I64)
267-
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_integer() => {
268-
if lt_val.is_signed() != rt_val.is_signed() {
269-
let left_size = lt_val.native_size();
270-
let right_size = rt_val.native_size();
271-
272-
// If the signed type is larger, it can hold the unsigned type's range
273-
// e.g., I16 + U8 -> I16 (I16 range: -32768..32767, U8 range: 0..255)
274-
if lt_val.is_signed() && left_size > right_size {
275-
return Ok(*lt_val);
276-
}
277-
if rt_val.is_signed() && right_size > left_size {
278-
return Ok(*rt_val);
279-
}
280-
281-
// Otherwise, need to promote to next larger signed type
282-
let max_size = left_size.max(right_size);
283-
return match max_size {
284-
1 => Ok(DataType::I16), // I8 + U8 → I16
285-
2 => Ok(DataType::I32), // I16 + U16 → I32
286-
4 => Ok(DataType::I64), // I32 + U32 → I64
287-
_ => Err(TypeError::Unsupported(
288-
"I64/U64 mixing not supported".into(),
289-
)),
290-
};
291-
}
292-
293-
// Same signedness: promote to wider type
294-
if lt_val.native_size() > rt_val.native_size() {
295-
Ok(*lt_val) // e.g., I64 + I32 -> I64
296-
} else {
297-
Ok(*rt_val) // e.g., I32 + I64 -> I64
298-
}
299-
}
300-
// Float Promotion: Promote to the wider float type. (e.g., F32 + F64 -> F64)
301-
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_float() => {
302-
if lt_val.native_size() > rt_val.native_size() {
303-
Ok(*lt_val) // e.g., F64 + F32 -> F64
304-
} else {
305-
Ok(*rt_val) // e.g., F32 + F64 -> F64
306-
}
307-
}
308-
309-
// Integer/Float Promotion (Left is Float): Promote to the wider float type. (e.g., F32 + I64 -> F64)
310-
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_integer() => {
311-
match (lt_val, rt_val.native_size()) {
312-
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
313-
(DataType::F32, size) if size > DataType::F32.native_size() => {
314-
Ok(DataType::F64)
315-
} // I64 is larger than F32
316-
_ => Ok(*lt_val), // F32 + smaller integer => F32
317-
}
318-
}
319-
320-
// 5. Integer/Float Promotion (Right is Float): Same logic, reversed. (e.g., I64 + F32 -> F64)
321-
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_float() => {
322-
match (rt_val, lt_val.native_size()) {
323-
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
324-
(DataType::F32, size) if size > DataType::F32.native_size() => {
325-
Ok(DataType::F64)
326-
} // I64 is larger than F32
327-
_ => Ok(*rt_val), // F32 + smaller integer => F32
328-
}
329-
}
330-
331-
(lt_val, rt_val) => {
332-
Err(TypeError::TypeMismatch { expected: *lt_val, got: *rt_val })
333-
}
334-
}
263+
infer_binary_op_type(lt, rt)
335264
}
336265
Operator::Eq
337266
| Operator::NotEq
@@ -423,6 +352,49 @@ impl Expr {
423352
right: Box::new(new_right),
424353
})
425354
}
355+
Operator::Eq
356+
| Operator::NotEq
357+
| Operator::Lt
358+
| Operator::LtEq
359+
| Operator::Gt
360+
| Operator::GtEq => {
361+
if lt != rt {
362+
let (l_promoted, r_promoted) = match infer_binary_op_type(lt, rt) {
363+
Ok(promoted_type) => (promoted_type, promoted_type),
364+
Err(TypeError::Unsupported(_)) => {
365+
debug_assert!(lt.native_size() == 8 || rt.native_size() == 8);
366+
let left_type =
367+
if lt.is_signed() { DataType::I64 } else { DataType::U64 };
368+
let right_type =
369+
if rt.is_signed() { DataType::I64 } else { DataType::U64 };
370+
(left_type, right_type)
371+
}
372+
Err(err) => Err(err)?,
373+
};
374+
375+
let new_left = if lt != l_promoted {
376+
Expr::Cast { expr: Box::new(left), to: l_promoted }
377+
} else {
378+
left
379+
};
380+
let new_right = if rt != r_promoted {
381+
Expr::Cast { expr: Box::new(right), to: r_promoted }
382+
} else {
383+
right
384+
};
385+
Ok(Expr::Binary {
386+
op: *op,
387+
left: Box::new(new_left),
388+
right: Box::new(new_right),
389+
})
390+
} else {
391+
Ok(Expr::Binary {
392+
op: *op,
393+
left: Box::new(left),
394+
right: Box::new(right),
395+
})
396+
}
397+
}
426398
_ => Ok(Expr::Binary { op: *op, left: Box::new(left), right: Box::new(right) }),
427399
}
428400
}
@@ -446,6 +418,73 @@ impl Expr {
446418
}
447419
}
448420

421+
fn infer_binary_op_type(lt: DataType, rt: DataType) -> Result<DataType, TypeError> {
422+
match (&lt, &rt) {
423+
(lt_val, rt_val) if lt_val == rt_val => Ok(*lt_val),
424+
425+
//Integer Promotion: Promote to the wider integer type. (e.g., I32 + I64 -> I64)
426+
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_integer() => {
427+
if lt_val.is_signed() != rt_val.is_signed() {
428+
let left_size = lt_val.native_size();
429+
let right_size = rt_val.native_size();
430+
431+
// If the signed type is larger, it can hold the unsigned type's range
432+
// e.g., I16 + U8 -> I16 (I16 range: -32768..32767, U8 range: 0..255)
433+
if lt_val.is_signed() && left_size > right_size {
434+
return Ok(*lt_val);
435+
}
436+
if rt_val.is_signed() && right_size > left_size {
437+
return Ok(*rt_val);
438+
}
439+
440+
// Otherwise, need to promote to next larger signed type
441+
let max_size = left_size.max(right_size);
442+
return match max_size {
443+
1 => Ok(DataType::I16), // I8 + U8 → I16
444+
2 => Ok(DataType::I32), // I16 + U16 → I32
445+
4 => Ok(DataType::I64), // I32 + U32 → I64
446+
_ => Err(TypeError::Unsupported("I64/U64 mixing not supported".into())),
447+
};
448+
}
449+
450+
// Same signedness: promote to wider type
451+
if lt_val.native_size() > rt_val.native_size() {
452+
Ok(*lt_val) // e.g., I64 + I32 -> I64
453+
} else {
454+
Ok(*rt_val) // e.g., I32 + I64 -> I64
455+
}
456+
}
457+
// Float Promotion: Promote to the wider float type. (e.g., F32 + F64 -> F64)
458+
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_float() => {
459+
if lt_val.native_size() > rt_val.native_size() {
460+
Ok(*lt_val) // e.g., F64 + F32 -> F64
461+
} else {
462+
Ok(*rt_val) // e.g., F32 + F64 -> F64
463+
}
464+
}
465+
466+
// Integer/Float Promotion (Left is Float): Promote to the wider float type. (e.g., F32 + I64 -> F64)
467+
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_integer() => {
468+
match (lt_val, rt_val.native_size()) {
469+
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
470+
(DataType::F32, size) if size > DataType::F32.native_size() => Ok(DataType::F64), // I64 is larger than F32
471+
_ => Ok(*lt_val), // F32 + smaller integer => F32
472+
}
473+
}
474+
475+
// 5. Integer/Float Promotion (Right is Float): Same logic, reversed. (e.g., I64 + F32 -> F64)
476+
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_float() => {
477+
match (rt_val, lt_val.native_size()) {
478+
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
479+
(DataType::F32, size) if size > DataType::F32.native_size() => Ok(DataType::F64), // I64 is larger than F32
480+
_ => Ok(*rt_val), // F32 + smaller integer => F32
481+
}
482+
}
483+
484+
(lt_val, rt_val) => Err(TypeError::TypeMismatch { expected: *lt_val, got: *rt_val }),
485+
}
486+
}
487+
449488
impl fmt::Display for Expr {
450489
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451490
match self {

0 commit comments

Comments
 (0)