Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
*.pickle
*.tar
*.tar.*
*.tar *.tar.*
*.zip

# Ignore dataset for vector database
**/gist/
**/miniset/
**/hotpot15/
**/msmarco/

# Ignore VS Code settings and extensions
.vscode/
Expand Down Expand Up @@ -36,4 +36,4 @@ build-*
*.jpg
*.arrow
*inat/*
*google-landmark/*
*google-landmark/*
64 changes: 15 additions & 49 deletions benchmark/run_encoder_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

using namespace derecho::cascade;

#define UDL1_PATH "/rag/emb/encode_search"
#define UDL1_PATH "/rag/emb/encode"
#define UDL2_PATH "/rag/emb/centroids_search"
#define UDL3_PATH "/rag/emb/clusters_search"
#define UDL4_PATH "/rag/generate/agg"
Expand All @@ -37,56 +37,22 @@ const int ID = 0;
ServiceClientAPI& capi = ServiceClientAPI::get_service_client();

int main() {

std::vector<std::pair<query_id_t, std::shared_ptr<std::string>>> queries = {
{0 ,std::make_shared<std::string>("hello this is query 1")},
{1 ,std::make_shared<std::string>("and this is query 2")},
{2 ,std::make_shared<std::string>("What's the weather today?")},
{4 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{5 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{6 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{7 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{8 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{9 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{10 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{11 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{12 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{13 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{14 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{15 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{16 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{17 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{18 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{19 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{20 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{21 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{22 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{23 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{24 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{25 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{26 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{27 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{28 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{29 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
{30 ,std::make_shared<std::string>("SUPER LONG MESSAGE kjsdjflkasfj;asdkljf;klasdfj;aosdklfj;asdlkfjl;asdkfjl;asdfjk;asldkfj;asldkfjasl;dfjkl;asdkfjl;asdkfjl;askdfjl;aksdjfl;adksjfl;asdkjf;alskdfj;lakdjf;")},
};

EncoderQueryBatcher batcher(384, queries.size());

for(const auto& query : queries) {
batcher.add_query(
query.first,
ID,
query.second
);
const size_t number_iterations = 128;
const size_t number_queries = 64;
for (size_t i = 0; i < number_iterations; i++) {

EncoderQueryBatcher batcher(number_queries);
for(size_t j = 0; j < number_queries; j++) {
batcher.add_query(j + number_queries * i, ID, std::make_shared<std::string>("What is the weather today?"));
}

batcher.serialize();
ObjectWithStringKey obj;
obj.key = UDL1_PATH "/" + std::string("batch") + std::to_string(number_iterations);
obj.blob = std::move(*batcher.get_blob());
capi.trigger_put(obj);
}

batcher.serialize();
ObjectWithStringKey obj;
obj.key = UDL1_PATH "/" + std::string("batch1");
obj.blob = std::move(*batcher.get_blob());
capi.trigger_put(obj);

std::this_thread::sleep_for(std::chrono::milliseconds(1000));
return 0;
}
9 changes: 9 additions & 0 deletions benchmark/setup/perf_test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024):
chunk_idx = break_into_chunks(len(table_dict), NUM_KEY_PER_MAP_OBJ)
table_key_list = list(table_dict.keys())
for j, (start_idx, end_idx) in enumerate(chunk_idx):
if j > 2:
break
key = f"/rag/doc/emb_doc_map/cluster{cluster_id}/{j}"
table_dict_chunk = {k: table_dict[k] for k in table_key_list[start_idx:end_idx]}
table_json = json.dumps(table_dict_chunk)
Expand All @@ -119,6 +121,8 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024):
centroids_chunk_idx = break_into_chunks(centroids_embs.shape[0], NUM_EMB_PER_OBJ)
print(f"Initilizing: put {centroids_embs.shape[0]} centroids embeddings to cascade")
for i, (start_idx, end_idx) in enumerate(centroids_chunk_idx):
if i > 2:
break
key = f"/rag/emb/centroids_obj/{i}"
centroids_embs_chunk = centroids_embs[start_idx:end_idx]
res = capi.put(key, centroids_embs_chunk.tobytes())
Expand All @@ -142,7 +146,12 @@ def put_initial_embeddings_docs(capi, basepath, put_docs=True, embed_dim=1024):
cluster_embs = get_embeddings(basepath, cluster_file_name, embed_dim)
num_embeddings = cluster_embs.shape[0]
cluster_chunk_idx = break_into_chunks(num_embeddings, NUM_EMB_PER_OBJ)
if cluster_id > 2:
break
for i, (start_idx, end_idx) in enumerate(cluster_chunk_idx):
if i > 2:
break

key = f"/rag/emb/clusters/cluster{cluster_id}/{i}"
cluster_embs_chunk = cluster_embs[start_idx:end_idx]
res = capi.put(key, cluster_embs_chunk.tobytes())
Expand Down
5 changes: 4 additions & 1 deletion cfg/dfgs.json.tmp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
"python_path":["python_udls"],
"module":"encode_udl",
"entry_class":"EncodeUDL",
"max_queued_entries": 64,
"max_batch_size": 32,
"batch_time_us": 1000,
"encoder_config": {
"model": "BAAI/bge-small-en-v1.5",
"device": "cuda"
"device": "cuda:0"
}
}],
"destinations": [{"/rag/emb/centroids_search":"put"}]
Expand Down
177 changes: 152 additions & 25 deletions vortex_udls/python/python_udls/encode_udl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,182 @@
import time
import json
import warnings

import threading
import numpy as np

from queue import Queue, Empty
from typing import Any
from FlagEmbedding import FlagModel

import cascade_context #type: ignore
from derecho.cascade.udl import UserDefinedLogic
from derecho.cascade.member_client import TimestampLogger
from derecho.cascade.member_client import ServiceClientAPI

warnings.filterwarnings("ignore")
from pyudl_serialize_utils import Batch

# batching queue
# encoding queue
# sending queue

class Batcher(threading.Thread):
def __init__(self, output: Queue[Batch], max_batch_size: int, capacity: int = 32, batch_time_us: int = 10_000):
"""_summary_

Args:
max_batch_size (int): _description_
capacity (int, optional): _description_. Defaults to 32.
batch_time_us (int, optional): _description_. Defaults to 10_000.
"""
super().__init__()
self._max_batch_size = max_batch_size
self._batch_time_s = batch_time_us / 1_000_000
self._output_queue: Queue[Batch] = output
self._blob_queue: Queue[Batch] = Queue(capacity)

self._query_ids: list[int] = []
self._client_ids: list[int] = []
self._text: list[str] = []

def push_batch(self, batch: Batch):
# verified shallow copy
self._blob_queue.put(batch)

def run(self):
def send():
self._output_queue.put(Batch(self._text, self._query_ids, self._client_ids))

# since it's shallow copy, clear() will also
# clear out the entries in the EncoderBatch
self._query_ids = []
self._client_ids = []
self._text = []

last_reset = time.time()
while True:
if len(self._query_ids) == 0:
batch = self._blob_queue.get()
last_reset = time.time()
else:
now = time.time()
target_time = last_reset + self._batch_time_s
if target_time <= now:
send()
continue

try:
batch = self._blob_queue.get(timeout=target_time-now)
except Empty:
send()
continue

for i in range(batch.size):
self._query_ids.append(batch.query_id_list[i])
self._client_ids.append(batch.client_id_list[i])
self._text.append(batch.query_list[i])

if len(self._query_ids) == self._max_batch_size:
send()

class Worker(threading.Thread):
def __init__(self, encode_queue: Queue[Batch], send_queue: Queue[Batch], model_name: str, device: str):
super().__init__()
self._encode_queue = encode_queue
self._send_queue = send_queue
self._model: FlagModel | None = None
self._model_name = model_name
self._device = device

def run(self):
while True:
batch = self._encode_queue.get()

if self._model is None:
# load encoder when we need it to prevent overloading
# the hardware during startup

# NOTE: use fp16 should be set to false because later udls assume f32
# for deserialization
self._model = FlagModel(self._model_name, device=self._device, use_fp16 = False)

timeit_start = time.time()
embeddings: np.ndarray = self._model.encode( # type: ignore
batch.query_list,
convert_to_numpy=True
)
timeit_stop = time.time()
print(f"Prediction time: {(timeit_stop - timeit_start)*1_000_000} us")

batch.add_embeddings(embeddings)
self._send_queue.put(batch)

class Sender(threading.Thread):
def __init__(self, send_queue: Queue[Batch]):
super().__init__()
self._send_queue = send_queue
self._batch_id = 0
self._my_id = id
self._capi = ServiceClientAPI()
self._my_id = self._capi.get_my_id()

def run(self):
while True:
batch = self._send_queue.get()
timeit_start = time.time()
output_bytes = batch.serialize()
timeit_stop = time.time()
print(f"Serialization time: {(timeit_stop - timeit_start)*1_000_000} us")

# format should be {client}_{batch_id}
key_str = f"/rag/emb/centroids_search/{self._my_id}_{self._batch_id}"
self._capi.put(key_str, output_bytes.tobytes())
self._batch_id += 1

class EncodeUDL(UserDefinedLogic):
def __init__(self, conf_str: str):
# TODO: parse in conf
self._conf: dict[str, Any] = json.loads(conf_str)

self._tl = TimestampLogger()
self._encoder = None
self._batch = Batch()
self._batch_id = 0
self._encode_queue: Queue[Batch] = Queue()
self._send_queue: Queue[Batch] = Queue()

self._batcher = Batcher(
self._encode_queue,
self._conf["max_batch_size"],
self._conf["max_queued_entries"],
self._conf["batch_time_us"]
)
self._batcher.start()

self._worker = Worker(
self._encode_queue,
self._send_queue,
self._conf["encoder_config"]["model"],
device=self._conf["encoder_config"]["device"]
)
self._worker.start()

self._sender = Sender(self._send_queue)
self._sender.start()

def ocdpo_handler(self, **kwargs):
# move data onto queue for processing
# self._tl.log("EncodeUDL: ocdpo_handler")
if self._encoder is None:
# load encoder when we need it to prevent overloading
# the hardware during startup
self._encoder = FlagModel('BAAI/bge-small-en-v1.5', devices="cuda:0")
message_id = kwargs["message_id"]

# TODO: this logging only works for batch of 1
self._tl.log(10001, message_id, 0, 0)
data = kwargs["blob"]


self._batch.deserialize(data)

query_embeddings: np.ndarray = self._encoder.encode(
self._batch.query_list
)

self._batch_id += 1


# format should be {client}_{batch_id}
key_str = kwargs["key"]
output_bytes = self._batch.serialize(query_embeddings)
cascade_context.emit(key_str, output_bytes, message_id=kwargs["message_id"])
return None

batch = Batch()

timeit_start = time.time()
batch.deserialize(kwargs["blob"])
timeit_stop = time.time()
print(f"Deserialization time: {(timeit_stop - timeit_start)*1_000_000} us")

self._batcher.push_batch(batch)

def __del__(self):
pass
Loading