Using LightGBM from Haskell

Summary

  • Easily run LightGBM from Haskell
  • Improved parameter value safety thanks to refined types
  • Improved parameter grouping safety thanks to algebraic data types

Introduction

LightGBM is one of the most popular recent approaches to tree boosting in machine learning. It's fast, performs well on various sorts of regression and classification problems, and has an open-source implementation under active development which includes APIs in Python and R as well as a community-provided Julia wrapper. It does not, however, do anything to make life easy for the average Haskell programmer. Let's change that.

I've created a Haskell wrapper around LightGBM which makes it easy to call the library from a Haskell program. Beyond simply making it possible to use LightGBM from Haskell, however, I've tried to make this interface to the library relatively safe. In particular:

  • Many LightGBM parameters are supposed to be limited to defined sets of values (e.g. fractions in the range \([0, 1]\), numbers in the range \([1, 2)\), or strictly positive integers). In this Haskell interface I use refinement types to ensure that these limitations are honored - at compile time when possible, at run time when necessary.
  • Many LightGBM parameters are only supposed to be used in a subset of cases (e.g. for a particular objective or with a particular metric function). In this Haskell interface I use algebraic data types to try to ensure that combinations of parameters that don't make sense cannot be represented.
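
To make the first point concrete, here is a minimal sketch of the run-time half of that idea using only base. It is a hand-rolled smart constructor, not the refined library's API and not this wrapper's actual code; the names UnitFraction and mkUnitFraction are hypothetical.

```haskell
module Main where

-- | A Double known to lie in the closed interval [0, 1].
-- Hypothetical type for illustration; in a real module the constructor
-- would not be exported, so mkUnitFraction is the only way to build one.
newtype UnitFraction = UnitFraction Double
  deriving (Show, Eq)

-- | Run-time check: accept the value only if it lies in [0, 1].
mkUnitFraction :: Double -> Either String UnitFraction
mkUnitFraction x
  | x >= 0 && x <= 1 = Right (UnitFraction x)
  | otherwise        = Left ("not in [0, 1]: " ++ show x)

main :: IO ()
main = do
  print (mkUnitFraction 0.8)  -- accepted
  print (mkUnitFraction 1.5)  -- rejected with an error message
```

A refinement-type library generalizes this pattern: the predicate lives in the type, and literal values can be checked while the program is being compiled rather than when it runs.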

I hope that these decisions make the library safer and easier to use than it would otherwise be.

So let's take a look at how to put the library to use.

How to use the library

We'll apply LightGBM to the venerable Kaggle problem of predicting who will survive the sinking of the Titanic. The code and data for this exercise can be found in the Titanic example subdirectory of the GitHub repo.1

Data Preparation

We start by taking the Kaggle training data and splitting it into two parts - one for training and one for validation.2 In the repo, you'll find these files named train_part.csv and validate_part.csv.
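
The split itself can be sketched in a few lines. This is an assumption about how one might do it, not the procedure used to produce the files in the repo: holdoutSplit is a hypothetical helper, the 80/20 ratio is arbitrary, and the rows are assumed to exclude the CSV header.

```haskell
module Main where

-- | Split rows into (training, validation) parts, putting the given
-- fraction of rows (rounded down) into the training part.
-- Hypothetical helper; it does not shuffle, so shuffle the input first
-- if the row order carries information.
holdoutSplit :: Double -> [a] -> ([a], [a])
holdoutSplit trainFraction rows = splitAt nTrain rows
  where
    nTrain = floor (trainFraction * fromIntegral (length rows))

main :: IO ()
main = do
  let (trainRows, valRows) = holdoutSplit 0.8 [1 .. 10 :: Int]
  print trainRows
  print valRows
```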

Here's a sample of the first few rows of the training data:

PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked
1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.25 | | S
2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Thayer) | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C
3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. 3101282 | 7.925 | | S

The next step is to note that LightGBM requires that categorical features be encoded as integer values, so we'll need to convert some of the columns (e.g. Sex, Embarked) into integers. In my case, I wrote some code using Haskell's Cassava CSV library to do the conversion,3 but any approach you want to use is fine.
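
As an illustration of that encoding step, here is a small sketch using only base. It is not the code from ConvertData.hs; encodeSex and encodeEmbarked are hypothetical names. The codes for female, male, C and S match the filtered sample shown further down, while the code for Q is an assumption.

```haskell
module Main where

-- | Encode the Sex column as an integer category
-- (female = 0, male = 1, as in the filtered sample in this article).
encodeSex :: String -> Maybe Int
encodeSex "female" = Just 0
encodeSex "male"   = Just 1
encodeSex _        = Nothing

-- | Encode the Embarked column as an integer category.
-- C = 0 and S = 2 match the filtered sample; Q = 1 is an assumption.
encodeEmbarked :: String -> Maybe Int
encodeEmbarked "C" = Just 0
encodeEmbarked "Q" = Just 1
encodeEmbarked "S" = Just 2
encodeEmbarked _   = Nothing

main :: IO ()
main = do
  print (map encodeSex ["male", "female", "unknown"])
  print (map encodeEmbarked ["S", "C", ""])
```

Returning Maybe leaves the decision about missing or unexpected values to the caller instead of silently mapping them to some category.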

We should also get rid of any columns which will induce spurious correlations or which have so little data that we can't use them. In our case, that would include Name and Ticket for probable irrelevance, and Cabin for lack of a reasonable amount of data.4 We end up with a filtered dataset that looks like this:

Survived | PassengerId | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked
0 | 1 | 3 | 1 | 22.0 | 1 | 0 | 7.25 | 2
1 | 2 | 1 | 0 | 38.0 | 1 | 0 | 71.2833 | 0
1 | 3 | 3 | 0 | 26.0 | 0 | 0 | 7.925 | 2

Note that our dependent variable (a.k.a. what we're trying to predict) is in the 0th column - this is the default data format expected by LightGBM.
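
The column layout can be sketched as a record plus a rendering function. This is an illustration only, with a hypothetical Passenger type and toFilteredRow helper; it assumes every field is present, which is a simplification (the real data has gaps).

```haskell
module Main where

import Data.List (intercalate)

-- | One already-encoded passenger record (hypothetical type; the field
-- names follow the Kaggle column names).
data Passenger = Passenger
  { survived    :: Int
  , passengerId :: Int
  , pclass      :: Int
  , sex         :: Int     -- integer-encoded category
  , age         :: Double
  , sibSp       :: Int
  , parch       :: Int
  , fare        :: Double
  , embarked    :: Int     -- integer-encoded category
  }

-- | Render one CSV row with the label (Survived) in column 0 and the
-- Name, Ticket and Cabin columns left out.
toFilteredRow :: Passenger -> String
toFilteredRow p = intercalate ","
  [ show (survived p), show (passengerId p), show (pclass p), show (sex p)
  , show (age p), show (sibSp p), show (parch p), show (fare p)
  , show (embarked p)
  ]

main :: IO ()
main = putStrLn (toFilteredRow (Passenger 0 1 3 1 22.0 1 0 7.25 2))
```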

Also note that we still have the PassengerId field, even though that's in no way relevant to the problem. We keep it around since the final submission file format for Kaggle asks for it. To keep it from influencing the learning algorithm, we'll use a parameter to ask LightGBM to ignore it.

Training the booster

Training the booster is pretty straightforward - feed the training data into the training algorithm, give a few hyperparameters, and let it rip. The gory details are in the example code, but here are the highlights.

Imports

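First, some imports to get the LightGBM library and CSV filtering logic (the later snippets also use refineTH from the refined library's Refined module, hClose from System.IO, (</>) from System.FilePath, and the qualified aliases TMP for System.IO.Temp and SD for System.Directory):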

import qualified LightGBM as LGBM
import qualified LightGBM.DataSet as DS
import qualified LightGBM.Parameters as P
import           ConvertData (csvFilter, testFilter)

Training Parameters

Next, the parameters that control the training of the booster:

trainParams :: [P.Param]
trainParams =
  [ P.Objective P.BinaryClassification
  , P.Metric [P.BinaryLogloss, P.AUC]
  , P.TrainingMetric True
  , P.LearningRate $$(refineTH 0.1)
  , P.NumLeaves $$(refineTH 63)
  , P.FeatureFraction $$(refineTH 0.8)
  , P.BaggingFreq $$(refineTH 5)
  , P.BaggingFraction $$(refineTH 0.8)
  , P.MinDataInLeaf 50
  , P.MinSumHessianInLeaf $$(refineTH 5.0)
  , P.IsSparse True
  , P.LabelColumn $ P.ColName "Survived"
  , P.IgnoreColumns [P.ColName "PassengerId"]
  , P.CategoricalFeatures
      [P.ColName "Pclass", P.ColName "Sex", P.ColName "Embarked"]
  ]

There are several things to note here:

  • I use the refined library to limit some of the parameters to their proper domains. For example, the FeatureFraction should be in the range \((0, 1]\), and by using a refinement of the Double type I can ensure that it's so at compile time (with that $$(refineTH ...) Template Haskell stuff).5
  • LightGBM multi-parameters are converted into lists (e.g. the Metric parameter).
  • LightGBM enumerated parameters are turned into equivalent sum types (e.g. the Objective parameter).
  • Column selection is based on a sum type rather than a string prefix as is done in the standard LightGBM parameters (e.g. in the LabelColumn parameter).
  • We can select which column contains the "labels" (the dependent quantity being predicted) with the LabelColumn parameter.
  • We can ignore some columns that we might be carrying along just for reporting purposes using the IgnoreColumns parameter.
  • Categorical features are encoded as integers, so we have to signal explicitly to LightGBM whether a feature is categorical (i.e. it's just an enum of a finite set of values) or not (i.e. it's a numerical value of some sort). We do this with the CategoricalFeatures parameter.
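
To illustrate the column-selection point, here is a sketch of how a sum type might be rendered into LightGBM's native syntax, where a column name carries a "name:" prefix and a column index is written bare. The ColumnSelector type and both rendering functions are hypothetical stand-ins, not the wrapper's actual definitions; only the ColName constructor appears in the example above.

```haskell
module Main where

import Data.List (intercalate)

-- | Select a column by position or by header name.
-- Hypothetical stand-in for the wrapper's type.
data ColumnSelector
  = ColIndex Int
  | ColName String

-- | Render a label_column setting in LightGBM's native syntax.
renderLabelColumn :: ColumnSelector -> String
renderLabelColumn (ColIndex i) = "label_column=" ++ show i
renderLabelColumn (ColName n)  = "label_column=name:" ++ n

-- | Render several named columns, e.g. for ignore_column or
-- categorical_feature; the "name:" prefix is written once, before the
-- comma-separated names.
renderNamedColumns :: String -> [String] -> String
renderNamedColumns key names = key ++ "=name:" ++ intercalate "," names

main :: IO ()
main = do
  putStrLn (renderLabelColumn (ColName "Survived"))
  putStrLn (renderNamedColumns "ignore_column" ["PassengerId"])
  putStrLn (renderNamedColumns "categorical_feature" ["Pclass", "Sex", "Embarked"])
```

With the prefix handled in one rendering function, callers never have to remember it, which is the point of using a sum type here.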

More generally, note that the parameters module does some parameter bundling to ensure that nonsensical combinations of parameters don't occur. For instance, the NumClasses parameter can only be set with the MultiClass application. This is a break from the flat parameter space of the underlying LightGBM library where ensuring parameter coherence is up to the user.
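
The bundling idea can be sketched as follows. This is a simplified illustration with hypothetical constructors, not the definitions in the parameters module: the number of classes is an argument of the MultiClass constructor, so it cannot be supplied for any other application.

```haskell
module Main where

-- | The learning task, with task-specific settings carried by the
-- constructor that needs them. Hypothetical constructors for illustration.
data Application
  = Regression
  | BinaryClassification
  | MultiClass Int   -- number of classes; only expressible here

-- | Flatten an Application into LightGBM's flat "key=value" settings.
toFlatParams :: Application -> [String]
toFlatParams Regression           = ["objective=regression"]
toFlatParams BinaryClassification = ["objective=binary"]
toFlatParams (MultiClass n)       = ["objective=multiclass", "num_class=" ++ show n]

main :: IO ()
main = do
  print (toFlatParams BinaryClassification)
  print (toFlatParams (MultiClass 3))
```

In the flat parameter space nothing stops a user from writing objective=binary together with num_class=3; with the sum type that combination has no representation.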

Loading Data

The library provides a simple interface to load data from a CSV file with an optional header into a DataSet for use with the algorithm. In our case, all of the files have headers so a simple helper function is in order.

loadData :: FilePath -> DS.DataSet
loadData = DS.readCsvFile (DS.HasHeader True)

Training

I create a couple of temporary files to hold the filtered data (I'm doing the filtering inline - I could also have filtered the data out-of-band, saved it, and then fed the saved files in directly).

trainModel :: IO LGBM.Model
trainModel =
  -- Filter the raw CSV files into temporary files before loading them
  TMP.withSystemTempFile "filtered_train" $ \trainFile trainHandle -> do
    _ <- csvFilter "train_part.csv" trainHandle
    hClose trainHandle
    TMP.withSystemTempFile "filtered_val" $ \valFile valHandle -> do
      _ <- csvFilter "validate_part.csv" valHandle
      hClose valHandle
      let trainingData = loadData trainFile
          validationData = loadData valFile
          predictionFile = "LightGBM_predict_result.txt"
          modelName = "LightGBM_model.txt"
      model <-
        LGBM.trainNewModel modelName trainParams trainingData validationData 100
      case model of
        Left e -> error $ "Error training model:  " ++ show e
        Right m -> do
          -- Persist the trained model for future use
          _ <- LGBM.writeModelFile modelName m
          -- [... a bit of self validation code elided here ...]
          return m

Note that both the training data and the validation data are passed to the training call.

The effect of this code is to train a model, write the model out to the modelName file for future use, and return the model for immediate use (or return an error-log in case there was an error during training).

Predicting

Now that we have the model, we can use it to predict the fate of other passengers. Here we go:

main :: IO ()
main = do
  cwd <- SD.getCurrentDirectory
  SD.withCurrentDirectory
    (cwd </> "examples" </> "titanic")
    (do
        m <- trainModel

        TMP.withSystemTempFile "filtered_test" $ \testFile testHandle -> do
          _ <- testFilter "test.csv" testHandle
          hClose testHandle
          TMP.withSystemTempFile "predictions" $ \predFile predHandle -> do
            hClose predHandle

            predResults <- LGBM.predict m [] (loadData testFile)
            case predResults of
              Left e -> error $ "Error predicting final results:  " ++ show e
              Right predValues -> do
                LGBM.writeCsvFile predFile predValues

          -- [... some code to report the output in Kaggle format elided ...]
    )

The model output will go to the predFile, where it can be used for further processing (e.g. massaging into the proper format for submitting to Kaggle).
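
As a rough idea of what that massaging could look like, here is a sketch using only base. It is not the elided reporting code from the example: toSubmission is a hypothetical helper, it assumes the predictions are survival probabilities in the same order as the passenger IDs, and the 0.5 threshold is an arbitrary choice.

```haskell
module Main where

-- | Turn predicted survival probabilities into Kaggle submission lines:
-- a "PassengerId,Survived" header followed by one line per passenger.
-- Assumes one probability per ID, in matching order; extra entries in
-- the longer list are dropped by zipWith.
toSubmission :: [Int] -> [Double] -> [String]
toSubmission ids probs = "PassengerId,Survived" : zipWith line ids probs
  where
    -- Probabilities at or above 0.5 are reported as survived (1).
    line pid p = show pid ++ "," ++ (if p >= 0.5 then "1" else "0")

main :: IO ()
main = mapM_ putStrLn (toSubmission [892, 893] [0.12, 0.87])
```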

Caveats

This interface to the LightGBM library is fundamentally a wrapper around the command-line interface to LightGBM, which makes it rather heavily embedded in the IO type and heavily dependent on the file system. The file system dependence is not particularly bad - data sets and models in the machine learning space are typically large enough that you'd want to have them persisted to disk anyway - but it perhaps gives an odd feel to the wrapper API. Most wrappers around LightGBM use foreign function calls to the C API and pass data structures in directly (e.g. as Pandas or R data frames); I might do something like that in the future if it looks like it would help matters.
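
To give a feel for what wrapping the command line involves, here is a sketch of assembling the arguments for a training run. It is an assumption about the general shape of such a wrapper, not this library's internals: trainArgs is a hypothetical helper, while the keys it emits (task, data, valid, output_model, num_iterations) are LightGBM's own parameter names.

```haskell
module Main where

-- | Build the "key=value" argument list for a LightGBM command-line
-- training run. Hypothetical helper for illustration.
trainArgs :: FilePath -> FilePath -> FilePath -> Int -> [String] -> [String]
trainArgs trainFile valFile modelFile rounds extra =
  [ "task=train"
  , "data=" ++ trainFile
  , "valid=" ++ valFile
  , "output_model=" ++ modelFile
  , "num_iterations=" ++ show rounds
  ] ++ extra

main :: IO ()
main =
  -- A wrapper could hand this list to the lightgbm executable, for
  -- example via System.Process.readProcessWithExitCode; here we only
  -- print the arguments so the sketch runs without LightGBM installed.
  mapM_ putStrLn
    (trainArgs "train.csv" "val.csv" "LightGBM_model.txt" 100 ["objective=binary"])
```

Every input and output in this scheme is a file path, which is where the dependence on IO and the file system comes from.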

Future directions?

The wrapper presented here is still very rudimentary, and many improvements should be made. For example:

  • Add the library to Hackage
  • Grid search for parameter tuning
  • Cross-validation support
  • Better validation metrics
  • Using the C API via the Haskell FFI rather than wrapping the command line interface
  • Provide a data frame interface to DataSet, allowing us to use a data frame directly as input and extract a data frame as output (done)

Footnotes:

1

Please note that this article is only meant to introduce you to the Haskell wrapper for LightGBM in the simplest way possible! The approach taken in this article to the actual machine learning problem at hand is intentionally naive to make sure that the exposition of the library doesn't get lost in the complexity of handling a machine learning problem properly. Do not take this approach as a tutorial in how to approach machine learning in general! A few of the most egregious simplifications will be called out in the footnotes.

2

This "holdout" approach is not a particularly good validation method, but it's simple to implement. Some high-level interfaces to LightGBM provide support for cross-validation, and I might supply that too eventually.

3

Take a look at ConvertData.hs in the repo if you're interested.

4

I leave open the possibility of engineering features on the basis of these columns (e.g. using titles from names as a proxy for social class, or correlating cabin numbers to particular locations on the ship); I'm just saying that leaving these columns as they are doesn't give us any useful information for a feature.

5

Note that to use the Template Haskell approach we need to enable the TemplateHaskell extension using a file pragma (as is done in the example files) or similar approach. You always have the option of using refine instead of refineTH in which case Template Haskell is not needed and the bounds checks will happen at runtime rather than compile time.

A simple beginning

It all started with a simple classification problem: Given a document, identify what company it was about. Here are 106 documents we know how to classify, and here are 105 we don't. A standard undergrad homework exercise in machine learning, they said.

Little did I know I was headed down the rabbit hole…

A year later, I feel like I'm just beginning to get my head around all this and I'm going to try to get it down in a form that makes some sense to me.


Contents © 2018 Dan Katz - Powered by Nikola