Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 52 additions & 38 deletions ml/guides/text_classification/explore_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Contains functions to help study, visualize and understand datasets.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -18,8 +19,8 @@ def get_num_classes(labels):

# Arguments
labels: list, label values.
There should be at lease one sample for values in the
range (0, num_classes -1)
There should be at least one sample for values in the
range(0, num_classes - 1)

# Returns
int, total number of classes.
Expand All @@ -31,17 +32,21 @@ def get_num_classes(labels):
num_classes = max(labels) + 1
missing_classes = [i for i in range(num_classes) if i not in labels]
if len(missing_classes):
raise ValueError('Missing samples with label value(s) '
'{missing_classes}. Please make sure you have '
'at least one sample for every label value '
'in the range(0, {max_class})'.format(
missing_classes=missing_classes,
max_class=num_classes - 1))
raise ValueError(
"Missing samples with label value(s) "
"{missing_classes}. Please make sure you have "
"at least one sample for every label value "
"in the range(0, {max_class})".format(
missing_classes=missing_classes, max_class=num_classes - 1
)
)

if num_classes <= 1:
raise ValueError('Invalid number of labels: {num_classes}.'
'Please make sure there are at least two classes '
'of samples'.format(num_classes=num_classes))
raise ValueError(
"Invalid number of labels: {num_classes}."
"Please make sure there are at least two classes "
"of samples".format(num_classes=num_classes)
)
return num_classes


Expand All @@ -58,53 +63,62 @@ def get_num_words_per_sample(sample_texts):
return np.median(num_words)


def plot_frequency_distribution_of_ngrams(
    sample_texts, ngram_range=(1, 1), num_ngrams=50
):
    """Plots the frequency distribution of n-grams.

    # Arguments
        sample_texts: list, sample texts.
        ngram_range: tuple (min, max), The range of n-gram values to consider.
            min and max are the lower and the upper bound values for the range.
        num_ngrams: int, number of n-grams to plot.
            Top `num_ngrams` frequent n-grams will be plotted.
    """
    # Create args required for vectorizing. The caller's `ngram_range` is
    # forwarded as-is (it used to be silently overridden with (1, 1)).
    kwargs = {
        "ngram_range": ngram_range,
        "dtype": "int32",
        "strip_accents": "unicode",
        "decode_error": "replace",
        "analyzer": "word",  # Split text into word tokens.
    }
    vectorizer = CountVectorizer(**kwargs)

    # This creates a vocabulary (dict, where keys are n-grams and values are
    # indices). This also converts every text to an array the length of
    # vocabulary, where every element represents the count of the n-gram
    # corresponding at that index in vocabulary.
    vectorized_texts = vectorizer.fit_transform(sample_texts)

    # This is the list of all n-grams in the index order from the vocabulary.
    # Prefer `get_feature_names_out`; fall back to the older accessor when the
    # installed vectorizer does not provide it.
    get_names = getattr(vectorizer, "get_feature_names_out", None)
    if get_names is None:
        get_names = vectorizer.get_feature_names
    all_ngrams = list(get_names())
    num_ngrams = min(num_ngrams, len(all_ngrams))

    # Add up the counts per n-gram ie. column-wise
    all_counts = vectorized_texts.sum(axis=0).tolist()[0]

    # Sort (count, n-gram) pairs by frequency, most frequent first, and keep
    # only the top `num_ngrams` of them.
    top_pairs = sorted(zip(all_counts, all_ngrams), reverse=True)[:num_ngrams]
    counts = [count for count, _ in top_pairs]
    ngrams = [ngram for _, ngram in top_pairs]

    idx = np.arange(num_ngrams)

    # A wide figure leaves room for the rotated n-gram tick labels.
    plt.figure(figsize=(12, 5))
    plt.bar(idx, counts, width=0.8, color="b")
    plt.xlabel("Top {num_ngrams} N-grams".format(num_ngrams=num_ngrams))
    plt.ylabel("Frequencies")
    plt.title(
        "Frequency distribution of n-grams with range={ngram_range}".format(
            ngram_range=ngram_range
        )
    )
    plt.xticks(idx, ngrams, rotation=45)
    plt.show()

Expand All @@ -116,9 +130,9 @@ def plot_sample_length_distribution(sample_texts):
samples_texts: list, sample texts.
"""
plt.hist([len(s) for s in sample_texts], 50)
plt.xlabel('Length of a sample')
plt.ylabel('Number of samples')
plt.title('Sample length distribution')
plt.xlabel("Length of a sample")
plt.ylabel("Number of samples")
plt.title("Sample length distribution")
plt.show()


Expand All @@ -134,9 +148,9 @@ def plot_class_distribution(labels):
count_map = Counter(labels)
counts = [count_map[i] for i in range(num_classes)]
idx = np.arange(num_classes)
plt.bar(idx, counts, width=0.8, color='b')
plt.xlabel('Class')
plt.ylabel('Number of samples')
plt.title('Class distribution')
plt.bar(idx, counts, width=0.8, color="b")
plt.xlabel("Class")
plt.ylabel("Number of samples")
plt.title("Class distribution")
plt.xticks(idx, idx)
plt.show()