Machine Learning - Deep Learning


Fits a neural network for classification or regression

This method is only available in Q5.

A random 30% of the data is used for cross-validation to find the optimal number of epochs according to cross-validation loss. The final network is trained on all data for the optimal number of epochs.
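The epoch selection described above can be sketched as follows. This is a minimal illustration, not the code used by the underlying R package: the function name `bestEpoch` and the loss values are hypothetical, and the sketch assumes that one validation loss has been recorded per epoch on the held-out 30%.

```javascript
// Hypothetical helper: return the epoch (counted from 1) with the lowest
// validation loss. valLosses[i] is the loss on the held-out data after
// epoch i + 1. Ties are resolved in favour of the earliest epoch.
function bestEpoch(valLosses) {
    if (valLosses.length === 0) {
        throw new Error("No epochs recorded");
    }
    let best = 0;
    for (let i = 1; i < valLosses.length; i++) {
        if (valLosses[i] < valLosses[best]) {
            best = i;
        }
    }
    return best + 1;
}

// Made-up loss curve, for illustration only.
const exampleLosses = [0.90, 0.61, 0.48, 0.45, 0.47, 0.52];
const optimalEpochs = bestEpoch(exampleLosses); // 4
```

In this sketch the final network would then be refitted on all of the data for `optimalEpochs` epochs.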

Usage

To run Deep Learning:

1. In Displayr, select Insert > More > Machine Learning > Deep Learning. In Q, select Automate > Browse Online Library > Machine Learning > Deep Learning.
2. Under Inputs > Deep Learning > Outcome select the outcome variable.
3. Under Inputs > Deep Learning > Predictors select the predictor variables.
4. Change any other settings as required.

Example

The Cross Validation output of a deep learning model. Loss is the quantity minimized by the model, which is mean squared error in the case of a numeric output. mean_absolute_error is also shown.

Options

Outcome The variable to be predicted by the predictor variables. It may be either a numeric or categorical variable.

Predictors The variable(s) to predict the Outcome.

Algorithm The machine learning algorithm. Defaults to Deep Learning but may be changed to other machine learning methods.

Output

Accuracy When Outcome is categorical, produces a table of accuracy by class. When Outcome is numeric, calculates Root Mean Squared Error and R-squared.
Prediction-Accuracy Table Produces a table relating the observed and predicted outcome. Also known as a confusion matrix.
Cross Validation Produces charts of loss (i.e. network error) and accuracy or mean absolute error vs training epoch.
Network Layers This returns a description of the layers of the network.
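For a numeric Outcome, the two statistics reported by the Accuracy output can be sketched as below. The function names are hypothetical, and the R-squared shown is the common 1 - SSE/SST form. The page does not state which definition the package uses, so treat that choice as an assumption.

```javascript
// Root Mean Squared Error between observed and predicted values.
function rmse(observed, predicted) {
    let sse = 0;
    for (let i = 0; i < observed.length; i++) {
        sse += (observed[i] - predicted[i]) ** 2;
    }
    return Math.sqrt(sse / observed.length);
}

// R-squared as 1 - SSE/SST (assumed definition, see the note above).
function rSquared(observed, predicted) {
    const mean = observed.reduce((a, b) => a + b, 0) / observed.length;
    let sse = 0;
    let sst = 0;
    for (let i = 0; i < observed.length; i++) {
        sse += (observed[i] - predicted[i]) ** 2;
        sst += (observed[i] - mean) ** 2;
    }
    return 1 - sse / sst;
}
```

A perfect prediction gives an RMSE of 0 and an R-squared of 1. Predicting the mean for every case gives an R-squared of 0 under this definition.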

Missing data See Missing Data Options.

Maximum epochs The maximum number of epochs to train the network for. The actual number of epochs may be lower if the cross-validation error stops improving.
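One common way to stop training before the maximum is a patience rule: stop once the validation loss has failed to improve for a set number of consecutive epochs. The sketch below illustrates that idea only. The function name `epochsRun` and the `patience` parameter are assumptions, and the page does not say which stopping rule the package applies.

```javascript
// Hypothetical early-stopping rule: return the number of epochs that would
// be run, given per-epoch validation losses, a maximum number of epochs and
// a patience (number of consecutive non-improving epochs tolerated).
function epochsRun(valLosses, maxEpochs, patience) {
    const limit = Math.min(maxEpochs, valLosses.length);
    let best = Infinity;
    let sinceBest = 0;
    for (let epoch = 1; epoch <= limit; epoch++) {
        const loss = valLosses[epoch - 1];
        if (loss < best) {
            best = loss;
            sinceBest = 0;
        } else {
            sinceBest++;
            if (sinceBest >= patience) {
                return epoch; // stopped early
            }
        }
    }
    return limit; // reached the maximum without triggering the rule
}
```

With made-up losses `[0.9, 0.6, 0.5, 0.55, 0.56, 0.4]`, a maximum of 10 epochs and a patience of 2, this sketch stops after epoch 5 and never sees the lower loss at epoch 6, which is the usual trade-off of a short patience.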

Hidden layers A comma-delimited list of the number of units in each hidden layer.
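As an illustration of the input format, the following sketch turns such a list into an array of layer sizes. The function `parseHiddenLayers` is hypothetical and the validation rules (positive whole numbers only) are an assumption. The actual parsing is done by the R package.

```javascript
// Hypothetical parser for a comma-delimited list of hidden layer sizes.
// Whitespace around each entry is ignored; anything that is not a positive
// whole number is rejected.
function parseHiddenLayers(text) {
    return text.split(",").map(function (part) {
        const trimmed = part.trim();
        if (!/^[0-9]+$/.test(trimmed) || Number(trimmed) < 1) {
            throw new Error("Invalid layer size: '" + part + "'");
        }
        return Number(trimmed);
    });
}

// Two hidden layers, with 10 and 5 units respectively.
const layers = parseHiddenLayers("10, 5"); // [10, 5]
```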

Normalize predictors Whether the predictor variables are normalized to zero mean and unit variance. This is recommended if the variables differ significantly in their ranges. Note that categorical variables are also converted to dummy variables.
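The two transformations mentioned above can be sketched as follows. Both function names are hypothetical. The sketch assumes the population standard deviation and one dummy column per category, neither of which is confirmed by this page.

```javascript
// Rescale a numeric predictor to zero mean and unit variance.
// Assumes the population variance (divide by n); a constant column would
// produce NaN and would need special handling.
function standardize(values) {
    const n = values.length;
    const mean = values.reduce((a, b) => a + b, 0) / n;
    const variance = values.reduce((a, v) => a + (v - mean) ** 2, 0) / n;
    const sd = Math.sqrt(variance);
    return values.map(v => (v - mean) / sd);
}

// Convert a categorical predictor to 0/1 dummy columns, one per category,
// with categories ordered by first appearance (an assumption).
function dummyCode(values) {
    const levels = Array.from(new Set(values));
    return values.map(v => levels.map(level => (v === level ? 1 : 0)));
}
```

Standardizing matters most when predictors are on very different scales, for example age in years next to income in dollars, because otherwise the larger-valued predictor dominates the early weight updates.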

Variable names Displays Variable Names in the output.

Weight Where a weight has been set for the R Output, a new data set is generated via resampling, and this new data set is used in the estimation.
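One way to generate a data set by resampling is to draw cases with replacement, with probability proportional to each case's weight. The sketch below shows that approach under several assumptions: the function name `resampleByWeight` is hypothetical, the new data set has the same number of cases as the original, and `rng` is any function returning a number in [0, 1). The page does not describe the exact scheme used.

```javascript
// Hypothetical weighted resampling with replacement.
// rows:    array of cases
// weights: non-negative weight for each case
// rng:     function returning a uniform random number in [0, 1)
function resampleByWeight(rows, weights, rng) {
    const total = weights.reduce((a, b) => a + b, 0);
    const sample = [];
    for (let k = 0; k < rows.length; k++) {
        // Walk along the cumulative weights until the draw is reached.
        let target = rng() * total;
        let i = 0;
        while (i < rows.length - 1 && target >= weights[i]) {
            target -= weights[i];
            i++;
        }
        sample.push(rows[i]);
    }
    return sample;
}

// A case with weight 0 is never selected; heavier cases appear more often.
const resampled = resampleByWeight(["a", "b", "c"], [0, 1, 0], Math.random);
```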

Filter The data is automatically filtered using any filters prior to estimating the model.

Additional options are available by editing the code.

DIAGNOSTICS

Prediction-Accuracy Table Creates a table showing the observed and predicted values, as a heatmap.
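For a categorical Outcome, the counts behind such a table, and the accuracy by class derived from it, can be sketched as below. The function names are hypothetical, and the layout (observed classes in rows, predicted classes in columns) is an assumption about the output.

```javascript
// Cross-tabulate observed against predicted classes.
// Rows are observed classes, columns are predicted classes.
function confusionMatrix(observed, predicted) {
    const labels = Array.from(new Set(observed.concat(predicted))).sort();
    const index = new Map(labels.map((label, i) => [label, i]));
    const counts = labels.map(() => labels.map(() => 0));
    for (let i = 0; i < observed.length; i++) {
        counts[index.get(observed[i])][index.get(predicted[i])] += 1;
    }
    return { labels: labels, counts: counts };
}

// Proportion of each observed class that was predicted correctly.
// Classes that were never observed yield NaN.
function accuracyByClass(matrix) {
    return matrix.counts.map(function (row, i) {
        const total = row.reduce((a, b) => a + b, 0);
        return total === 0 ? NaN : row[i] / total;
    });
}
```

The diagonal of `counts` holds the correctly predicted cases, so a model that predicts well produces a heatmap that is darkest along the diagonal.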

SAVE VARIABLE(S)

Predicted Values Creates a new variable containing predicted values for each case in the data.

Probabilities of Each Response Creates new variables containing predicted probabilities of each response.

Acknowledgements

Uses the keras package, which uses TensorFlow.

More information

See this blog post for an introduction to deep learning.

Code

var controls = [];

// ALGORITHM
var algorithm = form.comboBox({label: "Algorithm",
                               alternatives: ["CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                                              "Random Forest", "Regression", "Support Vector Machine"],
                               name: "formAlgorithm", default_value: "Deep Learning",
                               prompt: "Machine learning or regression algorithm for fitting the model"});

controls.push(algorithm);
algorithm = algorithm.getValue();

var regressionType = "";
if (algorithm == "Regression")
{
    var regressionTypeControl = form.comboBox({label: "Regression type", 
                                               alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                              "Quasi-Poisson", "NBD"], 
                                               name: "formRegressionType", default_value: "Linear",
                                               prompt: "Select type according to outcome variable type"});
    regressionType = regressionTypeControl.getValue();
    controls.push(regressionTypeControl);
}

// DEFAULT CONTROLS
var missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"];
var output_options; // set below according to the selected algorithm

// AMEND DEFAULT CONTROLS PER ALGORITHM
if (algorithm == "Support Vector Machine")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Gradient Boosting") 
    output_options = ["Accuracy", "Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Random Forest")
    output_options = ["Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Deep Learning")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Cross Validation", "Network Layers"];
if (algorithm == "Linear Discriminant Analysis")
    output_options = ["Means", "Detail", "Prediction-Accuracy Table", "Scatterplot", "Moonplot"];

if (algorithm == "CART") {
    output_options = ["Sankey", "Tree", "Text", "Prediction-Accuracy Table", "Cross Validation"];
    missing_data_options = ["Error if missing data", "Exclude cases with missing data",
                             "Use partial data", "Imputation (replace missing values with estimates)"]
}
if (algorithm == "Regression") {
    if (regressionType == "Multinomial Logit")
        output_options = ["Summary", "Detail", "ANOVA"];
    else if (regressionType == "Linear")
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Shapley Regression", "Jaccard Coefficient", "Correlation", "Effects Plot"];
    else
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Effects Plot"];
}

// COMMON CONTROLS FOR ALL ALGORITHMS
var outputControl = form.comboBox({label: "Output", prompt: "The type of output used to show the results",
                                   alternatives: output_options, name: "formOutput",
                                   default_value: output_options[0]});
controls.push(outputControl);
var output = outputControl.getValue();

if (algorithm == "Regression") {
    if (regressionType == "Linear") {
        if (output == "Jaccard Coefficient" || output == "Correlation")
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Use partial data (pairwise correlations)"];
        else
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Use partial data (pairwise correlations)", "Multiple imputation"];
    }        
    else
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Multiple imputation"];
}

var missingControl = form.comboBox({label: "Missing data", 
                                    alternatives: missing_data_options, name: "formMissing", default_value: "Exclude cases with missing data",
                                    prompt: "Options for handling cases with missing data"});
var missing = missingControl.getValue();
controls.push(missingControl);
controls.push(form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"}));

// CONTROLS FOR SPECIFIC ALGORITHMS

if (algorithm == "Support Vector Machine")
    controls.push(form.textBox({label: "Cost", name: "formCost", default_value: 1, type: "number",
                                prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler model with risk of underfitting"}));

if (algorithm == "Gradient Boosting") {
    controls.push(form.comboBox({label: "Booster", 
                                 alternatives: ["gbtree", "gblinear"], name: "formBooster", default_value: "gbtree",
                                 prompt: "Boost tree or linear underlying models"}));
    controls.push(form.checkBox({label: "Grid search", name: "formSearch", default_value: false,
                                 prompt: "Search for optimal hyperparameters"}));
}

if (algorithm == "Random Forest")
    if (output == "Importance")
        controls.push(form.checkBox({label: "Sort by importance", name: "formImportance", default_value: true}));

if (algorithm == "Deep Learning") {
    controls.push(form.numericUpDown({name:"formEpochs", label:"Maximum epochs", default_value: 10, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                      prompt: "Number of rounds of training"}));
    controls.push(form.textBox({name: "formHiddenLayers", label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true}));
    controls.push(form.checkBox({label: "Normalize predictors", name: "formNormalize", default_value: true,
                                 prompt: "Normalize to zero mean and unit variance"}));
}

if (algorithm == "Linear Discriminant Analysis") {
    if (output == "Scatterplot")
    {
        controls.push(form.colorPicker({label: "Outcome color", name: "formOutColor", default_value:"#5B9BD5"}));
        controls.push(form.colorPicker({label: "Predictors color", name: "formPredColor", default_value:"#ED7D31"}));
    }
    controls.push(form.comboBox({label: "Prior", alternatives: ["Equal", "Observed"], name: "formPrior", default_value: "Observed",
                                 prompt: "Probabilities of group membership"}));
}

if (algorithm == "CART") {
    controls.push(form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"], 
                                 name: "formPruning", default_value: "Minimum error",
                                 prompt: "Remove nodes after tree has been built"}));
    controls.push(form.checkBox({label: "Early stopping", name: "formStopping", default_value: false,
                                 prompt: "Stop building tree when fit does not improve"}));
    controls.push(form.comboBox({label: "Predictor category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formPredictorCategoryLabels", default_value: "Abbreviated labels",
                                 prompt: "Labelling of predictor categories in the tree"}));
    controls.push(form.comboBox({label: "Outcome category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formOutcomeCategoryLabels", default_value: "Full labels",
                                 prompt: "Labelling of outcome categories in the tree"}));
    controls.push(form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations", default_value: false,
                                 prompt: "Allow predictors with more than 30 categories"}));
}

var stacked_check = false;
if (algorithm == "Regression") {
    if (missing == "Multiple imputation")
        controls.push(form.dropBox({label: "Auxiliary variables",
                                    types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                                    name: "formAuxiliaryVariables", required: false, multi:true,
                                    prompt: "Additional variables to use when imputing missing values"}));
    controls.push(form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection",
                                 default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"}));
    var is_RIA_or_shapley = output == "Relative Importance Analysis" || output == "Shapley Regression";
    var is_Jaccard_or_Correlation = output == "Jaccard Coefficient" || output == "Correlation";
    if (regressionType == "Linear" && missing != "Use partial data (pairwise correlations)" && missing != "Multiple imputation")
        controls.push(form.checkBox({label: "Robust standard errors", name: "formRobustSE", default_value: false,
                                     prompt: "Standard errors are robust to violations of assumption of constant variance"}));
    if (is_RIA_or_shapley)
        controls.push(form.checkBox({label: "Absolute importance scores", name: "formAbsoluteImportance", default_value: false,
                                     prompt: "Show absolute instead of signed importances"}));
    if (regressionType != "Multinomial Logit" && (is_RIA_or_shapley || is_Jaccard_or_Correlation || output == "Summary"))
        controls.push(form.dropBox({label: "Crosstab interaction", name: "formInteraction", types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                                    required: false, prompt: "Categorical variable to test for interaction with other variables"}));
    if (regressionType !== "Multinomial Logit")
        controls.push(form.numericUpDown({name : "formOutlierProportion", label:"Automated outlier removal percentage", default_value: 0, 
                                          minimum:0, maximum:49.9, increment:0.1,
                                          prompt: "Data points removed and model refitted based on the residual values in the model using the full dataset"}));
    var stacked_check_box = form.checkBox({label: "Stack data", name: "formStackedData", default_value: false,
                                           prompt: "Allow input into the Outcome control to be a single multi variable and Predictors to be a single grid variable"});
    stacked_check = stacked_check_box.getValue();
    controls.push(stacked_check_box);
}

controls.push(form.numericUpDown({name:"formSeed", label:"Random seed", default_value: 12321, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                  prompt: "Initializes randomization for imputation and certain algorithms"}));

var outcome = form.dropBox({label: "Outcome", 
                            types: [ stacked_check ? "VariableSet: BinaryMulti, NominalMulti, OrdinalMulti, NumericMulti" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                            multi: false,
                            name: "formOutcomeVariable",
                            prompt: "Dependent (target) variable to be predicted"});
var predictors = form.dropBox({label: "Predictor(s)",
                               types:[ stacked_check ? "VariableSet: BinaryGrid, NumericGrid" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                               name: "formPredictorVariables", multi: stacked_check ? false : true,
                               prompt: "Independent (input) variables"});

controls.unshift(predictors);
controls.unshift(outcome);

form.setInputControls(controls);
form.setHeading((regressionType == "" ? "" : (regressionType + " ")) + algorithm);

library(flipMultivariates)

model <- MachineLearning(formula = if (isTRUE(get0("formStackedData"))) as.formula(NULL) else QFormula(formOutcomeVariable ~ formPredictorVariables),
                         algorithm = formAlgorithm,
                         weights = QPopulationWeight, subset = QFilter,
                         missing = formMissing,
                         output = if (formOutput == "Shapley Regression") "Shapley regression" else formOutput,
                         show.labels = !formNames,
                         seed = get0("formSeed"),
                         cost = get0("formCost"),
                         booster = get0("formBooster"),
                         grid.search = get0("formSearch"),
                         sort.by.importance = get0("formImportance"),
                         hidden.nodes = get0("formHiddenLayers"),
                         max.epochs = get0("formEpochs"),
                         normalize = get0("formNormalize"),
                         outcome.color = get0("formOutColor"),
                         predictors.color = get0("formPredColor"),
                         prior = get0("formPrior"),
                         prune = get0("formPruning"),
                         early.stopping = get0("formStopping"),
                         predictor.level.treatment = get0("formPredictorCategoryLabels"),
                         outcome.level.treatment = get0("formOutcomeCategoryLabels"),
                         long.running.calculations = get0("formLongRunningCalculations"),
                         type = get0("formRegressionType"),
                         auxiliary.data = get0("formAuxiliaryVariables"),
                         correction = get0("formCorrection"),
                         robust.se = get0("formRobustSE", ifnotfound = FALSE),
                         importance.absolute = get0("formAbsoluteImportance"),
                         interaction = get0("formInteraction"),
                         outlier.prop.to.remove = if (get0("formRegressionType", ifnotfound = "") != "Multinomial Logit") get0("formOutlierProportion")/100 else NULL,
                         stacked.data.check = get0("formStackedData"),
                         unstacked.data = if (isTRUE(get0("formStackedData"))) list(Y = get0("formOutcomeVariable"), X = get0("formPredictorVariables")) else NULL)