diff --git a/tests/utils/test_visualization.py b/tests/utils/test_visualization.py index f92588f..b203246 100644 --- a/tests/utils/test_visualization.py +++ b/tests/utils/test_visualization.py @@ -24,6 +24,15 @@ def sample_spatiotemporal(): U = np.sin(T) * np.cos(2 * np.pi * X) return U +@pytest.fixture +def sample_3d_series(): + # Create a simple 3D time series: shape (Nt=100) + t = np.linspace(0, 10, 100) + x = np.sin(t) + y = np.cos(t) + z = np.sin(t) * np.cos(0.5 * t) + return np.column_stack((x, y, z)) + @patch('matplotlib.pyplot.show') def test_plot_time_series_basic(mock_show, sample_time_series): # Test with basic parameters @@ -73,16 +82,62 @@ def test_imshow_1D_spatiotemp_with_options(mock_show, sample_spatiotemporal): ) mock_show.assert_called_once() +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_basic(mock_show, sample_3d_series): + # Test with basic parameters + vis.plot_in_3D_state_space(sample_3d_series) + mock_show.assert_called_once() + +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_with_options(mock_show, sample_3d_series): + # Test with optional parameters + vis.plot_in_3D_state_space( + [sample_3d_series, sample_3d_series], + time_series_labels=["Data 1", "Data 2"], + line_formats=['-', '--'], + state_var_names=["x1", "x2", "x3"], + title="3D Attractor", + linewidth=1.5, + ) + mock_show.assert_called_once() + +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_with_jax(mock_show): + # Test with JAX arrays + t = jnp.linspace(0, 10, 100) + x = jnp.sin(t) + y = jnp.cos(t) + z = jnp.sin(t) * jnp.cos(0.5 * t) + data = jnp.column_stack((x, y, z)) + vis.plot_in_3D_state_space(data) + mock_show.assert_called_once() + def test_input_validation(): - # Test input validation for both functions + # Test input validation for all three functions with pytest.raises(TypeError): vis.plot_time_series("not an array") with pytest.raises(TypeError): vis.plot_time_series(np.array([1, 2, 3])) # 1D array + with pytest.raises(ValueError): + a = np.zeros((10, 3)) + b = np.zeros((9, 3)) + vis.plot_time_series([a, b]) # mismatched shapes + with pytest.raises(TypeError): vis.imshow_1D_spatiotemp("not an array", 10) with pytest.raises(TypeError): vis.imshow_1D_spatiotemp(np.array([1, 2, 3]), 10) # 1D array + + with pytest.raises(TypeError): + vis.plot_in_3D_state_space("not an array") + + with pytest.raises(TypeError): + vis.plot_in_3D_state_space(np.array([1, 2, 3])) # 1D array + + with pytest.raises(ValueError): + a = np.zeros((10, 3)) + b = np.zeros((9, 3)) + vis.plot_in_3D_state_space([a, b]) # mismatched shapes