[SPARK-34591][ML] Add decision tree pruning as a parameter#55728

Open
WeichenXu123 wants to merge 18 commits into master from SPARK-34591

@WeichenXu123 WeichenXu123 commented May 7, 2026

What changes were proposed in this pull request?

This PR adds a parameter to enable or disable a feature where LearningNodes are merged after an RF model is trained.

This PR takes over #32813

Why are the changes needed?

Two reasons:

  1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
    Once the tree is pruned, these probabilities are lost, which makes the trees and their predictions challenging to work with, if not unusable.

  2. Always pruning is not in line with the default behavior in sklearn, where trees are left unpruned by default.

Please see Jira ticket for more explanation.
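
To make the first reason concrete, here is a minimal pure-Python sketch of how merging sibling leaves that predict the same class can change the reported probabilities. The `Node` class, `prune`, and `leaf_for` are hypothetical toy helpers written for illustration, not Spark's implementation, and the class counts are made up.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Toy tree node; `counts` holds per-class training counts reaching it."""
    counts: List[int]
    threshold: Optional[float] = None  # split threshold; unused on leaves
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    @property
    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

    def predict_class(self) -> int:
        # Index of the class with the largest count.
        return max(range(len(self.counts)), key=self.counts.__getitem__)

    def probabilities(self) -> List[float]:
        total = sum(self.counts)
        return [c / total for c in self.counts]


def prune(node: Node) -> Node:
    """Bottom-up: merge sibling leaves that predict the same class."""
    if node.is_leaf:
        return node
    left, right = prune(node.left), prune(node.right)
    if left.is_leaf and right.is_leaf and left.predict_class() == right.predict_class():
        # The merged leaf keeps only the parent's pooled counts.
        return Node(counts=node.counts)
    return Node(node.counts, node.threshold, left, right)


def leaf_for(node: Node, x: float) -> Node:
    """Walk a single-feature tree down to the leaf that x falls into."""
    while not node.is_leaf:
        node = node.left if x <= node.threshold else node.right
    return node


# Both leaves predict class 1, but with different confidence.
tree = Node(counts=[5, 15], threshold=0.5,
            left=Node(counts=[4, 6]), right=Node(counts=[1, 9]))
pruned = prune(tree)

unpruned_probs = leaf_for(tree, 0.2).probabilities()
pruned_probs = leaf_for(pruned, 0.2).probabilities()
print(unpruned_probs)  # per-leaf estimate from the left leaf
print(pruned_probs)    # pooled estimate from the merged leaf
```

In this toy case the predicted class is the same before and after pruning, but the per-leaf probability estimate is replaced by the pooled estimate of the merged node.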

Does this PR introduce any user-facing change?

Behavior change:
The default pruning behavior flips from always-on to off by default, so all existing decision tree/random forest/GBT callers produce larger, unpruned trees unless they enable pruning.

New params:
Adds a pruneTree parameter that is exposed on the tree-based classifiers. Tests will be added here to ensure the parameter is exposed correctly.

How was this patch tested?

I modified the two tests introduced with this change to verify positive/negative use of the feature. I also added assertions for the default behavior.

Tests that validate the user-exposed API will be added.

Locally ran ./build/mvn -pl mllib package and verified tests passed
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests

bribiescas-carlos and others added 15 commits June 8, 2021 12:10
    ### What changes were proposed in this pull request?

    This PR disables a feature created in SPARK-3159 where LearningNodes are
    merged after a RF model is trained.

    ### Why are the changes needed?

    Two reasons:

    1. In addition to basic classification, another use case for decision trees is the
    probabilities associated with predictions.  Once the tree is pruned, these probabilities
    are lost, which makes the trees/predictions challenging to work with, if not unusable.

    2. It is not in line with the default behavior in sklearn.  In sklearn, the trees
    are left unpruned by default.

    ### Does this PR introduce _any_ user-facing change?

    No, it's dev-only.

    ### How was this patch tested?
    Locally ran `./build/mvn -pl mllib package` and verified tests passed
    Additionally, running through git workflow as described here:
    	https://spark.apache.org/developer-tools.html#github-workflow-tests
…are merged after a RF model is trained.

Two reasons:

1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
 Once the tree is pruned, these probabilities are lost, which makes the trees/predictions challenging to work with, if not unusable.

2. It is not in line with the default behavior in sklearn.  In sklearn, the trees are left unpruned by default.

Please see Jira ticket for more explanation.

No, it's dev-only.

I modified the two tests introduced with this change to verify positive/negative use of the feature.  I also added assertions for the default behavior.

Locally ran `./build/mvn -pl mllib package` and verified tests passed
Locally ran `./dev/scalafmt` which resulted in some minor cosmetic changes

Additionally, running through git workflow as described here:
    	https://spark.apache.org/developer-tools.html#github-workflow-tests
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Copilot AI review requested due to automatic review settings May 7, 2026 08:34
Copilot AI left a comment

Pull request overview

Adds a configurable switch to control post-training decision tree “pruning” (merging redundant leaf nodes) and wires it through Spark ML (Scala + PySpark) APIs down to the underlying training implementation.

Changes:

  • Introduce a new pruneTree parameter on Spark ML tree-based classifiers (Scala + PySpark) and propagate it into the old Strategy used by the training code.
  • Modify ml/tree/impl/RandomForest to use strategy.pruneTree when converting LearningNode to final Node trees (affecting pruning behavior).
  • Update/extend RandomForest implementation tests and reformat a large portion of the suite.
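
As a rough illustration of the wiring described in the bullets above, the following pure-Python sketch threads a boolean flag from a strategy object into the step that converts learning nodes into final nodes. All names here (`Strategy.prune_tree`, `LearningNode`, `FinalNode`, `to_node`, `finalize`) are hypothetical stand-ins for illustration, not the actual Scala classes, and no default value for the flag is assumed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Strategy:
    # Hypothetical stand-in for the pruneTree setting; deliberately no default.
    prune_tree: bool


@dataclass
class LearningNode:
    """Mutable node used during training (toy version)."""
    prediction: int
    left: Optional["LearningNode"] = None
    right: Optional["LearningNode"] = None


@dataclass
class FinalNode:
    """Node of the finished model (toy version)."""
    prediction: int
    left: Optional["FinalNode"] = None
    right: Optional["FinalNode"] = None

    def num_nodes(self) -> int:
        if self.left is None:
            return 1
        return 1 + self.left.num_nodes() + self.right.num_nodes()


def to_node(node: LearningNode, prune: bool) -> FinalNode:
    """Convert bottom-up; when prune is set, merge same-prediction sibling leaves."""
    if node.left is None:
        return FinalNode(node.prediction)
    left, right = to_node(node.left, prune), to_node(node.right, prune)
    both_leaves = left.left is None and right.left is None
    if prune and both_leaves and left.prediction == right.prediction:
        return FinalNode(node.prediction)
    return FinalNode(node.prediction, left, right)


def finalize(root: LearningNode, strategy: Strategy) -> FinalNode:
    # The flag is read from the strategy exactly once, at conversion time.
    return to_node(root, prune=strategy.prune_tree)


root = LearningNode(
    1,
    left=LearningNode(1, LearningNode(1), LearningNode(1)),
    right=LearningNode(0),
)
print(finalize(root, Strategy(prune_tree=True)).num_nodes())
print(finalize(root, Strategy(prune_tree=False)).num_nodes())
```

With the flag set, the two same-prediction leaves under the left child collapse into one leaf, so the toy tree shrinks; with it unset, the tree is converted node for node.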

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 15 comments.

| File | Description |
| --- | --- |
| python/pyspark/ml/tree.py | Adds pruneTree param + getter to the shared Python tree classifier params. |
| python/pyspark/ml/classification.py | Exposes pruneTree in Python DecisionTreeClassifier / RandomForestClassifier constructors, setters, and docstrings. |
| mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | Adds pruneTree to Scala ML TreeClassifierParams with defaults and docs. |
| mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | Adds pruneTree to the old mllib Strategy so training code can read it. |
| mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | Uses strategy.pruneTree to decide whether to prune when finalizing trees (and for early-stop size estimation). |
| mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | Adds setPruneTree and sets strategy.pruneTree during training; logs the param. |
| mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | Same as above for RF classifier. |
| mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | Reformats tests and adds/updates pruning-related expectations (but currently contains compilation-breaking calls). |

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group setParam */
@Since("5.0.0")
@zhengruifeng zhengruifeng May 7, 2026

@HyukjinKwon do we have 4.3?

@zhengruifeng zhengruifeng changed the title [SPARK-34591] Add decision tree pruning as a parameter [SPARK-34591][ML] Add decision tree pruning as a parameter May 7, 2026
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
Contributor

Should the default value be true to align with the previous implementation?

Contributor Author

"default prune = false" is proposed in the jira: https://issues.apache.org/jira/browse/SPARK-34591

But to keep API compatibility, keeping the default at true might be safer.


4 participants