[FR] Classifier function #142

@KronosTheLate

I am pretty new to KNN, and so far I only know how to use it for classification. So it would seem natural for this package to include a simple classifier function. It should take the indices of the nearest neighbours returned by knn together with the training classes, and map each set of nearest neighbours to a class. A threshold for the minimum number of neighbours that must share the winning class should be optional and default to 1.

Possible sources of inspiration

  • kNN.jl's classifier.jl file.
  • The following function I have ended up having to define:
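"""
    classify(neighbor_inds::AbstractVector{Int}, train_classes::AbstractVector{Int}; tiebreaker=rand, l::Int=1)
    classify(neighbor_inds::Vector{Vector{Int}}, train_classes; kwargs...)

Return the most common class among the neighbours indexed by `neighbor_inds`,
or `missing` if that class occurs fewer than `l` times.

kwargs:
`tiebreaker` is
1) a function that takes a vector of the tied candidates and returns a value, or
2) a value that is returned upon a tie.
`l` is the minimum number of neighbours that must share the winning class.
"""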
function classify(neighbor_inds::AbstractVector{Int}, train_classes::AbstractVector{Int}; tiebreaker=rand, l::Int=1)
    possible_classes = unique(train_classes)
    neighbor_classes = train_classes[neighbor_inds]
    my_counts = [count(==(psbl_cls), neighbor_classes) for psbl_cls in possible_classes]
    A = [possible_classes my_counts]
    sorted_counts = sortslices(A, dims=1, by=x->x[2], rev=true)
    if sorted_counts[1, 2] < l
        return missing
    # Only compare against a runner-up if more than one class exists.
    elseif size(sorted_counts, 1) > 1 && sorted_counts[1, 2] == sorted_counts[2, 2]
        inds = [sorted_counts[i, 2] == sorted_counts[1, 2] for i in 1:size(sorted_counts, 1)]
        candidates_of_equal_count = sorted_counts[inds, :][:, 1]
        # A function picks among the tied candidates; any other value is returned as-is.
        return tiebreaker isa Function ? tiebreaker(candidates_of_equal_count) : tiebreaker
    else
        return sorted_counts[1, 1]
    end
end
function classify(neighbor_inds::Vector{Vector{Int}}, train_classes::AbstractVector{Int}; kwargs...)
    [classify(neighbor_inds[i], train_classes; kwargs...) for i in eachindex(neighbor_inds)]
end
nearest_neighbour_inds = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
train_classes = [1, 1, 2, 4, 5, 6, 8, 8, 9]
classify(nearest_neighbour_inds, train_classes)
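
For comparison, here is a minimal sketch of the same majority-vote idea that counts with a `Dict`, so the labels do not have to be `Int`. The names `majority_vote`, `majority_vote_all` and `min_count` are hypothetical and not part of any package; `min_count` plays the role of `l` above. The default `tiebreaker=first` (instead of `rand`) and the sorting of tied candidates are assumptions made only to keep the example deterministic, and sorting requires the labels to be orderable.

```julia
# Hypothetical helper, not part of any package: majority vote over one neighbour set.
function majority_vote(neighbor_inds::AbstractVector{<:Integer},
                       train_classes::AbstractVector;
                       tiebreaker=first, min_count::Integer=1)
    # Count how often each class occurs among the neighbours.
    counts = Dict{eltype(train_classes), Int}()
    for i in neighbor_inds
        cls = train_classes[i]
        counts[cls] = get(counts, cls, 0) + 1
    end
    isempty(counts) && return missing
    top = maximum(values(counts))
    # Too few agreeing neighbours: refuse to classify.
    top < min_count && return missing
    # Every class that reaches the top count, sorted for a reproducible order
    # (assumes the labels can be sorted).
    candidates = sort!([cls for (cls, n) in counts if n == top])
    length(candidates) == 1 && return first(candidates)
    # A function picks among the tied candidates; any other value is returned as-is.
    return tiebreaker isa Function ? tiebreaker(candidates) : tiebreaker
end

# Vote for each neighbour set in turn.
majority_vote_all(all_inds, train_classes; kwargs...) =
    [majority_vote(inds, train_classes; kwargs...) for inds in all_inds]

train_classes = [1, 1, 2, 4, 5, 6, 8, 8, 9]
neighbour_inds = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

majority_vote_all(neighbour_inds, train_classes)               # [1, 4, 8]
majority_vote_all(neighbour_inds, train_classes; min_count=2)  # [1, missing, 8]
```

The second neighbour set is a three-way tie with one vote each, so it resolves to the first sorted candidate by default and to `missing` once at least two agreeing neighbours are required.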
