Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions scripts/sim_fitness_2x2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
rng = np.random.default_rng()
N = 80
GROUPS = ["ProgramA","ProgramB"]
DEFAULT_SEED = 7

def simulate(n_per_group=None, seed=DEFAULT_SEED):
global rng, N
# Allow overriding globals via args, primarily for testing
if n_per_group is not None:
N = n_per_group * len(GROUPS) # Adjust N based on n_per_group

rng = np.random.default_rng(seed)

def simulate():
os.makedirs("data/synthetic", exist_ok=True)
subjects = pd.DataFrame({
"id": np.arange(1, N+1),
Expand Down Expand Up @@ -39,7 +47,7 @@ def simulate():
long.to_csv("data/synthetic/fitness_long.csv", index=False)

meta = dict(
seed=int(getattr(rng, "seed_seq", np.random.SeedSequence()).entropy) if hasattr(rng, "seed_seq") else None,
seed=seed,
n=int(N),
design="2x2 mixed (Group between × Time within)",
programs=GROUPS,
Expand All @@ -50,15 +58,17 @@ def simulate():
json.dump(meta, f, indent=2)

print("Wrote fitness_subjects.csv, fitness_long.csv, fitness_meta.json")
# Return for testing convenience
return subjects, long

def main():
ap = argparse.ArgumentParser()
ap.add_argument("--seed", type=int, default=7, help="RNG seed")
ap.add_argument("--n-per-group", type=int, default=5, help="(kept for CLI parity; not used here)")
ap.add_argument("--seed", type=int, default=DEFAULT_SEED, help="RNG seed")
ap.add_argument("--n-per-group", type=int, default=40, help="Number of subjects per group")
args = ap.parse_args()
global rng
rng = np.random.default_rng(args.seed)
simulate()

# Pass CLI args to simulate
simulate(n_per_group=args.n_per_group, seed=args.seed)

if __name__ == "__main__":
main()
34 changes: 21 additions & 13 deletions scripts/sim_stroop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pandas as pd

RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)
rng = np.random.default_rng(RNG_SEED) # Will be reset by simulate()

N_SUBJ = 60
N_TRIALS_PER_COND = 100
Expand All @@ -37,7 +37,15 @@
OUTLIER_RT_LAPSE_SD = 200
RT_FLOOR = 80

def simulate():
def simulate(n_subjects=N_SUBJ, n_trials=N_TRIALS_PER_COND, seed=RNG_SEED):
global rng, N_SUBJ, N_TRIALS_PER_COND, RNG_SEED

# Set globals based on args, primarily for testing
N_SUBJ = n_subjects
N_TRIALS_PER_COND = n_trials
RNG_SEED = seed
rng = np.random.default_rng(RNG_SEED)

os.makedirs("data/synthetic", exist_ok=True)

# subject-level covariates
Expand Down Expand Up @@ -93,7 +101,10 @@ def simulate():
subjects.to_csv("data/synthetic/psych_stroop_subjects.csv", index=False)
trials.to_csv("data/synthetic/psych_stroop_trials.csv", index=False)
print("Wrote data/synthetic/psych_stroop_subjects.csv and psych_stroop_trials.csv")


# Write meta *after* simulation is done
write_meta(subjects, trials)

return subjects, trials

def write_meta(subjects: pd.DataFrame, trials: pd.DataFrame):
Expand Down Expand Up @@ -138,17 +149,14 @@ def write_meta(subjects: pd.DataFrame, trials: pd.DataFrame):

def main():
ap = argparse.ArgumentParser()
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--n-subjects", type=int, default=60)
ap.add_argument("--n-trials", type=int, default=100)
ap.add_argument("--seed", type=int, default=RNG_SEED)
ap.add_argument("--n-subjects", type=int, default=N_SUBJ)
ap.add_argument("--n-trials", type=int, default=N_TRIALS_PER_COND)
args = ap.parse_args()
global N_SUBJ, N_TRIALS_PER_COND, rng, RNG_SEED
N_SUBJ = args.n_subjects
N_TRIALS_PER_COND = args.n_trials
RNG_SEED = args.seed
rng = np.random.default_rng(RNG_SEED)
subjects_df, trials_df = simulate()
write_meta(subjects_df, trials_df)

# Pass CLI args to simulate
simulate(n_subjects=args.n_subjects, n_trials=args.n_trials, seed=args.seed)


if __name__ == "__main__":
main()
Loading