Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions custom_type_handling.jai
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Procedure type for custom column decoders registered via register_pg_type_handler.
// assign_member calls the handler registered for a column's type OID with the same
// arguments it received itself, and returns the handler's result unchanged:
//   name     - name of the destination member field (used in error messages by the pgvector handler)
//   info     - type info of the destination member field
//   slot     - pointer to the destination member's storage
//   col_type - type of the column being read
//   len/data - length of and pointer to the column's raw bytes
//   row/col  - position of the value in the result set
// The handlers visible here return false when the value could not be stored.
Pg_Type_Handler :: #type (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type, len: int, data: *u8, row: int, col: int) -> bool;

// Registers `handler` as the decoder for columns whose type OID equals `oid`.
//
// Returns false, leaving the registry untouched, when a handler for `oid` already exists.
// Otherwise a record holding a private copy of `name` is appended and true is returned.
register_pg_type_handler :: (oid: Oid, name: string, handler: Pg_Type_Handler) -> (success: bool) {

    // Every OID may be registered at most once.
    for existing: pg_custom_type_handlers {
        if existing.oid == oid  return false;
    }

    record: Pg_Type_Handler_Record;
    record.oid     = oid;
    record.name    = copy_string(name);  // Freed again by remove_pg_type_handler.
    record.handler = handler;

    array_add(*pg_custom_type_handlers, record);
    return true;
}

// Unregisters the handler stored for `oid`, freeing the name copy held by its record.
// Returns true when a record was found and removed, false when `oid` was not registered.
remove_pg_type_handler :: (oid: Oid) -> (success: bool) {

    for record: pg_custom_type_handlers {
        if record.oid != oid  continue;

        // Release the string copied at registration before dropping the record.
        free(record.name);
        remove record;
        return true;
    }

    return false;
}

// One entry of the custom type handler registry: ties a type OID to the procedure that decodes it.
Pg_Type_Handler_Record :: struct {
// Type OID this record was registered for; assign_member compares it against the column type.
oid: Oid;
// Copy of the name passed to register_pg_type_handler (made with copy_string, freed in remove_pg_type_handler).
name: string;
// Procedure called for columns whose type matches `oid`.
handler: Pg_Type_Handler;
}

// All currently registered custom type handlers. Entries are added by register_pg_type_handler,
// removed by remove_pg_type_handler, and scanned linearly by assign_member.
pg_custom_type_handlers: [..]Pg_Type_Handler_Record;
170 changes: 170 additions & 0 deletions examples/Jaipgvector.jai
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Based on: https://github.com/overlord-systems/jai-pgvector/tree/bfa620f4e004fe35a93d2ab0ca2bb2c02fddc500
#import "Basic";
#import,file "../module.jai";

// Result codes returned by register_pgvector_handler.
Register_Pgvector_Err :: enum {
// The 'vector' OID was found and its handler was registered.
OK;
// The OID lookup query did not succeed.
QUERY_ERR;
// The lookup returned no row or an OID of 0 for 'vector'.
NOT_INSTALLED;
// register_pg_type_handler returned false (it does so when the OID already has a handler).
REGISTER_ERR;
}

// Looks up the OID of the 'vector' type on `pg_conn` and registers pgvector_vector_handler for it.
// Note: this does a synchronous DB query.
//
// Returns .OK on success, .QUERY_ERR when the lookup query fails, .NOT_INSTALLED when no
// usable OID comes back, and .REGISTER_ERR when the handler could not be registered.
register_pgvector_handler :: (pg_conn: *PGconn) -> err: Register_Pgvector_Err {

    // Shape of the single row returned by QUERY.
    Pgvector_Oids :: struct {
        vector_oid: Oid;
    }

    // @TODO: We currently only handle the vector type. Note that '_vector' means vector[].
    //
    // SELECT to_regtype('vector')::oid, to_regtype('_vector')::oid, to_regtype('halfvec')::oid, to_regtype('_halfvec')::oid, to_regtype('sparsevec')::oid, to_regtype('_sparsevec')::oid

    QUERY :: #string EOF
SELECT to_regtype('vector')::oid AS "vector_oid"
EOF

    rows: []Pgvector_Oids;
    rows=, query_ok := execute(pg_conn, Pgvector_Oids, QUERY);
    if !query_ok  return .QUERY_ERR;

    // A missing row and a zero OID are both reported as "not installed".
    if rows.count == 0 || rows[0].vector_oid == 0  return .NOT_INSTALLED;

    registered := register_pg_type_handler(rows[0].vector_oid, "vector", pgvector_vector_handler);
    if !registered  return .REGISTER_ERR;

    return .OK;
}

/*
By default this will round to 6 digits after the dot, but float32 has a maximum of 7 digits after the dot.

If you want to show 7 digits, you can do:
old_float_formatter_trailing_width := context.print_style.default_format_float.trailing_width;
context.print_style.default_format_float.trailing_width = 7;

floats_to_pgvector_string(.[1, 0.1234567]);

defer context.print_style.default_format_float.trailing_width = old_float_formatter_trailing_width;

Note that the downside to this approach is that the returned string will be:
[1 ,0.1234567]

Without this adjustment, the returned string is:
[1,0.123457] // Note the rounding of ..567->..57
*/
// Formats `floats` as a pgvector text literal: comma-separated values enclosed in
// square brackets, using the float formatting currently set in the context.
floats_to_pgvector_string :: (floats: []float) -> string {

    // Every field is spelled out so that a change to FormatArray's defaults cannot alter the output.
    array_format := FormatArray.{
        value                                  = floats,
        separator                              = ",",
        begin_string                           = "[",
        end_string                             = "]",
        draw_separator_after_last_element      = false,
        stop_printing_after_this_many_elements = -1,
    };

    return sprint("%", array_format);
}

#scope_module

// Pg_Type_Handler that stores a binary pgvector 'vector' value into a float array member.
//
// `data`/`len` describe the raw column bytes: a 4-byte header (u16 dimension count followed by an
// unused u16) and then one 4-byte float per dimension, all converted with ntoh while decoding.
// `slot` points at the destination member described by `info`:
//   - array view:       receives a newly allocated buffer with `dims` floats,
//   - resizable array:  is reserved for and set to `dims` elements, then filled,
//   - fixed-size array: must have exactly `dims` elements and is filled in place.
//
// Returns false after logging an error when the member is not an array of 4-byte floats, when the
// value is too short for its header or its declared dimension count, or when a fixed-size array's
// count differs from the vector's dimensions. Returns true once the floats have been written.
pgvector_vector_handler :: (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type, len: int, data: *u8, row: int, col: int) -> bool {

    // Size of the dims+unused prefix in front of the floats.
    HEADER_SIZE :: 4;

    if info.type != .ARRAY {
        log_error("Error: Trying to write column % of type % into member field \"%\" of type %. Vectors can only be stored in member fields that are arrays of floats or views of floats", col, col_type, name, info.type);
        return false;
    }

    array_info := cast(*Type_Info_Array) info;
    if array_info.element_type.type != Type_Info_Tag.FLOAT {
        log_error("Error: Trying to write vector column % into member field \"%\" of type %. Member field must be a float array or view", col, name, info.type);
        return false;
    }

    // The header must be present before the dimension count can be read.
    if data == null || len < HEADER_SIZE {
        log_error("Error: Trying to write vector column % into member field \"%\", but the value is only % bytes long. At least % bytes are required for the vector header.", col, name, len, HEADER_SIZE);
        return false;
    }

    vector_bytes: []u8;
    vector_bytes.count = len;
    vector_bytes.data  = data;

    element_size := array_info.element_type.runtime_size;
    dims := pgvector_get_dims_from_vector_bytes(vector_bytes);
    if element_size != 4 {
        log_error("Error: Trying to write column % of type % with array count % into member field \"%\" with array element size of % instead of 4. Member field must be a float array or view.", col, col_type, dims, name, element_size);
        // Decoding writes 4-byte floats, so any other element size would corrupt the member.
        return false;
    }

    // Make sure every float announced by the header is actually present in the value.
    required_len := HEADER_SIZE + (cast(s64) dims) * 4;
    if len < required_len {
        log_error("Error: Trying to write vector column % into member field \"%\", but the value is % bytes long while its % dimensions require % bytes.", col, name, len, dims, required_len);
        return false;
    }

    if array_info.array_type == .VIEW {

        view := (cast(*Array_View_64) slot);

        // The view takes over the buffer allocated by pgvector_vector_bytes_to_floats.
        floats := pgvector_vector_bytes_to_floats(vector_bytes);

        view.count = floats.count;
        view.data  = floats.data.(*u8);

    } else if array_info.array_count == -1 {

        // An array_count of -1 is handled as a resizable array.
        resizable_array := cast(*Resizable_Array) slot;

        array_reserve(resizable_array, dims, element_size);
        resizable_array.count = dims;

        out: []float;
        out.count = dims;
        out.data  = resizable_array.data.(*float);
        pgvector_vector_bytes_to_floats_buf(vector_bytes, out);

    } else {

        if array_info.array_count != dims {
            log_error("Error: Trying to write column % of type % with array count % into member field \"%\", which is a fixed-size array with count %. For fixed-size arrays, the size must match!", col, col_type, dims, name, array_info.array_count);
            return false;
        }

        // Fixed-size arrays are stored inline, so the slot itself is the float buffer.
        out: []float;
        out.count = dims;
        out.data  = slot.(*float);
        pgvector_vector_bytes_to_floats_buf(vector_bytes, out);
    }

    return true;
}

// Returns the dimension count stored in the first two bytes of a binary pgvector value.
// The bytes are read as a u16 and passed through ntoh. Asserts that `vector` holds at least 2 bytes.
pgvector_get_dims_from_vector_bytes :: (vector: []u8) -> u16 {
    assert(vector.count >= 2);

    raw_dims := vector.data.(*u16).*;
    return ntoh(raw_dims);
}

// Decodes a binary pgvector value into a newly allocated array of floats.
// The returned view has one element per dimension announced in the value's header.
pgvector_vector_bytes_to_floats :: (vector: []u8) -> []float {

    count := pgvector_get_dims_from_vector_bytes(vector);

    // Allocate room for `count` floats; the decoder below fills every element.
    floats: [..]float;
    array_reserve(*floats, count);
    floats.count = count;

    result: []float = floats;
    pgvector_vector_bytes_to_floats_buf(vector, result);
    return result;
}

// Parses the binary pgvector value in `vector` and writes its floats to the 'out' buffer.
// Note that it's required that: out.count == pgvector_get_dims_from_vector_bytes(vector)
//
// Asserts when the header is incomplete, when `out` does not match the dimension count, or when
// `vector` is too short to hold all of the floats its header announces.
pgvector_vector_bytes_to_floats_buf :: (vector: []u8, out: []float) {

    // At least as of pgvector 0.8.0, when responding in binary mode, pgvector prefixes
    // the float array with two int16, the first being the number of dimensions, and the second unused (always zero).
    //
    // Therefore the header alone takes 4 bytes, and a value with N dimensions takes 4 + 4*N bytes.
    //
    // https://github.com/pgvector/pgvector/blob/fef635c9e5512597621e5669dce845c744170822/src/vector.c#L402
    HEADER_SIZE :: 4;
    assert(vector.count >= HEADER_SIZE);

    dims := pgvector_get_dims_from_vector_bytes(vector);
    assert(out.count == dims, "'out' view used in pgvector_vector_bytes_to_floats_buf must match vector dimensions");

    // Never read past the end of the input, even if the header claims more floats than were sent.
    assert(vector.count >= HEADER_SIZE + (cast(s64) dims) * 4, "pgvector value is too short for the number of dimensions in its header");

    floats := vector.data.(*float) + 1; // +1 to skip the dims+unused at the beginning

    // out.count is signed, so an empty vector yields the empty range 0..-1.
    for index: 0..out.count-1 {
        // Incoming bytes are usually big endian, so this (from Jaipq) takes care of swapping bytes if our machine
        // is little endian.
        out[index] = ntoh((floats + index).*);
    }
}
31 changes: 28 additions & 3 deletions examples/example.jai
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,39 @@ main :: () {
assert(result.varchar == "");
assert(result.text == "Something");

test_register_handlers(conn);

log("ALL OK");
}

// Example/test: registers the pgvector handler on `conn` and reacts to each possible outcome.
// A missing pgvector installation only produces a warning; the two error outcomes assert.
test_register_handlers :: (conn: *PGconn) {

    status := register_pgvector_handler(conn);

    if #complete status == {

        case .OK;
            // OK means you can now run and parse queries returning the 'vector' type from pgvector
            return;

        case .NOT_INSTALLED;
            log("Didn't find the pgvector.vector type in your DB. Seems you don't have pgvector installed on your postgres DB", flags=.WARNING);

        case .QUERY_ERR;
            assert(false, "Error querying for the pgvector.vector oid");

        case .REGISTER_ERR;
            assert(false, "Found the pgvector.vector OID in your postgres DB, but failed to register it with jai-postgres!");
    }
}

UUID :: [16]u8;

// UUIDs are just two s64 values, so equality is simply comparing each
// of the two s64 values to each other from u1 and u2
uuid_equal :: (u1: UUID, u2: UUID) -> bool {
a, b := u1.([2]s64,force), u2.([2]s64,force);
return a[0] == b[0] && a[1] == b[1];
uuid_equal :: inline (u1: Uuid, u2: Uuid) -> bool {
u1_s64 := u1.data.(*s64);
u2_s64 := u2.data.(*s64);
return u1_s64.* == u2_s64.* && (u1_s64 + 1).* == (u2_s64 + 1).*;
}

uuid_to_string :: (uuid: UUID) -> string {
Expand Down Expand Up @@ -161,4 +184,6 @@ Test :: struct {
}

#import,file "../module.jai";
#import,file "Jaipgvector.jai";
#import "Basic";

8 changes: 8 additions & 0 deletions module.jai
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ assign_member :: (name: string, info: *Type_Info, slot: *u8, col_type: Pq_Type,
}

case;

// See if we have a handler that knows how to deal with this type
for pg_custom_type_handlers {
if it.oid != col_type.(Oid) continue;
return it.handler(name, info, slot, col_type, len, data, row, col);
}

if cast(s64) col_type > enum_highest_value(Pq_Type) {
// Seems to be a custom type. Try to interpret it as a string
val: string;
Expand Down Expand Up @@ -617,6 +624,7 @@ Reflection :: #import "Reflection";

#load "byte_order.jai";
#load "pq_types.jai";
#load "custom_type_handling.jai";

#if OS == .WINDOWS {
Windows :: #import "Windows";
Expand Down