GBTClassifier (Spark 4.2.0 JavaDoc)
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasStepSize, HasThresholds, HasValidationIndicatorCol, HasWeightCol, PredictorParams, DecisionTreeParams, GBTClassifierParams, GBTParams, HasVarianceImpurity, TreeEnsembleClassifierParams, TreeEnsembleParams, DefaultParamsWritable, Identifiable, MLWritable
Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- We expect to implement TreeBoost in the future: [https://issues.apache.org/jira/browse/SPARK-4240]
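The distinction above can be illustrated with a deliberately tiny sketch. It is not Spark's implementation, and the class `TinyGradientBoosting` and its methods are hypothetical. It assumes the following:

- There is a single numeric feature.
- The base learners are depth-1 regression trees (stumps).
- The loss is the textbook logistic loss log(1 + exp(-yF)), with 0/1 labels mapped to y in {-1, +1}.

Each round fits a stump to the negative gradient and keeps the stump's plain mean leaf outputs, as in the original gradient boosting method. It does not re-optimize the leaf values as TreeBoost would. The row subsampling that makes the method "stochastic" is omitted for brevity.

```java
import java.util.Arrays;

final class TinyGradientBoosting {
    private TinyGradientBoosting() {}

    /** A depth-1 regression tree (stump) on a single numeric feature. */
    static final class Stump {
        final double threshold;
        final double leftValue;
        final double rightValue;

        Stump(double threshold, double leftValue, double rightValue) {
            this.threshold = threshold;
            this.leftValue = leftValue;
            this.rightValue = rightValue;
        }

        double predict(double x) {
            return x <= threshold ? leftValue : rightValue;
        }
    }

    /** Least-squares stump fit; leaf outputs are plain means of the targets (no TreeBoost-style adjustment). */
    static Stump fitStump(double[] x, double[] target) {
        double[] sorted = x.clone();
        Arrays.sort(sorted);
        Stump best = null;
        double bestSse = Double.POSITIVE_INFINITY;
        for (int i = 0; i + 1 < sorted.length; i++) {
            if (sorted[i] == sorted[i + 1]) continue; // no threshold between equal values
            double t = (sorted[i] + sorted[i + 1]) / 2.0;
            double sumL = 0.0, sumR = 0.0;
            int nL = 0, nR = 0;
            for (int j = 0; j < x.length; j++) {
                if (x[j] <= t) { sumL += target[j]; nL++; } else { sumR += target[j]; nR++; }
            }
            double meanL = sumL / nL;
            double meanR = sumR / nR;
            double sse = 0.0;
            for (int j = 0; j < x.length; j++) {
                double d = target[j] - (x[j] <= t ? meanL : meanR);
                sse += d * d;
            }
            if (sse < bestSse) {
                bestSse = sse;
                best = new Stump(t, meanL, meanR);
            }
        }
        if (best == null) { // every x is identical: fall back to one constant prediction
            double mean = Arrays.stream(target).average().orElse(0.0);
            best = new Stump(Double.POSITIVE_INFINITY, mean, mean);
        }
        return best;
    }

    /** Runs numIters boosting rounds on 0/1 labels and returns the final raw margins F(x_i). */
    static double[] boost(double[] x, int[] labels, int numIters, double stepSize) {
        double[] margin = new double[x.length];
        double[] residual = new double[x.length];
        for (int m = 0; m < numIters; m++) {
            for (int i = 0; i < x.length; i++) {
                double y = labels[i] == 1 ? 1.0 : -1.0;
                // Negative gradient of log(1 + exp(-y F)) with respect to F.
                residual[i] = y / (1.0 + Math.exp(y * margin[i]));
            }
            Stump stump = fitStump(x, residual);
            for (int i = 0; i < x.length; i++) {
                margin[i] += stepSize * stump.predict(x[i]);
            }
        }
        return margin;
    }
}
```

For example, `boost(new double[]{1, 2, 3, 4}, new int[]{0, 0, 1, 1}, 10, 0.5)` returns negative margins for the first two points and positive margins for the last two.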
- Note:
- Multiclass labels are not currently supported.
-
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
Constructor Summary
Constructors
- GBTClassifier()
- GBTClassifier(String uid)
Method Summary
cacheNodeIds(): If false, the algorithm will pass trees to executors to match instances with nodes.
checkpointInterval(): Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1).
copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params.
featureSubsetStrategy(): The number of features to consider for splits at each tree node.
impurity(): Criterion used for information gain calculation (case-insensitive).
leafCol(): Leaf indices column name.
load(String path)
lossType(): Loss function which GBT tries to minimize.
maxBins(): Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
maxDepth(): Maximum depth of the tree (nonnegative).
maxIter(): Param for maximum number of iterations (>= 0).
maxMemoryInMB(): Maximum memory in MB allocated to histogram aggregation.
minInfoGain(): Minimum information gain for a split to be considered at a tree node.
minInstancesPerNode(): Minimum number of instances each child must have after split.
minWeightFractionPerNode(): Minimum fraction of the weighted sample count that each child must have after split.
read()
seed()
setCacheNodeIds(boolean value)
setCheckpointInterval(int value): Specifies how often to checkpoint the cached node IDs.
setFeatureSubsetStrategy(String value)
setImpurity(String value): The impurity setting is ignored for GBT models.
setLossType(String value)
setMaxBins(int value)
setMaxDepth(int value)
setMaxIter(int value)
setMaxMemoryInMB(int value)
setMinInfoGain(double value)
setMinInstancesPerNode(int value)
setMinWeightFractionPerNode(double value)
setSeed(long value)
setStepSize(double value)
setSubsamplingRate(double value)
setValidationIndicatorCol(String value)
setWeightCol(String value): Sets the value of param weightCol().
stepSize(): Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.
subsamplingRate(): Fraction of the training data used for learning each decision tree, in range (0, 1].
supportedLossTypes(): Accessor for supported loss settings: logistic
uid(): An immutable unique ID for the object and its derivatives.
validationIndicatorCol(): Param for name of the column that indicates whether each row is for training or for validation.
validationTol(): Threshold for stopping early when fit with validation is used.
weightCol(): Param for weight column name.
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, MDC, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, estimateMatadataSize, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
-
Constructor Details
-
GBTClassifier
public GBTClassifier(String uid)
-
GBTClassifier
public GBTClassifier()
-
Method Details
-
supportedLossTypes
public static final String[] supportedLossTypes()
Accessor for supported loss settings: logistic
-
load
-
read
-
lossType
Loss function which GBT tries to minimize (case-insensitive). Supported: "logistic". (default = logistic)
- Specified by:
lossType in interface GBTClassifierParams
- Returns:
- (undocumented)
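As a sketch of what "logistic" loss means for binary labels, the code below uses the textbook form log(1 + exp(-yF)), where y is in {-1, +1} and F is the ensemble's raw margin. It maps a margin to a probability with the logistic function. Spark's internal loss may scale the margin differently, and this page does not say, so treat the helper class `LogisticLossSketch` as hypothetical.

```java
final class LogisticLossSketch {
    private LogisticLossSketch() {}

    /** Textbook logistic loss for a 0/1 label and a raw margin, computed without overflow. */
    static double loss(int label01, double margin) {
        double y = label01 == 1 ? 1.0 : -1.0;
        double z = -y * margin;
        // log(1 + exp(z)), rewritten for large z so that exp(z) never overflows
        return z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
    }

    /** Probability of label 1 under the same model: 1 / (1 + exp(-margin)). */
    static double probability(double margin) {
        return 1.0 / (1.0 + Math.exp(-margin));
    }
}
```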
-
impurity
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeRegressor, RandomForestRegressor, GBTRegressor and GBTClassifier (since GBTClassificationModel is internally composed of DecisionTreeRegressionModels). Supported: "variance". (default = variance)
- Specified by:
impurity in interface HasVarianceImpurity
- Returns:
- (undocumented)
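Because the trees inside a GBT model are regression trees, the "variance" criterion scores a split by how much it reduces the variance of the targets. The minimal sketch below assumes unweighted instances and the common definition gain = Var(parent) - (n_L/n) Var(left) - (n_R/n) Var(right). The class `VarianceImpurity` is hypothetical, not Spark's API.

```java
final class VarianceImpurity {
    private VarianceImpurity() {}

    /** Population variance of the targets in one node (0 for an empty node). */
    static double variance(double[] y) {
        if (y.length == 0) return 0.0;
        double mean = 0.0;
        for (double v : y) mean += v;
        mean /= y.length;
        double sumSq = 0.0;
        for (double v : y) sumSq += (v - mean) * (v - mean);
        return sumSq / y.length;
    }

    /** Information gain of splitting a parent node into the given left and right children. */
    static double gain(double[] left, double[] right) {
        double[] parent = new double[left.length + right.length];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        double n = parent.length;
        return variance(parent)
                - (left.length / n) * variance(left)
                - (right.length / n) * variance(right);
    }
}
```

A split that perfectly separates {0, 0} from {1, 1} removes all of the parent's variance (0.25). A split that leaves both children as mixed as the parent gains nothing.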
-
validationTol
Description copied from interface: GBTParams
Threshold for stopping early when fit with validation is used. (This parameter is ignored when fit without validation is used.) The decision to stop early is made as follows:
- If the current loss on the validation set is greater than 0.01, the diff of validation error is compared to a relative tolerance of validationTol * (current loss on the validation set).
- If the current loss on the validation set is less than or equal to 0.01, the diff of validation error is compared to an absolute tolerance of validationTol * 0.01.
- Specified by:
validationTol in interface GBTParams
- Returns:
- (undocumented)
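The early-stopping rule above can be written compactly. The tolerance is validationTol * max(currentLoss, 0.01), which is relative above 0.01 and absolute at or below it. The sketch assumes that "diff of validation error" means the improvement of the current validation loss over the best loss seen so far. That reading and the class name `EarlyStoppingSketch` are assumptions.

```java
final class EarlyStoppingSketch {
    private EarlyStoppingSketch() {}

    /** Relative tolerance when currentLoss > 0.01, absolute tolerance (validationTol * 0.01) otherwise. */
    static double tolerance(double validationTol, double currentLoss) {
        return validationTol * Math.max(currentLoss, 0.01);
    }

    /** True if the improvement over the best validation loss so far is below the tolerance. */
    static boolean shouldStop(double validationTol, double bestLoss, double currentLoss) {
        double improvement = bestLoss - currentLoss; // assumed meaning of "diff of validation error"
        return improvement < tolerance(validationTol, currentLoss);
    }
}
```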
-
stepSize
Description copied from interface: GBTParams
Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator. (default = 0.1)
- Specified by:
stepSize in interface GBTParams
- Specified by:
stepSize in interface HasStepSize
- Returns:
- (undocumented)
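The sketch below shows what "shrinking the contribution of each estimator" means. It combines per-tree outputs for one instance into a raw margin and scales each output by stepSize. It is simplified: it gives every tree the same weight, stepSize, which may not match how Spark weights individual trees (for example, the first one). `ShrinkageSketch` is a hypothetical name.

```java
final class ShrinkageSketch {
    private ShrinkageSketch() {}

    /** Raw margin = sum over trees of stepSize * (tree output), with stepSize checked to lie in (0, 1]. */
    static double margin(double[] treeOutputs, double stepSize) {
        if (!(stepSize > 0.0 && stepSize <= 1.0)) {
            throw new IllegalArgumentException("stepSize must be in (0, 1]: " + stepSize);
        }
        double sum = 0.0;
        for (double out : treeOutputs) sum += stepSize * out;
        return sum;
    }
}
```

With a smaller stepSize each tree moves the margin less, so more iterations (maxIter) are typically needed to reach the same fit.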
-
validationIndicatorCol
public final Param<String> validationIndicatorCol()
Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
- Specified by:
validationIndicatorCol in interface HasValidationIndicatorCol
- Returns:
- (undocumented)
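Conceptually, the indicator column partitions the input into two sets. Rows marked false train the trees, and rows marked true form the validation set on which the validationTol early-stopping rule is evaluated. This sketch shows only that partitioning on an in-memory list. The class `ValidationSplitSketch` and the predicate standing in for the column are hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

final class ValidationSplitSketch {
    private ValidationSplitSketch() {}

    /** Key false maps to the training rows, key true to the validation rows; input order is kept. */
    static <T> Map<Boolean, List<T>> split(List<T> rows, Predicate<T> isValidation) {
        return rows.stream().collect(Collectors.partitioningBy(isValidation));
    }
}
```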
-
maxIter
Description copied from interface: HasMaxIter
Param for maximum number of iterations (>= 0).
- Specified by:
maxIter in interface HasMaxIter
- Returns:
- (undocumented)
-
subsamplingRate
Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
- Specified by:
subsamplingRate in interface TreeEnsembleParams
- Returns:
- (undocumented)
-
featureSubsetStrategy
public final Param<String> featureSubsetStrategy()
The number of features to consider for splits at each tree node. Supported options:
- "auto": Choose automatically for task: If numTrees == 1, set to "all". If numTrees is greater than 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
- "all": use all features
- "onethird": use 1/3 of the features
- "sqrt": use sqrt(number of features)
- "log2": use log2(number of features)
- "n": when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features.
(default = "auto")
These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
- Specified by:
featureSubsetStrategy in interface TreeEnsembleParams
- Returns:
- (undocumented)
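The options above reduce to a per-node feature count. The minimal sketch below makes several assumptions. It rounds fractional results up, it accepts the "n" form only as an integer in (1, number of features) or a fraction in (0, 1.0], and `FeatureSubsetSketch` is a hypothetical name. Because boosting fits one tree per iteration, the numTrees == 1 branch of "auto" is plausibly the relevant one for GBT, but this page does not state it.

```java
final class FeatureSubsetSketch {
    private FeatureSubsetSketch() {}

    /** Number of candidate features per node; rounding up is an assumption of this sketch. */
    static int numFeaturesPerNode(String strategy, int numFeatures, int numTrees, boolean isClassification) {
        String s = strategy;
        if (s.equals("auto")) {
            s = numTrees == 1 ? "all" : (isClassification ? "sqrt" : "onethird");
        }
        switch (s) {
            case "all":
                return numFeatures;
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            default:
                double n = Double.parseDouble(s); // throws NumberFormatException for other strings
                if (n > 0.0 && n <= 1.0) return (int) Math.ceil(n * numFeatures); // fraction of features
                if (n > 1.0 && n < numFeatures && n == Math.rint(n)) return (int) n; // absolute count
                throw new IllegalArgumentException("Unsupported featureSubsetStrategy: " + strategy);
        }
    }
}
```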
-
leafCol
Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
- Specified by:
leafCol in interface DecisionTreeParams
- Returns:
- (undocumented)
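The sketch below illustrates leafCol's leaf indices under two assumptions. It reads "by preorder" as numbering only the leaves, in the order a preorder (left-before-right) traversal visits them. It also assumes an instance goes left when x[feature] <= threshold. Both readings, and the class `LeafIndexSketch`, are assumptions rather than Spark's documented behavior. The reached leaf's index is the number of leaves skipped whenever the path turns right.

```java
final class LeafIndexSketch {
    private LeafIndexSketch() {}

    /** Binary tree node: internal nodes send x[feature] <= threshold to the left child. */
    static final class Node {
        final int feature;
        final double threshold;
        final Node left;
        final Node right;

        private Node(int feature, double threshold, Node left, Node right) {
            this.feature = feature;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        static Node leaf() {
            return new Node(-1, Double.NaN, null, null);
        }

        static Node split(int feature, double threshold, Node left, Node right) {
            return new Node(feature, threshold, left, right);
        }

        boolean isLeaf() {
            return left == null && right == null;
        }

        int numLeaves() {
            return isLeaf() ? 1 : left.numLeaves() + right.numLeaves();
        }
    }

    /** 0-based index of the leaf reached by x, counting leaves in preorder (left subtree first). */
    static int leafIndex(Node root, double[] x) {
        int offset = 0;
        Node node = root;
        while (!node.isLeaf()) {
            if (x[node.feature] <= node.threshold) {
                node = node.left;
            } else {
                offset += node.left.numLeaves(); // every leaf of the left subtree comes first
                node = node.right;
            }
        }
        return offset;
    }
}
```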
-
maxDepth
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by:
maxDepth in interface DecisionTreeParams
- Returns:
- (undocumented)
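The depth examples above generalize to a full binary tree. Depth d allows at most 2^d leaves and 2^(d+1) - 1 nodes in total, so the default depth of 5 allows up to 32 leaves per tree. `TreeSizeSketch` is a hypothetical helper that computes these bounds.

```java
final class TreeSizeSketch {
    private TreeSizeSketch() {}

    /** Maximum number of leaves in a binary tree of the given depth: 2^depth. */
    static long maxLeaves(int depth) {
        checkDepth(depth);
        return 1L << depth;
    }

    /** Maximum total number of nodes (internal plus leaf): 2^(depth + 1) - 1. */
    static long maxNodes(int depth) {
        checkDepth(depth);
        return (1L << (depth + 1)) - 1;
    }

    private static void checkDepth(int depth) {
        // The upper bound only keeps the shifts above inside the range of a long.
        if (depth < 0 || depth > 61) throw new IllegalArgumentException("depth out of range: " + depth);
    }
}
```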
-
maxBins
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2 and at least the number of categories in any categorical feature. (default = 32)
- Specified by:
maxBins in interface DecisionTreeParams
- Returns:
- (undocumented)
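The constraint above can be expressed as a simple check. maxBins must be at least 2 and no smaller than the number of categories of any categorical feature. `MaxBinsSketch` is a hypothetical helper, and passing the category counts as an int array is an assumption of the sketch.

```java
final class MaxBinsSketch {
    private MaxBinsSketch() {}

    /** True if maxBins is at least 2 and at least the category count of every categorical feature. */
    static boolean isValid(int maxBins, int... categoriesPerFeature) {
        if (maxBins < 2) return false;
        for (int numCategories : categoriesPerFeature) {
            if (numCategories > maxBins) return false;
        }
        return true;
    }
}
```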
-
minInstancesPerNode
public final IntParam minInstancesPerNode()
Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by:
minInstancesPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minWeightFractionPerNode
public final DoubleParam minWeightFractionPerNode()
Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by:
minWeightFractionPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minInfoGain
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by:
minInfoGain in interface DecisionTreeParams
- Returns:
- (undocumented)
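The three rules above (minInstancesPerNode, minWeightFractionPerNode, minInfoGain) each reject a candidate split independently. The minimal sketch below makes three assumptions:

- "Total weight" means the total instance weight of the training data.
- A child or gain exactly at its threshold is accepted.
- The gain comes from the impurity criterion.

`SplitRulesSketch` is a hypothetical helper, not part of Spark.

```java
final class SplitRulesSketch {
    private SplitRulesSketch() {}

    /** minInstancesPerNode: each child must keep at least this many instances. */
    static boolean enoughInstances(long leftCount, long rightCount, int minInstancesPerNode) {
        return leftCount >= minInstancesPerNode && rightCount >= minInstancesPerNode;
    }

    /** minWeightFractionPerNode: each child must keep at least this fraction of the total weight. */
    static boolean enoughWeight(double leftWeight, double rightWeight, double totalWeight,
                                double minWeightFractionPerNode) {
        return leftWeight / totalWeight >= minWeightFractionPerNode
                && rightWeight / totalWeight >= minWeightFractionPerNode;
    }

    /** minInfoGain: the split's information gain must reach the minimum. */
    static boolean enoughGain(double gain, double minInfoGain) {
        return gain >= minInfoGain;
    }
}
```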
-
maxMemoryInMB
public final IntParam maxMemoryInMB()
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by:
maxMemoryInMB in interface DecisionTreeParams
- Returns:
- (undocumented)
-
cacheNodeIds
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable checkpointing, by setting checkpointInterval. (default = false)
- Specified by:
cacheNodeIds in interface DecisionTreeParams
- Returns:
- (undocumented)
-
weightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
weightCol in interface HasWeightCol
- Returns:
- (undocumented)
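The default described above, where an unset weight column means every weight is 1.0, can be illustrated with a weighted mean. Here a missing (null) weight array stands in for an unset weightCol. This is only an analogy for how instance weights enter aggregate statistics. `WeightSketch` and the use of null for "not set" are assumptions.

```java
final class WeightSketch {
    private WeightSketch() {}

    /** Weighted mean of values; a null weights array means every instance has weight 1.0. */
    static double weightedMean(double[] values, double[] weights) {
        double sum = 0.0;
        double totalWeight = 0.0;
        for (int i = 0; i < values.length; i++) {
            double w = weights == null ? 1.0 : weights[i];
            sum += w * values[i];
            totalWeight += w;
        }
        return sum / totalWeight;
    }
}
```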
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
checkpointInterval
public final IntParam checkpointInterval()
Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by:
checkpointInterval in interface HasCheckpointInterval
- Returns:
- (undocumented)
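The sketch below shows the schedule implied by checkpointInterval. It assumes iterations are counted from 1 and a checkpoint happens on every multiple of the interval. It ignores the separate requirement that a checkpoint directory be set. `CheckpointSketch` is hypothetical.

```java
final class CheckpointSketch {
    private CheckpointSketch() {}

    /** True if the cache would be checkpointed after the given 1-based iteration. */
    static boolean shouldCheckpoint(int iteration, int checkpointInterval) {
        if (checkpointInterval == -1) return false; // checkpointing disabled
        if (checkpointInterval < 1) {
            throw new IllegalArgumentException("checkpointInterval must be >= 1 or -1: " + checkpointInterval);
        }
        return iteration % checkpointInterval == 0;
    }
}
```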
-
uid
An immutable unique ID for the object and its derivatives.
- Specified by:
uid in interface Identifiable
- Returns:
- (undocumented)
-
setMaxDepth
-
setMaxBins
-
setMinInstancesPerNode
public GBTClassifier setMinInstancesPerNode(int value)
-
setMinWeightFractionPerNode
public GBTClassifier setMinWeightFractionPerNode(double value)
-
setMinInfoGain
-
setMaxMemoryInMB
-
setCacheNodeIds
-
setCheckpointInterval
public GBTClassifier setCheckpointInterval(int value)
Specifies how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the checkpoint directory is set in SparkContext. Must be at least 1. (default = 10)
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setImpurity
The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setSubsamplingRate
public GBTClassifier setSubsamplingRate(double value)
-
setSeed
-
setMaxIter
-
setStepSize
-
setFeatureSubsetStrategy
-
setLossType
-
setValidationIndicatorCol
-
setWeightCol
Sets the value of param weightCol(). If this is not set or empty, we treat all instance weights as 1.0. By default the weightCol is not set, so all instances have weight 1.0.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
copy in interface Params
- Specified by:
copy in class Predictor<Vector, GBTClassifier, GBTClassificationModel>
- Parameters:
extra - (undocumented)
- Returns:
- (undocumented)
-