Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 313 additions & 0 deletions ZQ003/scripts/make_simulation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.signal import lfilter, butter \n",
"from scipy.stats import poisson\n",
"import matplotlib.pyplot as plt\n",
"from scipy.interpolate import make_interp_spline\n",
"import random\n",
"import pandas as pd\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parameters\n",
"sample_rate = 1017.25 # (Hz) - based on normal data recording rate\n",
"t = 1800 # (s) - based on std experiment\n",
"cutoff = 0.1 # based on OG simulation paper\n",
"n_dtpts = int(sample_rate*t) # Number of data points\n",
"movement_attenuation = 50 # Example attenuation percentage as per OG sim paper\n",
"noise_factor = 2 # as per OG sim paper\n",
"time_pts = np.linspace(0,t,n_dtpts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def calculate_movement_component(cutoff = 0.1, sample_rate = sample_rate, movement_attenuation = 50):\n",
" '''\n",
" Calculate the movement component of the signal, \n",
" based on a lowpass filtered random data and movement attenuation parameter\n",
" '''\n",
"\n",
" b, a = butter(N=4, Wn=cutoff / (sample_rate / 2), btype='low') # check cutofffffffff\n",
"\n",
" # Apply the filter\n",
" lowpass_values = lfilter(b, a, np.random.rand(n_dtpts))\n",
"\n",
" movement_component = 1 - (lowpass_values * (movement_attenuation / 100))\n",
" return movement_component\n",
"\n",
"def calculate_decay_component(time_pts, decay_rate1 = 0.02, decay_rate2 = 0.002, decay_base = 40):\n",
" '''\n",
" Make a double exponatial decaying curve, sampled at every time_pts\n",
" '''\n",
" decay_rate = ((1 - decay_rate1) ** time_pts + (1 - decay_rate2) ** time_pts) / 2\n",
" print(np.shape(decay_rate))\n",
" decay = decay_rate*(decay_base/100)+(1-decay_base/100)\n",
" \n",
" return decay\n",
" \n",
"\n",
"def calculate_ERT(lambda_val = 2, peak = 1, scale = 5, vis = False):\n",
" '''\n",
" Makes a Poisson distribution, \n",
" with mean = lambda_val, range = t, max value = peak\n",
" '''\n",
" # evaluate lambda over a duration 5 times longer to capture the whole distribution\n",
" t = lambda_val*5\n",
"\n",
" # Generate discrete values of the theoretical Poisson probability mass\n",
" # function (pmf) from 0 to t\n",
" x = np.arange(0, t)\n",
" pmf = poisson.pmf(x, lambda_val)\n",
" # Rescale x axis. The lowest reasonable value of lambda is 2,\n",
" # corresponding to t = 10, our response timescale is >50ms\n",
" x = x * scale\n",
" # Rescale y axis\n",
" pmf = pmf/max(pmf)*peak\n",
" # print(max(pmf))\n",
"\n",
" # Interpolate pmf\n",
" b = make_interp_spline(x, pmf, k=2) # b spline interpolation\n",
" x = np.arange(0, t * scale)\n",
" pmf = b(x)\n",
" # print(max(pmf))\n",
"\n",
" # Reindex where pmf values are >= 0.01\n",
" indices = np.where(pmf >= 0.01)[0]\n",
" pmf = pmf[indices]\n",
" # x = x[indices]\n",
" # x = np.arange(len(x))\n",
" # print(max(pmf))\n",
" \n",
" if vis:\n",
" plt.plot(x, pmf)\n",
" plt.xlabel('time (ms; 1017.25Hz)')\n",
" plt.show()\n",
" \n",
" return pmf\n",
"\n",
"\n",
"def calculate_noise_component(n_dtpts, sample_rate, noise_factor=8):\n",
" '''\n",
" Make a vector of length n_dtpts with random noised scaled by noise_factor\n",
" '''\n",
" noise_component = np.random.randn(n_dtpts) * noise_factor\n",
"\n",
" # b = sig.firwin(noise_component, cutoff=[1], fs=data.attrs['fs'],\n",
" # pass_zero=False)\n",
" # noise_component = detrend.filter_b = b\n",
" # b, a = butter(N=6, Wn=0.99, btype='low')\n",
"\n",
" # # Apply the filter\n",
" # noise_component = lfilter(b, a, np.random.rand(n_dtpts))\n",
" \n",
" return noise_component\n",
"\n",
"\n",
"\n",
"\n",
" \n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %%\n",
"def make_event(n_dtpts = 1000, n_events = False, lambda_val = False, peak_m = False, vis = False, delay_og = 0):\n",
"\n",
" true_signal = np.zeros(n_dtpts)\n",
" if not n_events: n_events = random.randint(2,3)\n",
" events = np.zeros(n_dtpts)\n",
" if not lambda_val: lambda_val = random.randint(2, 5)\n",
" if not peak_m: peak_m = random.uniform(5,15)\n",
"\n",
" for i in range(n_events):\n",
" delay = delay_og + random.randint(0,5)\n",
" peak = peak_m + random.uniform(-2, 2)\n",
" print(peak)\n",
" ert = calculate_ERT(lambda_val, peak, scale=10)\n",
" event_duration = len(ert)\n",
"\n",
" initial_response = random.randint(delay, len(true_signal)-event_duration)\n",
" events[initial_response] = 1\n",
"\n",
" true_signal[initial_response:initial_response+event_duration] += ert\n",
"\n",
" if vis: plt.plot(true_signal)\n",
" \n",
" return events, true_signal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ert = calculate_ERT(20, 7, scale=10)\n",
"max(ert)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Put it all together to make the simmulated signal made of noise, underlying true signal, photobleaching decay and movement \n",
"\n",
"events1, true_signal1 = make_event(n_dtpts=n_dtpts, n_events=20, delay_og=1, peak_m = 8, vis = False, lambda_val=2)\n",
"events2, true_signal2 = make_event(n_dtpts=n_dtpts, n_events=15, delay_og=2, peak_m = 10, vis = False, lambda_val=20)\n",
"events3, true_signal3 = make_event(n_dtpts=n_dtpts, n_events=21, delay_og=0, peak_m = 12, vis = False, lambda_val=50)\n",
"\n",
"true_signal = true_signal1 + true_signal2 + true_signal3\n",
"\n",
"movement_component = calculate_movement_component(cutoff, sample_rate, movement_attenuation)\n",
"\n",
"noise_component = calculate_noise_component(n_dtpts, sample_rate)\n",
"noise_component_iso = calculate_noise_component(n_dtpts, sample_rate)\n",
"\n",
"decay_component = calculate_decay_component(time_pts)\n",
"\n",
"data = (true_signal + 200) * movement_component * decay_component + noise_component\n",
"\n",
"isob = 100 * movement_component * decay_component + noise_component_iso\n",
"\n",
"# np.save('C:\\Users\\levip\\Desktop\\NSB\\BrainHack\\behapy\\SIM\\rawdata\\sub-test1\\ses-TEST1\\sub-test1_ses-TEST.2_task-TEST_run-1_label-LNAc_channel-iso.npy', isob)\n",
"\n",
"p, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)\n",
"ax1.plot(true_signal)\n",
"ax2 = plt.subplot(2,2, 2)\n",
"ax2.plot(decay_component)\n",
"ax3 = plt.subplot(2,2, 3)\n",
"ax3.plot(movement_component)\n",
"ax4 = plt.subplot(2,2, 4)\n",
"ax4.plot(noise_component)\n",
"plt.show()\n",
"\n",
"# plt.plot(decay_component)\n",
"# plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(true_signal1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rescale onset of each event from index to seconds\n",
"aa = np.where(events1 == 1)[0] / sample_rate\n",
"bb = np.where(events2 == 1)[0] / sample_rate\n",
"cc = np.where(events3 == 1)[0] / sample_rate\n",
"\n",
"# Combine all event times and labels\n",
"onsets = np.concatenate([aa, bb, cc])\n",
"duration = [0.1] * len(onsets)\n",
"event_ids = ['event1'] * len(aa) + ['event2'] * len(bb) + ['event3'] * len(cc)\n",
"\n",
"# Create the DataFrame and sort by time\n",
"df = pd.DataFrame({'onset': onsets, 'duration': duration, 'event_id': event_ids}).sort_values(by='onset').reset_index(drop=True)\n",
"df = df.set_index('onset')\n",
"\n",
"\n",
"df.to_csv(r'\\Users\\levip\\Desktop\\NSB\\BrainHack\\behapy\\SIM\\rawdata\\sub-test1\\ses-TEST1\\sub-test1_ses-TEST1_task-TEST_run-1_events.csv')\n",
"\n",
"# print(df)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.save('/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-ACh.npy', data)\n",
"np.save('/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-iso.npy', isob)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#%%\n",
"import numpy as np\n",
"import holoviews as hv\n",
"import datashader as ds\n",
"from holoviews.operation.datashader import datashade\n",
"from bokeh.plotting import output_notebook\n",
"\n",
"# Enable Bokeh and Holoviews support in the notebook\n",
"hv.extension('bokeh')\n",
"# output_notebook()\n",
"\n",
"# Convert data to a Holoviews Curve\n",
"curve = hv.Curve((np.arange(len(true_signal3)), true_signal3))\n",
"shaded_curve = datashade(curve).opts(width=800)\n",
"\n",
"shaded_curve"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "behapy",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}