Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,22 +347,23 @@ def optimize_model(self):
# Unzip (6-element tuples: state, action, reward, next_state, done, gamma)
batch_state, batch_action, batch_reward, batch_next, batch_done, batch_gamma = zip(*transitions)

action_batch = torch.tensor(batch_action, dtype=torch.long).unsqueeze(1).to(self.device)
reward_batch = torch.tensor(batch_reward, dtype=torch.float32).to(self.device)
done_batch = torch.tensor(batch_done, dtype=torch.float32).to(self.device)
weights_batch = torch.tensor(is_weights, dtype=torch.float32).to(self.device)
gamma_batch = torch.tensor(batch_gamma, dtype=torch.float32).to(self.device)
action_batch = torch.as_tensor(batch_action, dtype=torch.long, device=self.device).unsqueeze(1)
reward_batch = torch.as_tensor(batch_reward, dtype=torch.float32, device=self.device)
done_batch = torch.as_tensor(batch_done, dtype=torch.float32, device=self.device)
weights_batch = torch.as_tensor(is_weights, dtype=torch.float32, device=self.device)
gamma_batch = torch.as_tensor(batch_gamma, dtype=torch.float32, device=self.device)

# Reward scaling (scale=1.0 preserves signal, clamp asymmetric: deaths must hurt)
reward_scale = max(self.config.opt.reward_scale, 1.0)
norm_rewards = torch.clamp(reward_batch / reward_scale, -50.0, 30.0)

if self.use_hybrid:
# Unpack tuples: (matrix_u8, sectors_f32)
s_matrices = torch.tensor(np.array([s[0] for s in batch_state]), dtype=torch.float32).to(self.device) / 255.0
s_sectors = torch.tensor(np.array([s[1] for s in batch_state]), dtype=torch.float32).to(self.device)
n_matrices = torch.tensor(np.array([s[0] for s in batch_next]), dtype=torch.float32).to(self.device) / 255.0
n_sectors = torch.tensor(np.array([s[1] for s in batch_next]), dtype=torch.float32).to(self.device)
# Optimize: create tensors directly on device
s_matrices = torch.as_tensor(np.array([s[0] for s in batch_state]), dtype=torch.float32, device=self.device) / 255.0
s_sectors = torch.as_tensor(np.array([s[1] for s in batch_state]), dtype=torch.float32, device=self.device)
n_matrices = torch.as_tensor(np.array([s[0] for s in batch_next]), dtype=torch.float32, device=self.device) / 255.0
n_sectors = torch.as_tensor(np.array([s[1] for s in batch_next]), dtype=torch.float32, device=self.device)

q_values = self.policy_net(s_matrices, s_sectors).gather(1, action_batch)

Expand All @@ -374,8 +375,8 @@ def optimize_model(self):
expected_q_values = (next_q_values * gamma_n * (1 - done_batch)) + norm_rewards
else:
# Legacy: plain uint8 arrays
state_batch = torch.tensor(np.array(batch_state), dtype=torch.float32).to(self.device) / 255.0
next_batch = torch.tensor(np.array(batch_next), dtype=torch.float32).to(self.device) / 255.0
state_batch = torch.as_tensor(np.array(batch_state), dtype=torch.float32, device=self.device) / 255.0
next_batch = torch.as_tensor(np.array(batch_next), dtype=torch.float32, device=self.device) / 255.0

q_values = self.policy_net(state_batch).gather(1, action_batch)

Expand Down
73 changes: 73 additions & 0 deletions benchmark_tensor_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

import torch
import numpy as np
import timeit


def benchmark(batch_size=64, n_iter=10000, warmup=100):
    """Benchmark three strategies for building device tensors from a replay batch.

    Compares:
      1. ``torch.tensor(...).to(device)``        -- original two-step pattern
      2. ``torch.tensor(..., device=device)``    -- single-step construction
      3. ``torch.as_tensor(..., device=device)`` -- avoids a copy when the
         input is already a numpy array of the right dtype (``is_weights``);
         for plain tuples it must still copy, so the gain there is smaller.

    Args:
        batch_size: Number of transitions in the simulated batch.
        n_iter: Timed iterations per method.
        warmup: Untimed warmup calls per method before timing.

    Returns:
        Tuple ``(t_original, t_device_arg, t_as_tensor)``: total seconds
        reported by ``timeit`` for each method over ``n_iter`` calls.
    """
    # Simulate the unzipped transition tuples produced by the replay buffer.
    batch_action = tuple(np.random.randint(0, 10, size=batch_size).tolist())
    batch_reward = tuple(np.random.randn(batch_size).tolist())
    batch_done = tuple(np.random.randint(0, 2, size=batch_size).astype(float).tolist())
    batch_gamma = tuple(np.random.uniform(0.9, 0.99, size=batch_size).tolist())

    # Importance-sampling weights arrive as a numpy float32 array, not a tuple.
    is_weights = np.random.uniform(0, 1, size=batch_size).astype(np.float32)

    device = torch.device("cpu")  # Testing on CPU as sandbox has no GPU

    print(f"Benchmarking with batch_size={batch_size} on {device}")

    def method_original():
        # Two-step: construct on CPU, then move with .to(device).
        action_batch = torch.tensor(batch_action, dtype=torch.long).unsqueeze(1).to(device)
        reward_batch = torch.tensor(batch_reward, dtype=torch.float32).to(device)
        done_batch = torch.tensor(batch_done, dtype=torch.float32).to(device)
        weights_batch = torch.tensor(is_weights, dtype=torch.float32).to(device)
        gamma_batch = torch.tensor(batch_gamma, dtype=torch.float32).to(device)
        return action_batch, reward_batch, done_batch, weights_batch, gamma_batch

    def method_optimized_device_arg():
        # Single-step: construct directly on the target device.
        action_batch = torch.tensor(batch_action, dtype=torch.long, device=device).unsqueeze(1)
        reward_batch = torch.tensor(batch_reward, dtype=torch.float32, device=device)
        done_batch = torch.tensor(batch_done, dtype=torch.float32, device=device)
        weights_batch = torch.tensor(is_weights, dtype=torch.float32, device=device)
        gamma_batch = torch.tensor(batch_gamma, dtype=torch.float32, device=device)
        return action_batch, reward_batch, done_batch, weights_batch, gamma_batch

    def method_as_tensor_numpy():
        # as_tensor can share memory with is_weights (already a float32 numpy
        # array), skipping the copy that torch.tensor always makes.
        weights_batch = torch.as_tensor(is_weights, dtype=torch.float32, device=device)

        # NOTE: the tuples below are NOT pre-converted to numpy here;
        # as_tensor on a tuple still allocates and copies, so these lines
        # measure as_tensor's tuple path, not a numpy fast path.
        action_batch = torch.as_tensor(batch_action, dtype=torch.long, device=device).unsqueeze(1)
        reward_batch = torch.as_tensor(batch_reward, dtype=torch.float32, device=device)
        done_batch = torch.as_tensor(batch_done, dtype=torch.float32, device=device)
        gamma_batch = torch.as_tensor(batch_gamma, dtype=torch.float32, device=device)
        return action_batch, reward_batch, done_batch, weights_batch, gamma_batch

    # Warmup: let allocators/caches settle before timing.
    for _ in range(warmup):
        method_original()
        method_optimized_device_arg()
        method_as_tensor_numpy()

    # Measure each strategy over n_iter calls.
    t0 = timeit.timeit(method_original, number=n_iter)
    t1 = timeit.timeit(method_optimized_device_arg, number=n_iter)
    t2 = timeit.timeit(method_as_tensor_numpy, number=n_iter)

    print(f"Original: {t0:.4f} s")
    print(f"Optimized (device arg): {t1:.4f} s")
    print(f"Optimized (as_tensor): {t2:.4f} s")

    print(f"Speedup (device arg): {t0/t1:.2f}x")
    print(f"Speedup (as_tensor): {t0/t2:.2f}x")

    return t0, t1, t2


if __name__ == "__main__":
    benchmark()
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from gen2.model import DuelingDQN
from model import DuelingDQN

def test_dueling_dqn_initialization():
"""Test that DuelingDQN initializes with different parameters."""
Expand Down