Machine Learning - Classification And Regression Trees (CART)


Create a CART predictive model which explains how an outcome variable's values can be predicted based on other values

A Classification And Regression Tree (CART) is a predictive model that explains how an outcome variable's values can be predicted from the values of other variables. A CART output is a decision tree in which each fork is a split on a predictor variable and each end node (leaf) contains a prediction for the outcome variable.
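To make this structure concrete, here is a minimal JavaScript sketch of how a fitted tree turns one case into a prediction: starting at the root, each fork sends the case left or right according to a predictor value until an end node is reached. The node layout, the split values and the outcome categories ("Brand A" and so on) are invented for illustration; they are not the output of rpart, partykit or the 'Preferred Cola' example.

```javascript
// Hypothetical tree layout (illustration only): an internal node splits on one
// predictor; an end node (leaf) holds the prediction for the outcome.
const tree = {
  variable: "Age",
  threshold: 35,               // numeric split: Age < 35 goes left
  left: { prediction: "Brand A" },
  right: {
    variable: "Gender",
    levels: ["Female"],        // categorical split: listed levels go left
    left: { prediction: "Brand B" },
    right: { prediction: "Brand C" }
  }
};

// Follow the forks from the root until a leaf is reached.
function predict(node, caseData) {
  while (!("prediction" in node)) {
    const value = caseData[node.variable];
    const goLeft = node.levels !== undefined
      ? node.levels.includes(value)
      : value < node.threshold;
    node = goLeft ? node.left : node.right;
  }
  return node.prediction;
}

console.log(predict(tree, { Age: 28, Gender: "Male" }));   // Brand A
console.log(predict(tree, { Age: 52, Gender: "Female" })); // Brand B
console.log(predict(tree, { Age: 52, Gender: "Male" }));   // Brand C
```

Each case follows exactly one path, so the prediction depends only on the predictors used in the splits along that path.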

Example

To run a CART model in Displayr, select Insert > More > Machine Learning > Classification and Regression Trees (CART).
In Q, select Automate > Browse Online Library > Machine Learning > Classification and Regression Trees (CART).

An interactive tree created with the Sankey output option, using 'Preferred Cola' as the Outcome variable and 'Age', 'Gender' and 'Exercise Frequency' as the Predictor variables.

Options

Outcome Variable to be predicted.

Predictors Variables which will be considered as predictors of the Outcome. Predictors that are considered to be uninformative will be automatically excluded from the model.

Algorithm The machine learning algorithm. Defaults to CART but can be changed to Deep Learning, Gradient Boosting, Linear Discriminant Analysis, Random Forest, Regression or Support Vector Machine.

Output How the tree should be displayed. The choices are:

  • Sankey: An interactive tree. This is the default.
  • Tree: A greyscale tree plot.
  • Text: A text representation of the tree.
  • Prediction-Accuracy Table: Produces a table relating the observed and predicted outcome. Also known as a confusion matrix.
  • Cross Validation: A plot of the cross-validation accuracy versus the size of the tree in terms of the number of leaves.
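The Prediction-Accuracy Table cross-tabulates observed against predicted categories, and its diagonal holds the correctly predicted cases. The following sketch builds such a table and the overall accuracy (the share of cases on the diagonal) from two arrays of labels. The function names, the observed-in-rows layout and the data are assumptions for illustration, not the layout used by Q or Displayr.

```javascript
// Build a confusion matrix. Assumed layout: rows = observed, columns = predicted.
function predictionAccuracyTable(observed, predicted) {
  if (observed.length !== predicted.length)
    throw new Error("observed and predicted must have the same length");
  const categories = [...new Set([...observed, ...predicted])].sort();
  const counts = categories.map(() => categories.map(() => 0));
  for (let i = 0; i < observed.length; i++) {
    const r = categories.indexOf(observed[i]);
    const c = categories.indexOf(predicted[i]);
    counts[r][c] += 1;
  }
  return { categories, counts };
}

// Overall accuracy = correctly predicted cases (the diagonal) / all cases.
function accuracy(table) {
  const total = table.counts.flat().reduce((a, b) => a + b, 0);
  const correct = table.counts.reduce((sum, row, i) => sum + row[i], 0);
  return correct / total;
}

// Made-up example data.
const observed  = ["Coke", "Pepsi", "Coke", "Diet", "Pepsi", "Coke"];
const predicted = ["Coke", "Coke",  "Coke", "Diet", "Pepsi", "Pepsi"];
const table = predictionAccuracyTable(observed, predicted);
console.log(table.categories); // [ 'Coke', 'Diet', 'Pepsi' ]
console.log(table.counts);     // [ [ 2, 0, 1 ], [ 0, 1, 0 ], [ 1, 0, 1 ] ]
console.log(accuracy(table));  // 0.6666666666666666
```

Off-diagonal cells show which categories the model confuses; here one observed Coke case was predicted as Pepsi and one observed Pepsi case as Coke.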

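Missing data Method for dealing with missing data. For CART the choices are Error if missing data, Exclude cases with missing data (the default), Use partial data and Imputation (replace missing values with estimates). See Missing Data Options.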

Pruning The type of post-pruning applied to the tree. Choices are:

  • Minimum error: Prune back leaf nodes to create the tree with the smallest cross-validation error.
  • Smallest tree: Prune to create the smallest tree whose cross-validation error is at most 1 standard error greater than the minimum error.
  • None: Retain the tree as it has been built. Note that choosing this option without Early stopping is prone to overfitting.
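The two pruning rules can be illustrated with a small table of cross-validation results for nested subtrees, similar in form to the xerror and xstd columns of rpart's cptable. This sketch uses made-up numbers and hypothetical field names; it only shows how Minimum error picks the lowest cross-validation error while Smallest tree applies the one-standard-error rule.

```javascript
// Made-up cross-validation results for nested subtrees of one fitted tree:
// number of leaves, cross-validated error, and the standard error of that error.
const cvTable = [
  { leaves: 1,  xerror: 1.00, xstd: 0.030 },
  { leaves: 3,  xerror: 0.72, xstd: 0.028 },
  { leaves: 5,  xerror: 0.66, xstd: 0.027 },
  { leaves: 8,  xerror: 0.64, xstd: 0.027 },
  { leaves: 12, xerror: 0.67, xstd: 0.027 }
];

// "Minimum error": the subtree with the smallest cross-validated error.
function minimumErrorTree(table) {
  return table.reduce((best, row) => (row.xerror < best.xerror ? row : best));
}

// "Smallest tree": the smallest subtree whose error is within one standard
// error of the minimum (the 1-SE rule).
function smallestTree(table) {
  const best = minimumErrorTree(table);
  const limit = best.xerror + best.xstd;
  const eligible = table.filter(row => row.xerror <= limit);
  return eligible.reduce((small, row) => (row.leaves < small.leaves ? row : small));
}

console.log(minimumErrorTree(cvTable).leaves); // 8
console.log(smallestTree(cvTable).leaves);     // 5
```

With these numbers the minimum error is 0.64 (8 leaves), and the 5-leaf tree's error of 0.66 is within one standard error (0.027) of it, so Smallest tree trades a slightly higher error for a simpler tree.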

Early stopping Whether to stop splitting nodes when further splits do not improve the fit. Setting this may reduce the time taken to build the tree, potentially at the cost of not finding the tree with the best accuracy. The blog post on the Pruning and Early stopping options, listed under More information, has more detail.
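The exact rule behind Early stopping is not spelled out on this page. As an assumption, the sketch below shows one common form of early stopping, similar in spirit to the complexity parameter cp in rpart: a split is kept only if it reduces the tree's lack of fit by at least a set fraction of the root node's lack of fit. It measures lack of fit with Gini impurity weighted by the number of cases; the function names and the 0.01 threshold are hypothetical.

```javascript
// Gini impurity of a node, given its count of cases in each outcome category.
function gini(counts) {
  const n = counts.reduce((a, b) => a + b, 0);
  return 1 - counts.reduce((s, c) => s + (c / n) ** 2, 0);
}

// Lack of fit of a node: number of cases times its Gini impurity (an assumed measure).
function nodeLoss(counts) {
  const n = counts.reduce((a, b) => a + b, 0);
  return n * gini(counts);
}

// Hypothetical early-stopping test: keep the split only if the reduction in
// lack of fit is at least minImprovement times the root node's lack of fit.
function acceptSplit(parent, left, right, rootLoss, minImprovement) {
  const improvement = nodeLoss(parent) - (nodeLoss(left) + nodeLoss(right));
  return improvement >= minImprovement * rootLoss;
}

const rootLoss = nodeLoss([50, 50]); // 100 cases split evenly between two categories

// A split that separates the categories well is kept.
console.log(acceptSplit([50, 50], [40, 10], [10, 40], rootLoss, 0.01)); // true
// A split whose children have the same mix as the parent is rejected.
console.log(acceptSplit([40, 10], [20, 5], [20, 5], rootLoss, 0.01));   // false
```

Rejecting the second split stops growth at that node, which is why early stopping can save time but may miss a later split that would have helped.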

Variable names Displays variable names instead of labels in the output.

Predictor category labels Whether to shorten category labels from categorical predictor variables. The choices are:

  • Full labels: The complete labels.
  • Abbreviated labels: Labels that have been shortened by taking the first few letters from each word.
  • Letters: Letters from the alphabet where "a" corresponds to the first category, "b" to the second category, and so on.
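The Letters and Abbreviated labels options can be sketched as simple string functions. The number of letters kept per word is not stated above, so lettersPerWord = 3 is a hypothetical choice, as are the example categories; the actual abbreviation rule used by Q and Displayr may differ.

```javascript
// "Letters": first category -> "a", second -> "b", and so on.
// This sketch only handles up to 26 categories.
function letterLabels(categories) {
  if (categories.length > 26) throw new Error("sketch only handles up to 26 categories");
  return categories.map((_, i) => String.fromCharCode(97 + i)); // 97 is "a"
}

// "Abbreviated labels": keep the first few letters of each word.
// lettersPerWord is a hypothetical parameter, not a documented setting.
function abbreviatedLabel(label, lettersPerWord = 3) {
  return label
    .split(/\s+/)
    .filter(word => word.length > 0)
    .map(word => word.slice(0, lettersPerWord))
    .join(" ");
}

// Made-up categories for an exercise-frequency predictor.
const exercise = ["Never", "Once a week", "Every day"];
console.log(letterLabels(exercise));                 // [ 'a', 'b', 'c' ]
console.log(exercise.map(l => abbreviatedLabel(l))); // [ 'Nev', 'Onc a wee', 'Eve day' ]
```

Shorter labels keep the splits in a tree readable when categories have long names, at the cost of needing a key to decode them.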

Outcome category labels Same as above but for the outcome variable.

Allow long-running calculations A categorical predictor with m categories can be split into two groups in up to 2^(m - 1) - 1 different ways, so predictors with many categories may cause calculations to run for a long time. Checking this box allows categorical variables with more than 30 categories to be included in Predictors.
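The growth in the number of candidate splits can be checked by enumerating them. Each split sends some categories to the left branch and the rest to the right, with neither branch empty; fixing the first category on the left avoids counting mirror-image splits twice, which gives 2^(m - 1) - 1 splits. This sketch is illustrative only: the function names are hypothetical, and some implementations avoid trying every split in special cases (for example, a numeric or two-category outcome).

```javascript
// Enumerate all ways to split a categorical predictor's categories into two
// non-empty groups. The first category is always placed on the left so that
// each split and its left/right mirror image are counted once.
function binarySplits(categories) {
  const m = categories.length;
  const splits = [];
  // Bit (i - 1) of mask decides whether category i joins category 0 on the left.
  for (let mask = 0; mask < 2 ** (m - 1); mask++) {
    const left = [categories[0]];
    const right = [];
    for (let i = 1; i < m; i++) {
      if (mask & (1 << (i - 1))) left.push(categories[i]);
      else right.push(categories[i]);
    }
    if (right.length > 0) splits.push({ left, right }); // skip the "everything left" case
  }
  return splits;
}

// Closed-form count of the same splits. Enumeration is only practical for
// small m; 30 categories already give over 500 million splits.
function splitCount(m) {
  return 2 ** (m - 1) - 1;
}

console.log(binarySplits(["a", "b", "c"]).length); // 3
console.log(splitCount(30));                       // 536870911
```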

DIAGNOSTICS

Prediction-Accuracy Table Creates a table of the observed versus predicted values, shown as a heatmap.

SAVE VARIABLE(S)

Predicted Values Creates a new variable containing predicted values for each case in the data.

Probabilities of Each Response Creates new variables containing predicted probabilities of each response.
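For a categorical outcome, a common way to produce these saved variables is to read them off the leaf each case falls into: the predicted value is the most common outcome category among the training cases in that leaf, and the probabilities are the category proportions. The sketch below assumes unweighted counts and no priors; the function name and the counts are hypothetical, and the saved variables in Q and Displayr may be computed differently.

```javascript
// Assumed rule: predicted probabilities are the leaf's class proportions and the
// predicted value is the most common class (ties go to the first category listed).
function leafPrediction(categories, counts) {
  const total = counts.reduce((a, b) => a + b, 0);
  const probabilities = {};
  categories.forEach((cat, i) => { probabilities[cat] = counts[i] / total; });
  let best = 0;
  for (let i = 1; i < counts.length; i++) if (counts[i] > counts[best]) best = i;
  return { predicted: categories[best], probabilities };
}

// Made-up counts of training cases in one leaf.
const result = leafPrediction(["Coke", "Diet Coke", "Pepsi"], [12, 3, 5]);
console.log(result.predicted);     // Coke
console.log(result.probabilities); // { Coke: 0.6, 'Diet Coke': 0.15, Pepsi: 0.25 }
```

Every case that lands in the same leaf receives the same predicted value and the same probabilities.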

Acknowledgements

Uses the R packages rpart and partykit.

More information

For an introduction to decision trees, see this blog post.
This blog post discusses the Pruning and Early stopping options.
This post describes how trees are built.

Code

var controls = [];

// ALGORITHM
var algorithm = form.comboBox({label: "Algorithm",
                               alternatives: ["CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                                              "Random Forest", "Regression", "Support Vector Machine"],
                               name: "formAlgorithm", default_value: "CART",
                               prompt: "Machine learning or regression algorithm for fitting the model"});

controls.push(algorithm);
algorithm = algorithm.getValue();

var regressionType = "";
if (algorithm == "Regression")
{
    var regressionTypeControl = form.comboBox({label: "Regression type",
                                               alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                              "Quasi-Poisson", "NBD"],
                                               name: "formRegressionType", default_value: "Linear",
                                               prompt: "Select type according to outcome variable type"});
    regressionType = regressionTypeControl.getValue();
    controls.push(regressionTypeControl);
}

// DEFAULT CONTROLS
var missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"];
var output_options = [];

// AMEND DEFAULT CONTROLS PER ALGORITHM
if (algorithm == "Support Vector Machine")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Gradient Boosting") 
    output_options = ["Accuracy", "Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Random Forest")
    output_options = ["Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Deep Learning")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Cross Validation", "Network Layers"];
if (algorithm == "Linear Discriminant Analysis")
    output_options = ["Means", "Detail", "Prediction-Accuracy Table", "Scatterplot", "Moonplot"];

if (algorithm == "CART") {
    output_options = ["Sankey", "Tree", "Text", "Prediction-Accuracy Table", "Cross Validation"];
    missing_data_options = ["Error if missing data", "Exclude cases with missing data",
                            "Use partial data", "Imputation (replace missing values with estimates)"];
}
if (algorithm == "Regression") {
    if (regressionType == "Multinomial Logit")
        output_options = ["Summary", "Detail", "ANOVA"];
    else if (regressionType == "Linear")
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Shapley Regression", "Jaccard Coefficient", "Correlation", "Effects Plot"];
    else
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Effects Plot"];
}

// COMMON CONTROLS FOR ALL ALGORITHMS
var outputControl = form.comboBox({label: "Output", prompt: "The type of output used to show the results",
                                   alternatives: output_options, name: "formOutput",
                                   default_value: output_options[0]});
controls.push(outputControl);
var output = outputControl.getValue();

if (algorithm == "Regression") {
    if (regressionType == "Linear") {
        if (output == "Jaccard Coefficient" || output == "Correlation")
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Use partial data (pairwise correlations)"];
        else
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Use partial data (pairwise correlations)", "Multiple imputation"];
    }        
    else
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Multiple imputation"];
}

var missingControl = form.comboBox({label: "Missing data", 
                                    alternatives: missing_data_options, name: "formMissing", default_value: "Exclude cases with missing data",
                                    prompt: "Options for handling cases with missing data"});
var missing = missingControl.getValue();
controls.push(missingControl);
controls.push(form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"}));

// CONTROLS FOR SPECIFIC ALGORITHMS

if (algorithm == "Support Vector Machine")
    controls.push(form.textBox({label: "Cost", name: "formCost", default_value: 1, type: "number",
                                prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler model with risk of underfitting"}));

if (algorithm == "Gradient Boosting") {
    controls.push(form.comboBox({label: "Booster", 
                                 alternatives: ["gbtree", "gblinear"], name: "formBooster", default_value: "gbtree",
                                 prompt: "Boost tree or linear underlying models"}));
    controls.push(form.checkBox({label: "Grid search", name: "formSearch", default_value: false,
                                 prompt: "Search for optimal hyperparameters"}));
}

if (algorithm == "Random Forest")
    if (output == "Importance")
        controls.push(form.checkBox({label: "Sort by importance", name: "formImportance", default_value: true}));

if (algorithm == "Deep Learning") {
    controls.push(form.numericUpDown({name:"formEpochs", label:"Maximum epochs", default_value: 10, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                      prompt: "Number of rounds of training"}));
    controls.push(form.textBox({name: "formHiddenLayers", label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true}));
    controls.push(form.checkBox({label: "Normalize predictors", name: "formNormalize", default_value: true,
                                 prompt: "Normalize to zero mean and unit variance"}));
}

if (algorithm == "Linear Discriminant Analysis") {
    if (output == "Scatterplot")
    {
        controls.push(form.colorPicker({label: "Outcome color", name: "formOutColor", default_value:"#5B9BD5"}));
        controls.push(form.colorPicker({label: "Predictors color", name: "formPredColor", default_value:"#ED7D31"}));
    }
    controls.push(form.comboBox({label: "Prior", alternatives: ["Equal", "Observed"], name: "formPrior", default_value: "Observed",
                                 prompt: "Probabilities of group membership"}));
}

if (algorithm == "CART") {
    controls.push(form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"], 
                                 name: "formPruning", default_value: "Minimum error",
                                 prompt: "Remove nodes after tree has been built"}));
    controls.push(form.checkBox({label: "Early stopping", name: "formStopping", default_value: false,
                                 prompt: "Stop building tree when fit does not improve"}));
    controls.push(form.comboBox({label: "Predictor category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formPredictorCategoryLabels", default_value: "Abbreviated labels",
                                 prompt: "Labelling of predictor categories in the tree"}));
    controls.push(form.comboBox({label: "Outcome category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formOutcomeCategoryLabels", default_value: "Full labels",
                                 prompt: "Labelling of outcome categories in the tree"}));
    controls.push(form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations", default_value: false,
                                 prompt: "Allow predictors with more than 30 categories"}));
}

var stacked_check = false;
if (algorithm == "Regression") {
    if (missing == "Multiple imputation")
        controls.push(form.dropBox({label: "Auxiliary variables",
                                    types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                                    name: "formAuxiliaryVariables", required: false, multi:true,
                                    prompt: "Additional variables to use when imputing missing values"}));
    controls.push(form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection",
                                 default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"}));
    var is_RIA_or_shapley = output == "Relative Importance Analysis" || output == "Shapley Regression";
    var is_Jaccard_or_Correlation = output == "Jaccard Coefficient" || output == "Correlation";
    if (regressionType == "Linear" && missing != "Use partial data (pairwise correlations)" && missing != "Multiple imputation")
        controls.push(form.checkBox({label: "Robust standard errors", name: "formRobustSE", default_value: false,
                                     prompt: "Standard errors are robust to violations of assumption of constant variance"}));
    if (is_RIA_or_shapley)
        controls.push(form.checkBox({label: "Absolute importance scores", name: "formAbsoluteImportance", default_value: false,
                                     prompt: "Show absolute instead of signed importances"}));
    if (regressionType != "Multinomial Logit" && (is_RIA_or_shapley || is_Jaccard_or_Correlation || output == "Summary"))
        controls.push(form.dropBox({label: "Crosstab interaction", name: "formInteraction", types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                                    required: false, prompt: "Categorical variable to test for interaction with other variables"}));
    if (regressionType !== "Multinomial Logit")
        controls.push(form.numericUpDown({name : "formOutlierProportion", label:"Automated outlier removal percentage", default_value: 0, 
                                          minimum:0, maximum:49.9, increment:0.1,
                                          prompt: "Data points removed and model refitted based on the residual values in the model using the full dataset"}));
    var stacked_check_box = form.checkBox({label: "Stack data", name: "formStackedData", default_value: false,
                                           prompt: "Allow input into the Outcome control to be a single multi variable and Predictors to be a single grid variable"});
    stacked_check = stacked_check_box.getValue();
    controls.push(stacked_check_box);
}

controls.push(form.numericUpDown({name:"formSeed", label:"Random seed", default_value: 12321, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                  prompt: "Initializes randomization for imputation and certain algorithms"}));

var outcome = form.dropBox({label: "Outcome", 
                            types: [ stacked_check ? "VariableSet: BinaryMulti, NominalMulti, OrdinalMulti, NumericMulti" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                            multi: false,
                            name: "formOutcomeVariable",
                            prompt: "Dependent target variable to be predicted"});
var predictors = form.dropBox({label: "Predictor(s)",
                               types:[ stacked_check ? "VariableSet: BinaryGrid, NumericGrid" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                               name: "formPredictorVariables", multi: stacked_check ? false : true,
                               prompt: "Independent input variables"});

controls.unshift(predictors);
controls.unshift(outcome);

form.setInputControls(controls);
form.setHeading((regressionType == "" ? "" : (regressionType + " ")) + algorithm);

library(flipMultivariates)

model <- MachineLearning(formula = if (isTRUE(get0("formStackedData"))) as.formula(NULL) else QFormula(formOutcomeVariable ~ formPredictorVariables),
                         algorithm = formAlgorithm,
                         weights = QPopulationWeight, subset = QFilter,
                         missing = formMissing,
                         output = if (formOutput == "Shapley Regression") "Shapley regression" else formOutput,
                         show.labels = !formNames,
                         seed = get0("formSeed"),
                         cost = get0("formCost"),
                         booster = get0("formBooster"),
                         grid.search = get0("formSearch"),
                         sort.by.importance = get0("formImportance"),
                         hidden.nodes = get0("formHiddenLayers"),
                         max.epochs = get0("formEpochs"),
                         normalize = get0("formNormalize"),
                         outcome.color = get0("formOutColor"),
                         predictors.color = get0("formPredColor"),
                         prior = get0("formPrior"),
                         prune = get0("formPruning"),
                         early.stopping = get0("formStopping"),
                         predictor.level.treatment = get0("formPredictorCategoryLabels"),
                         outcome.level.treatment = get0("formOutcomeCategoryLabels"),
                         long.running.calculations = get0("formLongRunningCalculations"),
                         type = get0("formRegressionType"),
                         auxiliary.data = get0("formAuxiliaryVariables"),
                         correction = get0("formCorrection"),
                         robust.se = get0("formRobustSE", ifnotfound = FALSE),
                         importance.absolute = get0("formAbsoluteImportance"),
                         interaction = get0("formInteraction"),
                         outlier.prop.to.remove = if (get0("formRegressionType", ifnotfound = "") != "Multinomial Logit") get0("formOutlierProportion")/100 else NULL,
                         stacked.data.check = get0("formStackedData"),
                         unstacked.data = if (isTRUE(get0("formStackedData"))) list(Y = get0("formOutcomeVariable"), X = get0("formPredictorVariables")) else NULL)