diff --git a/commons/stan_files/bart_ewmv.stan b/commons/stan_files/bart_ewmv.stan index 4d521150..577fd85f 100644 --- a/commons/stan_files/bart_ewmv.stan +++ b/commons/stan_files/bart_ewmv.stan @@ -71,17 +71,21 @@ model { real p_burst = phi[j]; for (k in 1:Tsubj[j]) { - real u_gain = 1; + // real u_gain = 1; + // Reward values based on pump number; u_lose_array[l] is the sum of u_gain_array[1:(l - 1)], so it holds 12 entries for the 11 gains. + row_vector[11] u_gain_array = [0.05, 0.10, 0.10, 0.30, 0.40, .50, .60, .70, .70, .80, .90]; + row_vector[12] u_lose_array = [0, 0.05, 0.15, 0.25, 0.55, 0.95, 1.45, 2.05, 2.75, 3.45, 4.25, 5.15]; real u_loss; real u_pump; real u_stop = 0; real delta_u; for (l in 1:(pumps[j, k] + 1 - explosion[j, k])) { - u_loss = (l - 1); - - u_pump = (1 - p_burst) * u_gain - lambda[j] * p_burst * u_loss + - rho[j] * p_burst * (1 - p_burst) * (u_gain + lambda[j] * u_loss)^2; + // Because our reward gain is not constant across the trial, we need to use the summation of the previous rewards. + // This is pre-populated in the u_lose_array. NOTE(review): l can reach pumps[j, k] + 1, so u_gain_array[l] is out of range if a trial has 11 unexploded pumps - confirm the task's maximum pump count. + // u_loss = (l - 1); + + u_pump = (1 - p_burst) * u_gain_array[l] - lambda[j] * p_burst * u_lose_array[l] + rho[j] * p_burst * (1 - p_burst) * (u_gain_array[l] + lambda[j] * u_lose_array[l])^2; // u_stop always equals 0. delta_u = u_pump - u_stop; @@ -131,7 +135,10 @@ generated quantities { log_lik[j] = 0; for (k in 1:Tsubj[j]) { - real u_gain = 1; + // real u_gain = 1; + // Reward values based on pump number; u_lose_array[l] is the sum of u_gain_array[1:(l - 1)], so it holds 12 entries for the 11 gains. + row_vector[11] u_gain_array = [0.05, 0.10, 0.10, 0.30, 0.40, .50, .60, .70, .70, .80, .90]; + row_vector[12] u_lose_array = [0, 0.05, 0.15, 0.25, 0.55, 0.95, 1.45, 2.05, 2.75, 3.45, 4.25, 5.15]; real u_loss; real u_pump; real u_stop = 0; @@ -139,10 +146,13 @@ generated quantities { for (l in 1:(pumps[j, k] + 1 - explosion[j, k])) { // u_gain always equals r ^ rho. - u_loss = (l - 1); + + // Because our reward gain is not constant across the trial, we need to use the summation of the previous rewards.
+ // This is pre-populated in the u_lose_array. NOTE(review): l can reach pumps[j, k] + 1, so u_gain_array[l] is out of range if a trial has 11 unexploded pumps - confirm the task's maximum pump count. + // u_loss = (l - 1); - u_pump = (1 - p_burst) * u_gain - lambda[j] * p_burst * u_loss + - rho[j] * p_burst * (1 - p_burst) * (u_gain + lambda[j] * u_loss)^2; + // Updated to use u_gain_array values + u_pump = (1 - p_burst) * u_gain_array[l] - lambda[j] * p_burst * u_lose_array[l] + rho[j] * p_burst * (1 - p_burst) * (u_gain_array[l] + lambda[j] * u_lose_array[l])^2; // u_stop always equals 0. delta_u = u_pump - u_stop;