From 3f95b1b40c258cba13e1bb77667e67d588df8bb7 Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Sun, 28 Dec 2025 18:08:39 -0800 Subject: [PATCH 1/4] feat(knn): add k-Nearest Neighbors classifier MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a kNN classifier that leverages existing LSI infrastructure for similarity computations. Unlike Bayes which requires training, kNN uses instance-based learning—store examples and classify by finding the most similar ones. Key features: - Hash-style API consistent with Bayes and LSI (add, classify) - classify_with_neighbors() returns interpretable results with neighbor details, vote tallies, and confidence scores - Distance-weighted voting option for more nuanced classification - Full JSON/Marshal serialization and storage backend support - Handles edge cases like single-item classifiers gracefully This completes Issue #103 and provides a third classification algorithm suited for small datasets where interpretability matters. Closes #103 --- README.md | 88 ++++++- classifier.gemspec | 4 +- lib/classifier.rb | 1 + lib/classifier/knn.rb | 363 ++++++++++++++++++++++++++++ test/knn/knn_test.rb | 541 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 993 insertions(+), 4 deletions(-) create mode 100644 lib/classifier/knn.rb create mode 100644 test/knn/knn_test.rb diff --git a/README.md b/README.md index ccf69a0..c571169 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![CI](https://github.com/cardmagic/classifier/actions/workflows/ruby.yml/badge.svg)](https://github.com/cardmagic/classifier/actions/workflows/ruby.yml) [![License: LGPL](https://img.shields.io/badge/License-LGPL_2.1-blue.svg)](https://opensource.org/licenses/LGPL-2.1) -A Ruby library for text classification using Bayesian and Latent Semantic Indexing (LSI) algorithms. +A Ruby library for text classification using Bayesian, LSI (Latent Semantic Indexing), and k-Nearest Neighbors (kNN) algorithms. 
**[Documentation](https://rubyclassifier.com/docs)** · **[Tutorials](https://rubyclassifier.com/docs/tutorials)** · **[Guides](https://rubyclassifier.com/docs/guides)** @@ -13,6 +13,7 @@ A Ruby library for text classification using Bayesian and Latent Semantic Indexi - [Installation](#installation) - [Bayesian Classifier](#bayesian-classifier) - [LSI (Latent Semantic Indexing)](#lsi-latent-semantic-indexing) +- [k-Nearest Neighbors (kNN)](#k-nearest-neighbors-knn) - [Persistence](#persistence) - [Performance](#performance) - [Development](#development) @@ -170,9 +171,92 @@ gem 'pragmatic_segmenter' - [LSI Basics Guide](https://rubyclassifier.com/docs/guides/lsi/basics) - In-depth documentation - [Wikipedia: Latent Semantic Analysis](http://en.wikipedia.org/wiki/Latent_semantic_analysis) +## k-Nearest Neighbors (kNN) + +Instance-based classification that stores examples and classifies by finding the most similar ones. No training phase required—just add examples and classify. + +### Key Features + +- **No Training Required**: Uses instance-based learning—store examples and classify by similarity +- **Interpretable Results**: Returns neighbors that contributed to the decision +- **Incremental Updates**: Easy to add or remove examples without retraining +- **Distance-Weighted Voting**: Optional weighting by similarity score +- **Built on LSI**: Leverages LSI's semantic similarity for better matching + +### Quick Start + +```ruby +require 'classifier' + +knn = Classifier::KNN.new(k: 3) + +# Add labeled examples +knn.add(spam: ["Buy now! Limited offer!", "You've won a million dollars!"]) +knn.add(ham: ["Meeting at 3pm tomorrow", "Please review the document"]) + +# Classify new text +knn.classify "Congratulations! Claim your prize!" 
+# => "spam" +``` + +### Detailed Classification + +Get neighbor information for interpretable results: + +```ruby +result = knn.classify_with_neighbors "Free money offer" + +result[:category] # => "spam" +result[:confidence] # => 0.67 +result[:neighbors] # => [{item: "Buy now!...", category: "spam", similarity: 0.92}, ...] +result[:votes] # => {"spam" => 2.0, "ham" => 1.0} +``` + +### Distance-Weighted Voting + +Weight votes by similarity score for more accurate classification: + +```ruby +knn = Classifier::KNN.new(k: 5, weighted: true) + +knn.add( + positive: ["Great product!", "Loved it!", "Excellent service"], + negative: ["Terrible experience", "Would not recommend"] +) + +# Closer neighbors have more influence on the result +knn.classify "This was amazing!" +# => "positive" +``` + +### Updating the Classifier + +```ruby +# Add more examples anytime +knn.add(neutral: "It was okay, nothing special") + +# Remove examples +knn.remove_item "Buy now! Limited offer!" + +# Change k value +knn.k = 7 + +# List all categories +knn.categories +# => ["spam", "ham", "neutral"] +``` + +### When to Use kNN vs Bayes vs LSI + +| Classifier | Best For | +|------------|----------| +| **Bayes** | Large training sets, fast classification, spam filtering | +| **LSI** | Semantic similarity, document clustering, search | +| **kNN** | Small datasets, interpretable results, incremental learning | + ## Persistence -Save and load trained classifiers with pluggable storage backends. Works with both Bayes and LSI classifiers. +Save and load classifiers with pluggable storage backends. Works with Bayes, LSI, and kNN classifiers. ### File Storage diff --git a/classifier.gemspec b/classifier.gemspec index 8ee836d..a4cfdee 100644 --- a/classifier.gemspec +++ b/classifier.gemspec @@ -1,8 +1,8 @@ Gem::Specification.new do |s| s.name = 'classifier' s.version = '2.1.0' - s.summary = 'A general classifier module to allow Bayesian and other types of classifications.' 
- s.description = 'A general classifier module to allow Bayesian and other types of classifications.' + s.summary = 'Text classification with Bayesian, LSI, and k-Nearest Neighbors algorithms.' + s.description = 'A Ruby library for text classification using Bayesian, LSI (Latent Semantic Indexing), and k-Nearest Neighbors (kNN) algorithms. Includes native C extension for fast LSI operations.' s.author = 'Lucas Carlson' s.email = 'lucas@rufy.com' s.homepage = 'https://rubyclassifier.com' diff --git a/lib/classifier.rb b/lib/classifier.rb index 07ac499..81c9c90 100644 --- a/lib/classifier.rb +++ b/lib/classifier.rb @@ -31,3 +31,4 @@ require 'classifier/extensions/vector' require 'classifier/bayes' require 'classifier/lsi' +require 'classifier/knn' diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb new file mode 100644 index 0000000..c62e6eb --- /dev/null +++ b/lib/classifier/knn.rb @@ -0,0 +1,363 @@ +# rbs_inline: enabled + +# Author:: Lucas Carlson (mailto:lucas@rufy.com) +# Copyright:: Copyright (c) 2024 Lucas Carlson +# License:: LGPL + +require 'json' +require 'mutex_m' +require 'classifier/lsi' + +module Classifier + # This class implements a k-Nearest Neighbors classifier that leverages + # the existing LSI infrastructure for similarity computations. + # + # Unlike traditional classifiers that require training, kNN uses instance-based + # learning - it stores examples and classifies by finding the most similar ones. + # + # Example usage: + # knn = Classifier::KNN.new(k: 3) + # knn.add("spam" => ["Buy now!", "Limited offer!"]) + # knn.add("ham" => ["Meeting tomorrow", "Project update"]) + # knn.classify("Special discount!") # => "spam" + # + class KNN + include Mutex_m + + # @rbs @k: Integer + # @rbs @weighted: bool + # @rbs @lsi: LSI + # @rbs @dirty: bool + # @rbs @storage: Storage::Base? + + attr_reader :k + attr_accessor :weighted, :storage + + # Creates a new kNN classifier. 
+ # + # @param k [Integer] Number of neighbors to consider (default: 5) + # @param weighted [Boolean] Use distance-weighted voting (default: false) + # + # @rbs (?k: Integer, ?weighted: bool) -> void + def initialize(k: 5, weighted: false) # rubocop:disable Naming/MethodParameterName + super() + validate_k!(k) + @k = k + @weighted = weighted + @lsi = LSI.new(auto_rebuild: true) + @dirty = false + @storage = nil + end + + # Adds labeled examples to the classifier using hash-style syntax. + # Keys are categories, values are items (or arrays of items). + # + # @example Single item per category + # knn.add("spam" => "Buy now!") + # knn.add("ham" => "Meeting tomorrow") + # + # @example Multiple items per category + # knn.add("spam" => ["Buy now!", "Limited offer!"]) + # + # @example Batch operations + # knn.add( + # "spam" => ["Buy now!", "Limited offer!"], + # "ham" => ["Meeting tomorrow", "Project update"] + # ) + # + # @rbs (**untyped items) -> void + def add(**items) + synchronize { @dirty = true } + @lsi.add(**items) + end + + # Adds a single labeled example to the classifier. + # + # @deprecated Use {#add} instead for clearer hash-style syntax. + # + # @param item [String] The text content to add + # @param category [String, Symbol] The category/label for this item + # + # @rbs (String, String | Symbol) -> void + def add_item(item, category) + synchronize { @dirty = true } + @lsi.add_item(item, category) + end + + # Classifies the given text by finding the k nearest neighbors + # and using majority voting. + # + # @param text [String] The text to classify + # @return [String, Symbol, nil] The predicted category, or nil if no examples exist + # + # @rbs (String) -> (String | Symbol)? + def classify(text) + result = classify_with_neighbors(text) + result[:category] + end + + # Classifies the given text and returns detailed information about + # the neighbors that contributed to the decision. 
+ # + # @param text [String] The text to classify + # @return [Hash] A hash containing: + # - :category - The predicted category + # - :neighbors - Array of neighbor details (item, category, similarity) + # - :votes - Hash of category => vote count/weight + # - :confidence - Confidence score (winning vote share) + # + # @rbs (String) -> Hash[Symbol, untyped] + def classify_with_neighbors(text) + synchronize do + return empty_result if @lsi.items.empty? + + neighbors = find_neighbors(text) + return empty_result if neighbors.empty? + + votes = tally_votes(neighbors) + winner = votes.max_by { |_, v| v }&.first + total_votes = votes.values.sum + confidence = winner && total_votes.positive? ? votes[winner] / total_votes.to_f : 0.0 + + { + category: winner, + neighbors: neighbors, + votes: votes, + confidence: confidence + } + end + end + + # @rbs (String) -> Array[String | Symbol] + def categories_for(item) + @lsi.categories_for(item) + end + + # @rbs (String) -> void + def remove_item(item) + synchronize { @dirty = true } + @lsi.remove_item(item) + end + + # @rbs () -> Array[untyped] + def items + @lsi.items + end + + # @rbs () -> Array[String | Symbol] + def categories + synchronize do + @lsi.items.flat_map { |item| @lsi.categories_for(item) }.uniq + end + end + + # @rbs (Integer) -> void + def k=(value) + validate_k!(value) + @k = value + end + + # @rbs (?untyped) -> untyped + def as_json(_options = nil) + { + version: 1, + type: 'knn', + k: @k, + weighted: @weighted, + lsi: @lsi.as_json + } + end + + # @rbs (?untyped) -> String + def to_json(_options = nil) + as_json.to_json + end + + # Loads a classifier from a JSON string or Hash. + # + # @param json [String, Hash] JSON string or parsed hash + # @return [KNN] A new KNN instance with restored state + # + # @rbs (String | Hash[String, untyped]) -> KNN + def self.from_json(json) + data = json.is_a?(String) ? 
JSON.parse(json) : json + raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn' + + # Restore the LSI from its nested data + lsi_data = data['lsi'] + lsi_data['type'] = 'lsi' # Ensure type is set for LSI.from_json + + instance = new(k: data['k'], weighted: data['weighted']) + instance.instance_variable_set(:@lsi, LSI.from_json(lsi_data)) + instance.instance_variable_set(:@dirty, false) + instance + end + + # Saves the classifier to the configured storage. + # + # @rbs () -> void + def save + raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage + + storage.write(to_json) + @dirty = false + end + + # Saves the classifier to a file. + # + # @param path [String] The file path + # @return [Integer] Number of bytes written + # + # @rbs (String) -> Integer + def save_to_file(path) + result = File.write(path, to_json) + @dirty = false + result + end + + # Reloads the classifier from configured storage. + # + # @rbs () -> self + def reload + raise ArgumentError, 'No storage configured' unless storage + raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty + + data = storage.read + raise StorageError, 'No saved state found' unless data + + restore_from_json(data) + @dirty = false + self + end + + # Force reloads the classifier from storage. + # + # @rbs () -> self + def reload! + raise ArgumentError, 'No storage configured' unless storage + + data = storage.read + raise StorageError, 'No saved state found' unless data + + restore_from_json(data) + @dirty = false + self + end + + # @rbs () -> bool + def dirty? + @dirty + end + + # Loads a classifier from configured storage. 
+ # + # @param storage [Storage::Base] The storage to load from + # @return [KNN] The loaded classifier + # + # @rbs (storage: Storage::Base) -> KNN + def self.load(storage:) + data = storage.read + raise StorageError, 'No saved state found' unless data + + instance = from_json(data) + instance.storage = storage + instance + end + + # Loads a classifier from a file. + # + # @param path [String] The file path + # @return [KNN] The loaded classifier + # + # @rbs (String) -> KNN + def self.load_from_file(path) + from_json(File.read(path)) + end + + # @rbs () -> Array[untyped] + def marshal_dump + [@k, @weighted, @lsi, @dirty] + end + + # @rbs (Array[untyped]) -> void + def marshal_load(data) + mu_initialize + @k, @weighted, @lsi, @dirty = data + @storage = nil + end + + private + + # Finds the k nearest neighbors for the given text. + # + # @rbs (String) -> Array[Hash[Symbol, untyped]] + def find_neighbors(text) + # LSI requires at least 2 items to build an index + # For single item, return it directly with a default similarity + if @lsi.items.size == 1 + item = @lsi.items.first + return [{ + item: item, + category: @lsi.categories_for(item).first, + similarity: 1.0 + }] + end + + proximity = @lsi.proximity_array_for_content(text) + neighbors = proximity.reject { |item, _| item == text }.first(@k) + + neighbors.map do |item, similarity| + { + item: item, + category: @lsi.categories_for(item).first, + similarity: similarity + } + end + end + + # Tallies votes from neighbors, optionally weighted by similarity. + # + # @rbs (Array[Hash[Symbol, untyped]]) -> Hash[String | Symbol, Float] + def tally_votes(neighbors) + votes = Hash.new(0.0) + + neighbors.each do |neighbor| + category = neighbor[:category] + next unless category + + weight = @weighted ? 
[neighbor[:similarity], 0.0].max : 1.0 + votes[category] += weight + end + + votes + end + + # @rbs () -> Hash[Symbol, untyped] + def empty_result + { category: nil, neighbors: [], votes: {}, confidence: 0.0 } + end + + # @rbs (Integer) -> void + def validate_k!(val) + raise ArgumentError, "k must be a positive integer, got #{val}" unless val.is_a?(Integer) && val.positive? + end + + # Restores state from JSON (used by reload). + # + # @rbs (String) -> void + def restore_from_json(json) + data = JSON.parse(json) + raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn' + + synchronize do + @k = data['k'] + @weighted = data['weighted'] + + lsi_data = data['lsi'] + lsi_data['type'] = 'lsi' + @lsi = LSI.from_json(lsi_data) + @dirty = false + end + end + end +end diff --git a/test/knn/knn_test.rb b/test/knn/knn_test.rb new file mode 100644 index 0000000..7c932dc --- /dev/null +++ b/test/knn/knn_test.rb @@ -0,0 +1,541 @@ +require_relative '../test_helper' + +class KNNTest < Minitest::Test + def setup + @str1 = 'This text deals with dogs. Dogs.' + @str2 = 'This text involves dogs too. Dogs!' + @str3 = 'This text revolves around cats. Cats.' + @str4 = 'This text also involves cats. Cats!' + @str5 = 'This text involves birds. Birds.' 
+ end + + # Initialization tests + + def test_default_initialization + knn = Classifier::KNN.new + + assert_equal 5, knn.k + refute knn.weighted + assert_empty knn.items + end + + def test_custom_k_initialization + knn = Classifier::KNN.new(k: 3) + + assert_equal 3, knn.k + end + + def test_weighted_initialization + knn = Classifier::KNN.new(weighted: true) + + assert knn.weighted + end + + def test_invalid_k_raises_error + assert_raises(ArgumentError) { Classifier::KNN.new(k: 0) } + assert_raises(ArgumentError) { Classifier::KNN.new(k: -1) } + assert_raises(ArgumentError) { Classifier::KNN.new(k: 1.5) } + end + + def test_k_setter + knn = Classifier::KNN.new(k: 5) + knn.k = 3 + + assert_equal 3, knn.k + end + + def test_k_setter_validation + knn = Classifier::KNN.new + + assert_raises(ArgumentError) { knn.k = 0 } + assert_raises(ArgumentError) { knn.k = -1 } + end + + # Adding items tests + + def test_add_with_hash_syntax + knn = Classifier::KNN.new + knn.add('Dog' => 'Dogs are loyal pets') + knn.add('Cat' => 'Cats are independent') + + assert_equal 2, knn.items.size + assert_includes knn.items, 'Dogs are loyal pets' + assert_includes knn.items, 'Cats are independent' + end + + def test_add_with_symbol_keys + knn = Classifier::KNN.new + knn.add(Dog: 'Dogs are loyal', Cat: 'Cats are independent') + + assert_equal 2, knn.items.size + assert_equal ['Dog'], knn.categories_for('Dogs are loyal') + assert_equal ['Cat'], knn.categories_for('Cats are independent') + end + + def test_add_multiple_items_same_category + knn = Classifier::KNN.new + knn.add('Dog' => ['Dogs are loyal', 'Puppies are cute', 'Canines are friendly']) + + assert_equal 3, knn.items.size + assert_equal ['Dog'], knn.categories_for('Dogs are loyal') + assert_equal ['Dog'], knn.categories_for('Puppies are cute') + assert_equal ['Dog'], knn.categories_for('Canines are friendly') + end + + def test_add_batch_operations + knn = Classifier::KNN.new + knn.add( + 'Dog' => ['Dogs are loyal', 'Puppies are cute'], 
+ 'Cat' => ['Cats are independent', 'Kittens are playful'] + ) + + assert_equal 4, knn.items.size + assert_equal ['Dog'], knn.categories_for('Dogs are loyal') + assert_equal ['Cat'], knn.categories_for('Cats are independent') + end + + def test_add_item_legacy_api + knn = Classifier::KNN.new + knn.add_item 'Dogs are loyal pets', 'Dog' + knn.add_item 'Cats are independent', 'Cat' + + assert_equal 2, knn.items.size + assert_equal ['Dog'], knn.categories_for('Dogs are loyal pets') + end + + # Classification tests + + def test_basic_classification + knn = Classifier::KNN.new(k: 3) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4], + 'Bird' => @str5 + ) + + assert_equal 'Dog', knn.classify('This is about dogs') + assert_equal 'Cat', knn.classify('This is about cats') + assert_equal 'Bird', knn.classify('This is about birds') + end + + def test_classify_empty_classifier + knn = Classifier::KNN.new + + assert_nil knn.classify('Some text') + end + + def test_classify_with_k_larger_than_items + knn = Classifier::KNN.new(k: 10) + knn.add('Dog' => 'Dogs are pets') + knn.add('Cat' => 'Cats are pets') + + # Should still work with fewer items than k + result = knn.classify('Dogs are great') + + refute_nil result + end + + # classify_with_neighbors tests + + def test_classify_with_neighbors_structure + knn = Classifier::KNN.new(k: 3) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4] + ) + + result = knn.classify_with_neighbors('Dogs are great pets') + + assert_instance_of Hash, result + assert result.key?(:category) + assert result.key?(:neighbors) + assert result.key?(:votes) + assert result.key?(:confidence) + end + + def test_classify_with_neighbors_returns_neighbors + knn = Classifier::KNN.new(k: 2) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => @str3 + ) + + result = knn.classify_with_neighbors('Dogs are great') + + assert_equal 2, result[:neighbors].size + result[:neighbors].each do |neighbor| + assert neighbor.key?(:item) + assert 
neighbor.key?(:category) + assert neighbor.key?(:similarity) + end + end + + def test_classify_with_neighbors_empty_classifier + knn = Classifier::KNN.new + + result = knn.classify_with_neighbors('Some text') + + assert_nil result[:category] + assert_empty result[:neighbors] + assert_empty result[:votes] + assert_in_delta(0.0, result[:confidence]) + end + + def test_classify_with_neighbors_confidence + knn = Classifier::KNN.new(k: 3) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => @str3 + ) + + result = knn.classify_with_neighbors('Dogs are wonderful') + + assert_kind_of Float, result[:confidence] + assert_operator result[:confidence], :>=, 0.0 + assert_operator result[:confidence], :<=, 1.0 + end + + # Weighted voting tests + + def test_weighted_voting + knn = Classifier::KNN.new(k: 3, weighted: true) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4] + ) + + result = knn.classify_with_neighbors('Dogs are great') + + # Votes should be weighted by similarity + assert knn.weighted + # Weighted votes should have non-integer values + assert(result[:votes].values.any? 
{ |v| v != v.to_i }) + end + + def test_unweighted_voting + knn = Classifier::KNN.new(k: 3, weighted: false) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => @str3 + ) + + result = knn.classify_with_neighbors('Dogs are great') + + # Unweighted votes should be integers (counts) + result[:votes].each_value do |vote| + assert_equal vote.to_i.to_f, vote + end + end + + # Categories tests + + def test_categories + knn = Classifier::KNN.new + knn.add( + 'Dog' => 'Dogs are loyal', + 'Cat' => 'Cats are independent', + 'Bird' => 'Birds can fly' + ) + + cats = knn.categories + + assert_equal 3, cats.size + assert_includes cats, 'Dog' + assert_includes cats, 'Cat' + assert_includes cats, 'Bird' + end + + def test_categories_empty + knn = Classifier::KNN.new + + assert_empty knn.categories + end + + # Remove item tests + + def test_remove_item + knn = Classifier::KNN.new + knn.add('Dog' => [@str1, @str2]) + + assert_equal 2, knn.items.size + + knn.remove_item(@str1) + + assert_equal 1, knn.items.size + refute_includes knn.items, @str1 + end + + def test_remove_nonexistent_item + knn = Classifier::KNN.new + knn.add('Dog' => @str1) + + knn.remove_item('nonexistent') + + assert_equal 1, knn.items.size + end + + # Serialization tests + + def test_as_json + knn = Classifier::KNN.new(k: 3, weighted: true) + knn.add('Dog' => @str1, 'Cat' => @str2) + + data = knn.as_json + + assert_instance_of Hash, data + assert_equal 1, data[:version] + assert_equal 'knn', data[:type] + assert_equal 3, data[:k] + assert data[:weighted] + assert data.key?(:lsi) + end + + def test_to_json + knn = Classifier::KNN.new(k: 3) + knn.add('Dog' => @str1) + + json = knn.to_json + data = JSON.parse(json) + + assert_equal 'knn', data['type'] + assert_equal 3, data['k'] + end + + def test_from_json_with_string + knn = Classifier::KNN.new(k: 3, weighted: true) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => @str3 + ) + + json = knn.to_json + loaded = Classifier::KNN.from_json(json) + + assert_equal knn.k, 
loaded.k + assert_equal knn.weighted, loaded.weighted + assert_equal knn.items.sort, loaded.items.sort + assert_equal knn.classify('Dogs are great'), loaded.classify('Dogs are great') + end + + def test_from_json_with_hash + knn = Classifier::KNN.new(k: 5) + knn.add('Dog' => @str1, 'Cat' => @str2) + + hash = JSON.parse(knn.to_json) + loaded = Classifier::KNN.from_json(hash) + + assert_equal knn.k, loaded.k + assert_equal knn.items.sort, loaded.items.sort + end + + def test_from_json_invalid_type + invalid_json = { version: 1, type: 'invalid' }.to_json + + assert_raises(ArgumentError) { Classifier::KNN.from_json(invalid_json) } + end + + def test_save_and_load_from_file + knn = Classifier::KNN.new(k: 3, weighted: true) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4] + ) + + Dir.mktmpdir do |dir| + path = File.join(dir, 'knn.json') + knn.save_to_file(path) + + assert_path_exists path + + loaded = Classifier::KNN.load_from_file(path) + + assert_equal knn.k, loaded.k + assert_equal knn.weighted, loaded.weighted + assert_equal knn.classify('Dogs are great'), loaded.classify('Dogs are great') + end + end + + def test_save_load_preserves_classification + knn = Classifier::KNN.new(k: 3) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4], + 'Bird' => @str5 + ) + + Dir.mktmpdir do |dir| + path = File.join(dir, 'knn.json') + knn.save_to_file(path) + loaded = Classifier::KNN.load_from_file(path) + + assert_equal knn.classify(@str1), loaded.classify(@str1) + assert_equal knn.classify('Dogs are nice'), loaded.classify('Dogs are nice') + assert_equal knn.classify('Cats are cute'), loaded.classify('Cats are cute') + end + end + + # Marshal tests + + def test_marshal_dump_load + knn = Classifier::KNN.new(k: 3, weighted: true) + knn.add('Dog' => [@str1, @str2], 'Cat' => @str3) + + dumped = Marshal.dump(knn) + loaded = Marshal.load(dumped) + + assert_equal knn.k, loaded.k + assert_equal knn.weighted, loaded.weighted + assert_equal knn.items.sort, 
loaded.items.sort + assert_equal knn.classify('Dogs are great'), loaded.classify('Dogs are great') + end + + # Dirty tracking tests + + def test_dirty_after_add + knn = Classifier::KNN.new + + refute_predicate knn, :dirty? + + knn.add('Dog' => 'Dogs are great') + + assert_predicate knn, :dirty? + end + + def test_dirty_after_remove + knn = Classifier::KNN.new + knn.add('Dog' => 'Dogs are great') + knn.instance_variable_set(:@dirty, false) + + knn.remove_item('Dogs are great') + + assert_predicate knn, :dirty? + end + + def test_save_clears_dirty + knn = Classifier::KNN.new + knn.add('Dog' => 'Dogs are great') + + assert_predicate knn, :dirty? + + Dir.mktmpdir do |dir| + path = File.join(dir, 'knn.json') + knn.save_to_file(path) + + refute_predicate knn, :dirty? + end + end + + # Storage tests + + def test_save_without_storage_raises + knn = Classifier::KNN.new + + assert_raises(ArgumentError) { knn.save } + end + + def test_reload_without_storage_raises + knn = Classifier::KNN.new + + assert_raises(ArgumentError) { knn.reload } + end + + def test_storage_save_and_load + knn = Classifier::KNN.new(k: 3) + knn.add('Dog' => @str1, 'Cat' => @str2) + + storage = Classifier::Storage::Memory.new + knn.storage = storage + knn.save + + loaded = Classifier::KNN.load(storage: storage) + + assert_equal knn.k, loaded.k + assert_equal knn.items.sort, loaded.items.sort + end + + def test_reload + storage = Classifier::Storage::Memory.new + + knn = Classifier::KNN.new(k: 3) + knn.add('Dog' => @str1) + knn.storage = storage + knn.save + + # Modify after save + knn.add('Cat' => @str2) + + assert_equal 2, knn.items.size + + # Reload should restore to saved state + knn.reload! 
+ + assert_equal 1, knn.items.size + assert_includes knn.items, @str1 + end + + def test_reload_with_unsaved_changes + storage = Classifier::Storage::Memory.new + + knn = Classifier::KNN.new + knn.add('Dog' => @str1) + knn.storage = storage + knn.save + + knn.add('Cat' => @str2) + + assert_raises(Classifier::UnsavedChangesError) { knn.reload } + end + + def test_reload_success + storage = Classifier::Storage::Memory.new + + knn = Classifier::KNN.new(k: 3) + knn.add('Dog' => @str1) + knn.storage = storage + knn.save + + # Modify but don't mark as dirty (simulate external change) + knn.instance_variable_set(:@dirty, false) + + result = knn.reload + + assert_same knn, result + refute_predicate knn, :dirty? + end + + # Edge cases + + def test_single_item_classification + knn = Classifier::KNN.new(k: 5) + knn.add('Dog' => 'Dogs are great') + + result = knn.classify('Something about dogs') + + assert_equal 'Dog', result + end + + def test_classification_with_very_different_text + knn = Classifier::KNN.new(k: 3) + knn.add( + 'Dog' => [@str1, @str2], + 'Cat' => [@str3, @str4] + ) + + # Even very different text should return some classification + result = knn.classify('Completely unrelated computer programming text') + + refute_nil result + end + + def test_items_returns_copy + knn = Classifier::KNN.new + knn.add('Dog' => 'Dogs are great') + + items = knn.items + + # Modifying returned array shouldn't affect internal state + items.clear + + assert_equal 1, knn.items.size + end +end From 0f4c7b5c6815e1b503edde6b19b6ecf3ba6c669f Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Sun, 28 Dec 2025 18:16:47 -0800 Subject: [PATCH 2/4] docs: clarify kNN vs Bayes size guidance --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c571169..b4e0eaa 100644 --- a/README.md +++ b/README.md @@ -250,9 +250,11 @@ knn.categories | Classifier | Best For | |------------|----------| -| **Bayes** | Large training sets, fast 
classification, spam filtering | +| **Bayes** | Fast classification, any training size (stores only word counts) | | **LSI** | Semantic similarity, document clustering, search | -| **kNN** | Small datasets, interpretable results, incremental learning | +| **kNN** | <1000 examples, interpretable results, incremental updates | + +**Why the size difference?** Bayes stores aggregate statistics—adding 10,000 documents just increments counters. kNN stores every example and compares against all of them during classification, so performance degrades with size. ## Persistence From 55f205c252b6058dbff3c3cfbd73f7080ebb995e Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Sun, 28 Dec 2025 18:23:58 -0800 Subject: [PATCH 3/4] refactor(knn): remove YARD tags, keep descriptions Remove @param/@return/@example tags that duplicate @rbs type info. Keep method descriptions that explain behavior. Remove deprecated add_item method (new class doesn't need legacy API). --- lib/classifier/knn.rb | 79 +++---------------------------------------- test/knn/knn_test.rb | 9 ----- 2 files changed, 4 insertions(+), 84 deletions(-) diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index c62e6eb..092ac34 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -34,10 +34,6 @@ class KNN attr_accessor :weighted, :storage # Creates a new kNN classifier. - # - # @param k [Integer] Number of neighbors to consider (default: 5) - # @param weighted [Boolean] Use distance-weighted voting (default: false) - # # @rbs (?k: Integer, ?weighted: bool) -> void def initialize(k: 5, weighted: false) # rubocop:disable Naming/MethodParameterName super() @@ -49,63 +45,21 @@ def initialize(k: 5, weighted: false) # rubocop:disable Naming/MethodParameterNa @storage = nil end - # Adds labeled examples to the classifier using hash-style syntax. - # Keys are categories, values are items (or arrays of items). 
- # - # @example Single item per category - # knn.add("spam" => "Buy now!") - # knn.add("ham" => "Meeting tomorrow") - # - # @example Multiple items per category - # knn.add("spam" => ["Buy now!", "Limited offer!"]) - # - # @example Batch operations - # knn.add( - # "spam" => ["Buy now!", "Limited offer!"], - # "ham" => ["Meeting tomorrow", "Project update"] - # ) - # + # Adds labeled examples. Keys are categories, values are items or arrays. # @rbs (**untyped items) -> void def add(**items) synchronize { @dirty = true } @lsi.add(**items) end - # Adds a single labeled example to the classifier. - # - # @deprecated Use {#add} instead for clearer hash-style syntax. - # - # @param item [String] The text content to add - # @param category [String, Symbol] The category/label for this item - # - # @rbs (String, String | Symbol) -> void - def add_item(item, category) - synchronize { @dirty = true } - @lsi.add_item(item, category) - end - - # Classifies the given text by finding the k nearest neighbors - # and using majority voting. - # - # @param text [String] The text to classify - # @return [String, Symbol, nil] The predicted category, or nil if no examples exist - # + # Classifies text using k nearest neighbors with majority voting. # @rbs (String) -> (String | Symbol)? def classify(text) result = classify_with_neighbors(text) result[:category] end - # Classifies the given text and returns detailed information about - # the neighbors that contributed to the decision. - # - # @param text [String] The text to classify - # @return [Hash] A hash containing: - # - :category - The predicted category - # - :neighbors - Array of neighbor details (item, category, similarity) - # - :votes - Hash of category => vote count/weight - # - :confidence - Confidence score (winning vote share) - # + # Classifies and returns {category:, neighbors:, votes:, confidence:}. 
# @rbs (String) -> Hash[Symbol, untyped] def classify_with_neighbors(text) synchronize do @@ -174,10 +128,6 @@ def to_json(_options = nil) end # Loads a classifier from a JSON string or Hash. - # - # @param json [String, Hash] JSON string or parsed hash - # @return [KNN] A new KNN instance with restored state - # # @rbs (String | Hash[String, untyped]) -> KNN def self.from_json(json) data = json.is_a?(String) ? JSON.parse(json) : json @@ -194,7 +144,6 @@ def self.from_json(json) end # Saves the classifier to the configured storage. - # # @rbs () -> void def save raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage @@ -204,10 +153,6 @@ def save end # Saves the classifier to a file. - # - # @param path [String] The file path - # @return [Integer] Number of bytes written - # # @rbs (String) -> Integer def save_to_file(path) result = File.write(path, to_json) @@ -216,7 +161,6 @@ def save_to_file(path) end # Reloads the classifier from configured storage. - # # @rbs () -> self def reload raise ArgumentError, 'No storage configured' unless storage @@ -230,8 +174,7 @@ def reload self end - # Force reloads the classifier from storage. - # + # Force reloads, discarding unsaved changes. # @rbs () -> self def reload! raise ArgumentError, 'No storage configured' unless storage @@ -250,10 +193,6 @@ def dirty? end # Loads a classifier from configured storage. - # - # @param storage [Storage::Base] The storage to load from - # @return [KNN] The loaded classifier - # # @rbs (storage: Storage::Base) -> KNN def self.load(storage:) data = storage.read @@ -265,10 +204,6 @@ def self.load(storage:) end # Loads a classifier from a file. - # - # @param path [String] The file path - # @return [KNN] The loaded classifier - # # @rbs (String) -> KNN def self.load_from_file(path) from_json(File.read(path)) @@ -288,8 +223,6 @@ def marshal_load(data) private - # Finds the k nearest neighbors for the given text. 
- # # @rbs (String) -> Array[Hash[Symbol, untyped]] def find_neighbors(text) # LSI requires at least 2 items to build an index @@ -315,8 +248,6 @@ def find_neighbors(text) end end - # Tallies votes from neighbors, optionally weighted by similarity. - # # @rbs (Array[Hash[Symbol, untyped]]) -> Hash[String | Symbol, Float] def tally_votes(neighbors) votes = Hash.new(0.0) @@ -342,8 +273,6 @@ def validate_k!(val) raise ArgumentError, "k must be a positive integer, got #{val}" unless val.is_a?(Integer) && val.positive? end - # Restores state from JSON (used by reload). - # # @rbs (String) -> void def restore_from_json(json) data = JSON.parse(json) diff --git a/test/knn/knn_test.rb b/test/knn/knn_test.rb index 7c932dc..9eed7db 100644 --- a/test/knn/knn_test.rb +++ b/test/knn/knn_test.rb @@ -94,15 +94,6 @@ def test_add_batch_operations assert_equal ['Cat'], knn.categories_for('Cats are independent') end - def test_add_item_legacy_api - knn = Classifier::KNN.new - knn.add_item 'Dogs are loyal pets', 'Dog' - knn.add_item 'Cats are independent', 'Cat' - - assert_equal 2, knn.items.size - assert_equal ['Dog'], knn.categories_for('Dogs are loyal pets') - end - # Classification tests def test_basic_classification From fd02858a28dea0e1144a2be358c210e73eaf5499 Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Sun, 28 Dec 2025 18:44:28 -0800 Subject: [PATCH 4/4] refactor: address code review feedback - Fix LSI single-item bug instead of working around in KNN - Remove defensive max(0) check on similarity scores - Use early return for nil winner - Simplify guard clause with 'or next' - Condense class docstring - Avoid mutating input hash with .dup --- lib/classifier/knn.rb | 43 +++++++++++-------------------------------- lib/classifier/lsi.rb | 1 + 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index 092ac34..d473abc 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -9,13 +9,9 @@ require 'classifier/lsi' 
module Classifier - # This class implements a k-Nearest Neighbors classifier that leverages - # the existing LSI infrastructure for similarity computations. + # Instance-based classification: stores examples and classifies by similarity. # - # Unlike traditional classifiers that require training, kNN uses instance-based - # learning - it stores examples and classifies by finding the most similar ones. - # - # Example usage: + # Example: # knn = Classifier::KNN.new(k: 3) # knn.add("spam" => ["Buy now!", "Limited offer!"]) # knn.add("ham" => ["Meeting tomorrow", "Project update"]) @@ -70,15 +66,12 @@ def classify_with_neighbors(text) votes = tally_votes(neighbors) winner = votes.max_by { |_, v| v }&.first + return empty_result unless winner + total_votes = votes.values.sum - confidence = winner && total_votes.positive? ? votes[winner] / total_votes.to_f : 0.0 + confidence = total_votes.positive? ? votes[winner] / total_votes.to_f : 0.0 - { - category: winner, - neighbors: neighbors, - votes: votes, - confidence: confidence - } + { category: winner, neighbors: neighbors, votes: votes, confidence: confidence } end end @@ -133,9 +126,8 @@ def self.from_json(json) data = json.is_a?(String) ? 
JSON.parse(json) : json raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'knn' - # Restore the LSI from its nested data - lsi_data = data['lsi'] - lsi_data['type'] = 'lsi' # Ensure type is set for LSI.from_json + lsi_data = data['lsi'].dup + lsi_data['type'] = 'lsi' instance = new(k: data['k'], weighted: data['weighted']) instance.instance_variable_set(:@lsi, LSI.from_json(lsi_data)) @@ -225,17 +217,6 @@ def marshal_load(data) # @rbs (String) -> Array[Hash[Symbol, untyped]] def find_neighbors(text) - # LSI requires at least 2 items to build an index - # For single item, return it directly with a default similarity - if @lsi.items.size == 1 - item = @lsi.items.first - return [{ - item: item, - category: @lsi.categories_for(item).first, - similarity: 1.0 - }] - end - proximity = @lsi.proximity_array_for_content(text) neighbors = proximity.reject { |item, _| item == text }.first(@k) @@ -253,10 +234,8 @@ def tally_votes(neighbors) votes = Hash.new(0.0) neighbors.each do |neighbor| - category = neighbor[:category] - next unless category - - weight = @weighted ? [neighbor[:similarity], 0.0].max : 1.0 + category = neighbor[:category] or next + weight = @weighted ? neighbor[:similarity] : 1.0 votes[category] += weight end @@ -282,7 +261,7 @@ def restore_from_json(json) @k = data['k'] @weighted = data['weighted'] - lsi_data = data['lsi'] + lsi_data = data['lsi'].dup lsi_data['type'] = 'lsi' @lsi = LSI.from_json(lsi_data) @dirty = false diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index d67f188..416a5cf 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -629,6 +629,7 @@ def needs_rebuild_unlocked? # @rbs (String) ?{ (String) -> String } -> Array[[String, Float]] def proximity_array_for_content_unlocked(doc, &) return [] if needs_rebuild_unlocked? + return @items.keys.map { |item| [item, 1.0] } if @items.size == 1 content_node = node_for_content_unlocked(doc, &) result =