diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a1b2739 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + node-version: [14, 16, 18, 20] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: ${{ matrix.node-version }} + - run: npm install + - run: npm test diff --git a/README.md b/README.md index b5a2a95..b280ffa 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,19 @@ # classificator +[![CI](https://github.com/Wozacosta/classificator/actions/workflows/ci.yml/badge.svg)](https://github.com/Wozacosta/classificator/actions/workflows/ci.yml) [![NPM Licence shield](https://img.shields.io/github/license/Wozacosta/classificator.svg)](https://github.com/Wozacosta/classificator/blob/master/LICENSE) [![NPM release version shield](https://img.shields.io/npm/v/classificator.svg)](https://www.npmjs.com/package/classificator) -Naive Bayes classifier for node.js +A fast, lightweight Naive Bayes classifier for Node.js with explainable predictions. -`bayes` takes a document (piece of text), and tells you what category that document belongs to. +``` + +-----------------+ + "great movie" -->| classificator |--> { predictedCategory: "positive", proba: 0.83 } + +-----------------+ + | trained on | + | your data | + +--------------+ +``` ## What can I use this for? @@ -16,83 +24,254 @@ You can use this for categorizing any text content into any arbitrary set of **c - is a news article about **technology**, **politics**, or **sports** ? - is a piece of text expressing **positive** emotions, or **negative** emotions? 
+``` + +----------+ + +--->| positive | 0.72 + "awesome movie" | +----------+ + | | +----------+ + v +-->| negative | 0.18 + [ tokenize ] | +----------+ + | | +----------+ + v +-->| neutral | 0.10 + [ calculate ]------+ +----------+ + [ probability ] +``` + More here: https://en.wikipedia.org/wiki/Naive_Bayes_classifier ## Installing -Recommended: Node v6.0.0 + +Recommended: Node v14.0.0 + ``` npm install --save classificator ``` -## Usage +## Quick Start -``` +```js const bayes = require('classificator') const classifier = bayes() + +// Train +classifier.learn('amazing, awesome movie!', 'positive') +classifier.learn('terrible, boring film', 'negative') + +// Classify +const result = classifier.categorize('awesome film') +console.log(result.predictedCategory) // => 'positive' ``` -### Teach your classifier +## How It Works + +Classificator uses the [Naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier) algorithm with Laplace smoothing. Here's the pipeline: + +``` + Input Text + | + v ++-------------+ +------------------+ +-------------------+ +| Tokenizer |---->| Preprocessor |---->| Frequency Table | +| split words | | stopwords/stem | | count each word | ++-------------+ +------------------+ +-------------------+ + | + +---------------------------------------+ + | + v ++---------------------------+ +------------------+ +| For each category: | | Normalize with | +| P(cat) * P(w1|cat) * |---->| logsumexp for | +| P(w2|cat) * ... | | final proba | ++---------------------------+ +------------------+ + | + v + +------------------+ + | Return sorted | + | likelihoods + | + | predictedCategory| + +------------------+ ``` + +**Laplace smoothing** prevents zero-probability issues — even words never seen in a category get a small probability instead of zeroing everything out. + + +## Usage + +### Teach your classifier + +```js classifier.learn('amazing, awesome movie! 
Had a good time', 'positive') classifier.learn('Buy my free viagra pill and get rich!', 'spam') classifier.learn('I really hate dust and annoying cats', 'negative') classifier.learn('LOL this sucks so hard', 'troll') ``` -### Make your classifier unlearn +### Batch learning +```js +classifier.learnBatch([ + { text: 'amazing, awesome movie!', category: 'positive' }, + { text: 'Buy my free viagra pill', category: 'spam' }, + { text: 'I really hate dust', category: 'negative' } +]) ``` + +### Make your classifier unlearn + +```js classifier.learn('i hate mornings', 'positive'); -// uh oh, that was mistake. Time to unlearn +// uh oh, that was a mistake. Time to unlearn classifier.unlearn('i hate mornings', 'positive'); ``` +If the last document in a category is unlearned, the category is automatically removed. + ### Remove a category -``` +```js classifier.removeCategory('troll'); ``` -### categorization +### Categorization -``` +```js classifier.categorize("I've always hated Martians"); // => { - likelihoods: [ - { - category: 'negative', - logLikelihood: -17.241944258040537, - logProba: -0.6196197927020783, - proba: 0.538149006882628 - }, { - category: 'positive', - logLikelihood: -17.93509143860048, - logProba: -1.312766973262022, - proba: 0.26907450344131445 - }, { - category: 'spam', - logLikelihood: -18.26854831109384, - logProba: -1.646223845755383, - proba: 0.19277648967605832 } - ], - predictedCategory: 'negative' - } +// likelihoods: [ +// { category: 'negative', proba: 0.538, logLikelihood: -17.24, logProba: -0.62 }, +// { category: 'positive', proba: 0.269, logLikelihood: -17.94, logProba: -1.31 }, +// { category: 'spam', proba: 0.193, logLikelihood: -18.27, logProba: -1.65 } +// ], +// predictedCategory: 'negative' +// } ``` -### serialize the classifier's state as a JSON string. 
+### Categorize with confidence threshold -`let stateJson = classifier.toJson()` +Reject low-confidence predictions instead of guessing: -### load the classifier back from its JSON representation. +```js +classifier.categorizeWithConfidence('some ambiguous text', 0.7); +// => predictedCategory is null if the top probability is below 0.7 +// likelihoods array is always returned in full +``` -`let revivedClassifier = bayes.fromJson(stateJson)` +``` + "ambiguous text" + | + v + [ categorize ] + | + proba = 0.42 + | + 0.42 < 0.70 ? --yes--> predictedCategory: null (rejected) + | + no + | + v + predictedCategory: "spam" (accepted) +``` + +### Get top N categories + +```js +classifier.categorizeTopN("I've always hated Martians", 2); +// => same as categorize(), but likelihoods array has at most 2 entries +``` + +### Understand why a prediction was made -note: `stateJson` can either be a JSON string (obtained from `classifier.toJson()`), or an object +```js +classifier.topInfluentialTokens("I've always hated Martians", 3); +// => [ +// { token: 'hated', probability: 0.42, frequency: 1 }, +// { token: 'always', probability: 0.21, frequency: 1 }, +// { token: 'Martians', probability: 0.12, frequency: 1 } +// ] +``` + +``` + "I've always hated Martians" --> predicted: negative + | + Why? v + +----------------------------------------------------+ + | Token | P(token|negative) | Influence | + |-----------|-------------------|--------------------| + | hated | 0.42 | ################## | + | always | 0.21 | ######### | + | Martians | 0.12 | ##### | + +----------------------------------------------------+ +``` + +### Serialize / Deserialize + +```js +// Save +let stateJson = classifier.toJson() + +// Restore +let revivedClassifier = bayes.fromJson(stateJson) +``` + +`stateJson` can be a JSON string or a plain object. + +**Important:** Functions (`tokenizer`, `tokenPreprocessor`) can't be serialized to JSON. 
Pass them back when restoring: + +```js +let revivedClassifier = bayes.fromJson(stateJson, { + tokenizer: myTokenizer, + tokenPreprocessor: myPreprocessor +}) +``` + +``` + Classifier JSON String Classifier + (in memory) (on disk) (restored) + | | | + +--- toJson() --------------->| | + | +--- fromJson(json, opts) --->| + | | ^ | + | tokenizer: fn - LOST | | | + | alpha: 0.5 - KEPT | pass functions | + | fitPrior: true - KEPT | back in opts | + | | | +``` + +### Inspect your classifier + +```js +classifier.getCategories() +// => ['positive', 'spam', 'negative', 'troll'] + +classifier.getCategoryStats() +// => { +// positive: { docCount: 1, wordCount: 7, vocabularySize: 7 }, +// spam: { docCount: 1, wordCount: 8, vocabularySize: 8 }, +// ... +// _total: { docCount: 4, wordCount: 25, vocabularySize: 20 } +// } +``` + +### Reset the classifier + +```js +classifier.reset() +// clears all learned data but preserves options (tokenizer, alpha, fitPrior) +``` + +### Method chaining + +Most methods return `this`, so you can chain calls: + +```js +const result = bayes() + .learn('happy fun', 'positive') + .learn('sad bad', 'negative') + .categorize('happy') +``` -------- @@ -104,71 +283,263 @@ note: `stateJson` can either be a JSON string (obtained from `classifier.toJson( Returns an instance of a Naive-Bayes Classifier. -Pass in an optional `options` object to configure the instance. +| Option | Type | Default | Description | +|----------------------|------------|----------------------------|-------------------------------------------------------------------------------------------------| +| `tokenizer` | `Function` | Splits on whitespace/punct | Custom tokenization function. Receives `text` (string), must return an array of string tokens. | +| `tokenPreprocessor` | `Function` | none | Transform tokens after tokenization (e.g. stopword removal, stemming, lowercasing). Receives and returns an array of tokens. 
| +| `alpha` | `number` | `1` | Additive (Laplace) smoothing parameter. Higher values = more conservative predictions. `0` disables smoothing (can cause zero-probability issues). | +| `fitPrior` | `boolean` | `true` | If `true`, prior probability is proportional to learned document frequencies (categories with more training docs are favored). If `false`, uses uniform prior (all categories equally likely before seeing the text). | -If you specify a `tokenizer` function in `options`, it will be used as the instance's tokenizer. It receives a (string) `text` argument - this is the string value that is passed in by you when you call `.learn()` or `.categorize()`. It must return an array of tokens. The default tokenizer removes punctuation and splits on spaces. +```js +let classifier = bayes({ + tokenizer: function (text) { return text.split(' ') }, + tokenPreprocessor: function (tokens) { + var stopwords = new Set(['the', 'a', 'is', 'in']) + return tokens + .map(function (t) { return t.toLowerCase() }) + .filter(function (t) { return !stopwords.has(t) }) + }, + alpha: 0.5, + fitPrior: false +}) +``` -Eg. +#### Understanding `alpha` (Laplace smoothing) ``` -let classifier = bayes({ - tokenizer: function (text) { return text.split(' ') } -}) + alpha controls how much probability "leaks" to unseen words: + + alpha = 0 Unseen words get 0 probability. Risky. + alpha = 0.5 Lidstone smoothing. Less aggressive. + alpha = 1 Standard Laplace. Good default. <-- default + alpha = 10 Very conservative. Small datasets. + + Effect on P(word|category): + + P(word|cat) = (count + alpha) / (total + alpha * vocabSize) + ────────────── ───────────────────────────── + numerator gets denominator grows with alpha + a boost spreading probability to all + possible words ``` -You can specify the `alpha` parameter of the [additive smoothing operation](https://en.wikipedia.org/wiki/Additive_smoothing). -This is an integer. 
-The default value is 1 +#### Understanding `fitPrior` -You can also specify the `fitPrior` parameter. -Defines how the [prior probablity](https://en.wikipedia.org/wiki/Prior_probability) is calculated. -If set to `false`, the classifier will use an uniform prior rather than a learnt one. -The default value is `true`. +``` + fitPrior: true (default) fitPrior: false + ───────────────────────── ──────────────────────── + P(cat) = docCount / total P(cat) = 1 (uniform) + + 900 positive docs + 100 negative Same data, but: + P(positive) = 0.9 P(positive) = P(negative) + P(negative) = 0.1 Only word content matters + + Good when training data Good when training data + reflects real-world is imbalanced but you want + distribution fair comparison +``` ### `classifier.learn(text, category)` -Teach your classifier what `category` should be associated with an array `text` of words. +Teach your classifier what `category` should be associated with a `text` string. -### `classifier.unlearn(text, category)` +Returns `this` for chaining. Throws `TypeError` if text or category is not a string. -The classifier will unlearn the `text` that was associated with `category`. +### `classifier.learnBatch(items)` -### `classifier.removeCategory(category)` +Learn from multiple text/category pairs at once. `items` is an array of `{ text, category }` objects. -The category is removed and the classifier data are updated accordingly. +Returns `this` for chaining. Throws `TypeError` if items is not an array. -### `classifier.categorize(text)` +### `classifier.unlearn(text, category)` -*Parameters* +The classifier will unlearn the `text` that was associated with `category`. If the last document in a category is unlearned, the category is automatically removed. -`text {String}` +Returns `this` for chaining. Throws `Error` if the category does not exist. 
-*Returns* +### `classifier.removeCategory(category)` -`{Object}` An object with the `predictedCategory` and an array of the categories -ordered by likelihood (most likely first). +The category is removed and the classifier data are updated accordingly. Vocabulary is cleaned up: tokens only present in the removed category are removed from the global vocabulary. No-op if the category does not exist. -``` +Returns `this` for chaining. + +### `classifier.categorize(text)` + +Returns `{Object}` with `predictedCategory` and `likelihoods` array sorted by probability (highest first). Returns `{ predictedCategory: null, likelihoods: [] }` if no categories have been learned. + +```js { - likelihoods : [ - ... - { - category: 'positive', - logLikelihood: -17.93509143860048, - logProba: -1.312766973262022, - proba: 0.26907450344131445 - }, + likelihoods: [ + { category: 'positive', logLikelihood: -17.94, logProba: -1.31, proba: 0.27 }, ... ], - predictedCategory : 'negative' //--> the main category bayes thinks text - belongs to. As a string + predictedCategory: 'negative' } ``` +### `classifier.categorizeWithConfidence(text, threshold)` + +Like `categorize()`, but sets `predictedCategory` to `null` if the top category's probability is below `threshold` (a number between 0 and 1). The `likelihoods` array is always returned in full. Throws `TypeError` if threshold is invalid. + +### `classifier.categorizeTopN(text, n)` + +Like `categorize()`, but returns only the top `n` most likely categories in the likelihoods array. + +### `classifier.topInfluentialTokens(text[, n])` + +Returns the top `n` (default 5) tokens that most influenced the predicted category, sorted by probability. Each entry has `{ token, probability, frequency }`. + +### `classifier.getCategories()` + +Returns an array of all category names the classifier has learned. 
+ +### `classifier.getCategoryStats()` + +Returns an object with per-category stats (`docCount`, `wordCount`, `vocabularySize`) and a `_total` key with aggregate stats including total `wordCount`. + +### `classifier.reset()` + +Resets the classifier to its initial untrained state, preserving configuration options. + +Returns `this` for chaining. + ### `classifier.toJson()` Returns the JSON representation of a classifier. -### `let classifier = bayes.fromJson(jsonStr)` +### `let classifier = bayes.fromJson(jsonStr[, options])` + +Returns a classifier instance from the JSON representation. Use this with `classifier.toJson()`. + +`jsonStr` can be a JSON string or a plain object. + +`options` is an optional object for runtime-only options (e.g. `{ tokenizer: fn, tokenPreprocessor: fn }`) that cannot be serialized to JSON. + + +-------- + + +## Typical Workflows + +### Spam Filter + +``` + +-----------+ +-----------+ +-------------+ +--------+ + | Collect |---->| Train |---->| Serialize |---->| Deploy | + | emails | | classifier| | to JSON | | in app | + +-----------+ +-----------+ +-------------+ +--------+ + | | + learn('buy now fromJson(saved) then + free!!!', 'spam') categorize(newEmail) + learn('meeting at + 3pm', 'ham') +``` + +### Sentiment Analysis with Preprocessing + +```js +const classifier = bayes({ + tokenPreprocessor: (tokens) => { + const stops = new Set(['the', 'a', 'is', 'it', 'and', 'of', 'to']) + return tokens + .map(t => t.toLowerCase()) + .filter(t => !stops.has(t) && t.length > 2) + } +}) + +// Train on labeled reviews +reviews.forEach(r => classifier.learn(r.text, r.sentiment)) + +// Classify new review +const result = classifier.categorize('This product is absolutely amazing!') +if (result.likelihoods[0].proba > 0.7) { + console.log(`Confident: ${result.predictedCategory}`) +} else { + console.log('Uncertain, needs human review') +} +``` + +### Model Persistence + +```js +const fs = require('fs') + +// Save trained model 
+fs.writeFileSync('model.json', classifier.toJson()) + +// Load later +const saved = fs.readFileSync('model.json', 'utf8') +const revivedClassifier = bayes.fromJson(saved, { tokenizer: myTokenizer }) +``` + + +-------- + + +## Test Suite + +The library includes a comprehensive test suite with **109 tests**: + +``` + Unit tests (82) - Individual method correctness, edge cases, + parameter validation, numerical stability + + Integration tests (7) - Feature combinations: serialize/restore pipelines, + learn/unlearn/relearn cycles, preprocessor + consistency, method chaining workflows + + E2E tests (20) - Real-world scenarios: spam detection, sentiment + analysis, multi-category topic classification, + incremental learning, mistake correction, + imbalanced dataset handling +``` + +Run with: + +``` +npm test +``` + + +-------- + + +## Changelog + +### 0.5.0 + +**New features:** +- `tokenPreprocessor` option for stopword removal, stemming, and custom token transforms +- `categorizeWithConfidence(text, threshold)` for rejecting low-confidence predictions +- `topInfluentialTokens(text, n)` for explainable classification +- `getCategories()`, `categorizeTopN()`, `learnBatch()`, `reset()`, `getCategoryStats()` +- Input validation on all public methods (throws TypeError for non-string inputs) + +**Bug fixes:** +- Fixed `alpha: 0` being silently overridden to `1` +- Fixed `fromJson(null)` crash +- Fixed `unlearn()` not cleaning up categories when last document is removed +- Fixed `unlearn()` crash on non-existent category +- Fixed `categorize()` crash on empty classifier (now returns `predictedCategory: null`) +- Fixed default tokenizer returning empty tokens for empty strings +- Fixed `removeCategory()` not guarding against negative vocabulary counts +- Fixed `wordCount` going negative in `unlearn()` edge cases +- Fixed logsumexp numerical instability (now uses max-subtraction trick) +- Fixed `fromJson()` losing runtime options after state restoration +- Fixed error message typo and 
inconsistent capitalization + +**Improvements:** +- Numerically stable logsumexp prevents underflow on large documents +- Tokenizer and tokenPreprocessor validation at construction time +- `getCategoryStats()` now includes `wordCount` in `_total` +- GitHub Actions CI for Node 14/16/18/20 +- Comprehensive test suite (109 tests: unit + integration + E2E) +- Improved JSDoc and README documentation with diagrams + +### 0.4.0 + +- Allow custom tokenizer to be passed to `fromJson()` + +### 0.3.4 -Returns a classifier instance from the JSON representation. Use this with the JSON representation obtained from `classifier.toJson()` +- Initial tracked version diff --git a/lib/classificator.js b/lib/classificator.js index 73e21ba..af650a2 100644 --- a/lib/classificator.js +++ b/lib/classificator.js @@ -1,4 +1,4 @@ -const Decimal = require('decimal.js').default; // handles arbitrary-precision arithmetics. +const Decimal = require('decimal.js').default; /* Expose our naive-bayes generator function @@ -29,6 +29,7 @@ const DEFAULT_FIT_PRIOR = true; * @param {String|Object} jsonStrOrObject state representation obtained by classifier.toJson() * @param {Object} [options] optional options object (e.g. { tokenizer: fn }) * @return {NaiveBayes} Classifier + * @throws {Error} If input is not a valid JSON string or object, or if required state keys are missing. */ module.exports.fromJson = (jsonStrOrObject, options) => { let parameters; @@ -40,6 +41,9 @@ module.exports.fromJson = (jsonStrOrObject, options) => { break; case 'object': + if (jsonStrOrObject === null) { + throw new Error(''); + } parameters = jsonStrOrObject; break; @@ -47,8 +51,7 @@ module.exports.fromJson = (jsonStrOrObject, options) => { throw new Error(''); } } catch (e) { - console.error(e); - throw new Error('Naivebays.fromJson expects a valid JSON string or an object.'); + throw new Error('NaiveBayes.fromJson expects a valid JSON string or an object.'); } // merge any runtime-only options (e.g. 
tokenizer) into the restored options @@ -61,12 +64,16 @@ module.exports.fromJson = (jsonStrOrObject, options) => { STATE_KEYS.forEach((k) => { if (typeof parameters[k] === 'undefined') { throw new Error( - `Naivebayes.fromJson: JSON string is missing an expected property: [${k}].` + `NaiveBayes.fromJson: JSON string is missing an expected property: [${k}].` ); } classifier[k] = parameters[k]; }); + // restore the merged options (STATE_KEYS includes 'options' which overwrites + // with the saved state, losing runtime-only options like tokenizer/tokenPreprocessor) + classifier.options = restoredOptions; + return classifier; }; @@ -78,15 +85,9 @@ module.exports.fromJson = (jsonStrOrObject, options) => { * @return {Array} */ const defaultTokenizer = (text) => { - // remove punctuation from text - remove anything that isn't a word char or a space const rgxPunctuation = /[^(a-zA-ZA-Яa-я0-9_)+\s]/g; - const sanitized = text.replace(rgxPunctuation, ' '); - // tokens = tokens.filter(function(token) { - // return token.length >= _that.config.minimumLength; - // }); - - return sanitized.split(/\s+/); + return sanitized.split(/\s+/).filter(token => token.length > 0); }; /** @@ -94,52 +95,71 @@ const defaultTokenizer = (text) => { * * This is a naive-bayes classifier that uses Laplace Smoothing. * - * Takes an (optional) options object containing: - * - `tokenizer` => custom tokenization function - * + * @param {Object} [options] Configuration options + * @param {Function} [options.tokenizer] Custom tokenization function. Receives a string, + * must return an array of string tokens. + * @param {Function} [options.tokenPreprocessor] Optional function to transform tokens after + * tokenization (e.g. stopword removal, stemming). + * Receives an array of tokens, must return an array. + * @param {number} [options.alpha=1] Additive (Laplace) smoothing parameter. + * Higher values = more conservative predictions. + * Set to 0 to disable smoothing. 
+ * @param {boolean} [options.fitPrior=true] Whether to use learned prior probabilities. + * When false, uses uniform prior (all categories + * equally likely before seeing the text). + * @throws {TypeError} If options is truthy but not a plain object, or tokenizer/tokenPreprocessor is not a function. */ function Naivebayes(options) { - // set options object this.options = {}; if (typeof options !== 'undefined') { if (!options || typeof options !== 'object' || Array.isArray(options)) { throw TypeError( - `NaiveBayes got invalid 'options': ${options}'. Pass in an object.` + `NaiveBayes got invalid 'options': '${options}'. Pass in an object.` ); } this.options = options; } + if (this.options.tokenizer && typeof this.options.tokenizer !== 'function') { + throw TypeError('NaiveBayes: tokenizer must be a function.'); + } + if (this.options.tokenPreprocessor && typeof this.options.tokenPreprocessor !== 'function') { + throw TypeError('NaiveBayes: tokenPreprocessor must be a function.'); + } + this.tokenizer = this.options.tokenizer || defaultTokenizer; - this.alpha = this.options.alpha || DEFAULT_ALPHA; + this.tokenPreprocessor = this.options.tokenPreprocessor || null; + this.alpha = this.options.alpha === undefined ? DEFAULT_ALPHA : this.options.alpha; this.fitPrior = this.options.fitPrior === undefined ? 
DEFAULT_FIT_PRIOR : this.options.fitPrior; - // initialize our vocabulary and its size + this.vocabulary = {}; this.vocabularySize = 0; - - // number of documents we have learned from this.totalDocuments = 0; - - // document frequency table for each of our categories - //= > for each category, how often were documents mapped to it this.docCount = {}; - - // for each category, how many words total were mapped to it this.wordCount = {}; - - // word frequency table for each category - //= > for each category, how frequent was a given word mapped to it this.wordFrequencyCount = {}; - - // hashmap of our category names this.categories = {}; } +/** + * Tokenize text and optionally apply the preprocessor. + * + * @param {String} text + * @return {Array} tokens + */ +Naivebayes.prototype.tokenize = function(text) { + const tokens = this.tokenizer(text); + if (this.tokenPreprocessor) { + return this.tokenPreprocessor(tokens); + } + return tokens; +}; /** - * Initialize each of our data structure entries for this new category + * Initialize each of our data structure entries for this new category. * * @param {String} categoryName + * @return {Naivebayes} this */ Naivebayes.prototype.initializeCategory = function(categoryName) { if (!this.categories[categoryName]) { @@ -155,17 +175,19 @@ Naivebayes.prototype.initializeCategory = function(categoryName) { * Properly remove a category, unlearning all words that were associated to it. 
* * @param {String} categoryName + * @return {Naivebayes} this */ Naivebayes.prototype.removeCategory = function(categoryName) { if (!this.categories[categoryName]) { return this; } - // update the total number of documents we have learned from this.totalDocuments -= this.docCount[categoryName]; Object.keys(this.wordFrequencyCount[categoryName]).forEach((token) => { - this.vocabulary[token]--; - if (this.vocabulary[token] === 0) this.vocabularySize--; + if (this.vocabulary[token] && this.vocabulary[token] > 0) { + this.vocabulary[token]--; + if (this.vocabulary[token] === 0) this.vocabularySize--; + } }); delete this.docCount[categoryName]; @@ -177,48 +199,44 @@ Naivebayes.prototype.removeCategory = function(categoryName) { }; /** - * train our naive-bayes classifier by telling it what `category` + * Train our naive-bayes classifier by telling it what `category` * the `text` corresponds to. * * @param {String} text * @param {String} category Category to learn as being text + * @return {Naivebayes} this + * @throws {TypeError} If text or category is not a string. 
*/ Naivebayes.prototype.learn = function(text, category) { - // initialize category data structures if we've never seen this category + if (typeof text !== 'string') { + throw new TypeError(`NaiveBayes: text must be a string, got ${typeof text}.`); + } + if (typeof category !== 'string') { + throw new TypeError(`NaiveBayes: category must be a string, got ${typeof category}.`); + } + this.initializeCategory(category); - // update our count of how many documents mapped to this category this.docCount[category]++; - - // update the total number of documents we have learned from this.totalDocuments++; - // normalize the text into a word array - const tokens = this.tokenizer(text); - - // get a frequency count for each token in the text + const tokens = this.tokenize(text); const frequencyTable = this.frequencyTable(tokens); Object.keys(frequencyTable).forEach((token) => { const frequencyInText = frequencyTable[token]; - // add this word to our vocabulary if not already existing if (!this.vocabulary[token] || this.vocabulary[token] === 0) { this.vocabularySize++; this.vocabulary[token] = 1; - // this.vocabulary[token] = frequencyInText; - } else if (this.vocabulary[token] > 0) { + } else { this.vocabulary[token]++; - // this.vocabulary[token] += frequencyInText; } - - // update the frequency information for this word in this category if (!this.wordFrequencyCount[category][token]) { this.wordFrequencyCount[category][token] = frequencyInText; } else this.wordFrequencyCount[category][token] += frequencyInText; - // update the count of all words we have seen mapped to this category this.wordCount[category] += frequencyInText; }); @@ -226,77 +244,93 @@ Naivebayes.prototype.learn = function(text, category) { }; /** - * untrain our naive-bayes classifier by telling it what `category` + * Untrain our naive-bayes classifier by telling it what `category` * the `text` to remove corresponds to. 
* * @param {String} text * @param {String} category Category to unlearn as being text + * @return {Naivebayes} this + * @throws {TypeError} If text or category is not a string. + * @throws {Error} If category does not exist. */ Naivebayes.prototype.unlearn = function(text, category) { - // update our count of how many documents mapped to this category + if (typeof text !== 'string') { + throw new TypeError(`NaiveBayes: text must be a string, got ${typeof text}.`); + } + if (typeof category !== 'string') { + throw new TypeError(`NaiveBayes: category must be a string, got ${typeof category}.`); + } + if (!this.categories[category]) { + throw new Error(`NaiveBayes: cannot unlearn from non-existent category: '${category}'.`); + } + this.docCount[category]--; if (this.docCount[category] === 0) { delete this.docCount[category]; } - // update the total number of documents we have learned from this.totalDocuments--; - // normalize the text into a word array - const tokens = this.tokenizer(text); - - // get a frequency count for each token in the text + const tokens = this.tokenize(text); const frequencyTable = this.frequencyTable(tokens); - /* - Update our vocabulary and our word frequency count for this category - */ - Object.keys(frequencyTable).forEach((token) => { const frequencyInText = frequencyTable[token]; - // add this word to our vocabulary if not already existing if (this.vocabulary[token] && this.vocabulary[token] > 0) { - this.vocabulary[token] -= frequencyInText; + this.vocabulary[token]--; if (this.vocabulary[token] === 0) this.vocabularySize--; } - - this.wordFrequencyCount[category][token] -= frequencyInText; - if (this.wordFrequencyCount[category][token] === 0) { - delete this.wordFrequencyCount[category][token]; + if (this.wordFrequencyCount[category] && this.wordFrequencyCount[category][token]) { + this.wordFrequencyCount[category][token] -= frequencyInText; + if (this.wordFrequencyCount[category][token] <= 0) { + delete 
this.wordFrequencyCount[category][token]; + } } - // update the count of all words we have seen mapped to this category - this.wordCount[category] -= frequencyInText; - if (this.wordCount[category] === 0) { - delete this.wordCount[category]; - delete this.wordFrequencyCount[category]; + if (this.wordCount[category] !== undefined) { + this.wordCount[category] -= frequencyInText; + if (this.wordCount[category] <= 0) { + delete this.wordCount[category]; + delete this.wordFrequencyCount[category]; + } } }); + // clean up category if no documents remain + if (!this.docCount[category]) { + delete this.categories[category]; + } + return this; }; - /** * Determine what category `text` belongs to. * * @param {String} text - * * @return {Object} The predicted category, and the likelihoods stats. + * @throws {TypeError} If text is not a string. */ Naivebayes.prototype.categorize = function(text) { - const tokens = this.tokenizer(text); + if (typeof text !== 'string') { + throw new TypeError(`NaiveBayes: text must be a string, got ${typeof text}.`); + } + + const tokens = this.tokenize(text); const frequencyTable = this.frequencyTable(tokens); const categories = Object.keys(this.categories); const likelihoods = []; - // iterate through our categories to find the one with max probability for this text + if (categories.length === 0) { + return { + likelihoods: [], + predictedCategory: null + }; + } + categories.forEach((category) => { - // start by calculating the overall probability of this category - //= > out of all documents we've ever looked at, how many were - // mapped to this category let categoryLikelihood; if (this.fitPrior) { categoryLikelihood = this.docCount[category] / this.totalDocuments; @@ -304,40 +338,39 @@ Naivebayes.prototype.categorize = function(text) { categoryLikelihood = 1; } - // take the log to avoid underflow - // let logLikelihood = Math.log(categoryLikelihood); let logLikelihood = Decimal(categoryLikelihood); logLikelihood = 
logLikelihood.naturalLogarithm(); - // now determine P( w | c ) for each word `w` in the text Object.keys(frequencyTable).forEach((token) => { if (this.vocabulary[token] && this.vocabulary[token] > 0) { const termFrequencyInText = frequencyTable[token]; const tokenProbability = this.tokenProbability(token, category); - // determine the log of the P( w | c ) for this word - // logLikelihood += termFrequencyInText * Math.log(tokenProbability); let logTokenProbability = Decimal(tokenProbability); logTokenProbability = logTokenProbability.naturalLogarithm(); - logLikelihood = logLikelihood.plus(termFrequencyInText * logTokenProbability); + logLikelihood = logLikelihood.plus(logTokenProbability.times(termFrequencyInText)); } }); - if (logLikelihood == Number.NEGATIVE_INFINITY) { - console.warn(`[Classificator] category ${category} had -Infinity odds`); - } likelihoods.push({ category, logLikelihood }); }); + // Numerically stable logsumexp: subtract max to prevent overflow/underflow const logsumexp = (likelihoods) => { + if (likelihoods.length === 0) return new Decimal(0); + + const maxLog = likelihoods.reduce((max, l) => { + const val = l.logLikelihood; + return val.greaterThan(max) ? 
val : max; + }, likelihoods[0].logLikelihood); + let sum = new Decimal(0); likelihoods.forEach((likelihood) => { - const x = Decimal(likelihood.logLikelihood); - const a = Decimal.exp(x); - sum = sum.plus(a); + const shifted = likelihood.logLikelihood.minus(maxLog); + sum = sum.plus(Decimal.exp(shifted)); }); - return sum.naturalLogarithm(); + return maxLog.plus(sum.naturalLogarithm()); }; const logProbX = logsumexp(likelihoods); @@ -349,7 +382,6 @@ Naivebayes.prototype.categorize = function(text) { likelihood.logLikelihood = likelihood.logLikelihood.toNumber(); }); - // sort to have first element with biggest probability likelihoods.sort((a, b) => b.proba - a.proba); return { @@ -359,20 +391,82 @@ Naivebayes.prototype.categorize = function(text) { }; /** - * Calculate probability that a `token` belongs to a `category` + * Like categorize(), but returns only the top N most likely categories. + * + * @param {String} text The text to categorize. + * @param {number} n Maximum number of categories to return. + * @return {Object} Same shape as categorize(), but with truncated likelihoods. + */ +Naivebayes.prototype.categorizeTopN = function(text, n) { + const result = this.categorize(text); + if (result.likelihoods.length > n) { + result.likelihoods = result.likelihoods.slice(0, n); + } + return result; +}; + +/** + * Categorize with a confidence threshold. Returns null predictedCategory + * if the top category's probability is below the threshold. + * + * @param {String} text + * @param {number} threshold Minimum probability (0 to 1) for a confident prediction. + * @return {Object} Same shape as categorize(), but predictedCategory is null if below threshold. 
+ */ +Naivebayes.prototype.categorizeWithConfidence = function(text, threshold) { + if (typeof threshold !== 'number' || threshold < 0 || threshold > 1) { + throw new TypeError('NaiveBayes: threshold must be a number between 0 and 1.'); + } + const result = this.categorize(text); + if (result.predictedCategory === null) return result; + + if (result.likelihoods[0].proba < threshold) { + result.predictedCategory = null; + } + return result; +}; + +/** + * Get the top N most influential tokens for a given text's classification. + * Shows which words most contributed to the predicted category. + * + * @param {String} text + * @param {number} [n=5] Number of top tokens to return. + * @return {Array<{token: string, probability: number, frequency: number}>} + */ +Naivebayes.prototype.topInfluentialTokens = function(text, n) { + n = (n === undefined || n === null) ? 5 : Math.max(0, Math.floor(n)); + const tokens = this.tokenize(text); + const frequencyTable = this.frequencyTable(tokens); + const result = this.categorize(text); + const topCategory = result.predictedCategory; + + if (!topCategory) return []; + + return Object.keys(frequencyTable) + .filter(token => this.vocabulary[token] && this.vocabulary[token] > 0) + .map(token => ({ + token, + probability: this.tokenProbability(token, topCategory), + frequency: frequencyTable[token] + })) + .sort((a, b) => b.probability - a.probability) + .slice(0, n); +}; + +/** + * Calculate probability that a `token` belongs to a `category`. + * Uses Laplace smoothing. If the token was never seen in the category, + * still returns a non-zero probability due to smoothing. 
* * @param {String} token * @param {String} category - * @return {Number} probability + * @return {Number} probability (0 < p <= 1, depending on alpha) */ Naivebayes.prototype.tokenProbability = function(token, category) { - // how many times this word has occurred in documents mapped to this category const wordFrequencyCount = this.wordFrequencyCount[category][token] || 0; - - // what is the count of all words that have ever been mapped to this category const wordCount = this.wordCount[category]; - // use laplace Add-1 Smoothing equation return (wordFrequencyCount + this.alpha) / (wordCount + this.alpha * this.vocabularySize); }; @@ -397,12 +491,79 @@ Naivebayes.prototype.frequencyTable = function(tokens) { /** * Dump the classifier's state as a JSON string. + * * @return {String} Representation of the classifier. */ Naivebayes.prototype.toJson = function() { const state = {}; - STATE_KEYS.forEach(k => (state[k] = this[k])); - return JSON.stringify(state); }; + +/** + * Get an array of all category names the classifier has learned. + * + * @return {String[]} Array of category name strings. + */ +Naivebayes.prototype.getCategories = function() { + return Object.keys(this.categories); +}; + +/** + * Learn from multiple text/category pairs at once. + * + * @param {Array<{text: string, category: string}>} items Array of training items. + * @return {Naivebayes} this + * @throws {TypeError} If items is not an array. + */ +Naivebayes.prototype.learnBatch = function(items) { + if (!Array.isArray(items)) { + throw new TypeError('NaiveBayes: learnBatch expects an array of { text, category } objects.'); + } + items.forEach(item => { + this.learn(item.text, item.category); + }); + return this; +}; + +/** + * Reset the classifier to its initial (untrained) state, preserving configuration options. 
+ * + * @return {Naivebayes} this + */ +Naivebayes.prototype.reset = function() { + this.vocabulary = {}; + this.vocabularySize = 0; + this.totalDocuments = 0; + this.docCount = {}; + this.wordCount = {}; + this.wordFrequencyCount = {}; + this.categories = {}; + return this; +}; + +/** + * Get statistics about each category's training data. + * + * @return {Object} Map of category names to { docCount, wordCount, vocabularySize }, + * plus a _total key with aggregate stats. + */ +Naivebayes.prototype.getCategoryStats = function() { + const stats = {}; + Object.keys(this.categories).forEach(category => { + stats[category] = { + docCount: this.docCount[category] || 0, + wordCount: this.wordCount[category] || 0, + vocabularySize: Object.keys(this.wordFrequencyCount[category] || {}).length + }; + }); + const totalWordCount = Object.keys(this.categories).reduce((sum, cat) => { + return sum + (this.wordCount[cat] || 0); + }, 0); + stats._total = { + docCount: this.totalDocuments, + wordCount: totalWordCount, + vocabularySize: this.vocabularySize + }; + return stats; +}; diff --git a/package.json b/package.json index ea12ea5..ac30e61 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "classificator", "description": "Naive Bayes classifier with verbose informations for node.js", - "version": "0.4.0", + "version": "0.5.0", "author": "Wozacosta", "keywords": [ "naive", @@ -23,7 +23,7 @@ "mocha": "^9.0.2" }, "engines": { - "node": ">=5.0.0" + "node": ">=14.0.0" }, "main": "./lib/classificator", "repository": { @@ -31,6 +31,6 @@ "url": "https://github.com/Wozacosta/classificator.git" }, "scripts": { - "test": "mocha -t 30000 -R spec" + "test": "mocha -t 30000 -R spec test/*.js" } } diff --git a/test/classificator.js b/test/classificator.js index 8efd319..64ba1d9 100644 --- a/test/classificator.js +++ b/test/classificator.js @@ -1,6 +1,4 @@ var assert = require('assert') - , fs = require('fs') - , path = require('path') , bayes = 
require('../lib/classificator') describe('bayes() init', function () { @@ -18,10 +16,18 @@ describe('bayes() init', function () { invalidOptionsCases.forEach(function (invalidOptions) { assert.throws(function () { bayes(invalidOptions) }, Error) - // check that it's a TypeError assert.throws(function () { bayes(invalidOptions) }, TypeError) }) }) + + it('throws TypeError when tokenizer is not a function', function () { + assert.throws(function () { bayes({ tokenizer: 'not a function' }) }, TypeError) + assert.throws(function () { bayes({ tokenizer: 42 }) }, TypeError) + }) + + it('throws TypeError when tokenPreprocessor is not a function', function () { + assert.throws(function () { bayes({ tokenPreprocessor: 'bad' }) }, TypeError) + }) }) describe('bayes using custom tokenizer', function () { @@ -34,7 +40,6 @@ describe('bayes using custom tokenizer', function () { classifier.learn('abcd', 'happy') - // check classifier's state is as expected assert.equal(classifier.totalDocuments, 1) assert.equal(classifier.docCount.happy, 1) assert.deepEqual(classifier.vocabulary, { a: 1, b: 1, c: 1, d: 1 }) @@ -44,36 +49,105 @@ describe('bayes using custom tokenizer', function () { assert.equal(classifier.wordFrequencyCount.happy.b, 1) assert.equal(classifier.wordFrequencyCount.happy.c, 1) assert.equal(classifier.wordFrequencyCount.happy.d, 1) - assert.deepEqual(classifier.categories, { happy: 1 }) + assert.deepStrictEqual(classifier.categories, { happy: true }) }) }) -describe('bayes serializing/deserializing its state', function () { - it('serializes/deserializes its state as JSON correctly.', function (done) { - var classifier = bayes() +describe('bayes using tokenPreprocessor', function () { + it('applies tokenPreprocessor after tokenizer', function () { + var stopwords = new Set(['the', 'a', 'is', 'in']) + var classifier = bayes({ + tokenPreprocessor: function (tokens) { + return tokens + .map(function (t) { return t.toLowerCase() }) + .filter(function (t) { return 
!stopwords.has(t) }) + } + }) + + classifier.learn('The cat is in a hat', 'animals') + + // stopwords should be removed + assert.equal(classifier.wordFrequencyCount.animals['the'], undefined) + assert.equal(classifier.wordFrequencyCount.animals['a'], undefined) + assert.equal(classifier.wordFrequencyCount.animals['is'], undefined) + assert.equal(classifier.wordFrequencyCount.animals['in'], undefined) + + // content words should remain (lowercased) + assert.equal(classifier.wordFrequencyCount.animals['cat'], 1) + assert.equal(classifier.wordFrequencyCount.animals['hat'], 1) + }) - classifier.learn('Fun times were had by all', 'positive') - classifier.learn('sad dark rainy day in the cave', 'negative') + it('works with stemming-style preprocessor', function () { + var classifier = bayes({ + tokenPreprocessor: function (tokens) { + return tokens.map(function (t) { + // crude stemming: strip trailing 's', 'ing', 'ed' + return t.replace(/(ing|ed|s)$/i, '').toLowerCase() + }) + } + }) + + classifier.learn('running dogs played', 'active') + classifier.learn('sleeping cats rested', 'passive') - var jsonRepr = classifier.toJson() + var result = classifier.categorize('dogs playing') + assert.equal(result.predictedCategory, 'active') + }) + + it('preprocessor is preserved through fromJson with options', function () { + var preprocessor = function (tokens) { + return tokens.map(function (t) { return t.toLowerCase() }) + } - // check serialized values - var state = JSON.parse(jsonRepr) + var classifier = bayes({ tokenPreprocessor: preprocessor }) + classifier.learn('HELLO', 'greetings') - // ensure classifier's state values are all in the json representation - bayes.STATE_KEYS.forEach(function (k) { - assert.deepEqual(state[k], classifier[k]) - }) + var revived = bayes.fromJson(classifier.toJson(), { tokenPreprocessor: preprocessor }) + var result = revived.categorize('HELLO') + assert.equal(result.predictedCategory, 'greetings') + }) - var revivedClassifier = 
bayes.fromJson(jsonRepr) + it('classifier.options preserves runtime options after fromJson', function () { + var preprocessor = function (tokens) { + return tokens.map(function (t) { return t.toLowerCase() }) + } + var tokenizer = function (text) { return text.split(' ') } - // ensure the revived classifier's state is same as original state - bayes.STATE_KEYS.forEach(function (k) { - assert.deepEqual(revivedClassifier[k], classifier[k]) - }) + var classifier = bayes({ tokenPreprocessor: preprocessor, tokenizer: tokenizer }) + classifier.learn('HELLO WORLD', 'greetings') - done() + var revived = bayes.fromJson(classifier.toJson(), { + tokenPreprocessor: preprocessor, + tokenizer: tokenizer }) + + assert.strictEqual(revived.options.tokenPreprocessor, preprocessor) + assert.strictEqual(revived.options.tokenizer, tokenizer) + assert.strictEqual(revived.tokenPreprocessor, preprocessor) + assert.strictEqual(revived.tokenizer, tokenizer) + }) +}) + +describe('bayes serializing/deserializing its state', function () { + it('serializes/deserializes its state as JSON correctly.', function () { + var classifier = bayes() + + classifier.learn('Fun times were had by all', 'positive') + classifier.learn('sad dark rainy day in the cave', 'negative') + + var jsonRepr = classifier.toJson() + var state = JSON.parse(jsonRepr) + + bayes.STATE_KEYS.forEach(function (k) { + assert.deepEqual(state[k], classifier[k]) + }) + + var revivedClassifier = bayes.fromJson(jsonRepr) + + bayes.STATE_KEYS.forEach(function (k) { + assert.deepEqual(revivedClassifier[k], classifier[k]) + }) + }) }) describe('bayes using custom tokenizer with fromJson', function () { @@ -90,47 +164,58 @@ describe('bayes using custom tokenizer with fromJson', function () { var jsonRepr = classifier.toJson() var revivedClassifier = bayes.fromJson(jsonRepr, { tokenizer: splitOnChar }) - // the revived classifier should use the custom tokenizer var result = revivedClassifier.categorize('abcd') 
assert.equal(result.predictedCategory, 'happy') }) }) -describe('bayes .learn() correctness', function () { - //sentiment analysis test - it('categorizes correctly for `positive` and `negative` categories', function (done) { +describe('bayes .fromJson() edge cases', function () { + it('throws on null input', function () { + assert.throws(function () { bayes.fromJson(null) }, Error) + }) + + it('throws on numeric input', function () { + assert.throws(function () { bayes.fromJson(42) }, Error) + }) + + it('throws on invalid JSON string', function () { + assert.throws(function () { bayes.fromJson('not valid json') }, Error) + }) + + it('throws when JSON is missing required state keys', function () { + assert.throws(function () { bayes.fromJson('{"options":{}}') }, Error) + }) + + it('preserves options through serialization round-trip', function () { + var classifier = bayes({ alpha: 2, fitPrior: false }) + classifier.learn('hello world', 'greetings') + + var revived = bayes.fromJson(classifier.toJson()) + assert.equal(revived.alpha, 2) + assert.equal(revived.fitPrior, false) + }) +}) +describe('bayes .learn() correctness', function () { + it('categorizes correctly for `positive` and `negative` categories', function () { let classifier = bayes(); - //teach it positive phrases classifier.learn('amazing, awesome movie!! Yeah!!', 'positive') classifier.learn('Sweet, this is incredibly, amazing, perfect, great!!', 'positive') - - //teach it a negative phrase classifier.learn('terrible, shitty thing. Damn. Sucks!!', 'negative') - - //teach it a neutral phrase classifier.learn('I dont really know what to make of this.', 'neutral') - //now test it to see that it correctly categorizes a new document assert.deepEqual(classifier.categorize('awesome, cool, amazing!! 
Yay.').predictedCategory, 'positive') - done() }) - //topic analysis test - it('categorizes correctly for `chinese` and `japanese` categories', function (done) { - + it('categorizes correctly for `chinese` and `japanese` categories', function () { var classifier = bayes() - //teach it how to identify the `chinese` category classifier.learn('Chinese Beijing Chinese', 'chinese') classifier.learn('Chinese Chinese Shanghai', 'chinese') classifier.learn('Chinese Macao', 'chinese') - - //teach it how to identify the `japanese` category classifier.learn('Tokyo Japan Chinese', 'japanese') - //make sure it learned the `chinese` category correctly var chineseFrequencyCount = classifier.wordFrequencyCount.chinese assert.equal(chineseFrequencyCount['Chinese'], 5) @@ -138,20 +223,16 @@ describe('bayes .learn() correctness', function () { assert.equal(chineseFrequencyCount['Shanghai'], 1) assert.equal(chineseFrequencyCount['Macao'], 1) - //make sure it learned the `japanese` category correctly var japaneseFrequencyCount = classifier.wordFrequencyCount.japanese assert.equal(japaneseFrequencyCount['Tokyo'], 1) assert.equal(japaneseFrequencyCount['Japan'], 1) assert.equal(japaneseFrequencyCount['Chinese'], 1) - //now test it to see that it correctly categorizes a new document assert.deepEqual(classifier.categorize('Chinese Chinese Chinese Tokyo Japan').predictedCategory,'chinese') - - done() }) - it('correctly tokenizes cyrlic characters', function (done) { + it('correctly tokenizes cyrlic characters', function () { var classifier = bayes() classifier.learn('Надежда за', 'a') @@ -169,42 +250,649 @@ describe('bayes .learn() correctness', function () { assert.equal(bFreqCount['еп'], 2) assert.equal(bFreqCount['36'], 2) assert.equal(bFreqCount['Тест'], 2) - - done() }) - it('correctly computes probabilities without prior', function (done) { + it('correctly computes probabilities without prior', function () { var classifier = bayes({ fitPrior: false}) - // learn on a very unbalanced 
dataset classifier.learn('aa', '1') classifier.learn('aa', '1') classifier.learn('aa', '1') classifier.learn('bb', '2') - // test the likelihoods obtained on test strings assert.equal(classifier.categorize('cc').likelihoods[0].proba, 0.5) assert.equal(Number(classifier.categorize('bb').likelihoods[0].proba).toFixed(6), Number(0.76923077).toFixed(6)) assert.equal(Number(classifier.categorize('aa').likelihoods[0].proba).toFixed(6), Number(0.70588235).toFixed(6)) - - done() }) - it('correctly computes probabilities with prior', function (done) { + it('correctly computes probabilities with prior', function () { var classifier = bayes() - // learn on a very unbalanced dataset classifier.learn('aa', '1') classifier.learn('aa', '1') classifier.learn('aa', '1') classifier.learn('bb', '2') - // test the likelihoods obtained on test strings assert.equal(classifier.categorize('cc').likelihoods[0].proba, 0.75) assert.equal(Number(classifier.categorize('bb').likelihoods[0].proba).toFixed(6), Number(0.52631579).toFixed(6)) assert.equal(Number(classifier.categorize('aa').likelihoods[0].proba).toFixed(6), Number(0.87804878).toFixed(6)) + }) + + it('throws TypeError when text is not a string', function () { + var classifier = bayes() + assert.throws(function () { classifier.learn(123, 'cat') }, TypeError) + assert.throws(function () { classifier.learn(null, 'cat') }, TypeError) + }) + + it('throws TypeError when category is not a string', function () { + var classifier = bayes() + assert.throws(function () { classifier.learn('hello', 123) }, TypeError) + assert.throws(function () { classifier.learn('hello', null) }, TypeError) + }) +}) + +describe('bayes .unlearn() correctness', function () { + it('reverses the effect of a single learn call', function () { + var classifier = bayes() + + classifier.learn('fun times', 'positive') + classifier.learn('bad times', 'negative') + classifier.learn('great day', 'positive') + + var docsBefore = classifier.totalDocuments + 
classifier.unlearn('fun times', 'positive') + + assert.equal(classifier.totalDocuments, docsBefore - 1) + }) + + it('throws when unlearning from a non-existent category', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + + assert.throws(function () { + classifier.unlearn('hello', 'nonexistent') + }, Error) + }) + + it('removes category from categories when last doc is unlearned', function () { + var classifier = bayes() + + classifier.learn('hello', 'greetings') + assert.ok(classifier.categories['greetings']) + + classifier.unlearn('hello', 'greetings') + assert.equal(classifier.categories['greetings'], undefined) + }) + + it('classifier still works correctly after unlearn', function () { + var classifier = bayes() + + classifier.learn('amazing great', 'positive') + classifier.learn('terrible awful', 'negative') + classifier.learn('bad horrible', 'negative') - done() + classifier.unlearn('bad horrible', 'negative') + + var result = classifier.categorize('terrible') + assert.equal(result.predictedCategory, 'negative') + }) + + it('returns this for method chaining', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + var result = classifier.unlearn('hello', 'greetings') + assert.strictEqual(result, classifier) + }) + + it('throws TypeError when text is not a string', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + assert.throws(function () { classifier.unlearn(123, 'greetings') }, TypeError) + }) + + it('does not leave negative wordCount', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + classifier.unlearn('hello', 'greetings') + + assert.equal(classifier.wordCount['greetings'], undefined) + }) +}) + +describe('bayes .removeCategory()', function () { + it('removes a category and its associated data', function () { + var classifier = bayes() + + classifier.learn('hello world', 'greetings') + classifier.learn('bad stuff', 'negative') 
+ + classifier.removeCategory('greetings') + + assert.equal(classifier.categories['greetings'], undefined) + assert.equal(classifier.docCount['greetings'], undefined) + assert.equal(classifier.wordCount['greetings'], undefined) + assert.equal(classifier.wordFrequencyCount['greetings'], undefined) + }) + + it('returns this for chaining', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + var result = classifier.removeCategory('greetings') + assert.strictEqual(result, classifier) + }) + + it('returns this when removing non-existent category (no-op)', function () { + var classifier = bayes() + var result = classifier.removeCategory('nonexistent') + assert.strictEqual(result, classifier) + }) + + it('classifier still categorizes correctly after removing a category', function () { + var classifier = bayes() + + classifier.learn('amazing great', 'positive') + classifier.learn('terrible bad', 'negative') + classifier.learn('meh ok', 'neutral') + + classifier.removeCategory('neutral') + + var result = classifier.categorize('amazing') + assert.equal(result.predictedCategory, 'positive') + assert.equal(result.likelihoods.length, 2) + }) + + it('updates vocabulary and vocabularySize correctly', function () { + var classifier = bayes() + + classifier.learn('unique', 'only') + var sizeBefore = classifier.vocabularySize + + classifier.removeCategory('only') + assert.ok(classifier.vocabularySize < sizeBefore) + }) + + it('does not produce negative vocabulary counts', function () { + var classifier = bayes() + + classifier.learn('shared word', 'a') + classifier.learn('shared word', 'b') + + classifier.removeCategory('a') + + // 'shared' and 'word' should still have count >= 0 + Object.keys(classifier.vocabulary).forEach(function (token) { + assert.ok(classifier.vocabulary[token] >= 0, + 'vocabulary[' + token + '] should be >= 0, got ' + classifier.vocabulary[token]) + }) }) }) +describe('bayes .categorize() return structure', function () { + 
it('returns an object with predictedCategory and likelihoods', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + + var result = classifier.categorize('hello') + assert.ok(result.hasOwnProperty('predictedCategory')) + assert.ok(result.hasOwnProperty('likelihoods')) + assert.ok(Array.isArray(result.likelihoods)) + }) + + it('likelihoods contain category, logLikelihood, logProba, proba', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + + var result = classifier.categorize('hello') + var likelihood = result.likelihoods[0] + + assert.ok(likelihood.hasOwnProperty('category')) + assert.ok(likelihood.hasOwnProperty('logLikelihood')) + assert.ok(likelihood.hasOwnProperty('logProba')) + assert.ok(likelihood.hasOwnProperty('proba')) + }) + + it('likelihoods are sorted by proba descending', function () { + var classifier = bayes() + classifier.learn('aa', 'a') + classifier.learn('bb', 'b') + classifier.learn('cc', 'c') + + var result = classifier.categorize('aa') + for (var i = 1; i < result.likelihoods.length; i++) { + assert.ok(result.likelihoods[i - 1].proba >= result.likelihoods[i].proba) + } + }) + + it('probabilities sum to approximately 1.0', function () { + var classifier = bayes() + classifier.learn('happy fun', 'positive') + classifier.learn('sad bad', 'negative') + + var result = classifier.categorize('happy') + var sum = result.likelihoods.reduce(function (acc, l) { return acc + l.proba }, 0) + assert.ok(Math.abs(sum - 1.0) < 0.001) + }) + + it('returns predictedCategory null for empty classifier', function () { + var classifier = bayes() + var result = classifier.categorize('hello') + + assert.equal(result.predictedCategory, null) + assert.deepEqual(result.likelihoods, []) + }) + + it('throws TypeError when text is not a string', function () { + var classifier = bayes() + assert.throws(function () { classifier.categorize(123) }, TypeError) + assert.throws(function () { classifier.categorize(null) 
}, TypeError) + }) +}) + +describe('bayes edge cases', function () { + it('handles empty string input to categorize()', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + + var result = classifier.categorize('') + assert.ok(result.hasOwnProperty('predictedCategory')) + }) + + it('handles text with only punctuation', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + + var result = classifier.categorize('!@#$%') + assert.ok(result.hasOwnProperty('predictedCategory')) + }) + + it('handles unknown tokens gracefully', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + + var result = classifier.categorize('xyzzy foobar') + assert.ok(result.hasOwnProperty('predictedCategory')) + }) +}) + +describe('bayes alpha parameter', function () { + it('uses default alpha of 1 when not specified', function () { + var classifier = bayes() + assert.equal(classifier.alpha, 1) + }) + + it('accepts alpha: 0 without overriding to default', function () { + var classifier = bayes({ alpha: 0 }) + assert.strictEqual(classifier.alpha, 0) + }) + + it('custom alpha affects token probability calculation', function () { + var classifier1 = bayes({ alpha: 1 }) + var classifier2 = bayes({ alpha: 10 }) + + classifier1.learn('hello world', 'greetings') + classifier1.learn('goodbye world', 'farewells') + classifier2.learn('hello world', 'greetings') + classifier2.learn('goodbye world', 'farewells') + + var prob1 = classifier1.tokenProbability('hello', 'greetings') + var prob2 = classifier2.tokenProbability('hello', 'greetings') + + assert.notEqual(prob1, prob2) + }) + + it('alpha: 0 categorization works but unseen tokens zero-out a category', function () { + var classifier = bayes({ alpha: 0 }) + + classifier.learn('happy fun', 'positive') + classifier.learn('sad bad', 'negative') + + // 'happy' was only seen in positive, so negative gets zero probability for it + var result = 
classifier.categorize('happy') + assert.equal(result.predictedCategory, 'positive') + + // result should not contain NaN + result.likelihoods.forEach(function (l) { + assert.ok(!isNaN(l.proba), 'proba should not be NaN') + }) + }) +}) + +describe('bayes method chaining', function () { + it('learn() returns this', function () { + var classifier = bayes() + assert.strictEqual(classifier.learn('hello', 'greetings'), classifier) + }) + + it('initializeCategory() returns this', function () { + var classifier = bayes() + assert.strictEqual(classifier.initializeCategory('test'), classifier) + }) + + it('removeCategory() returns this', function () { + var classifier = bayes() + assert.strictEqual(classifier.removeCategory('test'), classifier) + }) + + it('supports fluent chain: learn().learn().categorize()', function () { + var result = bayes() + .learn('happy fun', 'positive') + .learn('sad bad', 'negative') + .categorize('happy') + + assert.equal(result.predictedCategory, 'positive') + }) +}) + +describe('bayes .getCategories()', function () { + it('returns empty array for new classifier', function () { + var classifier = bayes() + assert.deepEqual(classifier.getCategories(), []) + }) + + it('returns array of learned category names', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + classifier.learn('bye', 'farewells') + + var categories = classifier.getCategories() + assert.ok(categories.indexOf('greetings') !== -1) + assert.ok(categories.indexOf('farewells') !== -1) + assert.equal(categories.length, 2) + }) + + it('reflects category removal', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + classifier.learn('bye', 'farewells') + classifier.removeCategory('greetings') + + var categories = classifier.getCategories() + assert.ok(categories.indexOf('greetings') === -1) + assert.equal(categories.length, 1) + }) +}) + +describe('bayes .categorizeTopN()', function () { + it('returns only top N categories', 
function () { + var classifier = bayes() + classifier.learn('aa', 'a') + classifier.learn('bb', 'b') + classifier.learn('cc', 'c') + + var result = classifier.categorizeTopN('aa', 2) + assert.equal(result.likelihoods.length, 2) + }) + + it('returns all categories if N >= total categories', function () { + var classifier = bayes() + classifier.learn('aa', 'a') + classifier.learn('bb', 'b') + + var result = classifier.categorizeTopN('aa', 10) + assert.equal(result.likelihoods.length, 2) + }) + + it('predictedCategory is the most likely', function () { + var classifier = bayes() + classifier.learn('happy fun great', 'positive') + classifier.learn('sad bad terrible', 'negative') + classifier.learn('ok meh whatever', 'neutral') + + var result = classifier.categorizeTopN('happy fun', 1) + assert.equal(result.predictedCategory, 'positive') + assert.equal(result.likelihoods.length, 1) + }) +}) + +describe('bayes .categorizeWithConfidence()', function () { + it('returns predictedCategory when above threshold', function () { + var classifier = bayes() + classifier.learn('happy fun great amazing', 'positive') + classifier.learn('sad bad terrible awful', 'negative') + + var result = classifier.categorizeWithConfidence('happy fun great', 0.5) + assert.equal(result.predictedCategory, 'positive') + }) + + it('returns null predictedCategory when below threshold', function () { + var classifier = bayes() + classifier.learn('aa', 'a') + classifier.learn('bb', 'b') + + // with a very high threshold, prediction should be null + var result = classifier.categorizeWithConfidence('cc', 0.99) + assert.equal(result.predictedCategory, null) + }) + + it('returns null for empty classifier', function () { + var classifier = bayes() + var result = classifier.categorizeWithConfidence('hello', 0.5) + assert.equal(result.predictedCategory, null) + }) + + it('still returns full likelihoods array', function () { + var classifier = bayes() + classifier.learn('aa', 'a') + classifier.learn('bb', 'b') + 
+ var result = classifier.categorizeWithConfidence('cc', 0.99) + assert.ok(Array.isArray(result.likelihoods)) + assert.ok(result.likelihoods.length > 0) + }) + + it('throws TypeError for invalid threshold', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + + assert.throws(function () { classifier.categorizeWithConfidence('hello', -1) }, TypeError) + assert.throws(function () { classifier.categorizeWithConfidence('hello', 2) }, TypeError) + assert.throws(function () { classifier.categorizeWithConfidence('hello', 'bad') }, TypeError) + }) +}) + +describe('bayes .topInfluentialTokens()', function () { + it('returns top tokens for a classification', function () { + var classifier = bayes() + classifier.learn('happy fun great joy', 'positive') + classifier.learn('sad bad terrible gloom', 'negative') + + var tokens = classifier.topInfluentialTokens('happy fun great', 3) + assert.ok(Array.isArray(tokens)) + assert.ok(tokens.length <= 3) + assert.ok(tokens.length > 0) + + // each token should have the right shape + tokens.forEach(function (t) { + assert.ok(t.hasOwnProperty('token')) + assert.ok(t.hasOwnProperty('probability')) + assert.ok(t.hasOwnProperty('frequency')) + }) + }) + + it('returns empty array for empty classifier', function () { + var classifier = bayes() + var tokens = classifier.topInfluentialTokens('hello') + assert.deepEqual(tokens, []) + }) + + it('tokens are sorted by probability descending', function () { + var classifier = bayes() + classifier.learn('apple banana cherry', 'fruit') + classifier.learn('dog cat bird', 'animal') + + var tokens = classifier.topInfluentialTokens('apple banana cherry', 5) + for (var i = 1; i < tokens.length; i++) { + assert.ok(tokens[i - 1].probability >= tokens[i].probability) + } + }) + + it('defaults to 5 tokens', function () { + var classifier = bayes() + classifier.learn('a b c d e f g h', 'letters') + classifier.learn('1 2 3', 'numbers') + + var tokens = 
classifier.topInfluentialTokens('a b c d e f g h') + assert.ok(tokens.length <= 5) + }) + + it('returns empty array when n is 0', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + + var tokens = classifier.topInfluentialTokens('hello world', 0) + assert.deepEqual(tokens, []) + }) + + it('handles negative n by returning empty array', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + + var tokens = classifier.topInfluentialTokens('hello world', -3) + assert.deepEqual(tokens, []) + }) +}) + +describe('bayes .learnBatch()', function () { + it('learns multiple items at once', function () { + var classifier = bayes() + classifier.learnBatch([ + { text: 'happy fun', category: 'positive' }, + { text: 'sad bad', category: 'negative' } + ]) + + assert.equal(classifier.totalDocuments, 2) + assert.ok(classifier.categories['positive']) + assert.ok(classifier.categories['negative']) + }) + + it('produces same result as individual learn calls', function () { + var classifier1 = bayes() + classifier1.learn('happy fun', 'positive') + classifier1.learn('sad bad', 'negative') + + var classifier2 = bayes() + classifier2.learnBatch([ + { text: 'happy fun', category: 'positive' }, + { text: 'sad bad', category: 'negative' } + ]) + + assert.deepEqual(classifier1.vocabulary, classifier2.vocabulary) + assert.equal(classifier1.totalDocuments, classifier2.totalDocuments) + assert.deepEqual(classifier1.docCount, classifier2.docCount) + }) + + it('throws TypeError on non-array input', function () { + var classifier = bayes() + assert.throws(function () { classifier.learnBatch('not an array') }, TypeError) + assert.throws(function () { classifier.learnBatch(42) }, TypeError) + }) + + it('returns this for chaining', function () { + var classifier = bayes() + var result = classifier.learnBatch([{ text: 'hello', category: 'greetings' }]) + assert.strictEqual(result, classifier) + }) +}) + +describe('bayes .reset()', 
function () { + it('clears all learned data', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + classifier.learn('goodbye world', 'farewells') + + classifier.reset() + + assert.equal(classifier.totalDocuments, 0) + assert.equal(classifier.vocabularySize, 0) + assert.deepEqual(classifier.categories, {}) + assert.deepEqual(classifier.vocabulary, {}) + assert.deepEqual(classifier.docCount, {}) + assert.deepEqual(classifier.wordCount, {}) + assert.deepEqual(classifier.wordFrequencyCount, {}) + }) + + it('preserves options (tokenizer, alpha, fitPrior)', function () { + var customTokenizer = function (text) { return text.split('') } + var classifier = bayes({ tokenizer: customTokenizer, alpha: 2, fitPrior: false }) + classifier.learn('abc', 'letters') + + classifier.reset() + + assert.strictEqual(classifier.tokenizer, customTokenizer) + assert.equal(classifier.alpha, 2) + assert.equal(classifier.fitPrior, false) + }) + + it('classifier can be retrained after reset', function () { + var classifier = bayes() + classifier.learn('hello', 'greetings') + classifier.reset() + classifier.learn('goodbye', 'farewells') + + assert.equal(classifier.totalDocuments, 1) + assert.deepEqual(classifier.getCategories(), ['farewells']) + }) + + it('returns this for chaining', function () { + var classifier = bayes() + assert.strictEqual(classifier.reset(), classifier) + }) +}) + +describe('bayes .getCategoryStats()', function () { + it('returns correct doc and word counts per category', function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + classifier.learn('goodbye world', 'farewells') + classifier.learn('hi there', 'greetings') + + var stats = classifier.getCategoryStats() + + assert.equal(stats.greetings.docCount, 2) + assert.equal(stats.farewells.docCount, 1) + assert.ok(stats.greetings.wordCount > 0) + assert.ok(stats.greetings.vocabularySize > 0) + }) + + it('includes _total aggregate stats with wordCount', 
function () { + var classifier = bayes() + classifier.learn('hello world', 'greetings') + classifier.learn('bye now', 'farewells') + + var stats = classifier.getCategoryStats() + + assert.equal(stats._total.docCount, 2) + assert.ok(stats._total.vocabularySize > 0) + assert.equal(stats._total.wordCount, stats.greetings.wordCount + stats.farewells.wordCount) + }) +}) + +describe('bayes numerical stability (logsumexp)', function () { + it('probabilities still sum to ~1.0 with many categories', function () { + var classifier = bayes() + for (var i = 0; i < 20; i++) { + classifier.learn('word' + i + ' common shared text', 'cat' + i) + } + + var result = classifier.categorize('word0 common shared') + var sum = result.likelihoods.reduce(function (acc, l) { return acc + l.proba }, 0) + assert.ok(Math.abs(sum - 1.0) < 0.01, 'probabilities should sum to ~1.0, got ' + sum) + }) + + it('handles long documents without NaN', function () { + var classifier = bayes() + classifier.learn('good great amazing wonderful fantastic', 'positive') + classifier.learn('bad terrible awful horrible dreadful', 'negative') + + // create a long document + var longText = '' + for (var i = 0; i < 100; i++) { + longText += 'good great amazing ' + } + + var result = classifier.categorize(longText) + assert.ok(!isNaN(result.likelihoods[0].proba), 'proba should not be NaN') + assert.ok(result.likelihoods[0].proba > 0, 'proba should be positive') + assert.equal(result.predictedCategory, 'positive') + }) +}) diff --git a/test/integration.js b/test/integration.js new file mode 100644 index 0000000..75a2a62 --- /dev/null +++ b/test/integration.js @@ -0,0 +1,556 @@ +var assert = require('assert') + , bayes = require('../lib/classificator') + +// ============================================================================= +// Integration Tests +// Test multiple features working together in realistic combinations +// ============================================================================= + 
+describe('[Integration] full train → serialize → restore → classify pipeline', function () { + it('classifier survives a full round-trip with custom options', function () { + var tokenizer = function (text) { return text.toLowerCase().split(/\s+/) } + var preprocessor = function (tokens) { + var stops = new Set(['the', 'a', 'is', 'it', 'and', 'of', 'to', 'in']) + return tokens.filter(function (t) { return !stops.has(t) }) + } + + // 1. Create with custom options + var classifier = bayes({ + tokenizer: tokenizer, + tokenPreprocessor: preprocessor, + alpha: 0.5, + fitPrior: true + }) + + // 2. Train + classifier.learn('The movie was amazing and wonderful', 'positive') + classifier.learn('It is a great film to watch', 'positive') + classifier.learn('The movie was terrible and boring', 'negative') + classifier.learn('It is a bad film, awful acting', 'negative') + + // 3. Verify pre-serialization + var before = classifier.categorize('amazing film') + assert.equal(before.predictedCategory, 'positive') + + // 4. Serialize + var json = classifier.toJson() + + // 5. Restore with runtime options + var restored = bayes.fromJson(json, { + tokenizer: tokenizer, + tokenPreprocessor: preprocessor + }) + + // 6. Verify post-restoration + var after = restored.categorize('amazing film') + assert.equal(after.predictedCategory, 'positive') + assert.equal(after.likelihoods.length, before.likelihoods.length) + + // 7. Probabilities should match + assert.equal( + before.likelihoods[0].proba.toFixed(8), + after.likelihoods[0].proba.toFixed(8) + ) + + // 8. 
Options preserved + assert.equal(restored.alpha, 0.5) + assert.equal(restored.fitPrior, true) + assert.strictEqual(restored.tokenizer, tokenizer) + assert.strictEqual(restored.tokenPreprocessor, preprocessor) + }) +}) + +describe('[Integration] learn → unlearn → relearn cycle', function () { + it('classifier state is consistent after learn/unlearn/relearn', function () { + var classifier = bayes() + + // learn initial data + classifier.learn('good morning sunshine', 'positive') + classifier.learn('terrible horrible day', 'negative') + classifier.learn('wonderful great time', 'positive') + + var stats1 = classifier.getCategoryStats() + assert.equal(stats1._total.docCount, 3) + + // unlearn a mistake + classifier.unlearn('wonderful great time', 'positive') + assert.equal(classifier.totalDocuments, 2) + + // re-learn corrected data + classifier.learn('wonderful great time', 'neutral') + assert.equal(classifier.totalDocuments, 3) + + // classifier should now have 3 categories + var categories = classifier.getCategories() + assert.ok(categories.indexOf('positive') !== -1) + assert.ok(categories.indexOf('negative') !== -1) + assert.ok(categories.indexOf('neutral') !== -1) + + // classification still works + var result = classifier.categorize('terrible') + assert.equal(result.predictedCategory, 'negative') + }) +}) + +describe('[Integration] batch learning + stats + reset + retrain', function () { + it('full lifecycle: batch train, inspect, reset, retrain differently', function () { + var classifier = bayes() + + // batch train + classifier.learnBatch([ + { text: 'buy cheap viagra now', category: 'spam' }, + { text: 'limited offer free pills', category: 'spam' }, + { text: 'hello how are you doing', category: 'ham' }, + { text: 'meeting at 3pm tomorrow', category: 'ham' }, + { text: 'project update attached', category: 'ham' } + ]) + + // inspect + var stats = classifier.getCategoryStats() + assert.equal(stats.spam.docCount, 2) + assert.equal(stats.ham.docCount, 3) + 
assert.equal(stats._total.docCount, 5) + + // classify + assert.equal(classifier.categorize('free offer').predictedCategory, 'spam') + assert.equal(classifier.categorize('meeting tomorrow').predictedCategory, 'ham') + + // reset + classifier.reset() + assert.equal(classifier.totalDocuments, 0) + assert.deepEqual(classifier.getCategories(), []) + + // retrain with different categories + classifier.learn('breaking news politics', 'news') + classifier.learn('football scores today', 'sports') + + assert.equal(classifier.categorize('political news').predictedCategory, 'news') + assert.equal(classifier.categorize('football game').predictedCategory, 'sports') + }) +}) + +describe('[Integration] removeCategory + reclassification', function () { + it('removing a dominant category shifts predictions correctly', function () { + var classifier = bayes() + + classifier.learn('python code function', 'programming') + classifier.learn('java class object', 'programming') + classifier.learn('javascript react component', 'programming') + classifier.learn('recipe bake flour sugar', 'cooking') + classifier.learn('stock market investment', 'finance') + + // programming dominates with fitPrior + var before = classifier.categorize('new class today') + assert.equal(before.predictedCategory, 'programming') + assert.equal(before.likelihoods.length, 3) + + // remove programming + classifier.removeCategory('programming') + + // now should choose between cooking and finance + var after = classifier.categorize('new class today') + assert.equal(after.likelihoods.length, 2) + assert.ok(after.predictedCategory === 'cooking' || after.predictedCategory === 'finance') + + // probabilities still sum to ~1 + var sum = after.likelihoods.reduce(function (acc, l) { return acc + l.proba }, 0) + assert.ok(Math.abs(sum - 1.0) < 0.01) + }) +}) + +describe('[Integration] tokenPreprocessor affects all operations consistently', function () { + it('preprocessor is applied in learn, unlearn, categorize, and 
topInfluentialTokens', function () { + var lowered = [] + var preprocessor = function (tokens) { + var result = tokens.map(function (t) { return t.toLowerCase() }) + lowered.push(result) + return result + } + + var classifier = bayes({ tokenPreprocessor: preprocessor }) + + // learn — preprocessor called + lowered = [] + classifier.learn('HELLO WORLD', 'greetings') + assert.ok(lowered.length > 0) + assert.deepEqual(lowered[0], ['hello', 'world']) + + // should have lowercase tokens in vocabulary + assert.ok(classifier.vocabulary['hello']) + assert.equal(classifier.vocabulary['HELLO'], undefined) + + // categorize — preprocessor called + lowered = [] + var result = classifier.categorize('HELLO') + assert.ok(lowered.length > 0) + assert.equal(result.predictedCategory, 'greetings') + + // topInfluentialTokens — preprocessor called + lowered = [] + var tokens = classifier.topInfluentialTokens('HELLO WORLD', 2) + assert.ok(lowered.length > 0) + assert.ok(tokens.length > 0) + + // unlearn — preprocessor called (lowercase matches original learn) + lowered = [] + classifier.learn('BYE NOW', 'farewells') + classifier.unlearn('HELLO WORLD', 'greetings') + assert.ok(lowered.length > 0) + assert.equal(classifier.categories['greetings'], undefined) + }) +}) + +describe('[Integration] confidence threshold + topN combined workflow', function () { + it('uses confidence to filter then topN to limit results', function () { + var classifier = bayes() + + classifier.learnBatch([ + { text: 'cat dog hamster', category: 'pets' }, + { text: 'cat dog hamster', category: 'pets' }, + { text: 'car truck bus', category: 'vehicles' }, + { text: 'apple banana orange', category: 'fruit' }, + { text: 'table chair desk', category: 'furniture' } + ]) + + // high confidence prediction + var confident = classifier.categorizeWithConfidence('cat dog', 0.3) + assert.equal(confident.predictedCategory, 'pets') + + // low confidence for ambiguous text + var unsure = classifier.categorizeWithConfidence('xyz 
unknown', 0.5) + assert.equal(unsure.predictedCategory, null) + + // topN limits output + var top2 = classifier.categorizeTopN('cat dog', 2) + assert.equal(top2.likelihoods.length, 2) + assert.equal(top2.predictedCategory, 'pets') + }) +}) + +describe('[Integration] method chaining complex workflow', function () { + it('chains learn, unlearn, removeCategory, and ends with categorize', function () { + var result = bayes() + .learn('happy joy love', 'positive') + .learn('sad hate anger', 'negative') + .learn('meh whatever ok', 'neutral') + .learn('oops wrong category', 'neutral') + .unlearn('oops wrong category', 'neutral') + .removeCategory('neutral') + .learn('wonderful amazing', 'positive') + .categorize('love and joy') + + assert.equal(result.predictedCategory, 'positive') + assert.equal(result.likelihoods.length, 2) + }) +}) + +// ============================================================================= +// End-to-End Tests +// Simulate real-world classification scenarios from start to finish +// ============================================================================= + +describe('[E2E] email spam detection', function () { + var classifier + + beforeEach(function () { + classifier = bayes() + + // train spam + classifier.learnBatch([ + { text: 'Buy cheap viagra online now discount', category: 'spam' }, + { text: 'You won a free prize click here to claim', category: 'spam' }, + { text: 'Limited time offer buy one get one free', category: 'spam' }, + { text: 'Earn money fast work from home guaranteed', category: 'spam' }, + { text: 'Congratulations you have been selected winner', category: 'spam' } + ]) + + // train ham + classifier.learnBatch([ + { text: 'Hey can we meet for lunch tomorrow', category: 'ham' }, + { text: 'Please review the attached quarterly report', category: 'ham' }, + { text: 'The meeting has been moved to 3pm', category: 'ham' }, + { text: 'Here are the notes from today discussion', category: 'ham' }, + { text: 'Can you send me the 
project update please', category: 'ham' } + ]) + }) + + it('correctly classifies obvious spam', function () { + assert.equal( + classifier.categorize('Buy now free offer limited time').predictedCategory, + 'spam' + ) + }) + + it('correctly classifies legitimate email', function () { + assert.equal( + classifier.categorize('Can we discuss the project tomorrow').predictedCategory, + 'ham' + ) + }) + + it('handles ambiguous text with confidence check', function () { + var result = classifier.categorizeWithConfidence('please review this', 0.9) + // ambiguous text should either be confident ham or rejected + assert.ok( + result.predictedCategory === 'ham' || result.predictedCategory === null + ) + }) + + it('explains predictions with influential tokens', function () { + var tokens = classifier.topInfluentialTokens('Buy now free offer', 3) + assert.ok(tokens.length > 0) + // all tokens should have valid probabilities + tokens.forEach(function (t) { + assert.ok(t.probability > 0) + assert.ok(t.probability <= 1) + }) + }) + + it('survives serialization round-trip and still classifies', function () { + var json = classifier.toJson() + var restored = bayes.fromJson(json) + + assert.equal( + restored.categorize('Buy now free offer limited time').predictedCategory, + 'spam' + ) + assert.equal( + restored.categorize('Can we discuss the project tomorrow').predictedCategory, + 'ham' + ) + }) + + it('stats reflect training data', function () { + var stats = classifier.getCategoryStats() + assert.equal(stats.spam.docCount, 5) + assert.equal(stats.ham.docCount, 5) + assert.equal(stats._total.docCount, 10) + assert.ok(stats._total.wordCount > 0) + assert.ok(stats._total.vocabularySize > 0) + }) +}) + +describe('[E2E] sentiment analysis', function () { + var classifier + + beforeEach(function () { + classifier = bayes({ + tokenPreprocessor: function (tokens) { + return tokens.map(function (t) { return t.toLowerCase() }) + } + }) + + var trainingData = [ + { text: 'I love this product 
amazing quality', category: 'positive' }, + { text: 'Great experience wonderful service', category: 'positive' }, + { text: 'Excellent work very satisfied happy', category: 'positive' }, + { text: 'Best purchase ever highly recommend', category: 'positive' }, + { text: 'Terrible quality waste of money', category: 'negative' }, + { text: 'Horrible experience worst service ever', category: 'negative' }, + { text: 'Broken on arrival very disappointed', category: 'negative' }, + { text: 'Would not recommend awful product', category: 'negative' } + ] + + classifier.learnBatch(trainingData) + }) + + it('classifies positive review correctly', function () { + var result = classifier.categorize('Amazing product love the quality') + assert.equal(result.predictedCategory, 'positive') + assert.ok(result.likelihoods[0].proba > 0.5) + }) + + it('classifies negative review correctly', function () { + var result = classifier.categorize('Terrible waste would not buy again') + assert.equal(result.predictedCategory, 'negative') + assert.ok(result.likelihoods[0].proba > 0.5) + }) + + it('probabilities always sum to 1', function () { + var texts = [ + 'Amazing product love it', + 'Terrible broken waste', + 'Some random unrelated words', + '' + ] + + texts.forEach(function (text) { + var result = classifier.categorize(text) + if (result.likelihoods.length > 0) { + var sum = result.likelihoods.reduce(function (a, l) { return a + l.proba }, 0) + assert.ok(Math.abs(sum - 1.0) < 0.01, 'proba sum should be ~1, got ' + sum) + } + }) + }) + + it('top influential tokens make semantic sense', function () { + var tokens = classifier.topInfluentialTokens('Amazing product love the quality', 3) + var tokenNames = tokens.map(function (t) { return t.token }) + + // at least one positive word should be influential + var positiveWords = ['amazing', 'love', 'quality', 'product'] + var hasPositive = tokenNames.some(function (t) { return positiveWords.indexOf(t) !== -1 }) + assert.ok(hasPositive, 'should 
have at least one positive word in influential tokens') + }) +}) + +describe('[E2E] multi-category topic classification', function () { + var classifier + + beforeEach(function () { + classifier = bayes() + + classifier.learnBatch([ + { text: 'The stock market rallied today as investors showed confidence', category: 'finance' }, + { text: 'Federal reserve announces interest rate decision', category: 'finance' }, + { text: 'Bitcoin cryptocurrency prices surged this week', category: 'finance' }, + + { text: 'Scientists discover new species in the Amazon rainforest', category: 'science' }, + { text: 'NASA launches new telescope to study distant galaxies', category: 'science' }, + { text: 'Research shows promising results for new cancer treatment', category: 'science' }, + + { text: 'Team wins championship in overtime thriller', category: 'sports' }, + { text: 'Player breaks scoring record in historic game', category: 'sports' }, + { text: 'Coach announces retirement after successful season', category: 'sports' }, + + { text: 'New smartphone features revolutionary camera technology', category: 'technology' }, + { text: 'AI startup raises funding for machine learning platform', category: 'technology' }, + { text: 'Software update fixes critical security vulnerability', category: 'technology' } + ]) + }) + + it('correctly classifies finance text', function () { + assert.equal( + classifier.categorize('investors concerned about market volatility').predictedCategory, + 'finance' + ) + }) + + it('correctly classifies science text', function () { + assert.equal( + classifier.categorize('researchers study new galaxy formation').predictedCategory, + 'science' + ) + }) + + it('correctly classifies sports text', function () { + assert.equal( + classifier.categorize('team wins game scoring record').predictedCategory, + 'sports' + ) + }) + + it('correctly classifies technology text', function () { + assert.equal( + classifier.categorize('new AI software update 
released').predictedCategory, + 'technology' + ) + }) + + it('topN returns correct number of categories', function () { + var result = classifier.categorizeTopN('new technology update', 2) + assert.equal(result.likelihoods.length, 2) + // top result should be technology + assert.equal(result.predictedCategory, 'technology') + }) + + it('all 4 categories exist', function () { + var cats = classifier.getCategories() + assert.equal(cats.length, 4) + assert.ok(cats.indexOf('finance') !== -1) + assert.ok(cats.indexOf('science') !== -1) + assert.ok(cats.indexOf('sports') !== -1) + assert.ok(cats.indexOf('technology') !== -1) + }) + + it('survives full serialize → restore → classify cycle', function () { + var json = classifier.toJson() + var restored = bayes.fromJson(json) + + assert.equal( + restored.categorize('stock market investors').predictedCategory, + 'finance' + ) + assert.equal(restored.getCategories().length, 4) + assert.equal(restored.getCategoryStats()._total.docCount, 12) + }) +}) + +describe('[E2E] incremental learning over time', function () { + it('classifier improves as more data is added', function () { + var classifier = bayes() + + // start with minimal data + classifier.learn('good', 'positive') + classifier.learn('bad', 'negative') + + // ambiguous initially + var initial = classifier.categorize('good bad ugly') + + // add more training data incrementally + classifier.learnBatch([ + { text: 'good great wonderful amazing', category: 'positive' }, + { text: 'good fantastic brilliant', category: 'positive' }, + { text: 'bad terrible awful', category: 'negative' }, + { text: 'bad horrible dreadful', category: 'negative' } + ]) + + // should now clearly classify 'good' as positive + var improved = classifier.categorize('good') + assert.equal(improved.predictedCategory, 'positive') + assert.ok(improved.likelihoods[0].proba > 0.5) + + // stats reflect all training + assert.equal(classifier.getCategoryStats()._total.docCount, 6) + }) +}) + +describe('[E2E] 
correcting classification mistakes', function () { + it('unlearn wrong data, re-learn correct data, verify improvement', function () { + var classifier = bayes() + + // initial correct training + classifier.learn('happy joy smile', 'positive') + classifier.learn('sad cry tears', 'negative') + + // oops, accidentally trained wrong + classifier.learn('happy celebration party', 'negative') // mistake! + + // observe the prediction while the mistaken training is in effect (intentionally not asserted: the outcome is not guaranteed) + var before = classifier.categorize('happy celebration') + + // fix the mistake + classifier.unlearn('happy celebration party', 'negative') + classifier.learn('happy celebration party', 'positive') + + // verify correction helped + var after = classifier.categorize('happy celebration') + assert.equal(after.predictedCategory, 'positive') + }) +}) + +describe('[E2E] fitPrior impact on imbalanced datasets', function () { + it('fitPrior=true favors the majority class on ambiguous input', function () { + var withPrior = bayes({ fitPrior: true }) + var withoutPrior = bayes({ fitPrior: false }) + + // heavily imbalanced: 10 positive, 1 negative + for (var i = 0; i < 10; i++) { + withPrior.learn('word' + i, 'majority') + withoutPrior.learn('word' + i, 'majority') + } + withPrior.learn('rare', 'minority') + withoutPrior.learn('rare', 'minority') + + // ambiguous text (unknown word) + var resultPrior = withPrior.categorize('unknown') + var resultNoPrior = withoutPrior.categorize('unknown') + + // with prior, majority class should be strongly favored + assert.equal(resultPrior.predictedCategory, 'majority') + assert.ok(resultPrior.likelihoods[0].proba > 0.8) + + // without prior, the two classes are expected to tie exactly (0.5 each) + assert.equal(resultNoPrior.likelihoods[0].proba, 0.5) + }) +})