diff --git a/.gitignore b/.gitignore index 9f738b7..f01c1b9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,12 @@ *.pickle -*.tar -*.tar.* +*.tar *.tar.* *.zip # Ignore dataset for vector database **/gist/ **/miniset/ **/hotpot15/ +**/msmarco/ # Ignore VS Code settings and extensions .vscode/ @@ -36,4 +36,4 @@ build-* *.jpg *.arrow *inat/* -*google-landmark/* \ No newline at end of file +*google-landmark/* diff --git a/benchmark/run_encoder_example.cpp b/benchmark/run_encoder_example.cpp index b9cfc12..a25c2e9 100644 --- a/benchmark/run_encoder_example.cpp +++ b/benchmark/run_encoder_example.cpp @@ -13,7 +13,7 @@ using namespace derecho::cascade; -#define UDL1_PATH "/rag/emb/encode_search" +#define UDL1_PATH "/rag/emb/encode" #define UDL2_PATH "/rag/emb/centroids_search" #define UDL3_PATH "/rag/emb/clusters_search" #define UDL4_PATH "/rag/generate/agg" @@ -37,56 +37,22 @@ const int ID = 0; ServiceClientAPI& capi = ServiceClientAPI::get_service_client(); int main() { - - std::vector>> queries = { - {0 ,std::make_shared("hello this is query 1")}, - {1 ,std::make_shared("and this is query 2")}, - {2 ,std::make_shared("What's the weather today?")}, - {4 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {5 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {6 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {7 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {8 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {9 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {10 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {11 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {12 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {13 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {14 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {15 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {16 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {17 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {18 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {19 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {20 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {21 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {22 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {23 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {24 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {25 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {26 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {27 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {28 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {29 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - {30 ,std::make_shared("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")}, - }; - - EncoderQueryBatcher batcher(384, queries.size()); - - for(const auto& query : queries) { - batcher.add_query( - query.first, - ID, - query.second - ); + const size_t number_iterations = 128; + const size_t number_queries = 64; + for (size_t i = 0; i < number_iterations; i++) { + + EncoderQueryBatcher batcher(number_queries); + for(size_t j = 0; j < number_queries; j++) { + batcher.add_query(j + number_queries * i, ID, std::make_shared("What is the weather today?")); + } + + batcher.serialize(); + ObjectWithStringKey obj; + obj.key = UDL1_PATH "/" + std::string("batch") + std::to_string(number_iterations); + obj.blob = std::move(*batcher.get_blob()); + capi.trigger_put(obj); } - batcher.serialize(); - ObjectWithStringKey obj; - obj.key = UDL1_PATH "/" + std::string("batch1"); - obj.blob = std::move(*batcher.get_blob()); - capi.trigger_put(obj); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); return 0; } \ No newline at end of file diff --git a/benchmark/setup/perf_test_setup.py b/benchmark/setup/perf_test_setup.py index 3ddbe27..c4748f7 100644 --- a/benchmark/setup/perf_test_setup.py +++ b/benchmark/setup/perf_test_setup.py @@ -104,6 +104,8 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024): chunk_idx = break_into_chunks(len(table_dict), NUM_KEY_PER_MAP_OBJ) table_key_list = list(table_dict.keys()) for j, (start_idx, end_idx) in enumerate(chunk_idx): + if j > 2: + break key = f"/rag/doc/emb_doc_map/cluster{cluster_id}/{j}" table_dict_chunk = {k: table_dict[k] for k in table_key_list[start_idx:end_idx]} table_json = json.dumps(table_dict_chunk) @@ -119,6 +121,8 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024): centroids_chunk_idx = break_into_chunks(centroids_embs.shape[0], NUM_EMB_PER_OBJ) print(f"Initilizing: put {centroids_embs.shape[0]} centroids embeddings to cascade") for i, (start_idx, end_idx) in enumerate(centroids_chunk_idx): + if i > 2: + break key = f"/rag/emb/centroids_obj/{i}" centroids_embs_chunk = centroids_embs[start_idx:end_idx] res = capi.put(key, centroids_embs_chunk.tobytes()) @@ -142,7 +146,12 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024): cluster_embs = get_embeddings(basepath, cluster_file_name, embed_dim) num_embeddings = cluster_embs.shape[0] cluster_chunk_idx = break_into_chunks(num_embeddings, NUM_EMB_PER_OBJ) + if cluster_id > 2: + break for i, (start_idx, end_idx) in enumerate(cluster_chunk_idx): + if i > 2: + break + key = f"/rag/emb/clusters/cluster{cluster_id}/{i}" cluster_embs_chunk = cluster_embs[start_idx:end_idx] res = capi.put(key, cluster_embs_chunk.tobytes()) diff --git a/cfg/dfgs.json.tmp b/cfg/dfgs.json.tmp index f944064..249db48 100644 --- a/cfg/dfgs.json.tmp +++ b/cfg/dfgs.json.tmp @@ -13,9 +13,12 @@ "python_path":["python_udls"], "module":"encode_udl", "entry_class":"EncodeUDL", + "max_queued_entries": 64, + "max_batch_size": 32, + "batch_time_us": 1000, "encoder_config": { "model": "BAAI/bge-small-en-v1.5", - "device": "cuda" + "device": "cuda:0" } }], "destinations": [{"/rag/emb/centroids_search":"put"}] diff --git a/vortex_udls/python/python_udls/encode_udl.py b/vortex_udls/python/python_udls/encode_udl.py index 1365d61..04511e7 100644 --- a/vortex_udls/python/python_udls/encode_udl.py +++ b/vortex_udls/python/python_udls/encode_udl.py @@ -2,55 +2,182 @@ import time import json import warnings - +import threading import numpy as np +from queue import Queue, Empty from typing import Any from FlagEmbedding import FlagModel import cascade_context #type: ignore from derecho.cascade.udl import UserDefinedLogic from derecho.cascade.member_client import TimestampLogger +from derecho.cascade.member_client import ServiceClientAPI warnings.filterwarnings("ignore") from pyudl_serialize_utils import Batch +# batching queue +# encoding queue +# sending queue + +class Batcher(threading.Thread): + def __init__(self, output: Queue[Batch], max_batch_size: int, capacity: int = 32, batch_time_us: int = 10_000): + """_summary_ + + Args: + max_batch_size (int): _description_ + capacity (int, optional): _description_. Defaults to 32. + batch_time_us (int, optional): _description_. Defaults to 10_000. + """ + super().__init__() + self._max_batch_size = max_batch_size + self._batch_time_s = batch_time_us / 1_000_000 + self._output_queue: Queue[Batch] = output + self._blob_queue: Queue[Batch] = Queue(capacity) + + self._query_ids: list[int] = [] + self._client_ids: list[int] = [] + self._text: list[str] = [] + + def push_batch(self, batch: Batch): + # verified shallow copy + self._blob_queue.put(batch) + + def run(self): + def send(): + self._output_queue.put(Batch(self._text, self._query_ids, self._client_ids)) + + # since it's shallow copy, clear() will also + # clear out the entries in the EncoderBatch + self._query_ids = [] + self._client_ids = [] + self._text = [] + + last_reset = time.time() + while True: + if len(self._query_ids) == 0: + batch = self._blob_queue.get() + last_reset = time.time() + else: + now = time.time() + target_time = last_reset + self._batch_time_s + if target_time <= now: + send() + continue + + try: + batch = self._blob_queue.get(timeout=target_time-now) + except Empty: + send() + continue + + for i in range(batch.size): + self._query_ids.append(batch.query_id_list[i]) + self._client_ids.append(batch.client_id_list[i]) + self._text.append(batch.query_list[i]) + + if len(self._query_ids) == self._max_batch_size: + send() + +class Worker(threading.Thread): + def __init__(self, encode_queue: Queue[Batch], send_queue: Queue[Batch], model_name: str, device: str): + super().__init__() + self._encode_queue = encode_queue + self._send_queue = send_queue + self._model: FlagModel | None = None + self._model_name = model_name + self._device = device + + def run(self): + while True: + batch = self._encode_queue.get() + + if self._model is None: + # load encoder when we need it to prevent overloading + # the hardware during startup + + # NOTE: use fp16 should be set to false because later udls assume f32 + # for deserialization + self._model = FlagModel(self._model_name, device=self._device, use_fp16 = False) + + timeit_start = time.time() + embeddings: np.ndarray = self._model.encode( # type: ignore + batch.query_list, + convert_to_numpy=True + ) + timeit_stop = time.time() + print(f"Prediction time: {(timeit_stop - timeit_start)*1_000_000} us") + + batch.add_embeddings(embeddings) + self._send_queue.put(batch) + +class Sender(threading.Thread): + def __init__(self, send_queue: Queue[Batch]): + super().__init__() + self._send_queue = send_queue + self._batch_id = 0 + self._my_id = id + self._capi = ServiceClientAPI() + self._my_id = self._capi.get_my_id() + + def run(self): + while True: + batch = self._send_queue.get() + timeit_start = time.time() + output_bytes = batch.serialize() + timeit_stop = time.time() + print(f"Serialization time: {(timeit_stop - timeit_start)*1_000_000} us") + # format should be {client}_{batch_id} + key_str = f"/rag/emb/centroids_search/{self._my_id}_{self._batch_id}" + self._capi.put(key_str, output_bytes.tobytes()) + self._batch_id += 1 class EncodeUDL(UserDefinedLogic): def __init__(self, conf_str: str): + # TODO: parse in conf self._conf: dict[str, Any] = json.loads(conf_str) + self._tl = TimestampLogger() - self._encoder = None - self._batch = Batch() - self._batch_id = 0 + self._encode_queue: Queue[Batch] = Queue() + self._send_queue: Queue[Batch] = Queue() + + self._batcher = Batcher( + self._encode_queue, + self._conf["max_batch_size"], + self._conf["max_queued_entries"], + self._conf["batch_time_us"] + ) + self._batcher.start() + + self._worker = Worker( + self._encode_queue, + self._send_queue, + self._conf["encoder_config"]["model"], + device=self._conf["encoder_config"]["device"] + ) + self._worker.start() + + self._sender = Sender(self._send_queue) + self._sender.start() def ocdpo_handler(self, **kwargs): + # move data onto queue for processing # self._tl.log("EncodeUDL: ocdpo_handler") - if self._encoder is None: - # load encoder when we need it to prevent overloading - # the hardware during startup - self._encoder = FlagModel('BAAI/bge-small-en-v1.5', devices="cuda:0") message_id = kwargs["message_id"] + # TODO: this logging only works for batch of 1 self._tl.log(10001, message_id, 0, 0) - data = kwargs["blob"] - - - self._batch.deserialize(data) - - query_embeddings: np.ndarray = self._encoder.encode( - self._batch.query_list - ) - - self._batch_id += 1 - - # format should be {client}_{batch_id} - key_str = kwargs["key"] - output_bytes = self._batch.serialize(query_embeddings) - cascade_context.emit(key_str, output_bytes, message_id=kwargs["message_id"]) - return None - + batch = Batch() + + timeit_start = time.time() + batch.deserialize(kwargs["blob"]) + timeit_stop = time.time() + print(f"Deserialization time: {(timeit_stop - timeit_start)*1_000_000} us") + + self._batcher.push_batch(batch) + def __del__(self): pass diff --git a/vortex_udls/python/python_udls/pyudl_serialize_utils.py b/vortex_udls/python/python_udls/pyudl_serialize_utils.py index f4bedc1..d84e888 100644 --- a/vortex_udls/python/python_udls/pyudl_serialize_utils.py +++ b/vortex_udls/python/python_udls/pyudl_serialize_utils.py @@ -11,61 +11,135 @@ def utf8_length(s: str) -> int: class Batch: - def __init__(self): + def __init__(self, + strings: list[str] | None = None, + query_id: list[int] | None = None, + client_id: list[int] | None = None + ): self._bytes: np.ndarray = np.ndarray(shape=(0, ), dtype=np.uint8) - self._strings: list[str] = [] - self._client_id: int = 0 + self._strings: list[str] = strings if strings else [] + self._query_id: list[int] = query_id if query_id else [] + self._client_id: list[int] = client_id if client_id else [] + self._embeddings: np.ndarray | None = None @property def query_list(self): return self._strings + + @property + def query_id_list(self): + return self._query_id @property - def client_id(self): + def client_id_list(self): return self._client_id + + @property + def size(self): + return len(self._strings) def deserialize(self, data: np.ndarray): self._bytes = data # structured dtype - header_type = np.dtype([ - ('count', np.uint32), - ('embeddings_start', np.uint32) - ]) metadata_type = np.dtype([ ('query_id', np.uint64), ('client_id', np.uint32), ('text_position', np.uint32), ('text_length', np.uint32), - ('embeddings_position', np.uint32), - ('embeddings_dim', np.uint32), ]) header_start = 0 - header_end = header_start + header_type.itemsize - (count, _) = data[header_start:header_end].view(header_type)[0] + header_end = header_start + 4 + count = data[header_start:header_end].view(np.uint32)[0] - metadata_start = 8 + metadata_start = 4 metadata_end = metadata_type.itemsize * count + metadata_start self._strings = [""] * count - - # get one record to grab client id - # saves on bne in loop - metadata_record = data[metadata_start:metadata_start + metadata_type.itemsize].view(metadata_type)[0] - self._client_id = metadata_record[1] + self._query_id = [0] * count + self._client_id = [0] * count for idx, m in enumerate(data[metadata_start:metadata_end].view(metadata_type)): string_start = m[2] string_length = m[3] string = data[string_start:string_start+string_length].tobytes().decode("utf-8") + self._strings[idx] = string + self._query_id[idx] = m[0] + self._client_id[idx] = m[1] + + def add_embeddings(self, embeddings: np.ndarray): + self._embeddings = embeddings + def serialize(self) -> np.ndarray: + assert self._embeddings is not None, "please add embeddings before calling serialize" + assert self._embeddings.dtype == np.float32, "embedding type is not float32" + assert self._embeddings.shape[0] == self.size, "mismatched number of embeddings" + + num_emb, emb_dim = self._embeddings.shape + + header_type = np.dtype([ + ('count', np.uint32), + ('embeddings_position', np.uint32) + ]) + + metadata_type = np.dtype([ + ('query_id', np.uint64), + ('client_id', np.uint32), + ('text_position', np.uint32), + ('text_length', np.uint32), + ('embeddings_position', np.uint32), + ('query_emb_size', np.uint32), + ]) + + encoded_texts = [s.encode("utf-8") for s in self._strings] + total_text_size = sum(len(t) for t in encoded_texts) + total_emb_size = self._embeddings.itemsize * emb_dim * num_emb + emb_bytes = self._embeddings.itemsize * emb_dim + + + header_size = header_type.itemsize + metadata_size = self.size * metadata_type.itemsize + total_size = header_size + metadata_size + total_text_size + total_emb_size + + metadata_position = header_size + text_position = metadata_position + metadata_size + embedding_position = text_position + total_text_size + + # Allocate buffer + + buffer = np.zeros(total_size, dtype=np.uint8) + + # **Step 1: Write header** + np.frombuffer(buffer[:header_size], dtype=header_type)[0] = (self.size, embedding_position) + + # **Step 2: Write responses directly into the buffer while encoding** + metadata_view = np.frombuffer(buffer[metadata_position:metadata_position + metadata_size], dtype=metadata_type) + text_ptr_offset = text_position + embeddings_ptr_offset = embedding_position + + for i in range(self.size): + text_len = len(encoded_texts[i]) + + metadata_view[i] = ( + self._query_id[i], + self._client_id[i], + text_ptr_offset, + text_len, + embeddings_ptr_offset, + emb_bytes + ) + + buffer[text_ptr_offset:text_ptr_offset+text_len] = np.frombuffer(encoded_texts[i], dtype=np.uint8) + + text_ptr_offset += text_len + embeddings_ptr_offset += emb_bytes - def serialize(self, embeddings: np.ndarray) -> np.ndarray: # lesson learned # do not use .astype as that is a cast on each element of the array # use .view, which is simular to C++'s reinterpret_cast - return np.concatenate((self._bytes, embeddings.flatten().view(np.uint8))) + buffer[embedding_position:] = self._embeddings.flatten().view(np.uint8) + return buffer class AggregateResultBatch: diff --git a/vortex_udls/serialize_utils.cpp b/vortex_udls/serialize_utils.cpp index 49ae9a7..4a5f838 100644 --- a/vortex_udls/serialize_utils.cpp +++ b/vortex_udls/serialize_utils.cpp @@ -850,7 +850,7 @@ std::priority_queue, CompareObjKey> filter /// /// -EncoderQueryBatcher::EncoderQueryBatcher(uint32_t emb_dim, uint64_t size_hint): _emb_size(static_cast(emb_dim * sizeof(float))) { +EncoderQueryBatcher::EncoderQueryBatcher(uint64_t size_hint) { _queries.reserve(size_hint); } @@ -897,19 +897,15 @@ void EncoderQueryBatcher::serialize() { const uint32_t num_queries = _queries.size(); const uint32_t metadata_position = EncoderQueryBatcher::HEADER_SIZE; const uint32_t text_position = metadata_position + (num_queries * EncoderQueryBatcher::METADATA_SIZE); - const uint32_t embeddings_position = text_position + _total_text_size; - - const uint32_t header[2] = {num_queries, embeddings_position}; + const uint32_t header = num_queries; static_assert(EncoderQueryBatcher::HEADER_SIZE == sizeof(header)); // write the header - std::memcpy(buffer, header, EncoderQueryBatcher::HEADER_SIZE); + std::memcpy(buffer, &header, EncoderQueryBatcher::HEADER_SIZE); // write each query to the buffer uint32_t metadata_ptr_offset = metadata_position; uint32_t text_ptr_offset = text_position; - uint32_t embedding_ptr_offset = embeddings_position; - for(const auto& query : _queries) { const query_id_t& query_id = std::get<0>(query); @@ -918,8 +914,7 @@ void EncoderQueryBatcher::serialize() { const uint32_t& text_len = _text_size_mapping[query_id]; // write metadata: query_id, {client_id, query_text_position, query_text_size, embeddings_position, query_emb_size} - std::cout << client_id << " " << text_ptr_offset << " " << text_len << " " << embedding_ptr_offset << " " << _emb_size << std::endl; - uint32_t metadata_array[5] = {client_id, text_ptr_offset, text_len, embedding_ptr_offset, _emb_size}; + uint32_t metadata_array[3] = {client_id, text_ptr_offset, text_len}; static_assert(EncoderQueryBatcher::METADATA_SIZE == sizeof(metadata_array) + sizeof(query_id_t)); std::memcpy(buffer + metadata_ptr_offset, &query_id, sizeof(query_id_t)); @@ -931,7 +926,6 @@ void EncoderQueryBatcher::serialize() { // update offsets metadata_ptr_offset += EncoderQueryBatcher::METADATA_SIZE; text_ptr_offset += text_len; - embedding_ptr_offset += _emb_size; } return size; }, _total_obj_size); diff --git a/vortex_udls/serialize_utils.hpp b/vortex_udls/serialize_utils.hpp index 9970b75..0e0fec6 100644 --- a/vortex_udls/serialize_utils.hpp +++ b/vortex_udls/serialize_utils.hpp @@ -108,15 +108,14 @@ using encoder_query_t = std::tuple _queries; private: - // number of queries, embeddings position (EMPTY SLOT) - static constexpr uint32_t HEADER_SIZE = sizeof(uint32_t) * 2; + // number of queries + static constexpr uint32_t HEADER_SIZE = sizeof(uint32_t); - // query_id, client_id, query_text_position, query_text_size, embeddings_position, query_emb_size - static constexpr uint32_t METADATA_SIZE = sizeof(uint32_t) * 5 + sizeof(query_id_t); + // query_id, client_id, query_text_position, query_text_size + static constexpr uint32_t METADATA_SIZE = sizeof(uint32_t) * 3 + sizeof(query_id_t); // maps from query_id to number of bytes used to encode query string std::unordered_map _text_size_mapping; @@ -130,7 +129,7 @@ class EncoderQueryBatcher { public: // emb_dim: specify the desired number of dims for the encoded text // size_hint: used to reserve space in the underlying vector for performance optimization - EncoderQueryBatcher(uint32_t emb_dim, uint64_t size_hint = 1000); + EncoderQueryBatcher(uint64_t size_hint = 1000); void add_query(const encoder_query_t &query); void add_query(query_id_t query_id, uint32_t node_id, std::shared_ptr query_text);