Skip to content

Commit f9df326

Browse files
committed
Feature: Improve model weight plot
1 parent 64caed2 commit f9df326

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

train/viz/model_weights_plot.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
FeatureWeights = dict[str, dict[str, list[float]]]
5858

5959

60-
def plot(features: FeatureWeights, output: pathlib.Path) -> str:
60+
def plot(features: FeatureWeights, output: pathlib.Path) -> pathlib.Path:
6161
"""Plot figure showing distribution of weights for each label for each feature.
6262
6363
Parameters
@@ -69,7 +69,7 @@ def plot(features: FeatureWeights, output: pathlib.Path) -> str:
6969
7070
Returns
7171
-------
72-
str
72+
pathlib.Path
7373
Path to saved file.
7474
"""
7575
fig, ax = plt.subplots(1, 1, figsize=(10, 100))
@@ -88,8 +88,8 @@ def plot(features: FeatureWeights, output: pathlib.Path) -> str:
8888
"SIZE",
8989
"PUNC",
9090
}
91-
for i, (feat, weights) in enumerate(features.items()):
92-
for label, weights in weights.items():
91+
for i, (_, feat_weights) in enumerate(sorted(features.items(), key=lambda x: x[0])):
92+
for label, weights in feat_weights.items():
9393
y = [i + LABEL_OFFSET[label]] * len(weights)
9494

9595
label_prefix = "" if label in unlabelled_lines else "_"
@@ -102,7 +102,7 @@ def plot(features: FeatureWeights, output: pathlib.Path) -> str:
102102
ax.hlines(i + 0.5, x_min, x_max, color="#7c6f64")
103103

104104
ax.set_yticks(list(range(len(features))))
105-
ax.set_yticklabels(list(features.keys()))
105+
ax.set_yticklabels(list(sorted(features.keys())))
106106
ax.set_ylim((-0.5, len(features) - 0.5))
107107
ax.set_xlim((x_min, x_max))
108108
ax.grid(True, axis="x")
@@ -128,11 +128,11 @@ def load_model_features(model_path: str) -> FeatureWeights:
128128
FeatureWeights
129129
Weights for each feature.
130130
"""
131-
tagger = pycrfsuite.Tagger()
131+
tagger = pycrfsuite.Tagger() # type: ignore
132132
tagger.open(str(model_path))
133133

134134
tagger_features = tagger.info()
135-
features = defaultdict(lambda: defaultdict(list))
135+
features: FeatureWeights = defaultdict(lambda: defaultdict(list))
136136
for (feature, label), weight in tagger_features.state_features.items():
137137
feature_name = feature.split(":", 1)[0]
138138
features[feature_name][label].append(weight)

0 commit comments

Comments
 (0)