diff --git a/scripts/sim_fitness_2x2.py b/scripts/sim_fitness_2x2.py index 8481ab9..c21aa4f 100644 --- a/scripts/sim_fitness_2x2.py +++ b/scripts/sim_fitness_2x2.py @@ -6,8 +6,16 @@ rng = np.random.default_rng() N = 80 GROUPS = ["ProgramA","ProgramB"] +DEFAULT_SEED = 7 + +def simulate(n_per_group=None, seed=DEFAULT_SEED): + global rng, N + # Allow overriding globals via args, primarily for testing + if n_per_group is not None: + N = n_per_group * len(GROUPS) # Adjust N based on n_per_group + + rng = np.random.default_rng(seed) -def simulate(): os.makedirs("data/synthetic", exist_ok=True) subjects = pd.DataFrame({ "id": np.arange(1, N+1), @@ -39,7 +47,7 @@ def simulate(): long.to_csv("data/synthetic/fitness_long.csv", index=False) meta = dict( - seed=int(getattr(rng, "seed_seq", np.random.SeedSequence()).entropy) if hasattr(rng, "seed_seq") else None, + seed=seed, n=int(N), design="2x2 mixed (Group between × Time within)", programs=GROUPS, @@ -50,15 +58,17 @@ def simulate(): json.dump(meta, f, indent=2) print("Wrote fitness_subjects.csv, fitness_long.csv, fitness_meta.json") + # Return for testing convenience + return subjects, long def main(): ap = argparse.ArgumentParser() - ap.add_argument("--seed", type=int, default=7, help="RNG seed") - ap.add_argument("--n-per-group", type=int, default=5, help="(kept for CLI parity; not used here)") + ap.add_argument("--seed", type=int, default=DEFAULT_SEED, help="RNG seed") + ap.add_argument("--n-per-group", type=int, default=40, help="Number of subjects per group") args = ap.parse_args() - global rng - rng = np.random.default_rng(args.seed) - simulate() + + # Pass CLI args to simulate + simulate(n_per_group=args.n_per_group, seed=args.seed) if __name__ == "__main__": main() \ No newline at end of file diff --git a/scripts/sim_stroop.py b/scripts/sim_stroop.py index 6162de9..607b69e 100644 --- a/scripts/sim_stroop.py +++ b/scripts/sim_stroop.py @@ -12,7 +12,7 @@ import pandas as pd RNG_SEED = 42 -rng = np.random.default_rng(RNG_SEED) +rng = np.random.default_rng(RNG_SEED) # Will be reset by simulate() N_SUBJ = 60 N_TRIALS_PER_COND = 100 @@ -37,7 +37,15 @@ OUTLIER_RT_LAPSE_SD = 200 RT_FLOOR = 80 -def simulate(): +def simulate(n_subjects=N_SUBJ, n_trials=N_TRIALS_PER_COND, seed=RNG_SEED): + global rng, N_SUBJ, N_TRIALS_PER_COND, RNG_SEED + + # Set globals based on args, primarily for testing + N_SUBJ = n_subjects + N_TRIALS_PER_COND = n_trials + RNG_SEED = seed + rng = np.random.default_rng(RNG_SEED) + os.makedirs("data/synthetic", exist_ok=True) # subject-level covariates @@ -93,7 +101,10 @@ def simulate(): subjects.to_csv("data/synthetic/psych_stroop_subjects.csv", index=False) trials.to_csv("data/synthetic/psych_stroop_trials.csv", index=False) print("Wrote data/synthetic/psych_stroop_subjects.csv and psych_stroop_trials.csv") - + + # Write meta *after* simulation is done + write_meta(subjects, trials) + return subjects, trials def write_meta(subjects: pd.DataFrame, trials: pd.DataFrame): @@ -138,17 +149,14 @@ def write_meta(subjects: pd.DataFrame, trials: pd.DataFrame): def main(): ap = argparse.ArgumentParser() - ap.add_argument("--seed", type=int, default=42) - ap.add_argument("--n-subjects", type=int, default=60) - ap.add_argument("--n-trials", type=int, default=100) + ap.add_argument("--seed", type=int, default=RNG_SEED) + ap.add_argument("--n-subjects", type=int, default=N_SUBJ) + ap.add_argument("--n-trials", type=int, default=N_TRIALS_PER_COND) args = ap.parse_args() - global N_SUBJ, N_TRIALS_PER_COND, rng, RNG_SEED - N_SUBJ = args.n_subjects - N_TRIALS_PER_COND = args.n_trials - RNG_SEED = args.seed - rng = np.random.default_rng(RNG_SEED) - subjects_df, trials_df = simulate() - write_meta(subjects_df, trials_df) + + # Pass CLI args to simulate + simulate(n_subjects=args.n_subjects, n_trials=args.n_trials, seed=args.seed) + if __name__ == "__main__": main() \ No newline at end of file