All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, FMClassifierParams, ProbabilisticClassifierParams, Params, HasFeaturesCol, HasFitIntercept, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasRegParam, HasSeed, HasSolver, HasStepSize, HasThresholds, HasTol, HasWeightCol, PredictorParams, FactorizationMachines, FactorizationMachinesParams, DefaultParamsWritable, Identifiable, MLWritable

Factorization Machines learning algorithm for classification. It supports two solvers: plain gradient descent and AdamW.

The implementation is based on S. Rendle, "Factorization Machines", 2010.

FM can estimate feature interactions even on very sparse data, such as that found in advertising and recommender systems. The FM formula is:

$$ \begin{align} y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right) \end{align} $$

The first two terms are the global bias and the linear term (the same as in linear regression), and the last term captures pairwise interactions. Each v_i is a vector of k factors describing the i-th variable.
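As a rough illustration of the formula, the following Java sketch evaluates the FM output for one dense feature vector in two ways: the direct double sum over pairs, and the linear-time rearrangement given in Rendle's paper. The class and method names (`FmPredictionSketch`, `rawScoreNaive`, `rawScoreFast`) are hypothetical, σ is assumed to be the logistic sigmoid, and none of this is Spark's internal code.

```java
// Illustrative sketch only: hypothetical names, not Spark's internal implementation.
final class FmPredictionSketch {

    // Logistic sigmoid, assumed here to be the sigma in the formula.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // w0 + sum_i w_i x_i, shared by both variants below.
    static double biasPlusLinear(double w0, double[] w, double[] x) {
        double score = w0;
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        return score;
    }

    // Direct evaluation: double sum over all pairs i < j, O(k * n^2).
    static double rawScoreNaive(double w0, double[] w, double[][] v, double[] x) {
        double score = biasPlusLinear(w0, w, x);
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                score += dot * x[i] * x[j];
            }
        }
        return score;
    }

    // Equivalent O(k * n) form:
    // sum_{i<j} <v_i, v_j> x_i x_j
    //   = 0.5 * sum_f [ (sum_i v_if x_i)^2 - sum_i (v_if x_i)^2 ]
    static double rawScoreFast(double w0, double[] w, double[][] v, double[] x) {
        double score = biasPlusLinear(w0, w, x);
        int k = v.length == 0 ? 0 : v[0].length;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumOfSquares += t * t;
            }
            score += 0.5 * (sum * sum - sumOfSquares);
        }
        return score;
    }

    public static void main(String[] args) {
        double w0 = 0.1;
        double[] w = {0.2, -0.3, 0.5};
        double[][] v = {{1.0, 2.0}, {0.5, -1.0}, {2.0, 0.0}}; // n = 3 variables, k = 2 factors
        double[] x = {1.0, 0.0, 2.0};

        double naive = rawScoreNaive(w0, w, v, x);
        double fast = rawScoreFast(w0, w, v, x);
        System.out.println("raw score (naive) = " + naive);
        System.out.println("raw score (fast)  = " + fast);
        System.out.println("y = sigmoid(raw)  = " + sigmoid(fast));
    }
}
```

Both variants should agree up to floating-point rounding. The second form matters in practice because its cost grows linearly with the number of non-zero features rather than quadratically.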

The FM classification model uses logistic loss, which can be minimized by gradient descent. Regularization terms such as L2 are usually added to the loss function to prevent overfitting.
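To make the objective concrete, here is a minimal Java sketch of a logistic loss with an L2 penalty. It assumes labels in {0, 1}, a raw score equal to the FM value before the sigmoid, and an objective of the form mean loss plus 0.5 * regParam * (sum of squared parameters). That exact form, and the names `FmLogLossSketch`, `logisticLoss`, and `objective`, are assumptions made for illustration and may differ from how Spark applies regularization internally.

```java
// Illustrative sketch only: hypothetical names and an assumed objective form.
final class FmLogLossSketch {

    // Numerically stable log(1 + exp(z)).
    static double log1pExp(double z) {
        return z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
    }

    // Logistic loss for one example.
    // label is 0.0 or 1.0; rawScore is the model output before the sigmoid.
    static double logisticLoss(double rawScore, double label) {
        double margin = (2.0 * label - 1.0) * rawScore; // map label to -1 / +1
        return log1pExp(-margin);
    }

    // Assumed objective: mean logistic loss + 0.5 * regParam * ||params||^2.
    static double objective(double[] rawScores, double[] labels,
                            double[] params, double regParam) {
        double lossSum = 0.0;
        for (int i = 0; i < rawScores.length; i++) {
            lossSum += logisticLoss(rawScores[i], labels[i]);
        }
        double squaredNorm = 0.0;
        for (double p : params) {
            squaredNorm += p * p;
        }
        return lossSum / rawScores.length + 0.5 * regParam * squaredNorm;
    }

    public static void main(String[] args) {
        double[] rawScores = {2.0, -1.0, 0.5};
        double[] labels = {1.0, 0.0, 0.0};
        double[] params = {0.1, -0.2, 0.3};
        System.out.println("objective (regParam = 0.0): "
                + objective(rawScores, labels, params, 0.0));
        System.out.println("objective (regParam = 0.1): "
                + objective(rawScores, labels, params, 0.1));
    }
}
```

A larger `regParam` raises the cost of large parameter values, which is how the penalty discourages overfitting.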

Note:
Multiclass labels are not currently supported.
  • Nested Class Summary

    Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging

    org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter

  • Constructor Summary

    Constructors

    FMClassifier()

    FMClassifier(String uid)

  • Method Summary

    copy(ParamMap extra)

    Creates a copy of this instance with the same UID and some extra params.

    estimateModelSize(Dataset<?> dataset)

    factorSize()

    Param for dimensionality of the factors (>= 0)

    fitIntercept()

    Param for whether to fit an intercept term.

    fitLinear()

    Param for whether to fit linear term (aka 1-way term)

    initStd()

    Param for standard deviation of initial coefficients

    maxIter()

    Param for maximum number of iterations (>= 0).

    miniBatchFraction()

    Param for mini-batch fraction, must be in range (0, 1]

    read()

    regParam()

    Param for regularization parameter (>= 0).

    seed()

    Param for random seed.

    setFactorSize(int value)

    Set the dimensionality of the factors.

    setFitIntercept(boolean value)

    Set whether to fit intercept term.

    setFitLinear(boolean value)

    Set whether to fit linear term.

    setInitStd(double value)

    Set the standard deviation of initial coefficients.

    setMaxIter(int value)

    Set the maximum number of iterations.

    setMiniBatchFraction(double value)

    Set the mini-batch fraction parameter.

    setRegParam(double value)

    Set the L2 regularization parameter.

    setSeed(long value)

    Set the random seed for weight initialization.

    setSolver(String value)

    Set the solver algorithm used for optimization.

    setStepSize(double value)

    Set the initial step size for the first step (like learning rate).

    setTol(double value)

    Set the convergence tolerance of iterations.

    solver()

    The solver algorithm for optimization.

    stepSize()

    Param for step size to be used for each iteration of optimization (> 0).

    tol()

    Param for the convergence tolerance for iterative algorithms (>= 0).

    uid()

    An immutable unique ID for the object and its derivatives.

    weightCol()

    Param for weight column name.

    Methods inherited from interface org.apache.spark.ml.param.shared.HasSeed

    getSeed

    Methods inherited from interface org.apache.spark.ml.param.shared.HasTol

    getTol

    Methods inherited from interface org.apache.spark.internal.Logging

    initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, MDC, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext

    Methods inherited from interface org.apache.spark.ml.util.MLWritable

    save

  • Constructor Details

    • FMClassifier

      public FMClassifier(String uid)

    • FMClassifier

      public FMClassifier()

  • Method Details

    • load

    • read

    • factorSize

      public final IntParam factorSize()

      Param for dimensionality of the factors (>= 0)

      Specified by:
      factorSize in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • fitLinear

      Param for whether to fit linear term (aka 1-way term)

      Specified by:
      fitLinear in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • miniBatchFraction

      Param for mini-batch fraction, must be in range (0, 1]

      Specified by:
      miniBatchFraction in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • initStd

      Param for standard deviation of initial coefficients

      Specified by:
      initStd in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • solver

      The solver algorithm for optimization. Supported options: "gd", "adamW". Default: "adamW"

      Specified by:
      solver in interface FactorizationMachinesParams
      Specified by:
      solver in interface HasSolver
      Returns:
      (undocumented)
    • weightCol

      Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.

      Specified by:
      weightCol in interface HasWeightCol
      Returns:
      (undocumented)
    • regParam

      Description copied from interface: HasRegParam

      Param for regularization parameter (>= 0).

      Specified by:
      regParam in interface HasRegParam
      Returns:
      (undocumented)
    • fitIntercept

      Param for whether to fit an intercept term.

      Specified by:
      fitIntercept in interface HasFitIntercept
      Returns:
      (undocumented)
    • seed

      Description copied from interface: HasSeed

      Param for random seed.

      Specified by:
      seed in interface HasSeed
      Returns:
      (undocumented)
    • tol

      Description copied from interface: HasTol

      Param for the convergence tolerance for iterative algorithms (>= 0).

      Specified by:
      tol in interface HasTol
      Returns:
      (undocumented)
    • stepSize

      Description copied from interface: HasStepSize

      Param for step size to be used for each iteration of optimization (> 0).

      Specified by:
      stepSize in interface HasStepSize
      Returns:
      (undocumented)
    • maxIter

      Description copied from interface: HasMaxIter

      Param for maximum number of iterations (>= 0).

      Specified by:
      maxIter in interface HasMaxIter
      Returns:
      (undocumented)
    • uid

      An immutable unique ID for the object and its derivatives.

      Specified by:
      uid in interface Identifiable
      Returns:
      (undocumented)
    • setFactorSize

      Set the dimensionality of the factors. Default is 8.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setFitIntercept

      public FMClassifier setFitIntercept(boolean value)

      Set whether to fit intercept term. Default is true.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setFitLinear

      Set whether to fit linear term. Default is true.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setRegParam

      Set the L2 regularization parameter. Default is 0.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setMiniBatchFraction

      public FMClassifier setMiniBatchFraction(double value)

      Set the mini-batch fraction parameter. Default is 1.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setInitStd

      Set the standard deviation of initial coefficients. Default is 0.01.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setMaxIter

      Set the maximum number of iterations. Default is 100.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setStepSize

      Set the initial step size for the first step (like learning rate). Default is 1.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setTol

      Set the convergence tolerance of iterations. Default is 1E-6.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setSolver

      Set the solver algorithm used for optimization. Supported options: "gd", "adamW". Default: "adamW"

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setSeed

      Set the random seed for weight initialization.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • copy

      Description copied from interface: Params

      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().

      Specified by:
      copy in interface Params
      Specified by:
      copy in class Predictor<Vector,FMClassifier,FMClassificationModel>
      Parameters:
      extra - (undocumented)
      Returns:
      (undocumented)
    • estimateModelSize

      public long estimateModelSize(Dataset<?> dataset)