-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_refined_target_decoding.py
More file actions
56 lines (45 loc) · 2.09 KB
/
test_refined_target_decoding.py
File metadata and controls
56 lines (45 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import polars as pl
import numpy as np
import random
# === Load the refined file ===
# NOTE: the path is relative, so it resolves against the current working directory.
df = pl.read_csv("datasets/refined/1h/ADAUSDT_1h_refined.csv")

# === Constants ===
VOLUME_K = 0.2  # divisor applied to arctanh(target) when decoding volume-like columns
NUM_ROWS = df.height
WINDOW_SIZE = 30  # number of consecutive rows decoded and compared per run

# === Choose a random starting index safely ===
# The start must be >= 1 because every row is decoded against the row before
# it, and the whole window [start_index, end_index) has to fit in the frame.
# With this upper bound the very last row of the file is never sampled.
_max_start = NUM_ROWS - WINDOW_SIZE - 1
if _max_start < 1:
    # random.randint would raise an opaque "empty range" ValueError here;
    # raise the same exception type with an actionable message instead.
    raise ValueError(
        f"Need at least {WINDOW_SIZE + 2} rows to sample a "
        f"{WINDOW_SIZE}-row window, but the file has only {NUM_ROWS}."
    )
start_index = random.randint(1, _max_start)  # randint is inclusive on both ends
end_index = start_index + WINDOW_SIZE
print(f"\n🔀 Decoding rows {start_index} to {end_index - 1}:")
# === Decoding helpers ===
# Column suffixes decoded as prices (relative to the previous row's close) and
# as volume-like quantities (each relative to its own previous value). Defined
# once here instead of being rebuilt on every loop iteration.
_PRICE_FIELDS = ("open", "high", "low", "close")
_VOLUME_FIELDS = (
    "volume",
    "quote_asset_volume",
    "number_of_trades",
    "taker_buy_base_asset_volume",
    "taker_buy_quote_asset_volume",
)


def _decode_price(prev_close, target):
    """Return ``prev_close * exp(target)`` for one ``target_*`` price value."""
    return prev_close * np.exp(target)


def _decode_volume(prev_value, target, k):
    """Return ``prev_value * exp(arctanh(target) / k)`` for one volume-like target.

    The target is clipped to [-0.999999, 0.999999] first so that ``arctanh``
    stays finite when a stored value sits exactly on +/-1.
    NOTE(review): presumably the refinement step stored
    ``tanh(k * log(value / prev_value))`` -- confirm against the encoder.
    """
    log_ret = np.arctanh(np.clip(target, -0.999999, 0.999999)) / k
    return prev_value * np.exp(log_ret)


def _decode_row(row, prev_row, k):
    """Decode every ``target_*`` value of ``row`` using ``prev_row`` as the base.

    Args:
        row: Mapping of column name to value for the row being checked.
        prev_row: Mapping of column name to value for the preceding row.
        k: Scale factor passed to the volume decoding.

    Returns:
        Dict mapping each field name (prices first, then volume-like fields)
        to its decoded value.
    """
    prev_close = prev_row["original_close"]
    decoded = {
        field: _decode_price(prev_close, row[f"target_{field}"])
        for field in _PRICE_FIELDS
    }
    for field in _VOLUME_FIELDS:
        decoded[field] = _decode_volume(
            prev_row[f"original_{field}"], row[f"target_{field}"], k
        )
    return decoded


# === Iterate over the selected range ===
# Each row is fetched once as a plain dict and then reused as the "previous"
# row of the next iteration, instead of slicing the frame per cell.
prev_row = df.row(start_index - 1, named=True)
for i in range(start_index, end_index):
    row = df.row(i, named=True)
    timestamp = row["timestamp"]
    decoded = _decode_row(row, prev_row, VOLUME_K)

    # === Print decoded vs original comparison ===
    print(f"\n🧪 Decoding check for row {i} @ {timestamp}:")
    for field, decoded_val in decoded.items():
        original_val = row[f"original_{field}"]
        delta = abs(original_val - decoded_val)
        print(f"{field:<35} original: {original_val:>15.8f} decoded: {decoded_val:>15.8f} Δ = {delta:.8f}")

    prev_row = row