Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions tachyon/compute/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,29 @@ impl CodeGen for Expr {
}
Expr::Cast { expr, to } => {
let e_var = expr.build_nvrtc_code::<B>(schema, code_block)?;
if to.kernel_type() == expr.infer_type(schema)?.kernel_type() {
let from = expr.infer_type(schema)?;
if *to == from {
return Ok(e_var);
}
let var = code_block.next_var();
let cast_fn = match (from, to) {
//(DataType::I8, DataType::F16) => "__ushort2half_rn",
(DataType::I16, DataType::F16) => "__short2half_rn",
(DataType::I32, DataType::F16) => "__int2half_rn",
(DataType::I64, DataType::F16) => "__ll2half_rn",
//(DataType::U8, DataType::F16) => "__ushort2half_rn",
(DataType::U16, DataType::F16) => "__ushort2half_rn",
(DataType::U32, DataType::F16) => "__uint2half_rn",
(DataType::U64, DataType::F16) => "__ull2half_rn",
_ => &format!("({})", to.c_type()),
};
code_block
.add_variable_decl(result_type.kernel_type(), &var)
.add_validity_check(&var, &[&format!("{}.valid", e_var)])
.add_conditional(&format!("{}.valid", var), |block| {
block.add_code(&format!(
"\t{}.value = ({})({}.value);\n",
var,
to.c_type(),
e_var,
"\t{}.value = {}({}.value);\n",
var, cast_fn, e_var,
));
});
var
Expand Down
183 changes: 111 additions & 72 deletions tachyon/compute/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,78 +260,7 @@ impl Expr {

match op {
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
match (&lt, &rt) {
(lt_val, rt_val) if lt_val == rt_val => Ok(*lt_val),

//Integer Promotion: Promote to the wider integer type. (e.g., I32 + I64 -> I64)
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_integer() => {
if lt_val.is_signed() != rt_val.is_signed() {
let left_size = lt_val.native_size();
let right_size = rt_val.native_size();

// If the signed type is larger, it can hold the unsigned type's range
// e.g., I16 + U8 -> I16 (I16 range: -32768..32767, U8 range: 0..255)
if lt_val.is_signed() && left_size > right_size {
return Ok(*lt_val);
}
if rt_val.is_signed() && right_size > left_size {
return Ok(*rt_val);
}

// Otherwise, need to promote to next larger signed type
let max_size = left_size.max(right_size);
return match max_size {
1 => Ok(DataType::I16), // I8 + U8 → I16
2 => Ok(DataType::I32), // I16 + U16 → I32
4 => Ok(DataType::I64), // I32 + U32 → I64
_ => Err(TypeError::Unsupported(
"I64/U64 mixing not supported".into(),
)),
};
}

// Same signedness: promote to wider type
if lt_val.native_size() > rt_val.native_size() {
Ok(*lt_val) // e.g., I64 + I32 -> I64
} else {
Ok(*rt_val) // e.g., I32 + I64 -> I64
}
}
// Float Promotion: Promote to the wider float type. (e.g., F32 + F64 -> F64)
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_float() => {
if lt_val.native_size() > rt_val.native_size() {
Ok(*lt_val) // e.g., F64 + F32 -> F64
} else {
Ok(*rt_val) // e.g., F32 + F64 -> F64
}
}

// Integer/Float Promotion (Left is Float): Promote to the wider float type. (e.g., F32 + I64 -> F64)
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_integer() => {
match (lt_val, rt_val.native_size()) {
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
(DataType::F32, size) if size > DataType::F32.native_size() => {
Ok(DataType::F64)
} // I64 is larger than F32
_ => Ok(*lt_val), // F32 + smaller integer => F32
}
}

// 5. Integer/Float Promotion (Right is Float): Same logic, reversed. (e.g., I64 + F32 -> F64)
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_float() => {
match (rt_val, lt_val.native_size()) {
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
(DataType::F32, size) if size > DataType::F32.native_size() => {
Ok(DataType::F64)
} // I64 is larger than F32
_ => Ok(*rt_val), // F32 + smaller integer => F32
}
}

(lt_val, rt_val) => {
Err(TypeError::TypeMismatch { expected: *lt_val, got: *rt_val })
}
}
infer_binary_op_type(lt, rt)
}
Operator::Eq
| Operator::NotEq
Expand Down Expand Up @@ -423,6 +352,49 @@ impl Expr {
right: Box::new(new_right),
})
}
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq => {
if lt != rt {
let (l_promoted, r_promoted) = match infer_binary_op_type(lt, rt) {
Ok(promoted_type) => (promoted_type, promoted_type),
Err(TypeError::Unsupported(_)) => {
debug_assert!(lt.native_size() == 8 || rt.native_size() == 8);
let left_type =
if lt.is_signed() { DataType::I64 } else { DataType::U64 };
let right_type =
if rt.is_signed() { DataType::I64 } else { DataType::U64 };
(left_type, right_type)
}
Err(err) => Err(err)?,
};

let new_left = if lt != l_promoted {
Expr::Cast { expr: Box::new(left), to: l_promoted }
} else {
left
};
let new_right = if rt != r_promoted {
Expr::Cast { expr: Box::new(right), to: r_promoted }
} else {
right
};
Ok(Expr::Binary {
op: *op,
left: Box::new(new_left),
right: Box::new(new_right),
})
} else {
Ok(Expr::Binary {
op: *op,
left: Box::new(left),
right: Box::new(right),
})
}
}
_ => Ok(Expr::Binary { op: *op, left: Box::new(left), right: Box::new(right) }),
}
}
Expand All @@ -446,6 +418,73 @@ impl Expr {
}
}

fn infer_binary_op_type(lt: DataType, rt: DataType) -> Result<DataType, TypeError> {
match (&lt, &rt) {
(lt_val, rt_val) if lt_val == rt_val => Ok(*lt_val),

//Integer Promotion: Promote to the wider integer type. (e.g., I32 + I64 -> I64)
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_integer() => {
if lt_val.is_signed() != rt_val.is_signed() {
let left_size = lt_val.native_size();
let right_size = rt_val.native_size();

// If the signed type is larger, it can hold the unsigned type's range
// e.g., I16 + U8 -> I16 (I16 range: -32768..32767, U8 range: 0..255)
if lt_val.is_signed() && left_size > right_size {
return Ok(*lt_val);
}
if rt_val.is_signed() && right_size > left_size {
return Ok(*rt_val);
}

// Otherwise, need to promote to next larger signed type
let max_size = left_size.max(right_size);
return match max_size {
1 => Ok(DataType::I16), // I8 + U8 → I16
2 => Ok(DataType::I32), // I16 + U16 → I32
4 => Ok(DataType::I64), // I32 + U32 → I64
_ => Err(TypeError::Unsupported("I64/U64 mixing not supported".into())),
};
}

// Same signedness: promote to wider type
if lt_val.native_size() > rt_val.native_size() {
Ok(*lt_val) // e.g., I64 + I32 -> I64
} else {
Ok(*rt_val) // e.g., I32 + I64 -> I64
}
}
// Float Promotion: Promote to the wider float type. (e.g., F32 + F64 -> F64)
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_float() => {
if lt_val.native_size() > rt_val.native_size() {
Ok(*lt_val) // e.g., F64 + F32 -> F64
} else {
Ok(*rt_val) // e.g., F32 + F64 -> F64
}
}

// Integer/Float Promotion (Left is Float): Promote to the wider float type. (e.g., F32 + I64 -> F64)
(lt_val, rt_val) if lt_val.is_float() && rt_val.is_integer() => {
match (lt_val, rt_val.native_size()) {
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
(DataType::F32, size) if size > DataType::F32.native_size() => Ok(DataType::F64), // I64 is larger than F32
_ => Ok(*lt_val), // F32 + smaller integer => F32
}
}

// 5. Integer/Float Promotion (Right is Float): Same logic, reversed. (e.g., I64 + F32 -> F64)
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_float() => {
match (rt_val, lt_val.native_size()) {
(DataType::F64, _) => Ok(DataType::F64), // F64 is always the widest
(DataType::F32, size) if size > DataType::F32.native_size() => Ok(DataType::F64), // I64 is larger than F32
_ => Ok(*rt_val), // F32 + smaller integer => F32
}
}

(lt_val, rt_val) => Err(TypeError::TypeMismatch { expected: *lt_val, got: *rt_val }),
}
}

impl fmt::Display for Expr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down
Loading