Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions spf/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from spf.dataset.v4_data import v4rx_2xf64_keys, v4rx_f64_keys, v4rx_new_dataset
from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys, v5rx_new_dataset
from spf.dataset.wall_array_v2_idxs import v2_column_names
from spf.rf import beamformer_given_steering, get_avg_phase, precompute_steering_vectors
from spf.rf import beamformer_given_steering, get_avg_phase, get_avg_phase_fast, get_avg_phase_fast2, precompute_steering_vectors
from spf.scripts.zarr_utils import zarr_shrink
from spf.sdrpluto.sdr_controller import (
EmitterConfig,
Expand Down Expand Up @@ -266,7 +266,7 @@ def get_data(self):
if sdr_rx is None:
raise ValueError("SDR RX is None, aborting.")
# process the data
signal_matrix = np.vstack(sdr_rx["signal_matrix"])
signal_matrix = np.vstack(sdr_rx["signal_matrix"],dtype=np.complex64)
current_time = time.time() - self.time_offset # timestamp

return data_to_snapshot(
Expand All @@ -285,10 +285,10 @@ def get_data(self):
sdr_rx = self.get_rx()

# process the data
signal_matrix = np.vstack(sdr_rx["signal_matrix"])
signal_matrix = np.vstack(sdr_rx["signal_matrix"]).astype(np.complex64)
current_time = time.time() - self.time_offset # timestamp after sample arrives

avg_phase_diff = get_avg_phase(signal_matrix)
avg_phase_diff =get_avg_phase_fast2(signal_matrix)
assert self.pplus.rx_config.rx_spacing > 0.001
return self.snapshot_class(
signal_matrix=signal_matrix,
Expand Down Expand Up @@ -541,12 +541,9 @@ def write_to_record_matrix(self, thread_idx, record_idx, data):
data.gps_long = current_pos_heading_and_time["gps"][0]
data.gps_lat = current_pos_heading_and_time["gps"][1]
data.gps_timestamp = current_pos_heading_and_time["gps_time"]

if self.realtime_v5inf is not None:
data_dict=asdict(data)
print(data_dict['signal_matrix'].dtype,"XXA")
data_dict['signal_matrix']=data_dict['signal_matrix'].reshape(1,1,*data_dict['signal_matrix'].shape).astype(np.complex64)
print(data_dict['signal_matrix'].dtype,"XXA")
self.realtime_v5inf.write_to_idx(record_idx, thread_idx, data_dict)
if self.data_filename is not None:
z = self.zarr[f"receivers/r{thread_idx}"]
Expand All @@ -556,10 +553,11 @@ def write_to_record_matrix(self, thread_idx, record_idx, data):


def close(self):
self.zarr.store.close()
self.zarr = None
logging.info(f"Trying to shrink... {self.data_filename}")
zarr_shrink(self.data_filename)
if self.data_filename is not None:
self.zarr.store.close()
self.zarr = None
logging.info(f"Trying to shrink... {self.data_filename}")
zarr_shrink(self.data_filename)


# V5 data format
Expand Down
15 changes: 13 additions & 2 deletions spf/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
pi_norm,
reduce_theta_to_positive_y,
windowed_trimmed_circular_mean_and_stddev,
windowed_trimmed_circular_mean_and_stddev_fast,
windowed_trimmed_circular_mean_and_stddev_fast2,
)
from spf.scripts.zarr_utils import (
new_yarr_dataset,
Expand Down Expand Up @@ -375,6 +377,7 @@ def segment_session(
skip_beamformer=False,
skip_detrend=False,
skip_segmentation=False,
fast_beamformer=False,
**kwrgs,
):
"""
Expand Down Expand Up @@ -507,9 +510,13 @@ def segment_session(
.T
)
else:

beamformer_fn=beamformer_given_steering_nomean
if fast_beamformer:
beamformer_fn=beamformer_given_steering_nomean_fast
# CPU version of beamforming (same algorithm but slower)
segmentation_results["windowed_beamformer"] = (
beamformer_given_steering_nomean_fast(
beamformer_fn(
steering_vectors=kwrgs["steering_vectors"],
signal_matrix=v.astype(np.complex64),
)
Expand Down Expand Up @@ -646,9 +653,13 @@ def get_all_windows_stats(
# - Trimmed circular mean of phase differences
# - Trimmed standard deviation of phase differences
# - Median absolute signal amplitude
step_idxs, step_stats = windowed_trimmed_circular_mean_and_stddev(
step_idxs, step_stats = windowed_trimmed_circular_mean_and_stddev_fast(
v, pd, window_size=window_size, stride=stride, trim=trim
)
# step_idxs2, step_stats2 = windowed_trimmed_circular_mean_and_stddev_fast2(
# v, pd, window_size=window_size, stride=stride, trim=trim
# )
#breakpoint()
return step_idxs, step_stats


Expand Down
1 change: 1 addition & 0 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def render_session(self, ridx, data):
skip_beamformer=False,
skip_detrend=self.skip_detrend,
skip_segmentation=self.skip_segmentation,
fast_beamformer=True,
**{
"steering_vectors": self.steering_vectors[ridx],
**DEFAULT_SEGMENT_ARGS,
Expand Down
6 changes: 3 additions & 3 deletions spf/dataset/spf_nn_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def to_absolute_north(self, sample):
sample[ridx]["paired"] = paired_nn_inference_rotated
return sample

@lru_cache
@lru_cache(4)
def get_inference_for_idx(self, idx):
if not self.ds.realtime:
return [
Expand All @@ -103,7 +103,7 @@ def get_inference_for_idx(self, idx):
]
return self.get_and_annotate_entry_at_idx(idx)

@lru_cache
@lru_cache(4)
def get_and_annotate_entry_at_idx(self, idx):
sample = self.ds[idx]
if not self.ds.realtime:
Expand Down Expand Up @@ -135,7 +135,7 @@ def __next__(self):
self.serving_idx += 1
return sample

@lru_cache
@lru_cache(4)
def __getitem__(self, idx):
return self.get_and_annotate_entry_at_idx(idx)

Expand Down
Loading
Loading