diff --git a/custom_type_handling.jai b/custom_type_handling.jai
new file mode 100644
index 0000000..f95a1e7
--- /dev/null
+++ b/custom_type_handling.jai
@@ -0,0 +1,51 @@
+// Callback invoked by assign_member for a column whose type OID has a registered handler.
+// It receives the destination member (name, info, slot) and the raw column value
+// (col_type, len, data, row, col), and returns true if the value was stored.
+Pg_Type_Handler :: #type (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type, len: int, data: *u8, row: int, col: int) -> bool;
+
+// Registers 'handler' for the Postgres type identified by 'oid'. 'name' is copied.
+// Returns false (and changes nothing) if a handler is already registered for 'oid'.
+register_pg_type_handler :: (oid: Oid, name: string, handler: Pg_Type_Handler) -> (success: bool) {
+
+    for pg_custom_type_handlers {
+
+        if it.oid == oid {
+            return false;
+        }
+    }
+
+    array_add(*pg_custom_type_handlers, .{
+            oid = oid,
+            name = copy_string(name),
+            handler = handler,
+        },
+    );
+
+    return true;
+}
+
+// Removes the handler registered for 'oid' and frees its copied name.
+// Returns false if no handler was registered for 'oid'.
+remove_pg_type_handler :: (oid: Oid) -> (success: bool) {
+
+    for pg_custom_type_handlers {
+
+        if it.oid == oid {
+            free(it.name);
+            remove it;
+            return true;
+        }
+    }
+
+    return false;
+}
+
+// One registered custom type: its OID, an owned copy of its name, and its handler.
+Pg_Type_Handler_Record :: struct {
+    oid: Oid;
+    name: string;
+    handler: Pg_Type_Handler;
+}
+
+// All currently registered custom type handlers (at most one per OID).
+pg_custom_type_handlers: [..]Pg_Type_Handler_Record;
diff --git a/examples/Jaipgvector.jai b/examples/Jaipgvector.jai
new file mode 100755
index 0000000..e74fee2
--- /dev/null
+++ b/examples/Jaipgvector.jai
@@ -0,0 +1,200 @@
+// Based on: https://github.com/overlord-systems/jai-pgvector/tree/bfa620f4e004fe35a93d2ab0ca2bb2c02fddc500
+#import "Basic";
+#import,file "../module.jai";
+
+// Outcome of register_pgvector_handler.
+Register_Pgvector_Err :: enum {
+    OK;
+    QUERY_ERR;
+    NOT_INSTALLED;
+    REGISTER_ERR;
+}
+
+// Looks up the OID of pgvector's 'vector' type and registers pgvector_vector_handler for it.
+// This does a synchronous DB query.
+// Returns .QUERY_ERR if the lookup query fails, .NOT_INSTALLED if the type does not exist,
+// .REGISTER_ERR if a handler is already registered for that OID, and .OK otherwise.
+register_pgvector_handler :: (pg_conn: *PGconn) -> err: Register_Pgvector_Err {
+
+    Pgvector_Oids :: struct {
+        vector_oid: Oid;
+    }
+
+    // @TODO: We currently only handle the vector type. Note that '_vector' means vector[].
+    //
+    // SELECT to_regtype('vector')::oid, to_regtype('_vector')::oid, to_regtype('halfvec')::oid, to_regtype('_halfvec')::oid, to_regtype('sparsevec')::oid, to_regtype('_sparsevec')::oid
+
+    QUERY :: #string EOF
+SELECT to_regtype('vector')::oid AS "vector_oid"
+EOF
+
+    rows: []Pgvector_Oids;
+    rows=, success := execute(pg_conn, Pgvector_Oids, QUERY);
+    if !success return .QUERY_ERR;
+
+    if rows.count == 0 return .NOT_INSTALLED;
+
+    result := rows[0];
+    if result.vector_oid == 0 return .NOT_INSTALLED;
+
+    success = register_pg_type_handler(result.vector_oid, "vector", pgvector_vector_handler);
+    if !success return .REGISTER_ERR;
+
+    return .OK;
+}
+
+/*
+Formats 'floats' as a pgvector text literal such as [1,0.5]. The string is built with sprint.
+
+By default this will round to 6 digits after the dot, while a float32 can carry about 7 significant digits.
+
+If you want to show 7 digits after the dot, you can do:
+    old_float_formatter_trailing_width := context.print_style.default_format_float.trailing_width;
+    context.print_style.default_format_float.trailing_width = 7;
+
+    floats_to_pgvector_string(.[1, 0.1234567]);
+
+    defer context.print_style.default_format_float.trailing_width = old_float_formatter_trailing_width;
+
+Note that the downside to this approach is that the returned string will be:
+    [1 ,0.1234567]
+
+Without this adjustment, the returned string is:
+    [1,0.123457] // Note the rounding of ..567->..57
+*/
+floats_to_pgvector_string :: (floats: []float) -> string {
+
+    // We set FormatArray explicitly so we don't break if the default array formatter changes
+    return sprint("%", FormatArray.{
+        value = floats,
+        separator = ",",
+        begin_string = "[",
+        end_string = "]",
+        draw_separator_after_last_element = false,
+        stop_printing_after_this_many_elements = -1,
+    });
+}
+
+#scope_module
+
+// Pg_Type_Handler for pgvector's 'vector' type: decodes the binary value in 'data' (len bytes)
+// and stores it in 'slot', which must be an array of 4-byte floats. A view is pointed at newly
+// reserved storage; a resizable array is reserved and resized to the vector's dimension count;
+// a fixed-size array must have exactly that many elements. Returns false (after logging) on any mismatch.
+pgvector_vector_handler :: (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type, len: int, data: *u8, row: int, col: int) -> bool {
+
+    if info.type != .ARRAY {
+        log_error("Error: Trying to write column % of type % into member field \"%\" of type %. Vectors can only be stored in member fields that are arrays of floats or views of floats", col, col_type, name, info.type);
+        return false;
+    }
+
+    array_info := cast(*Type_Info_Array) info;
+    if array_info.element_type.type != Type_Info_Tag.FLOAT {
+        log_error("Error: Trying to write vector column % into member field \"%\" of type %. Member field must be a float array or view", col, name, info.type);
+        return false;
+    }
+
+    // The value starts with a 4-byte header (u16 dimension count + u16 unused). Check that it is
+    // present before reading it, because pgvector_get_dims_from_vector_bytes asserts otherwise.
+    if len < 4 {
+        log_error("Error: Vector column % is only % bytes long, which is too short to contain the 4-byte pgvector header", col, len);
+        return false;
+    }
+
+    vector_bytes: []u8 = {count=len, data=data};
+    element_size := array_info.element_type.runtime_size;
+    dims := pgvector_get_dims_from_vector_bytes(vector_bytes);
+    if element_size != 4 {
+        log_error("Error: Trying to write column % of type % with array count % into member field \"%\" with array element size of % instead of 4. Member field must be a float array or view.", col, col_type, dims, name, element_size);
+        // The decoders below write 4-byte floats, so any other element size would be filled with garbage.
+        return false;
+    }
+
+    // Every dimension is a 4-byte float following the header; refuse truncated values
+    // instead of reading past the end of 'data'.
+    if len < 4 + dims.(int) * 4 {
+        log_error("Error: Vector column % claims % dimensions but is only % bytes long", col, dims, len);
+        return false;
+    }
+
+    if array_info.array_type == .VIEW {
+
+        view := (cast(*Array_View_64) slot);
+
+        floats := pgvector_vector_bytes_to_floats(vector_bytes);
+
+        view.count = floats.count;
+        view.data = floats.data.(*u8);
+
+    } else if array_info.array_count == -1 {
+
+        resizable_array := cast(*Resizable_Array) slot;
+
+        array_reserve(resizable_array, dims, element_size);
+        resizable_array.count = dims;
+
+        pgvector_vector_bytes_to_floats_buf(vector_bytes, {count=dims, data=resizable_array.data.(*float)});
+
+    } else {
+
+        if array_info.array_count != dims {
+            log_error("Error: Trying to write column % of type % with array count % into member field \"%\", which is a fixed-size array with count %. For fixed-size arrays, the size must match!", col, col_type, dims, name, array_info.array_count);
+            return false;
+        }
+
+        pgvector_vector_bytes_to_floats_buf(vector_bytes, {count=dims, data=slot.(*float)});
+    }
+
+    return true;
+}
+
+// Reads the dimension count: a u16 in network byte order at the very start of a binary pgvector value.
+pgvector_get_dims_from_vector_bytes :: (vector: []u8) -> u16 {
+    assert(vector.count >= 2);
+    dims := ntoh(vector.data.(*u16).*);
+    return dims;
+}
+
+// Decodes a binary pgvector value into a newly reserved float array and returns it as a view.
+// Nothing in this function frees that storage; releasing it is left to the caller.
+pgvector_vector_bytes_to_floats :: (vector: []u8) -> []float {
+
+    dims := pgvector_get_dims_from_vector_bytes(vector);
+
+    out: [..]float;
+    array_reserve(*out, dims);
+    out.count = dims;
+
+    pgvector_vector_bytes_to_floats_buf(vector, out);
+    return out;
+}
+
+// This parses the input bytes and writes to the 'out' buffer.
+// Note that it's required that: out.count == pgvector_get_dims_from_vector_bytes(vector)
+pgvector_vector_bytes_to_floats_buf :: (vector: []u8, out: []float) {
+
+    // At least as of pgvector 0.8.0, when responding in binary mode, pgvector prefixes
+    // the float array with two int16, the first being the number of dimensions, and the second unused (always zero).
+    //
+    // Therefore the 4-byte header (dims+unused) must always be present, followed by 4 bytes per dimension.
+    //
+    // https://github.com/pgvector/pgvector/blob/fef635c9e5512597621e5669dce845c744170822/src/vector.c#L402
+    assert(vector.count >= 4);
+
+    dims := pgvector_get_dims_from_vector_bytes(vector);
+    assert(out.count == dims, "'out' view used in pgvector_vector_bytes_to_floats must match vector dimensions");
+    assert(vector.count >= 4 + dims.(int) * 4, "vector bytes are too short for the number of dimensions in the header");
+
+    float_index := 0;
+    float_ptr := vector.data.(*float) + 1; // +1 to skip the dims+unused at the beginning
+    while float_index < dims {
+
+        defer float_index += 1;
+
+        // Incoming bytes are usually big endian, so this (from Jaipq) takes care of swapping bytes if our machine
+        // is little endian.
+        val := ntoh((float_ptr+float_index).*);
+
+        out[float_index] = val;
+    }
+}
diff --git a/examples/example.jai b/examples/example.jai
index ccfc56d..cf9a203 100644
--- a/examples/example.jai
+++ b/examples/example.jai
@@ -98,16 +98,40 @@ main :: () {
     assert(result.varchar == "");
     assert(result.text == "Something");
 
+    test_register_handlers(conn);
+
     log("ALL OK");
 }
 
+// Tries to register the pgvector handler: returns on success, asserts on a query or registration failure, and only logs a warning when pgvector is not installed.
+test_register_handlers :: (conn: *PGconn) {
+
+    err_msg := register_pgvector_handler(conn);
+    if #complete err_msg == {
+
+        // OK means you can now run and parse queries returning the 'vector' type from pgvector
+        case .OK;
+            return;
+
+        case .QUERY_ERR;
+            assert(false, "Error querying for the pgvector.vector oid");
+
+        case .NOT_INSTALLED;
+            log("Didn't find the pgvector.vector type in your DB. Seems you don't have pgvector installed on your postgres DB", flags=.WARNING);
+
+        case .REGISTER_ERR;
+            assert(false, "Found the pgvector.vector OID in your postgres DB, but failed to register it with jai-postgres!");
+    }
+}
+
 UUID :: [16]u8;
 
 // UUIDs are just two s64 values, so equality is simply comparing each
 // of the two s64 values to each other from u1 and u2
-uuid_equal :: (u1: UUID, u2: UUID) -> bool {
-    a, b := u1.([2]s64,force), u2.([2]s64,force);
-    return a[0] == b[0] && a[1] == b[1];
+uuid_equal :: inline (u1: Uuid, u2: Uuid) -> bool {
+    u1_s64 := u1.data.(*s64);
+    u2_s64 := u2.data.(*s64);
+    return u1_s64.* == u2_s64.* && (u1_s64 + 1).* == (u2_s64 + 1).*;
 }
 
 uuid_to_string :: (uuid: UUID) -> string {
@@ -161,4 +185,6 @@ Test :: struct {
 }
 
 #import,file "../module.jai";
+#import,file "Jaipgvector.jai";
 #import "Basic";
+
diff --git a/module.jai b/module.jai
index 5b31fa0..0dd72d4 100644
--- a/module.jai
+++ b/module.jai
@@ -456,6 +456,13 @@ assign_member :: (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type,
         }
 
     case;
+
+        // See if we have a handler that knows how to deal with this type
+        for pg_custom_type_handlers {
+            if it.oid != col_type.(Oid) continue;
+            return it.handler(name, info, slot, col_type, len, data, row, col);
+        }
+
         if cast(s64) col_type > enum_highest_value(Pq_Type) {
             // Seems to be a custom type. Try to interpret it as a string
             val: string;
@@ -617,6 +624,7 @@ Reflection :: #import "Reflection";
 
 #load "byte_order.jai";
 #load "pq_types.jai";
+#load "custom_type_handling.jai";
 
 #if OS == .WINDOWS {
     Windows :: #import "Windows";