Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bolt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set -e

setup_rust(){
echo "[INFO] Installing build essential"
sudo apt-get update -y
sudo apt-get install -y build-essential

echo "[INFO] Checking Rust installation..."
Expand All @@ -28,6 +29,7 @@ setup_rust(){
if ! cargo nextest --version >/dev/null 2>&1; then
cargo install cargo-nextest
fi
. "$HOME/.cargo/env"
export PATH="$HOME/.cargo/bin:$PATH"
cargo nextest --version
}
Expand Down
10 changes: 6 additions & 4 deletions tachyon/compute/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ impl CodeGen for Expr {
fn to_nvrtc<B: BitBlock>(
&self, schema: &SchemaContext, code_block: &mut CodeBlock,
) -> Result<(), TypeError> {
let result_type = self.infer_type(schema)?;
let res = self.build_nvrtc_code::<B>(schema, code_block)?;
let expr = self.simplify(schema)?;
let result_type = expr.infer_type(schema)?;
let res = expr.build_nvrtc_code::<B>(schema, code_block)?;
code_block.add_store_column::<B>(0, &result_type, &res);
Ok(())
}
Expand All @@ -110,6 +111,7 @@ impl CodeGen for Expr {
&self, schema: &SchemaContext, code_block: &mut CodeBlock,
) -> Result<String, TypeError> {
let result_type = self.infer_type(schema)?;
println!("Result Type: {:?}", result_type);
let error_mode = schema.error_mode() == ErrorMode::Ansi;

let var = match self {
Expand Down Expand Up @@ -208,9 +210,9 @@ impl CodeGen for Expr {
code_block
.add_variable_decl(result_type.kernel_type(), &var)
.add_validity_check(&var, &[&format!("{}.valid", e_var)])
.add_conditional(&var, |block| {
.add_conditional(&format!("{}.valid", var), |block| {
block.add_code(&format!(
"\t{}.value = ({})({}.value)\n",
"\t{}.value = ({})({}.value);\n",
var,
to.c_type(),
e_var,
Expand Down
84 changes: 84 additions & 0 deletions tachyon/compute/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,32 @@ impl Expr {

//Integer Promotion: Promote to the wider integer type. (e.g., I32 + I64 -> I64)
(lt_val, rt_val) if lt_val.is_integer() && rt_val.is_integer() => {
if lt_val.is_signed() != rt_val.is_signed() {
let left_size = lt_val.native_size();
let right_size = rt_val.native_size();

// If the signed type is larger, it can hold the unsigned type's range
// e.g., I16 + U8 -> I16 (I16 range: -32768..32767, U8 range: 0..255)
if lt_val.is_signed() && left_size > right_size {
return Ok(*lt_val);
}
if rt_val.is_signed() && right_size > left_size {
return Ok(*rt_val);
}

// Otherwise, need to promote to next larger signed type
let max_size = left_size.max(right_size);
return match max_size {
1 => Ok(DataType::I16), // I8 + U8 → I16
2 => Ok(DataType::I32), // I16 + U16 → I32
4 => Ok(DataType::I64), // I32 + U32 → I64
_ => Err(TypeError::Unsupported(
"I64/U64 mixing not supported".into(),
)),
};
}

// Same signedness: promote to wider type
if lt_val.native_size() > rt_val.native_size() {
Ok(*lt_val) // e.g., I64 + I32 -> I64
} else {
Expand Down Expand Up @@ -360,6 +386,64 @@ impl Expr {
}
}
}

pub fn simplify(&self, schema: &SchemaContext) -> Result<Expr, TypeError> {
match self {
Expr::Unary { op, expr } => {
let simplified_expr = expr.simplify(schema)?;
Ok(Expr::Unary { op: *op, expr: Box::new(simplified_expr) })
}

Expr::Binary { op, left, right } => {
let left = left.simplify(schema)?;
let right = right.simplify(schema)?;

let lt = left.infer_type(schema)?;
let rt = right.infer_type(schema)?;

let target_type = self.infer_type(schema)?;

match op {
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
let new_left = if lt != target_type {
Expr::Cast { expr: Box::new(left), to: target_type }
} else {
left
};

let new_right = if rt != target_type {
Expr::Cast { expr: Box::new(right), to: target_type }
} else {
right
};

Ok(Expr::Binary {
op: *op,
left: Box::new(new_left),
right: Box::new(new_right),
})
}
_ => Ok(Expr::Binary { op: *op, left: Box::new(left), right: Box::new(right) }),
}
}

Expr::Nary { op, args } => {
let simplified_args: Result<Vec<_>, _> =
args.iter().map(|arg| arg.simplify(schema).map(Box::new)).collect();
Ok(Expr::Nary { op: *op, args: simplified_args? })
}

Expr::Call { name, args } => {
let mut simplified_args = Vec::new();
for arg in args {
simplified_args.push(arg.simplify(schema)?);
}
Ok(Expr::Call { name: name.clone(), args: simplified_args })
}

_ => Ok(self.clone()),
}
}
}

impl fmt::Display for Expr {
Expand Down
57 changes: 36 additions & 21 deletions tachyon/compute/tests/codegen_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,19 @@ fn test_codegen_binary_same_type_cast() {
let _ = expr.to_nvrtc::<u64>(&schema, &mut code_block).expect("codegen");
println!("Code:");
println!("{}", code_block.code());
let expected = r#"Float64 var0 = input[0].load<TypeKind::Float64, uint64_t>(row_idx);
Float32 var1;
var1.valid = true;
var1.value = (float)2.5f;
Float64 var2 = math::mul<false>(ctx, var0, var1);
Float64 var3 = input[1].load<TypeKind::Float64, uint64_t>(row_idx);
Float64 var4 = math::add<false>(ctx, var2, var3);
output[0].store<TypeKind::Float64, uint64_t>(row_idx, var4);"#;
let expected = r#" Float64 var0 = input[0].load<TypeKind::Float64, uint64_t>(row_idx);
Float32 var1;
var1.valid = true;
var1.value = (float)2.5f;
Float64 var2;
var2.valid = var1.valid;
if (var2.valid) {
var2.value = (double)(var1.value);
}
Float64 var3 = math::mul<false>(ctx, var0, var2);
Float64 var4 = input[1].load<TypeKind::Float64, uint64_t>(row_idx);
Float64 var5 = math::add<false>(ctx, var3, var4);
output[0].store<TypeKind::Float64, uint64_t>(row_idx, var5);"#;
assert_eq!(normalize_code(code_block.code()), normalize_code(expected))
}

Expand All @@ -299,19 +304,29 @@ fn test_codegen_binary_different_type_cast() {
let _ = expr.to_nvrtc::<u64>(&schema, &mut code_block).expect("codegen");
println!("Code:");
println!("{}", code_block.code());
let expected = r#" Float64 var0 = input[0].load<TypeKind::Float64, uint64_t>(row_idx);
Float32 var1;
var1.valid = true;
var1.value = (float)2.5f;
Float64 var2 = math::mul<false>(ctx, var0, var1);
Int64 var3 = input[1].load<TypeKind::Int64, uint64_t>(row_idx);
Float32 var4;
var4.valid = var3.valid;
if (var4) {
var4.value = (float)(var3.value)
}
Float64 var5 = math::add<false>(ctx, var2, var4);
output[0].store<TypeKind::Float64, uint64_t>(row_idx, var5);"#;
let expected = r#" Float64 var0 = input[0].load<TypeKind::Float64, uint64_t>(row_idx);
Float32 var1;
var1.valid = true;
var1.value = (float)2.5f;
Float64 var2;
var2.valid = var1.valid;
if (var2.valid) {
var2.value = (double)(var1.value);
}
Float64 var3 = math::mul<false>(ctx, var0, var2);
Int64 var4 = input[1].load<TypeKind::Int64, uint64_t>(row_idx);
Float32 var5;
var5.valid = var4.valid;
if (var5.valid) {
var5.value = (float)(var4.value);
}
Float64 var6;
var6.valid = var5.valid;
if (var6.valid) {
var6.value = (double)(var5.value);
}
Float64 var7 = math::add<false>(ctx, var3, var6);
output[0].store<TypeKind::Float64, uint64_t>(row_idx, var7);"#;
assert_eq!(normalize_code(code_block.code()), normalize_code(expected))
}

Expand Down
Loading