Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 187 additions & 12 deletions tachyon/compute/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ impl CodeGen for Expr {
var.to_string()
}
Expr::Literal(l) => {
if matches!(l, Literal::Str(_)) {
return Err(TypeError::Unsupported(
"String literals are not yet supported in GPU codegen".to_string(),
));
}
let value = match l {
Literal::I8(i) => format!("{}", i),
Literal::I16(i) => format!("{}", i),
Expand All @@ -151,7 +156,7 @@ impl CodeGen for Expr {
Literal::F32(f) => format!("{}f", float_literal_to_str(*f)),
Literal::F64(f) => float_literal_to_str(*f).to_string(),
Literal::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
Literal::Str(s) => format!("\"{}\"", escape_c_string(s)),
Literal::Str(_) => unreachable!(),
};
let var = code_block.next_var();
let ty_c = result_type.c_type();
Expand Down Expand Up @@ -203,15 +208,188 @@ impl CodeGen for Expr {
}
}
Expr::Nary { op: _, args: _ } => unimplemented!(),
Expr::Call { name, args } => {
let mut arg_strs = Vec::with_capacity(args.len());
for a in args {
arg_strs.push(a.build_nvrtc_code::<B>(schema, code_block)?);
Expr::Call { name, args } => match name.as_str() {
"length" => {
if args.len() != 1 {
return Err(TypeError::Unsupported("length arity".into()));
}
let (col_idx, col_type, col_name) = match &args[0] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"length currently requires a string column argument".into(),
));
}
};
if col_type != DataType::Str {
return Err(TypeError::Unsupported("length expects string column".into()));
}
let arg_var =
code_block.add_load_column::<B>(&col_name, col_idx, &col_type).to_string();
let var = code_block.next_var();
code_block.add_code(&format!(
"\tUInt32 {} = string_ops::length({}, input[{}]);\n",
var, arg_var, col_idx
));
var
}
let var = code_block.next_var();
code_block.add_code(&format!("{}({})", name, arg_strs.join(", ")));
var
}
"lower" | "lower_case" => {
if args.len() != 1 {
return Err(TypeError::Unsupported("lower arity".into()));
}
let (col_idx, col_type, col_name) = match &args[0] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"lower currently requires a string column argument".into(),
));
}
};
if col_type != DataType::Str {
return Err(TypeError::Unsupported("lower expects string column".into()));
}
let arg_var =
code_block.add_load_column::<B>(&col_name, col_idx, &col_type).to_string();
let var = code_block.next_var();
code_block.add_code(&format!(
"\tString {} = string_ops::lower({}, input[{}], output[0], row_idx);\n",
var, arg_var, col_idx
));
var
}
"upper" | "upper_case" => {
if args.len() != 1 {
return Err(TypeError::Unsupported("upper arity".into()));
}
let (col_idx, col_type, col_name) = match &args[0] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"upper currently requires a string column argument".into(),
));
}
};
if col_type != DataType::Str {
return Err(TypeError::Unsupported("upper expects string column".into()));
}
let arg_var =
code_block.add_load_column::<B>(&col_name, col_idx, &col_type).to_string();
let var = code_block.next_var();
code_block.add_code(&format!(
"\tString {} = string_ops::upper({}, input[{}], output[0], row_idx);\n",
var, arg_var, col_idx
));
var
}
"substring" => {
if args.len() != 3 {
return Err(TypeError::Unsupported("substring arity".into()));
}
let (col_idx, col_type, col_name) = match &args[0] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"substring currently requires first argument as string column"
.into(),
));
}
};
if col_type != DataType::Str {
return Err(TypeError::Unsupported(
"substring expects first argument string column".into(),
));
}
let start_var = args[1].build_nvrtc_code::<B>(schema, code_block)?;
let len_var = args[2].build_nvrtc_code::<B>(schema, code_block)?;
let arg_var =
code_block.add_load_column::<B>(&col_name, col_idx, &col_type).to_string();
let var = code_block.next_var();
code_block.add_code(&format!(
"\tString {};\n\t{}.valid = {}.valid & {}.valid & {}.valid;\n\tif ({}.valid) {{\n\t\t{}.value = string_ops::substring({}, (int32_t)({}.value), (int32_t)({}.value), input[{}], output[0], row_idx).value;\n\t}}\n",
var, var, arg_var, start_var, len_var, var, var, arg_var, start_var, len_var, col_idx
));
var
}
"concat" => {
if args.len() != 2 {
return Err(TypeError::Unsupported("concat arity".into()));
}
let (l_idx, l_type, l_name) = match &args[0] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"concat currently requires string column arguments".into(),
));
}
};
let (r_idx, r_type, r_name) = match &args[1] {
Expr::Column(col_name) => {
let (idx, dt) = schema
.lookup(col_name)
.copied()
.ok_or_else(|| TypeError::UnknownColumn(col_name.clone()))?;
(idx, dt, col_name.clone())
}
_ => {
return Err(TypeError::Unsupported(
"concat currently requires string column arguments".into(),
));
}
};
if l_type != DataType::Str || r_type != DataType::Str {
return Err(TypeError::Unsupported("concat expects string columns".into()));
}
let l_var =
code_block.add_load_column::<B>(&l_name, l_idx, &l_type).to_string();
let r_var =
code_block.add_load_column::<B>(&r_name, r_idx, &r_type).to_string();
let var = code_block.next_var();
code_block.add_code(&format!(
"\tString {} = string_ops::concat({}, input[{}], {}, input[{}], output[0], row_idx);\n",
var, l_var, l_idx, r_var, r_idx
));
var
}
_ => {
let mut arg_strs = Vec::with_capacity(args.len());
for a in args {
arg_strs.push(a.build_nvrtc_code::<B>(schema, code_block)?);
}
let var = code_block.next_var();
code_block.add_code(&format!("{}({})", name, arg_strs.join(", ")));
var
}
},
Expr::Cast { expr, to } => {
let e_var = expr.build_nvrtc_code::<B>(schema, code_block)?;
let from = expr.infer_type(schema)?;
Expand Down Expand Up @@ -272,9 +450,6 @@ fn op_kernel_fn(op: Operator) -> String {
kernel_fn.to_string()
}

fn escape_c_string(s: &str) -> String {
s.replace('"', "\\\"")
}
/// Formats floating-point literals for CUDA code generation.
pub(crate) fn float_literal_to_str<T: Into<f64> + Copy + PartialEq>(f: T) -> String {
let f64_val = f.into();
Expand Down
Loading