diff --git a/tuned_lens/plotting/trajectory_plotting.py b/tuned_lens/plotting/trajectory_plotting.py index 1b89e83..de842a4 100644 --- a/tuned_lens/plotting/trajectory_plotting.py +++ b/tuned_lens/plotting/trajectory_plotting.py @@ -178,8 +178,7 @@ def heatmap( y=self._layer_labels, z=self.stats if not log_scale else np.log10(self.stats), colorbar=dict( - title=f"{self.name} ({self.units})", - titleside="right", + title=f"{self.name} ({self.units})" ), colorscale=colorscale, zmax=max if not log_scale else np.log10(max),