[SPARK-34591][ML] Add decision tree pruning as a parameter#55728
Open
WeichenXu123 wants to merge 18 commits intomasterfrom
Open
[SPARK-34591][ML] Add decision tree pruning as a parameter#55728WeichenXu123 wants to merge 18 commits intomasterfrom
WeichenXu123 wants to merge 18 commits intomasterfrom
Conversation
### What changes were proposed in this pull request?
This PR disables a feature created in SPARK-3159 where LearningNodes are
merged after a RF model is trained.
### Why are the changes needed?
2 Reasons:
1. In addition to basic classification, another use case for decision trees are the
probabilities associated with predictions. Once pruned, these predictions are lost
and it makes the trees/predictions challenging to work with if not unusable.
2. It is not in line with the default behavior in sklearn. In sklearn, the trees
are left unpruned by default.
### Does this PR introduce _any_ user-facing change?
No, it's dev-only.
### How was this patch tested?
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
This PR disables a feature created in SPARK-3159 where LearningNodes are merged after a RF model is trained. 2 Reasons: 1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions. Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable. 2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default. Please see Jira ticket for more explanation. No, it's dev-only. I modified the two tests introduced with this change to verify postive/negative use of feature. I also added assertions for default behavior Locally ran `./build/mvn -pl mllib package` and verified tests passed Additionally, running through git workflow as described here: https://spark.apache.org/developer-tools.html#github-workflow-tests
…are merged after a RF model is trained.
2 Reasons:
1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions.
Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable.
2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.
Please see Jira ticket for more explanation.
No, it's dev-only.
I modified the two tests introduced with this change to verify postive/negative use of feature. I also added assertions for default behavior
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Locally ran `./dev/scalafmt` which resulted in some minor cosmetic changes
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
There was a problem hiding this comment.
Pull request overview
Adds a configurable switch to control post-training decision tree “pruning” (merging redundant leaf nodes) and wires it through Spark ML (Scala + PySpark) APIs down to the underlying training implementation.
Changes:
- Introduce a new
pruneTreeparameter on Spark ML tree-based classifiers (Scala + PySpark) and propagate it into the oldStrategyused by the training code. - Modify
ml/tree/impl/RandomForestto usestrategy.pruneTreewhen convertingLearningNodeto finalNodetrees (affecting pruning behavior). - Update/extend RandomForest implementation tests and reformat a large portion of the suite.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| python/pyspark/ml/tree.py | Adds pruneTree param + getter to the shared Python tree classifier params. |
| python/pyspark/ml/classification.py | Exposes pruneTree in Python DecisionTreeClassifier / RandomForestClassifier constructors, setters, and docstrings. |
| mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | Adds pruneTree to Scala ML TreeClassifierParams with defaults and docs. |
| mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | Adds pruneTree to the old mllib Strategy so training code can read it. |
| mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | Uses strategy.pruneTree to decide whether to prune when finalizing trees (and for early-stop size estimation). |
| mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | Adds setPruneTree and sets strategy.pruneTree during training; logs the param. |
| mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | Same as above for RF classifier. |
| mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | Reformats tests and adds/updates pruning-related expectations (but currently contains compilation-breaking calls). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
zhengruifeng
reviewed
May 7, 2026
| def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("5.0.0") |
zhengruifeng
reviewed
May 7, 2026
| featureSubsetStrategy: String, | ||
| seed: Long, | ||
| instr: Option[Instrumentation], | ||
| prune: Boolean = true, // exposed for testing only, real trees are always pruned |
Contributor
There was a problem hiding this comment.
should the default value be true to align with previous impl?
Contributor
Author
There was a problem hiding this comment.
"default prune = false" is proposed in the jira: https://issues.apache.org/jira/browse/SPARK-34591
but to keep API compatibility, keeping it to true might be safer.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What changes were proposed in this pull request?
This PR adds a parameter to enable/disable a featuer where LearningNodes are merged after a RF model is trained.
This PR takes over #32813
Why are the changes needed?
2 Reasons:
In addition to basic classification, another use case for decision trees are the probabilities associated with predictions.
Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable.
It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.
Please see Jira ticket for more explanation.
Does this PR introduce any user-facing change?
Behavior change:
Default pruning behavior flips from always-on to always-off, making all existing decision tree/random forest/GBT callers produce larger, unpruned trees by default.
New params:
adds a parameter that is exposed to the Tree based classifiers. Will add tests here to ensure parameter is exposed correctly.
How was this patch tested?
I modified the two tests introduced with this change to verify postive/negative use of feature. I also added assertions for default behavior
Will add tests that ensure user exposed API is validated.
Locally ran
./build/mvn -pl mllib packageand verified tests passedAdditionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests