@@ -128,12 +128,12 @@ def shorten_name(name: str, depth: int | None) -> str:
128128
129129 # Set tick labels
130130 if layers1 is not None :
131- shortened = [shorten_name (l , layer_name_depth ) for l in layers1 ]
131+ shortened = [shorten_name (layer , layer_name_depth ) for layer in layers1 ]
132132 ax .set_yticks (range (n_layers1 ))
133133 ax .set_yticklabels (shortened , fontsize = tick_fontsize )
134134
135135 if layers2 is not None :
136- shortened = [shorten_name (l , layer_name_depth ) for l in layers2 ]
136+ shortened = [shorten_name (layer , layer_name_depth ) for layer in layers2 ]
137137 ax .set_xticks (range (n_layers2 ))
138138 ax .set_xticklabels (shortened , fontsize = tick_fontsize , rotation = 45 , ha = "right" )
139139
@@ -197,7 +197,11 @@ def plot_cka_trend(
197197 """
198198 # Normalize input to list of arrays
199199 if isinstance (cka_values , (torch .Tensor , np .ndarray )):
200- arr = cka_values .detach ().cpu ().numpy () if isinstance (cka_values , torch .Tensor ) else cka_values
200+ arr = (
201+ cka_values .detach ().cpu ().numpy ()
202+ if isinstance (cka_values , torch .Tensor )
203+ else cka_values
204+ )
201205 if arr .ndim == 1 :
202206 cka_values = [arr ]
203207 else :
@@ -304,6 +308,7 @@ def plot_cka_trend_with_range(
304308 Returns:
305309 Tuple of (Figure, Axes).
306310 """
311+
307312 # Convert to numpy
308313 def to_numpy (arr ):
309314 if isinstance (arr , torch .Tensor ):
@@ -415,7 +420,7 @@ def plot_cka_comparison(
415420 if figsize is None :
416421 figsize = (5 * ncols , 4 * nrows )
417422
418- fig , axes = plt .subplots (nrows , ncols , figsize = figsize , constrained_layout = True )
423+ fig , axes = plt .subplots (nrows , ncols , figsize = figsize , constrained_layout = share_colorbar )
419424 axes = np .atleast_2d (axes )
420425
421426 # Find global min/max for shared colorbar
@@ -458,6 +463,8 @@ def plot_cka_comparison(
458463 sm = plt .cm .ScalarMappable (norm = norm , cmap = cmap )
459464 sm .set_array ([])
460465 fig .colorbar (sm , ax = axes , fraction = 0.02 , pad = 0.02 , label = "CKA Similarity" )
466+ else :
467+ fig .tight_layout ()
461468
462469 if show :
463470 plt .show ()
0 commit comments