Using LightGBM from Haskell
Summary
- Easily run LightGBM from Haskell
- Improved parameter value safety thanks to refined types
- Improved parameter grouping safety thanks to algebraic data types
Introduction
One of the most popular recent approaches to tree boosting in machine learning is LightGBM. It's fast, performs well on various sorts of regression and classification problems, and has an open-source implementation under active development which includes APIs in Python and R as well as a community-provided Julia wrapper. It does not, however, do anything to make life easy for the average Haskell programmer. Let's change that.
I've created a Haskell wrapper around LightGBM which makes it easy to call the library from a Haskell program. Beyond simply making it possible to use LightGBM from Haskell, however, I've tried to make this interface to the library relatively safe. In particular:
- Many LightGBM parameters are supposed to be limited to defined sets of values (e.g. fractions in the range \([0, 1]\), numbers in the range \([1, 2)\), or strictly positive integers). In this Haskell interface I use refinement types to ensure that these limitations are honored - at compile time when possible, at run time when necessary (see the sketch after this list).
- Many LightGBM parameters are only supposed to be used in a subset of cases (e.g. for a particular objective or with a particular metric function). In this Haskell interface I use algebraic data types to try to ensure that combinations of parameters that don't make sense cannot be represented.
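To make the refinement-type point concrete, here is a minimal sketch using the refined library. The type aliases and values below are purely illustrative; they are not the library's actual definitions:

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TemplateHaskell #-}

import Refined (FromTo, Positive, Refined, refineTH)

-- Illustrative aliases only; the real parameter types live in LightGBM.Parameters.
type Fraction  = Refined (FromTo 0 1) Double  -- a value restricted to [0, 1]
type LeafCount = Refined Positive Int         -- a strictly positive integer

-- The Template Haskell splice checks the literal at compile time;
-- something like $$(refineTH 1.5) for a Fraction would be rejected by the compiler.
goodFraction :: Fraction
goodFraction = $$(refineTH 0.8)

leaves :: LeafCount
leaves = $$(refineTH 63)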
I hope that these decisions make the library safer and easier to use than it would otherwise be.
So let's take a look at how to put the library to use.
How to use the library
We'll apply LightGBM to the venerable Kaggle problem of predicting who will survive the sinking of the Titanic. The code and data for this exercise can be found in the Titanic example subdirectory of the GitHub repo.1
Data Preparation
We start by taking the Kaggle training data and splitting it into two
parts - one for training and one for validation.2 In the repo,
you'll find these files named train_part.csv and
validate_part.csv.
Here's a sample of the first few rows of the training data:
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.25 | | S |
| 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Thayer) | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. 3101282 | 7.925 | | S |
Next, note that LightGBM requires categorical features to be encoded as
integer values, so we'll need to convert some of the columns
(e.g. Sex, Embarked) into integers. In my case, I wrote some code
using Haskell's Cassava CSV library to do the
conversion,3 but any approach you want to use is fine.
We should also get rid of any columns which will induce spurious
correlations or which have so little data that we can't use them. In
our case, that would include Name and Ticket for probable
irrelevance, and Cabin for lack of a reasonable amount of
data.4 We end up with a filtered dataset that looks like this:
| Survived | PassengerId | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 3 | 1 | 22.0 | 1 | 0 | 7.25 | 2 |
| 1 | 2 | 1 | 0 | 38.0 | 1 | 0 | 71.2833 | 0 |
| 1 | 3 | 3 | 0 | 26.0 | 0 | 0 | 7.925 | 2 |
Note that our dependent variable (a.k.a. what we're trying to predict) is in the 0th column - this is the default data format expected by LightGBM.
Also note that we still have the PassengerId field, even though
that's in no way relevant to the problem. We keep it around since the
final submission file format for Kaggle asks for it. To keep it from
influencing the learning algorithm, we'll use a parameter to ask
LightGBM to ignore it.
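Before moving on to training, here's a rough sketch of the categorical conversion mentioned above. The example repo does this with Cassava in its ConvertData module; the helper names below are hypothetical, and the integer codes are simply read off the filtered table (the code for Q is an assumption):

-- Hypothetical helpers illustrating integer encodings consistent with the
-- filtered table above; the real conversion lives in ConvertData.hs.
encodeSex :: String -> Int
encodeSex "female" = 0
encodeSex _        = 1

encodeEmbarked :: String -> Int
encodeEmbarked "C" = 0
encodeEmbarked "Q" = 1   -- assumed; not visible in the sample rows
encodeEmbarked _   = 2   -- "S"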
Training the booster
Training the booster is pretty straightforward - feed the training data into the training algorithm, set a few hyperparameters, and let it rip. The gory details are in the example code, but here are the highlights.
Imports
First, some imports to get the LightGBM library and CSV filtering logic:
import qualified LightGBM as LGBM
import qualified LightGBM.DataSet as DS
import qualified LightGBM.Parameters as P
import ConvertData (csvFilter, testFilter)
Training Parameters
Next, the parameters that control the training of the booster:
trainParams :: [P.Param]
trainParams =
  [ P.Objective P.BinaryClassification
  , P.Metric [P.BinaryLogloss, P.AUC]
  , P.TrainingMetric True
  , P.LearningRate $$(refineTH 0.1)
  , P.NumLeaves $$(refineTH 63)
  , P.FeatureFraction $$(refineTH 0.8)
  , P.BaggingFreq $$(refineTH 5)
  , P.BaggingFraction $$(refineTH 0.8)
  , P.MinDataInLeaf 50
  , P.MinSumHessianInLeaf $$(refineTH 5.0)
  , P.IsSparse True
  , P.LabelColumn $ P.ColName "Survived"
  , P.IgnoreColumns [P.ColName "PassengerId"]
  , P.CategoricalFeatures
      [P.ColName "Pclass", P.ColName "Sex", P.ColName "Embarked"]
  ]
There are several things to note here:
- I use the refined library to limit some of the parameters to their
  proper domains. For example, FeatureFraction should be in the range
  \((0, 1]\), and by using a refinement of the Double type I can ensure
  that it's so at compile time (with that $$(refineTH ...) Template
  Haskell stuff).5
- LightGBM multi-parameters are converted into lists (e.g. the
  Metric parameter).
- LightGBM enumerated parameters are turned into equivalent sum
  types (e.g. the Objective parameter).
- Column selection is based on a sum type rather than a string
  prefix as is done in the standard LightGBM parameters (e.g. in the
  LabelColumn parameter).
- We can select which column contains the "labels" (the dependent
  quantity being predicted) with the LabelColumn parameter.
- We can ignore some columns that we might be carrying along just
  for reporting purposes using the IgnoreColumns parameter.
- Categorical features are encoded as integers, so we have to signal
  explicitly to LightGBM whether a feature is categorical (i.e. it's
  just an enum of a finite set of values) or not (i.e. it's a
  numerical value of some sort). We do this with the
  CategoricalFeatures parameter.
More generally, note that the parameters module does some parameter
bundling to ensure that nonsensical combinations of parameters don't
occur. For instance, the NumClasses parameter can only be set with
the MultiClass application. This is a break from the flat parameter
space of the underlying LightGBM library where ensuring parameter
coherence is up to the user.
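Here is a minimal sketch of that idea. The constructor and field names are illustrative, not the library's actual definitions:

-- Illustrative sketch only, not the real types from LightGBM.Parameters:
-- the class count is attached to the multiclass constructor, so it simply
-- cannot be supplied alongside a binary objective.
newtype NumClasses = NumClasses Int

data Application
  = BinaryClassification
  | MultiClass NumClasses
  | Regression

-- Fine: a multiclass objective with three classes.
multiway :: Application
multiway = MultiClass (NumClasses 3)

-- There is no way to write a BinaryClassification value that carries a
-- NumClasses, so the incoherent combination is unrepresentable.
binary :: Application
binary = BinaryClassification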
Loading Data
The library provides a simple interface to load data from a CSV file
with an optional header into a DataSet for use with the algorithm.
In our case, all of the files have headers, so a simple helper function
is in order.
loadData :: FilePath -> DS.DataSet
loadData = DS.readCsvFile (DS.HasHeader True)
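For example, pointing it at one of the filtered CSV files gives us a DataSet we can hand to the trainer (the path here is just illustrative):

titanicTraining :: DS.DataSet
titanicTraining = loadData "filtered_train.csv"  -- hypothetical path to a filtered CSV with a header row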
Training
I create a couple of temporary files to hold the filtered data (I'm doing the filtering inline here; I could also have filtered the data out-of-band, saved the results, and then fed those files in directly).
trainModel :: IO LGBM.Model
trainModel =
  TMP.withSystemTempFile "filtered_train" $ \trainFile trainHandle -> do
    _ <- csvFilter "train_part.csv" trainHandle
    hClose trainHandle
    TMP.withSystemTempFile "filtered_val" $ \valFile valHandle -> do
      _ <- csvFilter "validate_part.csv" valHandle
      hClose valHandle
      let trainingData = loadData trainFile
          validationData = loadData valFile
          predictionFile = "LightGBM_predict_result.txt"
          modelName = "LightGBM_model.txt"
      model <-
        LGBM.trainNewModel modelName trainParams trainingData validationData 100
      case model of
        Left e -> error $ "Error training model: " ++ show e
        Right m -> do
          _ <- LGBM.writeModelFile modelName m
          -- [... a bit of self validation code elided here ...]
          return m
Note how we use both the training data and the validation data in the training cycle.
The effect of this code is to train a model, write the model out to
the modelName file for future use, and return the model for
immediate use. (If training fails, the training call returns an error
log instead, which this example simply turns into a hard error.)
Predicting
Now that we have the model, we can use it to predict the fate of other passengers. Here we go:
main :: IO ()
main = do
  cwd <- SD.getCurrentDirectory
  SD.withCurrentDirectory
    (cwd </> "examples" </> "titanic")
    (do
       m <- trainModel
       TMP.withSystemTempFile "filtered_test" $ \testFile testHandle -> do
         _ <- testFilter "test.csv" testHandle
         hClose testHandle
         TMP.withSystemTempFile "predictions" $ \predFile predHandle -> do
           hClose predHandle
           predResults <- LGBM.predict m [] (loadData testFile)
           case predResults of
             Left e -> error $ "Error predicting final results: " ++ show e
             Right predValues -> do
               LGBM.writeCsvFile predFile predValues
               -- [... some code to report the output in Kaggle format elided ...]
    )
The model output will go to the predFile where it can be used for
further processing (e.g. massaging into the proper format for
submitting to Kaggle).
Caveats
This interface to the LightGBM library is fundamentally a wrapper
around the command-line interface to LightGBM, which makes it rather
heavily embedded in the IO type and heavily dependent on the file
system. The file system dependence is not particularly bad - data
sets and models in the machine learning space are typically large
enough that you'd want to have them persisted to disk anyway - but it
perhaps gives an odd feel to the wrapper API. Most wrappers around
LightGBM use foreign function calls to the C API and pass data
structures in directly (e.g. as Pandas or R data frames); I might do
something like that in the future if it looks like it would help
matters.
Future directions?
The wrapper presented here is still very rudimentary, and there are many improvements yet to be made. For example:
- Add the library to Hackage
- Grid search for parameter tuning
- Cross-validation support
- Better validation metrics
- Using the C API via the Haskell FFI rather than wrapping the command line interface
- (Done) Provide a data frame interface to DataSet, allowing us to use a data frame directly as input and extract a data frame as output.
Footnotes:
Please note that this article is only meant to introduce you to the Haskell wrapper for LightGBM in the simplest way possible! The approach taken in this article to the actual machine learning problem at hand is intentionally naive to make sure that the exposition of the library doesn't get lost in the complexity of handling a machine learning problem properly. Do not take this approach as a tutorial in how to approach machine learning in general! A few of the most egregious simplifications will be called out in the footnotes.
This "holdout" approach is not a particularly good validation method, but it's simple to implement. Some high-level interfaces to LightGBM provide support for cross-validation, and I might supply that too eventually.
I leave open the possibility of engineering features on the basis of these columns (e.g. using titles from names as a proxy for social class, or correlating cabin numbers to particular locations on the ship); I'm just saying that leaving these columns as they are doesn't give us any useful information for a feature.
Note that to use the Template Haskell approach we need to
enable the TemplateHaskell extension using a file pragma (as is done
in the example files) or similar approach. You always have the option
of using refine instead of refineTH in which case Template Haskell
is not needed and the bounds checks will happen at runtime rather than
compile time.
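For instance, a runtime check might look something like this (a sketch, assuming a reasonably recent version of refined and using a FromTo 0 1 predicate purely as a stand-in for whatever predicate the parameter actually uses):

import Refined (FromTo, Refined, RefineException, refine)

-- refine validates the value at runtime instead of at compile time,
-- returning Left with an explanation if the predicate fails.
checkedFraction :: Either RefineException (Refined (FromTo 0 1) Double)
checkedFraction = refine 0.8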