diff --git a/ctrader/.gitignore b/ctrader/.gitignore index 9dbbb71c..f60d1236 100644 --- a/ctrader/.gitignore +++ b/ctrader/.gitignore @@ -2,3 +2,5 @@ ctrader test_market_sim test_binance_rest +test_policy +libtorch/ diff --git a/ctrader/Makefile b/ctrader/Makefile index 4caaba42..3e087194 100644 --- a/ctrader/Makefile +++ b/ctrader/Makefile @@ -1,31 +1,49 @@ CC = gcc +CXX = g++ CFLAGS = -O3 -march=native -Wall -Wextra -Wpedantic -std=c11 +CXXFLAGS = -O3 -march=native -Wall -Wextra -std=c++17 LDFLAGS = -lm -SRCS = main.c trade_loop.c market_sim.c binance_rest.c policy_infer.c -OBJS = $(SRCS:.c=.o) +TORCH_DIR ?= +LIBTORCH_URL = https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcpu.zip + +C_SRCS = main.c trade_loop.c market_sim.c binance_rest.c -# libtorch (optional) ifdef TORCH_DIR - CFLAGS += -I$(TORCH_DIR)/include -I$(TORCH_DIR)/include/torch/csrc/api/include - LDFLAGS += -L$(TORCH_DIR)/lib -ltorch -lc10 -Wl,-rpath,$(TORCH_DIR)/lib + TORCH_CXXFLAGS = -I$(TORCH_DIR)/include -I$(TORCH_DIR)/include/torch/csrc/api/include -D_GLIBCXX_USE_CXX11_ABI=1 + TORCH_LDFLAGS = -L$(TORCH_DIR)/lib -ltorch -ltorch_cpu -lc10 -Wl,-rpath,$(TORCH_DIR)/lib -lstdc++ -lpthread + POLICY_SRC = policy_infer.cpp + POLICY_OBJ = policy_infer_cpp.o + LINK = $(CXX) +else + TORCH_CXXFLAGS = + TORCH_LDFLAGS = + POLICY_SRC = policy_infer.c + POLICY_OBJ = policy_infer.o + LINK = $(CC) endif +C_OBJS = $(C_SRCS:.c=.o) +ALL_OBJS = $(C_OBJS) $(POLICY_OBJ) + # libcurl (optional, for real binance_rest) ifdef USE_CURL LDFLAGS += -lcurl -lssl -lcrypto endif -.PHONY: all clean test test_valgrind +.PHONY: all clean test test_valgrind test_policy download_libtorch all: ctrader -ctrader: $(OBJS) - $(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) +ctrader: $(ALL_OBJS) + $(LINK) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(TORCH_LDFLAGS) %.o: %.c $(CC) $(CFLAGS) -c -o $@ $< +policy_infer_cpp.o: policy_infer.cpp policy_infer.h + $(CXX) $(CXXFLAGS) $(TORCH_CXXFLAGS) -c -o $@ $< + test: test_market_sim test_binance_rest ./test_market_sim ./test_binance_rest @@ -36,9 +54,25 @@ test_market_sim: tests/test_market_sim.c market_sim.c test_binance_rest: tests/test_binance_rest.c binance_rest.c $(CC) $(CFLAGS) -o $@ $^ -lm +ifdef TORCH_DIR +test_policy: tests/test_policy_infer.cpp policy_infer.cpp policy_infer.h + $(CXX) $(CXXFLAGS) $(TORCH_CXXFLAGS) -o $@ tests/test_policy_infer.cpp policy_infer.cpp $(LDFLAGS) $(TORCH_LDFLAGS) + ./test_policy +else +test_policy: + @echo "TORCH_DIR not set, skipping policy_infer tests" +endif + test_valgrind: test_market_sim test_binance_rest valgrind --leak-check=full --error-exitcode=1 ./test_market_sim valgrind --leak-check=full --error-exitcode=1 ./test_binance_rest +download_libtorch: + @if [ -d "libtorch" ]; then echo "libtorch already downloaded"; \ + else wget -q $(LIBTORCH_URL) -O /tmp/libtorch.zip && \ + unzip -qo /tmp/libtorch.zip -d . && \ + rm -f /tmp/libtorch.zip && \ + echo "libtorch downloaded to ctrader/libtorch/"; fi + clean: - rm -f ctrader test_market_sim test_binance_rest $(OBJS) + rm -f ctrader test_market_sim test_binance_rest test_policy $(ALL_OBJS) policy_infer.o policy_infer_cpp.o diff --git a/ctrader/policy_infer.cpp b/ctrader/policy_infer.cpp new file mode 100644 index 00000000..18d597e8 --- /dev/null +++ b/ctrader/policy_infer.cpp @@ -0,0 +1,92 @@ +extern "C" { +#include "policy_infer.h" +} + +#include +#include +#include + +int policy_load(Policy *policy, const char *model_path) { + memset(policy, 0, sizeof(*policy)); + try { + auto *mod = new torch::jit::Module(torch::jit::load(model_path)); + mod->eval(); + policy->model = static_cast(mod); + policy->loaded = 1; + return 0; + } catch (const c10::Error &e) { + fprintf(stderr, "policy_load(%s): %s\n", model_path, e.what()); + return -1; + } catch (...) { + fprintf(stderr, "policy_load(%s): unknown error\n", model_path); + return -1; + } +} + +void policy_unload(Policy *policy) { + if (policy->model) { + auto *mod = static_cast(policy->model); + delete mod; + policy->model = NULL; + } + policy->loaded = 0; +} + +int policy_forward( + Policy *policy, + const double *obs, + int obs_len, + PolicyAction *out_actions, + int n_symbols +) { + if (!policy->loaded || !policy->model) { + fprintf(stderr, "policy_forward: model not loaded\n"); + return -1; + } + + auto *mod = static_cast(policy->model); + + std::vector fobs(obs_len); + for (int i = 0; i < obs_len; i++) + fobs[i] = static_cast(obs[i]); + + try { + auto input = torch::from_blob(fobs.data(), {1, obs_len}, torch::kFloat32); + std::vector inputs; + inputs.push_back(input); + + auto output = mod->forward(inputs).toTensor(); + output = output.contiguous().to(torch::kFloat64); + int n_out = static_cast(output.numel()); + + const int fields_per_sym = 4; + const double *data = output.data_ptr(); + for (int i = 0; i < n_symbols; i++) { + if (i * fields_per_sym + 3 < n_out) { + out_actions[i].buy_price = data[i * fields_per_sym + 0]; + out_actions[i].sell_price = data[i * fields_per_sym + 1]; + out_actions[i].buy_amount = data[i * fields_per_sym + 2]; + out_actions[i].sell_amount = data[i * fields_per_sym + 3]; + } else { + out_actions[i].buy_price = 0.0; + out_actions[i].sell_price = 0.0; + out_actions[i].buy_amount = 0.0; + out_actions[i].sell_amount = 0.0; + } + } + return 0; + } catch (const c10::Error &e) { + fprintf(stderr, "policy_forward: %s\n", e.what()); + return -1; + } catch (...) { + fprintf(stderr, "policy_forward: unknown error\n"); + return -1; + } +} + +int policy_export_torchscript(const char *checkpoint_path, const char *output_path) { + (void)checkpoint_path; + (void)output_path; + fprintf(stderr, "policy_export_torchscript: use Python torch.jit.trace to export\n"); + return -1; +} diff --git a/ctrader/tests/test_policy_infer.cpp b/ctrader/tests/test_policy_infer.cpp new file mode 100644 index 00000000..f62b0d55 --- /dev/null +++ b/ctrader/tests/test_policy_infer.cpp @@ -0,0 +1,202 @@ +extern "C" { +#include "../policy_infer.h" +} + +#include +#include +#include +#include +#include +#include + +static int g_pass = 0, g_fail = 0; + +#define ASSERT_EQ_INT(a, b, msg) do { \ + int _a = (a), _b = (b); \ + if (_a != _b) { \ + fprintf(stderr, "FAIL %s: %d != %d\n", msg, _a, _b); \ + g_fail++; \ + } else { g_pass++; } \ +} while(0) + +#define ASSERT_NEAR(a, b, tol, msg) do { \ + double _a = (a), _b = (b), _t = (tol); \ + if (fabs(_a - _b) > _t) { \ + fprintf(stderr, "FAIL %s: %.10f != %.10f (tol=%.10f)\n", msg, _a, _b, _t); \ + g_fail++; \ + } else { g_pass++; } \ +} while(0) + +static std::string create_test_model(int in_dim, int out_dim) { + torch::jit::Module mod("TestModel"); + + std::ostringstream src; + src << "def forward(self, x: Tensor) -> Tensor:\n" + << " return torch.mm(x, torch.ones(" << in_dim << ", " << out_dim << ")) + 0.5\n"; + mod.define(src.str()); + + std::string path = "/tmp/test_policy_model.pt"; + mod.save(path); + return path; +} + +static void test_load_unload() { + std::string path = create_test_model(8, 4); + + Policy p; + int rc = policy_load(&p, path.c_str()); + ASSERT_EQ_INT(rc, 0, "load: returns 0"); + ASSERT_EQ_INT(p.loaded, 1, "load: loaded flag"); + + policy_unload(&p); + ASSERT_EQ_INT(p.loaded, 0, "unload: loaded flag"); + if (p.model != NULL) { + fprintf(stderr, "FAIL unload: model not null\n"); + g_fail++; + } else { + g_pass++; + } +} + +static void test_load_bad_path() { + Policy p; + int rc = policy_load(&p, "/tmp/nonexistent_model_xyz.pt"); + ASSERT_EQ_INT(rc, -1, "bad_path: returns -1"); + ASSERT_EQ_INT(p.loaded, 0, "bad_path: not loaded"); +} + +static void test_forward_not_loaded() { + Policy p; + memset(&p, 0, sizeof(p)); + double obs[4] = {1.0, 2.0, 3.0, 4.0}; + PolicyAction actions[1]; + int rc = policy_forward(&p, obs, 4, actions, 1); + ASSERT_EQ_INT(rc, -1, "not_loaded: returns -1"); +} + +static void test_forward_basic() { + int in_dim = 8, out_dim = 4; + std::string path = create_test_model(in_dim, out_dim); + + Policy p; + int rc = policy_load(&p, path.c_str()); + ASSERT_EQ_INT(rc, 0, "forward_basic: load ok"); + + double obs[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + PolicyAction actions[1]; + memset(actions, 0, sizeof(actions)); + + rc = policy_forward(&p, obs, in_dim, actions, 1); + ASSERT_EQ_INT(rc, 0, "forward_basic: returns 0"); + + // model is x @ ones(8,4) + 0.5, sum of obs = 36, so each output = 36.5 + ASSERT_NEAR(actions[0].buy_price, 36.5, 0.01, "forward_basic: buy_price"); + ASSERT_NEAR(actions[0].sell_price, 36.5, 0.01, "forward_basic: sell_price"); + ASSERT_NEAR(actions[0].buy_amount, 36.5, 0.01, "forward_basic: buy_amount"); + ASSERT_NEAR(actions[0].sell_amount, 36.5, 0.01, "forward_basic: sell_amount"); + + policy_unload(&p); +} + +static void test_forward_multi_symbol() { + int in_dim = 6, n_symbols = 3, out_dim = n_symbols * 4; + std::string path = create_test_model(in_dim, out_dim); + + Policy p; + int rc = policy_load(&p, path.c_str()); + ASSERT_EQ_INT(rc, 0, "multi_sym: load ok"); + + double obs[6] = {1, 2, 3, 4, 5, 6}; + PolicyAction actions[3]; + memset(actions, 0, sizeof(actions)); + + rc = policy_forward(&p, obs, in_dim, actions, n_symbols); + ASSERT_EQ_INT(rc, 0, "multi_sym: returns 0"); + + // sum = 21, each output = 21.5 + for (int i = 0; i < n_symbols; i++) { + ASSERT_NEAR(actions[i].buy_price, 21.5, 0.01, "multi_sym: output value"); + } + + policy_unload(&p); +} + +static void test_forward_deterministic() { + int in_dim = 4, out_dim = 4; + std::string path = create_test_model(in_dim, out_dim); + + Policy p; + policy_load(&p, path.c_str()); + + double obs[4] = {1.0, 2.0, 3.0, 4.0}; + PolicyAction a1, a2; + + policy_forward(&p, obs, in_dim, &a1, 1); + policy_forward(&p, obs, in_dim, &a2, 1); + + ASSERT_NEAR(a1.buy_price, a2.buy_price, 1e-6, "deterministic: buy_price"); + ASSERT_NEAR(a1.sell_price, a2.sell_price, 1e-6, "deterministic: sell_price"); + ASSERT_NEAR(a1.buy_amount, a2.buy_amount, 1e-6, "deterministic: buy_amount"); + ASSERT_NEAR(a1.sell_amount, a2.sell_amount, 1e-6, "deterministic: sell_amount"); + + policy_unload(&p); +} + +static void test_load_reload() { + std::string path = create_test_model(4, 4); + + Policy p; + policy_load(&p, path.c_str()); + ASSERT_EQ_INT(p.loaded, 1, "reload: first load"); + + policy_unload(&p); + ASSERT_EQ_INT(p.loaded, 0, "reload: after unload"); + + int rc = policy_load(&p, path.c_str()); + ASSERT_EQ_INT(rc, 0, "reload: second load"); + ASSERT_EQ_INT(p.loaded, 1, "reload: loaded after second load"); + + policy_unload(&p); +} + +static void test_output_fewer_than_requested() { + int in_dim = 4, out_dim = 4; + std::string path = create_test_model(in_dim, out_dim); + + Policy p; + policy_load(&p, path.c_str()); + + double obs[4] = {1.0, 2.0, 3.0, 4.0}; + PolicyAction actions[3]; + memset(actions, 0xff, sizeof(actions)); + + int rc = policy_forward(&p, obs, in_dim, actions, 3); + ASSERT_EQ_INT(rc, 0, "fewer_out: returns 0"); + + ASSERT_NEAR(actions[1].buy_price, 0.0, 1e-10, "fewer_out: sym1 zeroed"); + ASSERT_NEAR(actions[2].buy_price, 0.0, 1e-10, "fewer_out: sym2 zeroed"); + + policy_unload(&p); +} + +static void test_export_returns_error() { + int rc = policy_export_torchscript("foo", "bar"); + ASSERT_EQ_INT(rc, -1, "export: returns -1"); +} + +int main() { + fprintf(stderr, "=== policy_infer tests (libtorch) ===\n"); + + test_load_unload(); + test_load_bad_path(); + test_forward_not_loaded(); + test_forward_basic(); + test_forward_multi_symbol(); + test_forward_deterministic(); + test_load_reload(); + test_output_fewer_than_requested(); + test_export_returns_error(); + + fprintf(stderr, "\n%d passed, %d failed\n", g_pass, g_fail); + return g_fail > 0 ? 1 : 0; +}