Skip to content

Commit da372bf

Browse files
committed
fix(ehgf): invalid eHGF update code
1 parent c653d6c commit da372bf

6 files changed

Lines changed: 159 additions & 123 deletions

File tree

commons/stan_files/ehgf_ibrb.stan

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ model {
218218
for (i in 1:N) {
219219
row_vector[L-1] mu = mu0[i]; // prediction (2 ~ L)
220220
real sa[L] = sigma_base; // uncertainty of prediction (2 ~ L)
221+
real sa_prev[L]; // uncertainty of prediction from previous trial (2 ~ L)
221222
real mu_hat[L] = rep_array(0.0, L); // prior prediction (1 ~ L)
222223
real sa_hat[L] = rep_array(0.0, L); // prior uncertainty of prediction (1 ~ L)
223224
real m = -1; // predictive probability that the next response will be 1 (0~1)
@@ -244,10 +245,12 @@ model {
244245
for (l in 2:(L-1)) {
245246
real ka = kappa[i,l-1];
246247
real om = omega[i,l-1];
247-
sa_hat[l] = sa[l] + exp(ka * mu[l] + om);
248+
sa_hat[l] = sa[l] + exp(ka * mu[(l+1)-1] + om);
248249
}
249250
sa_hat[L] = sa[L] + exp(omega[i,L-1]);
250-
251+
252+
sa_prev = sa;
253+
251254
// Level 2
252255
mu_prev = mu_hat[2];
253256
pe = u[i,t] - mu_hat[1]; // prediction error
@@ -258,26 +261,26 @@ model {
258261
for (l in 3:L) {
259262
real ka = kappa[i,(l-1)-1];
260263
real om = omega[i,(l-1)-1];
261-
real mu_lower = mu[l-1-1];
264+
real mu_lower = mu[(l-1)-1];
262265
real mu_prev_ = mu_hat[l];
263266
real mu_prev_lower = mu_hat[l-1];
264267
real sa_lower = sa[l-1];
265-
real sa_prev_lower = sa_hat[l-1];
268+
real sa_prev_lower = sa_prev[l-1];
266269
real vv, pimhat, ww, rr, dd;
267270

268271
real v = exp(ka * mu_prev_ + om); // volatility
269-
real w = v / sa_prev_lower; // weighting factor (level: l-1)
272+
real w = v / sa_hat[l-1]; // weighting factor (level: l-1)
270273

271-
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_prev_lower) - 1; // volatility prediction error
274+
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_hat[l-1]) - 1.0; // volatility prediction error
272275
real lr = 0.5 * sa_hat[l] * ka * w; // learning rate
273276
real pwpe = lr * vpe; // precision-weighted prediction error
274277
mu[l-1] = mu_prev_ + pwpe; // posterior prediction
275278

276-
vv = exp(ka * mu[l-1] +om);
277-
pimhat = 1.0 / (sa_prev_lower +vv);
279+
vv = exp(ka * mu[l-1] + om);
280+
pimhat = 1.0 / (sa_prev_lower + vv);
278281
ww = vv * pimhat;
279282
rr = (vv - sa_prev_lower) * pimhat;
280-
dd = (sa_lower + square(mu_lower - mu_prev_lower)) *pimhat - 1;
283+
dd = (sa_lower + square(mu_lower - mu_prev_lower)) * pimhat - 1.0;
281284

282285
sa[l] = 1.0/((1.0/sa_hat[l]) + fmax(0.0, 0.5 * square(ka) * ww * (ww + rr * dd)));
283286
}
@@ -302,11 +305,12 @@ generated quantities {
302305
vector[L-2] mu_kappa;
303306
vector[L-1] mu_omega;
304307
real<lower=0> mu_zeta;
305-
306-
// Subject loop
308+
309+
// Subject loop
307310
for (i in 1:N) {
308311
row_vector[L-1] mu = mu0[i]; // prediction (2 ~ L)
309312
real sa[L] = sigma_base; // uncertainty of prediction (2 ~ L)
313+
real sa_prev[L]; // uncertainty of prediction from previous trial (2 ~ L)
310314
real mu_hat[L] = rep_array(0.0, L); // prior prediction (1 ~ L)
311315
real sa_hat[L] = rep_array(0.0, L); // prior uncertainty of prediction (1 ~ L)
312316
real m = -1; // predictive probability that the next response will be 1 (0~1)
@@ -333,10 +337,12 @@ generated quantities {
333337
for (l in 2:(L-1)) {
334338
real ka = kappa[i,l-1];
335339
real om = omega[i,l-1];
336-
sa_hat[l] = sa[l] + exp(ka * mu[l+1-1] + om);
340+
sa_hat[l] = sa[l] + exp(ka * mu[(l+1)-1] + om);
337341
}
338342
sa_hat[L] = sa[L] + exp(omega[i,L-1]);
339-
343+
344+
sa_prev = sa;
345+
340346
// Level 2
341347
mu_prev = mu_hat[2];
342348
pe = u[i,t] - mu_hat[1]; // prediction error
@@ -347,26 +353,26 @@ generated quantities {
347353
for (l in 3:L) {
348354
real ka = kappa[i,(l-1)-1];
349355
real om = omega[i,(l-1)-1];
350-
real mu_lower = mu[l-1-1];
356+
real mu_lower = mu[(l-1)-1];
351357
real mu_prev_ = mu_hat[l];
352358
real mu_prev_lower = mu_hat[l-1];
353359
real sa_lower = sa[l-1];
354-
real sa_prev_lower = sa_hat[l-1];
360+
real sa_prev_lower = sa_prev[l-1];
355361
real vv, pimhat, ww, rr, dd;
356362

357-
real v = exp(ka * mu_prev_ + om);// volatility
358-
real w = v / sa_prev_lower; // weighting factor (level: l-1)
363+
real v = exp(ka * mu_prev_ + om); // volatility
364+
real w = v / sa_hat[l-1]; // weighting factor (level: l-1)
359365

360-
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_prev_lower) - 1; // volatility prediction error
366+
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_hat[l-1]) - 1.0; // volatility prediction error
361367
real lr = 0.5 * sa_hat[l] * ka * w; // learning rate
362368
real pwpe = lr * vpe; // precision-weighted prediction error
363369
mu[l-1] = mu_prev_ + pwpe; // posterior prediction
364370

365-
vv = exp(ka * mu[l-1] +om);
366-
pimhat = 1.0 / (sa_prev_lower +vv);
371+
vv = exp(ka * mu[l-1] + om);
372+
pimhat = 1.0 / (sa_prev_lower + vv);
367373
ww = vv * pimhat;
368374
rr = (vv - sa_prev_lower) * pimhat;
369-
dd = (sa_lower + square(mu_lower - mu_prev_lower)) *pimhat - 1;
375+
dd = (sa_lower + square(mu_lower - mu_prev_lower)) * pimhat - 1.0;
370376

371377
sa[l] = 1.0/((1.0/sa_hat[l]) + fmax(0.0, 0.5 * square(ka) * ww * (ww + rr * dd)));
372378
}

commons/stan_files/ehgf_ibrb_multipleB.stan

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ model {
225225
for (bIdx in 1:Bsubj[i]) {
226226
row_vector[L-1] mu = mu0[i]; // prediction (2 ~ L)
227227
real sa[L] = sigma_base; // uncertainty of prediction (2 ~ L)
228+
real sa_prev[L];
228229
real mu_hat[L] = rep_array(0.0, L); // prior prediction (1 ~ L)
229230
real sa_hat[L] = rep_array(0.0, L); // prior uncertainty of prediction (1 ~ L)
230231
real m = -1; // predictive probability that the next response will be 1 (0~1)
@@ -251,40 +252,42 @@ model {
251252
for (l in 2:(L-1)) {
252253
real ka = kappa[i,l-1];
253254
real om = omega[i,l-1];
254-
sa_hat[l] = sa[l] + exp(ka * mu[l] + om);
255+
sa_hat[l] = sa[l] + exp(ka * mu[(l+1)-1] + om);
255256
}
256257
sa_hat[L] = sa[L] + exp(omega[i,L-1]);
257-
258+
259+
sa_prev = sa;
260+
258261
// Level 2
259262
mu_prev = mu_hat[2];
260-
pe = u[i,bIdx,t] - mu_hat[1]; // prediction error
263+
pe = u[i,bIdx,t] - mu_hat[1]; // prediction error
261264
sa[2] = 1.0 / ((1.0/sa_hat[2]) + sa_hat[1]); // learning rate
262265
mu[2-1] = mu_prev + sa[2] * pe; // posterior prediction
263266

264267
// Level 3 ~ L
265268
for (l in 3:L) {
266269
real ka = kappa[i,(l-1)-1];
267270
real om = omega[i,(l-1)-1];
268-
real mu_lower = mu[l-1-1];
271+
real mu_lower = mu[(l-1)-1];
269272
real mu_prev_ = mu_hat[l];
270273
real mu_prev_lower = mu_hat[l-1];
271274
real sa_lower = sa[l-1];
272-
real sa_prev_lower = sa_hat[l-1];
275+
real sa_prev_lower = sa_prev[l-1];
273276
real vv, pimhat, ww, rr, dd;
274277

275278
real v = exp(ka * mu_prev_ + om); // volatility
276-
real w = v / sa_prev_lower; // weighting factor (level: l-1)
279+
real w = v / sa_hat[l-1]; // weighting factor (level: l-1)
277280

278-
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_prev_lower) - 1; // volatility prediction error
281+
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_hat[l-1]) - 1.0; // volatility prediction error
279282
real lr = 0.5 * sa_hat[l] * ka * w; // learning rate
280283
real pwpe = lr * vpe; // precision-weighted prediction error
281284
mu[l-1] = mu_prev_ + pwpe; // posterior prediction
282285

283-
vv = exp(ka * mu[l-1] +om);
284-
pimhat = 1.0 / (sa_prev_lower +vv);
286+
vv = exp(ka * mu[l-1] + om);
287+
pimhat = 1.0 / (sa_prev_lower + vv);
285288
ww = vv * pimhat;
286289
rr = (vv - sa_prev_lower) * pimhat;
287-
dd = (sa_lower + square(mu_lower - mu_prev_lower)) *pimhat - 1;
290+
dd = (sa_lower + square(mu_lower - mu_prev_lower)) * pimhat - 1.0;
288291

289292
sa[l] = 1.0/((1.0/sa_hat[l]) + fmax(0.0, 0.5 * square(ka) * ww * (ww + rr * dd)));
290293
}
@@ -311,11 +314,12 @@ generated quantities {
311314
vector[L-1] mu_omega;
312315
real<lower=0> mu_zeta;
313316

314-
// Subject loop
317+
// Subject loop
315318
for (i in 1:N) {
316319
for (bIdx in 1:Bsubj[i]) {
317320
row_vector[L-1] mu = mu0[i]; // prediction (2 ~ L)
318321
real sa[L] = sigma_base; // uncertainty of prediction (2 ~ L)
322+
real sa_prev[L];
319323
real mu_hat[L] = rep_array(0.0, L); // prior prediction (1 ~ L)
320324
real sa_hat[L] = rep_array(0.0, L); // prior uncertainty of prediction (1 ~ L)
321325
real m = -1; // predictive probability that the next response will be 1 (0~1)
@@ -342,40 +346,42 @@ generated quantities {
342346
for (l in 2:(L-1)) {
343347
real ka = kappa[i,l-1];
344348
real om = omega[i,l-1];
345-
sa_hat[l] = sa[l] + exp(ka * mu[l+1-1] + om);
349+
sa_hat[l] = sa[l] + exp(ka * mu[(l+1)-1] + om);
346350
}
347351
sa_hat[L] = sa[L] + exp(omega[i,L-1]);
348-
352+
353+
sa_prev = sa;
354+
349355
// Level 2
350356
mu_prev = mu_hat[2];
351-
pe = u[i,bIdx,t] - mu_hat[1]; // prediction error
357+
pe = u[i,bIdx,t] - mu_hat[1]; // prediction error
352358
sa[2] = 1.0 / ((1.0/sa_hat[2]) + sa_hat[1]); // learning rate
353359
mu[2-1] = mu_prev + sa[2] * pe; // posterior prediction
354360

355361
// Level 3 ~ L
356362
for (l in 3:L) {
357363
real ka = kappa[i,(l-1)-1];
358364
real om = omega[i,(l-1)-1];
359-
real mu_lower = mu[l-1-1];
365+
real mu_lower = mu[(l-1)-1];
360366
real mu_prev_ = mu_hat[l];
361367
real mu_prev_lower = mu_hat[l-1];
362368
real sa_lower = sa[l-1];
363-
real sa_prev_lower = sa_hat[l-1];
369+
real sa_prev_lower = sa_prev[l-1];
364370
real vv, pimhat, ww, rr, dd;
365371

366-
real v = exp(ka * mu_prev_ + om);// volatility
367-
real w = v / sa_prev_lower; // weighting factor (level: l-1)
368-
369-
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_prev_lower) - 1; // volatility prediction error
372+
real v = exp(ka * mu_prev_ + om); // volatility
373+
real w = v / sa_hat[l-1]; // weighting factor (level: l-1)
374+
375+
real vpe = ((sa_lower + pow(mu_lower - mu_prev_lower, 2)) / sa_hat[l-1]) - 1.0; // volatility prediction error
370376
real lr = 0.5 * sa_hat[l] * ka * w; // learning rate
371377
real pwpe = lr * vpe; // precision-weighted prediction error
372378
mu[l-1] = mu_prev_ + pwpe; // posterior prediction
373379

374-
vv = exp(ka * mu[l-1] +om);
375-
pimhat = 1.0 / (sa_prev_lower +vv);
380+
vv = exp(ka * mu[l-1] + om);
381+
pimhat = 1.0 / (sa_prev_lower + vv);
376382
ww = vv * pimhat;
377383
rr = (vv - sa_prev_lower) * pimhat;
378-
dd = (sa_lower + square(mu_lower - mu_prev_lower)) *pimhat - 1;
384+
dd = (sa_lower + square(mu_lower - mu_prev_lower)) * pimhat - 1.0;
379385

380386
sa[l] = 1.0/((1.0/sa_hat[l]) + fmax(0.0, 0.5 * square(ka) * ww * (ww + rr * dd)));
381387
}

0 commit comments

Comments
 (0)