Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ctrader/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
ctrader
test_market_sim
test_binance_rest
test_policy
libtorch/
52 changes: 43 additions & 9 deletions ctrader/Makefile
Original file line number Diff line number Diff line change
@@ -1,31 +1,49 @@
CC = gcc
CXX = g++
CFLAGS = -O3 -march=native -Wall -Wextra -Wpedantic -std=c11
CXXFLAGS = -O3 -march=native -Wall -Wextra -std=c++17
LDFLAGS = -lm

SRCS = main.c trade_loop.c market_sim.c binance_rest.c policy_infer.c
OBJS = $(SRCS:.c=.o)
TORCH_DIR ?=
LIBTORCH_URL = https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcpu.zip

C_SRCS = main.c trade_loop.c market_sim.c binance_rest.c

# libtorch (optional)
ifdef TORCH_DIR
CFLAGS += -I$(TORCH_DIR)/include -I$(TORCH_DIR)/include/torch/csrc/api/include
LDFLAGS += -L$(TORCH_DIR)/lib -ltorch -lc10 -Wl,-rpath,$(TORCH_DIR)/lib
TORCH_CXXFLAGS = -I$(TORCH_DIR)/include -I$(TORCH_DIR)/include/torch/csrc/api/include -D_GLIBCXX_USE_CXX11_ABI=1
TORCH_LDFLAGS = -L$(TORCH_DIR)/lib -ltorch -ltorch_cpu -lc10 -Wl,-rpath,$(TORCH_DIR)/lib -lstdc++ -lpthread
POLICY_SRC = policy_infer.cpp
POLICY_OBJ = policy_infer_cpp.o
LINK = $(CXX)
else
TORCH_CXXFLAGS =
TORCH_LDFLAGS =
POLICY_SRC = policy_infer.c
POLICY_OBJ = policy_infer.o
LINK = $(CC)
endif

C_OBJS = $(C_SRCS:.c=.o)
ALL_OBJS = $(C_OBJS) $(POLICY_OBJ)

# libcurl (optional, for real binance_rest)
ifdef USE_CURL
LDFLAGS += -lcurl -lssl -lcrypto
endif

.PHONY: all clean test test_valgrind
.PHONY: all clean test test_valgrind test_policy download_libtorch

all: ctrader

ctrader: $(OBJS)
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS)
ctrader: $(ALL_OBJS)
$(LINK) $(CFLAGS) -o $@ $^ $(LDFLAGS) $(TORCH_LDFLAGS)

%.o: %.c
$(CC) $(CFLAGS) -c -o $@ $<

policy_infer_cpp.o: policy_infer.cpp policy_infer.h
$(CXX) $(CXXFLAGS) $(TORCH_CXXFLAGS) -c -o $@ $<

test: test_market_sim test_binance_rest
./test_market_sim
./test_binance_rest
Expand All @@ -36,9 +54,25 @@ test_market_sim: tests/test_market_sim.c market_sim.c
test_binance_rest: tests/test_binance_rest.c binance_rest.c
$(CC) $(CFLAGS) -o $@ $^ -lm

ifdef TORCH_DIR
test_policy: tests/test_policy_infer.cpp policy_infer.cpp policy_infer.h
$(CXX) $(CXXFLAGS) $(TORCH_CXXFLAGS) -o $@ tests/test_policy_infer.cpp policy_infer.cpp $(LDFLAGS) $(TORCH_LDFLAGS)
./test_policy
else
test_policy:
@echo "TORCH_DIR not set, skipping policy_infer tests"
endif

test_valgrind: test_market_sim test_binance_rest
valgrind --leak-check=full --error-exitcode=1 ./test_market_sim
valgrind --leak-check=full --error-exitcode=1 ./test_binance_rest

download_libtorch:
@if [ -d "libtorch" ]; then echo "libtorch already downloaded"; \
else wget -q $(LIBTORCH_URL) -O /tmp/libtorch.zip && \
unzip -qo /tmp/libtorch.zip -d . && \
rm -f /tmp/libtorch.zip && \
echo "libtorch downloaded to ctrader/libtorch/"; fi

clean:
rm -f ctrader test_market_sim test_binance_rest $(OBJS)
rm -f ctrader test_market_sim test_binance_rest test_policy $(ALL_OBJS) policy_infer.o policy_infer_cpp.o
92 changes: 92 additions & 0 deletions ctrader/policy_infer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
extern "C" {
#include "policy_infer.h"
}

#include <torch/script.h>
#include <cstring>
#include <cstdio>

int policy_load(Policy *policy, const char *model_path) {
memset(policy, 0, sizeof(*policy));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Unload the previous module before clearing Policy

If policy_load() is called on an already-loaded Policy (for example during an in-process model refresh), the memset here wipes out policy->model before policy_unload() can free it. That leaks the original torch::jit::Module, and if the second load fails we also lose the only handle to the previously working model.

Useful? React with 👍 / 👎.

try {
auto *mod = new torch::jit::Module(torch::jit::load(model_path));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Load TorchScript modules onto CPU explicitly

Checked ctrader/Makefile: download_libtorch fetches the CPU-only libtorch build, but policy_load() uses bare torch::jit::load(model_path) here. When a checkpoint was exported from CUDA (which is a common training path in this repo), TorchScript will try to restore tensors back to that saved device, so loading the model on the CPU trading bot fails before inference ever starts. Pass an explicit CPU device/map_location during load so GPU-trained artifacts remain deployable.

Useful? React with 👍 / 👎.

mod->eval();
policy->model = static_cast<void *>(mod);
policy->loaded = 1;
return 0;
} catch (const c10::Error &e) {
fprintf(stderr, "policy_load(%s): %s\n", model_path, e.what());
return -1;
} catch (...) {
fprintf(stderr, "policy_load(%s): unknown error\n", model_path);
return -1;
}
}

void policy_unload(Policy *policy) {
if (policy->model) {
auto *mod = static_cast<torch::jit::Module *>(policy->model);
delete mod;
policy->model = NULL;
}
policy->loaded = 0;
}

int policy_forward(
Policy *policy,
const double *obs,
int obs_len,
PolicyAction *out_actions,
int n_symbols
) {
if (!policy->loaded || !policy->model) {
fprintf(stderr, "policy_forward: model not loaded\n");
return -1;
}

auto *mod = static_cast<torch::jit::Module *>(policy->model);

std::vector<float> fobs(obs_len);
for (int i = 0; i < obs_len; i++)
fobs[i] = static_cast<float>(obs[i]);

try {
auto input = torch::from_blob(fobs.data(), {1, obs_len}, torch::kFloat32);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input);

auto output = mod->forward(inputs).toTensor();
output = output.contiguous().to(torch::kFloat64);
int n_out = static_cast<int>(output.numel());

const int fields_per_sym = 4;
const double *data = output.data_ptr<double>();
for (int i = 0; i < n_symbols; i++) {
if (i * fields_per_sym + 3 < n_out) {
out_actions[i].buy_price = data[i * fields_per_sym + 0];
out_actions[i].sell_price = data[i * fields_per_sym + 1];
out_actions[i].buy_amount = data[i * fields_per_sym + 2];
out_actions[i].sell_amount = data[i * fields_per_sym + 3];
} else {
out_actions[i].buy_price = 0.0;
out_actions[i].sell_price = 0.0;
out_actions[i].buy_amount = 0.0;
out_actions[i].sell_amount = 0.0;
}
}
return 0;
} catch (const c10::Error &e) {
fprintf(stderr, "policy_forward: %s\n", e.what());
return -1;
} catch (...) {
fprintf(stderr, "policy_forward: unknown error\n");
return -1;
}
}

int policy_export_torchscript(const char *checkpoint_path, const char *output_path) {
(void)checkpoint_path;
(void)output_path;
fprintf(stderr, "policy_export_torchscript: use Python torch.jit.trace to export\n");
return -1;
}
202 changes: 202 additions & 0 deletions ctrader/tests/test_policy_infer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
extern "C" {
#include "../policy_infer.h"
}

#include <torch/script.h>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <string>
#include <sstream>

static int g_pass = 0, g_fail = 0;

#define ASSERT_EQ_INT(a, b, msg) do { \
int _a = (a), _b = (b); \
if (_a != _b) { \
fprintf(stderr, "FAIL %s: %d != %d\n", msg, _a, _b); \
g_fail++; \
} else { g_pass++; } \
} while(0)

#define ASSERT_NEAR(a, b, tol, msg) do { \
double _a = (a), _b = (b), _t = (tol); \
if (fabs(_a - _b) > _t) { \
fprintf(stderr, "FAIL %s: %.10f != %.10f (tol=%.10f)\n", msg, _a, _b, _t); \
g_fail++; \
} else { g_pass++; } \
} while(0)

static std::string create_test_model(int in_dim, int out_dim) {
torch::jit::Module mod("TestModel");

std::ostringstream src;
src << "def forward(self, x: Tensor) -> Tensor:\n"
<< " return torch.mm(x, torch.ones(" << in_dim << ", " << out_dim << ")) + 0.5\n";
mod.define(src.str());

std::string path = "/tmp/test_policy_model.pt";
mod.save(path);
return path;
}

static void test_load_unload() {
std::string path = create_test_model(8, 4);

Policy p;
int rc = policy_load(&p, path.c_str());
ASSERT_EQ_INT(rc, 0, "load: returns 0");
ASSERT_EQ_INT(p.loaded, 1, "load: loaded flag");

policy_unload(&p);
ASSERT_EQ_INT(p.loaded, 0, "unload: loaded flag");
if (p.model != NULL) {
fprintf(stderr, "FAIL unload: model not null\n");
g_fail++;
} else {
g_pass++;
}
}

static void test_load_bad_path() {
Policy p;
int rc = policy_load(&p, "/tmp/nonexistent_model_xyz.pt");
ASSERT_EQ_INT(rc, -1, "bad_path: returns -1");
ASSERT_EQ_INT(p.loaded, 0, "bad_path: not loaded");
}

static void test_forward_not_loaded() {
Policy p;
memset(&p, 0, sizeof(p));
double obs[4] = {1.0, 2.0, 3.0, 4.0};
PolicyAction actions[1];
int rc = policy_forward(&p, obs, 4, actions, 1);
ASSERT_EQ_INT(rc, -1, "not_loaded: returns -1");
}

static void test_forward_basic() {
int in_dim = 8, out_dim = 4;
std::string path = create_test_model(in_dim, out_dim);

Policy p;
int rc = policy_load(&p, path.c_str());
ASSERT_EQ_INT(rc, 0, "forward_basic: load ok");

double obs[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
PolicyAction actions[1];
memset(actions, 0, sizeof(actions));

rc = policy_forward(&p, obs, in_dim, actions, 1);
ASSERT_EQ_INT(rc, 0, "forward_basic: returns 0");

// model is x @ ones(8,4) + 0.5, sum of obs = 36, so each output = 36.5
ASSERT_NEAR(actions[0].buy_price, 36.5, 0.01, "forward_basic: buy_price");
ASSERT_NEAR(actions[0].sell_price, 36.5, 0.01, "forward_basic: sell_price");
ASSERT_NEAR(actions[0].buy_amount, 36.5, 0.01, "forward_basic: buy_amount");
ASSERT_NEAR(actions[0].sell_amount, 36.5, 0.01, "forward_basic: sell_amount");

policy_unload(&p);
}

static void test_forward_multi_symbol() {
int in_dim = 6, n_symbols = 3, out_dim = n_symbols * 4;
std::string path = create_test_model(in_dim, out_dim);

Policy p;
int rc = policy_load(&p, path.c_str());
ASSERT_EQ_INT(rc, 0, "multi_sym: load ok");

double obs[6] = {1, 2, 3, 4, 5, 6};
PolicyAction actions[3];
memset(actions, 0, sizeof(actions));

rc = policy_forward(&p, obs, in_dim, actions, n_symbols);
ASSERT_EQ_INT(rc, 0, "multi_sym: returns 0");

// sum = 21, each output = 21.5
for (int i = 0; i < n_symbols; i++) {
ASSERT_NEAR(actions[i].buy_price, 21.5, 0.01, "multi_sym: output value");
}

policy_unload(&p);
}

static void test_forward_deterministic() {
int in_dim = 4, out_dim = 4;
std::string path = create_test_model(in_dim, out_dim);

Policy p;
policy_load(&p, path.c_str());

double obs[4] = {1.0, 2.0, 3.0, 4.0};
PolicyAction a1, a2;

policy_forward(&p, obs, in_dim, &a1, 1);
policy_forward(&p, obs, in_dim, &a2, 1);

ASSERT_NEAR(a1.buy_price, a2.buy_price, 1e-6, "deterministic: buy_price");
ASSERT_NEAR(a1.sell_price, a2.sell_price, 1e-6, "deterministic: sell_price");
ASSERT_NEAR(a1.buy_amount, a2.buy_amount, 1e-6, "deterministic: buy_amount");
ASSERT_NEAR(a1.sell_amount, a2.sell_amount, 1e-6, "deterministic: sell_amount");

policy_unload(&p);
}

static void test_load_reload() {
std::string path = create_test_model(4, 4);

Policy p;
policy_load(&p, path.c_str());
ASSERT_EQ_INT(p.loaded, 1, "reload: first load");

policy_unload(&p);
ASSERT_EQ_INT(p.loaded, 0, "reload: after unload");

int rc = policy_load(&p, path.c_str());
ASSERT_EQ_INT(rc, 0, "reload: second load");
ASSERT_EQ_INT(p.loaded, 1, "reload: loaded after second load");

policy_unload(&p);
}

static void test_output_fewer_than_requested() {
int in_dim = 4, out_dim = 4;
std::string path = create_test_model(in_dim, out_dim);

Policy p;
policy_load(&p, path.c_str());

double obs[4] = {1.0, 2.0, 3.0, 4.0};
PolicyAction actions[3];
memset(actions, 0xff, sizeof(actions));

int rc = policy_forward(&p, obs, in_dim, actions, 3);
ASSERT_EQ_INT(rc, 0, "fewer_out: returns 0");

ASSERT_NEAR(actions[1].buy_price, 0.0, 1e-10, "fewer_out: sym1 zeroed");
ASSERT_NEAR(actions[2].buy_price, 0.0, 1e-10, "fewer_out: sym2 zeroed");

policy_unload(&p);
}

static void test_export_returns_error() {
int rc = policy_export_torchscript("foo", "bar");
ASSERT_EQ_INT(rc, -1, "export: returns -1");
}

int main() {
fprintf(stderr, "=== policy_infer tests (libtorch) ===\n");

test_load_unload();
test_load_bad_path();
test_forward_not_loaded();
test_forward_basic();
test_forward_multi_symbol();
test_forward_deterministic();
test_load_reload();
test_output_fewer_than_requested();
test_export_returns_error();

fprintf(stderr, "\n%d passed, %d failed\n", g_pass, g_fail);
return g_fail > 0 ? 1 : 0;
}
Loading