This package exposes the scikit-learn interface. Packages that implement this interface can be used in conjunction with ScikitLearn.jl (pipelines, cross-validation, hyperparameter tuning, ...).
This is an intentionally slim package (~100 LOC, no dependencies). That way, ML libraries can import ScikitLearnBase without dragging along all of ScikitLearn's dependencies.
The docs contain an overview of the API and a more thorough specification.
There are two implementation strategies for an existing machine learning package:
For models with simple hyperparameters, it boils down to this:
import ScikitLearnBase

mutable struct NaiveBayes
    # The model hyperparameters (not learned from data)
    bias::Float64

    # The parameters learned from data
    counts::Matrix{Int}

    # A constructor that accepts the hyperparameters as keyword arguments
    # with sensible defaults. `counts` stays unset until `fit!` is called.
    NaiveBayes(; bias=0.0) = new(bias)
end

# This will define `clone`, `set_params!` and `get_params` for the model
ScikitLearnBase.@declare_hyperparameters(NaiveBayes, [:bias])

# NaiveBayes is a classifier
ScikitLearnBase.is_classifier(::NaiveBayes) = true   # not required for transformers

function ScikitLearnBase.fit!(model::NaiveBayes, X, y)
    # X should be of size (n_sample, n_feature)
    # .... modify model.counts here
    return model
end

function ScikitLearnBase.predict(model::NaiveBayes, X)
    # .... return a vector of predicted classes here
end
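For a version of this pattern that runs end to end, here is a minimal sketch of a toy classifier. The `NearestCentroid` type, its `p` hyperparameter, and its fields are hypothetical names chosen for illustration. They are not part of ScikitLearnBase.

```julia
import ScikitLearnBase

# Hypothetical toy classifier: predicts the label of the closest class centroid.
mutable struct NearestCentroid
    # Hyperparameter (assumed name): exponent used in the distance computation
    p::Float64

    # Learned from data
    classes::Vector{Any}
    centroids::Matrix{Float64}   # one row per class

    NearestCentroid(; p=2.0) = new(p)
end

ScikitLearnBase.@declare_hyperparameters(NearestCentroid, [:p])
ScikitLearnBase.is_classifier(::NearestCentroid) = true

function ScikitLearnBase.fit!(model::NearestCentroid, X, y)
    # X is of size (n_sample, n_feature)
    model.classes = unique(y)
    model.centroids = zeros(length(model.classes), size(X, 2))
    for (i, c) in enumerate(model.classes)
        rows = findall(==(c), y)
        model.centroids[i, :] = vec(sum(X[rows, :], dims=1)) ./ length(rows)
    end
    return model
end

function ScikitLearnBase.predict(model::NearestCentroid, X)
    # For each row, pick the class whose centroid has the smallest distance
    map(1:size(X, 1)) do i
        dists = [sum(abs.(X[i, :] .- model.centroids[k, :]) .^ model.p)
                 for k in 1:length(model.classes)]
        model.classes[argmin(dists)]
    end
end

X = [0.0 0.0; 0.1 0.0; 5.0 5.0; 5.1 5.0]
y = ["a", "a", "b", "b"]
model = ScikitLearnBase.fit!(NearestCentroid(), X, y)
predictions = ScikitLearnBase.predict(model, [0.0 0.2; 4.9 5.0])
```

Returning the model from `fit!` is what allows the chained `fit!(NearestCentroid(), X, y)` call in the last lines.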
Models with more complex hyperparameter specifications should implement clone, get_params and set_params! explicitly instead of using @declare_hyperparameters.
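As a rough sketch of what an explicit implementation might look like, the example below validates hyperparameters whenever they are set. `PenalizedModel`, `PENALTIES` and `check_hyperparameter` are hypothetical names. The method signatures, including the `deep` keyword on `get_params`, are an assumption meant to mirror what the macro generates, so check them against the API specification in the docs.

```julia
import ScikitLearnBase

const PENALTIES = (:l1, :l2)   # hypothetical set of allowed values

# Hypothetical helper: validates one hyperparameter and returns its value
function check_hyperparameter(name::Symbol, value)
    if name == :alpha
        value >= 0 || throw(ArgumentError("alpha must be non-negative"))
    elseif name == :penalty
        value in PENALTIES || throw(ArgumentError("penalty must be one of $PENALTIES"))
    else
        throw(ArgumentError("unknown hyperparameter: $name"))
    end
    return value
end

mutable struct PenalizedModel
    alpha::Float64
    penalty::Symbol
    coefs::Vector{Float64}   # learned from data, unset until fitting

    PenalizedModel(; alpha=1.0, penalty=:l2) =
        new(check_hyperparameter(:alpha, alpha),
            check_hyperparameter(:penalty, penalty))
end

# Returns the hyperparameters only, never the learned parameters
ScikitLearnBase.get_params(model::PenalizedModel; deep=true) =
    Dict{Symbol,Any}(:alpha => model.alpha, :penalty => model.penalty)

function ScikitLearnBase.set_params!(model::PenalizedModel; params...)
    for (name, value) in params
        # setproperty! converts the value to the field's declared type
        setproperty!(model, name, check_hyperparameter(name, value))
    end
    return model
end

# A clone is a fresh, unfitted model with the same hyperparameters
ScikitLearnBase.clone(model::PenalizedModel) =
    PenalizedModel(; ScikitLearnBase.get_params(model)...)

original = PenalizedModel(alpha=0.5, penalty=:l1)
copy_model = ScikitLearnBase.clone(original)
ScikitLearnBase.set_params!(copy_model; alpha=2)
```

Routing both the constructor and `set_params!` through the same check keeps invalid values out no matter how the model is configured.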
More examples of PRs that implement the interface: GaussianMixtures.jl, GaussianProcesses.jl, DecisionTree.jl, LowRankModels.jl.
Note: if the model performs unsupervised learning, implement transform instead of predict.
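A transformer could be sketched along the same lines. `SimpleScaler` and its `with_std` hyperparameter are hypothetical, and the optional `y=nothing` argument is an assumption so that `fit!` can be called with or without labels.

```julia
import ScikitLearnBase
using Statistics: mean, std

# Hypothetical transformer: centers each column and optionally scales it
mutable struct SimpleScaler
    # Hyperparameter (assumed name): divide by the column standard deviation?
    with_std::Bool

    # Learned from data
    means::Vector{Float64}
    stds::Vector{Float64}

    SimpleScaler(; with_std=true) = new(with_std)
end

ScikitLearnBase.@declare_hyperparameters(SimpleScaler, [:with_std])

function ScikitLearnBase.fit!(model::SimpleScaler, X, y=nothing)
    # X is of size (n_sample, n_feature); y is ignored
    model.means = vec(mean(X, dims=1))
    model.stds = model.with_std ? vec(std(X, dims=1)) : ones(size(X, 2))
    return model
end

# Broadcast the per-column statistics across all rows.
# Note: a constant column has zero standard deviation and is not handled here.
ScikitLearnBase.transform(model::SimpleScaler, X) =
    (X .- model.means') ./ model.stds'

X = [1.0 2.0; 3.0 6.0]
scaler = ScikitLearnBase.fit!(SimpleScaler(with_std=false), X)
Z = ScikitLearnBase.transform(scaler, X)
```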
Once your library implements the API, file an issue/PR to add it to the list of models.