Linear Trees: What If Every Decision-Tree Leaf Had Its Own Linear Model?

Last Updated on June 25, 2026 by Editorial Team

Author(s): Fern

Originally published on Towards AI.


As data scientists, our machine learning toolkits often force us into a frustrating ultimatum, making us choose between two entirely different worlds of modeling.

On one side of the ring, we have standard linear regression. It is the reliable workhorse of the industry — fast, incredibly simple, and highly interpretable. Every feature receives a clean coefficient, allowing us to walk into a meeting and clearly explain exactly how a prediction changes when a specific feature changes. But linear regression carries one massive, glaring weakness: it arrogantly assumes that a single, rigid equation can accurately describe your entire dataset.

On the other side of the ring, we have decision trees. They are wonderfully flexible, capable of capturing nonlinear relationships, complex thresholds, and feature interactions without requiring us to manually engineer them. However, standard regression trees have their own fatal flaw: they make constant predictions inside each leaf. They treat continuous data like flat platforms, making their predictions look like a clunky staircase rather than a smooth, realistic curve.

So, let’s ask the obvious question: what happens when we combine the hierarchical structure of a decision tree with the smooth predictive behavior of linear regression?

We get a Linear Tree.

A Linear Tree divides your data into different regions using decision-tree rules, and then fits a separate linear model inside each of those regions. It is an algorithm that acknowledges a simple truth: one straight line cannot describe the entire world, but several local straight lines often can.

Problem with Linear Regression

To understand why this is so powerful, let’s look at real estate. Suppose we want to build a model to predict apartment rental prices in a bustling tech corridor like Mahadevapura, using only property size.

A simple linear regression model might learn a global rule that looks like this: Rent = ₹15,000 + (₹30 × Square Footage)

This model assumes that every single additional square foot contributes the exact same amount of value to the rent, no matter what. But anyone who has ever hunted for an apartment knows real markets do not behave that way.

For a tiny studio apartment, an additional 100 square feet is life-changing and highly valuable. For a massive luxury Colive space, a few extra square feet barely register. The real relationship in the data is piecewise. It actually looks more like this:

  • If Size < 500 sq ft: Rent = ₹10,000 + (₹40 × Size)
  • If Size 500 to 1,500 sq ft: Rent = ₹20,000 + (₹25 × Size)
  • If Size > 1,500 sq ft: Rent = ₹45,000 + (₹15 × Size)
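
To make the contrast concrete, here is a minimal sketch that encodes the global rule and the three segment rules above as plain Python functions. The function names are illustrative, and the coefficients are the example figures from this article, not fitted values.

```python
def global_rent(size_sqft: float) -> float:
    """Single global line: every square foot adds the same ₹30."""
    return 15_000 + 30 * size_sqft


def piecewise_rent(size_sqft: float) -> float:
    """Three local lines, one per market segment from the list above."""
    if size_sqft < 500:
        return 10_000 + 40 * size_sqft
    if size_sqft <= 1_500:
        return 20_000 + 25 * size_sqft
    return 45_000 + 15 * size_sqft


for size in (400, 1_000, 2_000):
    print(size, global_rent(size), piecewise_rent(size))
```

The two rules can agree at some sizes and still disagree on the slope, which is exactly what a single global coefficient cannot express.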

Instead of forcing one stubborn line through every single observation, we need a different line for each distinct market segment. A standard linear model completely chokes on this unless we spend hours manually creating data transformations, interaction variables, and arbitrary threshold indicators.
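
For a sense of what that manual work looks like, the sketch below hand-builds threshold indicators and indicator-times-size interaction columns, then fits one ordinary least-squares model on them with NumPy. The thresholds (500 and 1,500 sq ft) are hard-coded by hand here, which is the whole problem; the helper name `segment_features` and the noise-free synthetic data are assumptions made for illustration.

```python
import numpy as np


def segment_features(size):
    """Design matrix with one intercept and one slope per hand-chosen segment."""
    size = np.asarray(size, dtype=float)
    small = size < 500
    large = size > 1_500
    mid = ~small & ~large
    columns = []
    for mask in (small, mid, large):
        columns.append(mask.astype(float))  # segment-specific intercept (indicator)
        columns.append(mask * size)         # segment-specific slope (indicator x size)
    return np.column_stack(columns)


# Noise-free rents generated from the three segment rules in this article
sizes = np.linspace(200, 2_500, 200)
rent = np.where(
    sizes < 500,
    10_000 + 40 * sizes,
    np.where(sizes <= 1_500, 20_000 + 25 * sizes, 45_000 + 15 * sizes),
)

# One least-squares fit on the engineered columns recovers all three lines
coef, *_ = np.linalg.lstsq(segment_features(sizes), rent, rcond=None)
print(np.round(coef, 2))  # intercept and slope for each segment, in order
```

This only works because the breakpoints were known in advance. With real data they are unknown, and they have to be searched for.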

A Linear Tree learns these specific regions automatically.

Problem with Decision Trees

A standard regression tree attempts to solve this nonlinearity problem in a completely different way. It repeatedly slices the data using rigid, binary rules.

It might ask: Is the apartment size below 500 square feet?

  • Yes: Predict a flat ₹25,000.
  • No: Is the size below 1,500 square feet?
      • Yes: Predict a flat ₹45,000.
      • No: Predict a flat ₹70,000.

This allows the model to capture those market thresholds easily. However, look at what happens inside the leaf. Every single apartment that reaches the same leaf receives the exact same prediction. A 510-square-foot flat and a 1,490-square-foot flat will receive the identical ₹45,000 prediction because they landed in the same bucket.
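
The same rules written as code make the flat-leaf behavior easy to see. This is a hand-written sketch of the example tree above, not a fitted model, and `tree_rent` is an illustrative name.

```python
def tree_rent(size_sqft: float) -> int:
    """Regression-tree style prediction: one constant per leaf."""
    if size_sqft < 500:
        return 25_000
    if size_sqft < 1_500:
        return 45_000
    return 70_000


# Two very different flats land in the same leaf and get the same number
print(tree_rent(510), tree_rent(1_490))
```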

If we plot its predictions on a graph, we get a chunky staircase. We could try to fix this by building a much deeper tree to create smaller steps, but that brings its own headaches: far more leaves to manage, rules that are hard to read, and a higher risk of overfitting. And no matter how deep it grows, a tree still predicts a flat constant outside the training range.

Instead of creating dozens of tiny, constant, flat regions, a Linear Tree creates a few meaningful, broad regions — and then fits a mathematical trendline inside each one.

So what exactly does a linear tree do?

A Linear Tree is simply a decision tree where the terminal leaves contain linear models instead of flat constants.

A standard tree looks like this:

Feature A < 10?

  • Yes: Predict 25
  • No: Predict 68

A Linear Tree looks like this:

Feature A < 10?

  • Yes: y = 4 + (2 × Feature X) - (0.5 × Feature Z)
  • No: y = 30 + (0.8 × Feature X) + (2 × Feature Z)

The tree structure still acts as the traffic cop, deciding which specific region an observation belongs in. But once the observation reaches its final leaf, the final prediction is calculated dynamically using that leaf’s unique linear equation.
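
Here is the toy example above as a small sketch, with the standard tree and the Linear Tree side by side. The function and argument names are illustrative; the numbers are the ones from the example.

```python
def standard_tree(feature_a: float) -> float:
    """Standard tree: the leaf returns a constant."""
    return 25.0 if feature_a < 10 else 68.0


def linear_tree(feature_a: float, feature_x: float, feature_z: float) -> float:
    """Linear Tree: the split routes the row, the leaf equation computes the value."""
    if feature_a < 10:
        return 4 + 2 * feature_x - 0.5 * feature_z
    return 30 + 0.8 * feature_x + 2 * feature_z


# Same leaf (Feature A < 10), different inputs: only the Linear Tree reacts
print(standard_tree(5), standard_tree(7))
print(linear_tree(5, 3, 2), linear_tree(7, 6, 2))
```

Note that the splitting feature (Feature A) and the features used inside the leaf equations (Feature X, Feature Z) do not have to be the same.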

How does a linear tree learn?

Step 1: Fit a global linear model. At the root of the tree, the algorithm fits a standard linear model to all available training data. This is the baseline.

Step 2: Evaluate candidate splits. The algorithm then looks for ways to slice the data, considering possible splits such as age < 35 or income < ₹70,000. For each candidate split, it divides the observations into two groups and fits a separate linear model to each group.

Step 3: Measure the improvement. The algorithm compares the error of the single parent model with the combined error of the two child models. If splitting the data into two linear regimes significantly reduces the overall error, the split is locked in.

Step 4: Repeat recursively. This process continues deeper into the child branches. The algorithm keeps splitting and fitting lines until it hits a stopping condition: reaching a maximum depth, having too few data points left in a leaf, or failing to find a split that improves the error.

Step 5: Predict. When a new data point arrives, it trickles down the decision rules until it lands in a terminal leaf, where the local linear model calculates the final output.
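
The five steps can be sketched from scratch in a few dozen lines of NumPy. This is a simplified illustration of the idea, not the implementation used by any particular library: the quantile grid of candidate thresholds, the `min_gain` cutoff for a "significant" improvement, and all function names are assumptions made for this sketch.

```python
import numpy as np


def fit_line(X, y):
    """Least-squares fit with intercept; returns (coefficients, sum of squared errors)."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef, float(np.sum((A @ coef - y) ** 2))


def grow(X, y, depth=0, max_depth=1, min_samples_leaf=20, min_gain=0.05):
    coef, sse = fit_line(X, y)                        # Step 1: model at this node
    node = {"coef": coef, "sse": sse}
    if depth >= max_depth:                            # Step 4: stopping condition
        return node
    best = None
    for j in range(X.shape[1]):                       # Step 2: candidate splits
        for t in np.unique(np.quantile(X[:, j], np.linspace(0.05, 0.95, 19))):
            left = X[:, j] < t
            if min(left.sum(), (~left).sum()) < min_samples_leaf:
                continue
            child_sse = fit_line(X[left], y[left])[1] + fit_line(X[~left], y[~left])[1]
            if best is None or child_sse < best[0]:
                best = (child_sse, j, float(t), left)
    # Step 3: keep the split only if the children clearly beat the parent
    if best is None or best[0] > (1 - min_gain) * sse:
        return node
    _, j, t, left = best
    node.update(
        feature=j,
        threshold=t,
        left=grow(X[left], y[left], depth + 1, max_depth, min_samples_leaf, min_gain),
        right=grow(X[~left], y[~left], depth + 1, max_depth, min_samples_leaf, min_gain),
    )
    return node


def predict_one(node, x):
    while "feature" in node:                          # Step 5: route down to a leaf
        node = node["left"] if x[node["feature"]] < node["threshold"] else node["right"]
    return float(node["coef"][0] + node["coef"][1:] @ x)


# Synthetic data with a slope change at x = 8
rng = np.random.default_rng(0)
X = np.linspace(0, 20, 300).reshape(-1, 1)
x = X.ravel()
y = np.where(x < 8, 10 + 4.5 * x, 46 + 1.2 * (x - 8)) + rng.normal(0, 2.5, 300)

tree = grow(X, y)
print("split threshold:", tree.get("threshold"))
print("prediction at x=4:", predict_one(tree, np.array([4.0])))
```

Refitting two linear models for every candidate threshold is the expensive part, which is why a coarse grid of thresholds is used here rather than every unique feature value.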

Implementing Linear Trees in Python

The Python ecosystem has a package called linear-tree that provides scikit-learn-style estimators. Let’s look at how clean this is to implement. We can even pass regularized models (like Ridge or Lasso) into the leaves to prevent our coefficients from going crazy.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.tree import DecisionTreeRegressor
from lineartree import LinearTreeRegressor

# ---------------------------------------------------------
# 1. Generate the Piecewise "Real Estate" Dataset
# ---------------------------------------------------------
np.random.seed(42)
X = np.linspace(0, 20, 300).reshape(-1, 1)

# Create a behavioral shift at X = 8 (The Mathematical Breaking Point)
# If X < 8, steep slope. If X >= 8, shallower slope.
y = np.where(
    X < 8,
    10 + 4.5 * X,        # Regime 1: Steep linear trend
    46 + 1.2 * (X - 8),  # Regime 2: Flatter linear trend
)

# Add some real-world noise
y += np.random.normal(0, 2.5, size=y.shape)

# ---------------------------------------------------------
# 2. Initialize and Train the Models
# ---------------------------------------------------------
# Linear Regression (The mediocre middle line)
linear_model = LinearRegression()

# Regression Tree (Max depth 2 creates a 4-step staircase)
tree_model = DecisionTreeRegressor(max_depth=2, random_state=42)

# Linear Tree (Max depth 1 creates a single split with two lines)
# We pass Ridge to keep the coefficients stable inside the leaves
linear_tree_model = LinearTreeRegressor(base_estimator=Ridge(), max_depth=1)

# Fit all models
linear_model.fit(X, y)
tree_model.fit(X, y)
linear_tree_model.fit(X, y)

# ---------------------------------------------------------
# 3. Generate Predictions for the Plot
# ---------------------------------------------------------
X_plot = np.linspace(0, 20, 500).reshape(-1, 1)

y_lr = linear_model.predict(X_plot)
y_tree = tree_model.predict(X_plot)
y_lt = linear_tree_model.predict(X_plot)

# ---------------------------------------------------------
# 4. Create the Visualization
# ---------------------------------------------------------
plt.figure(figsize=(12, 7))

# Plot the raw training data
plt.scatter(X, y, color='gray', alpha=0.4, label='Actual Data', edgecolors='k')

# Plot the competing models
plt.plot(X_plot, y_lr, color='#3498db', linestyle='--', linewidth=3, label='Linear Regression (Global)')
plt.plot(X_plot, y_tree, color='#e74c3c', linestyle='-.', linewidth=3, label='Regression Tree (Staircase)')
plt.plot(X_plot, y_lt, color='#2ecc71', linewidth=4, label='Linear Tree (Piecewise)')

# Formatting for a clean, Medium-ready aesthetic
plt.title('Algorithm Showdown: Forcing One Line vs. Finding Two', fontsize=16, fontweight='bold', pad=15)
plt.xlabel('Feature (e.g., Square Footage)', fontsize=12, labelpad=10)
plt.ylabel('Target (e.g., Rent Price)', fontsize=12, labelpad=10)

# Add a vertical line to explicitly show the breaking point
plt.axvline(x=8, color='black', linestyle=':', alpha=0.5, label='Hidden Threshold (X=8)')

plt.legend(fontsize=11, loc='lower right', framealpha=0.9)
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()

# Show the plot
plt.show()
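Output: [Figure: the noisy training data with the Linear Regression, Regression Tree, and Linear Tree fits overlaid, plus a dotted vertical line at the hidden threshold X = 8.]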

Limitations of Linear Trees

1. Discontinuous Boundaries: Because each leaf is an isolated mathematical island, predictions can jump abruptly at the boundaries. If the tree splits at an income of ₹50,000, two users making ₹49,999 and ₹50,001 might receive drastically different predictions because they are evaluated by entirely different leaf equations.

2. Coefficient Instability in Small Leaves: If you let your tree grow too deep, you might end up with a leaf containing only 15 data points, trying to fit a linear model with 10 features. With so few observations per coefficient, the least-squares estimates swing wildly with the noise, and once a leaf has fewer points than features the fit is not even uniquely determined. This is why using regularized models (like Ridge or Lasso) inside the leaves matters: regularization shrinks the coefficients and keeps the local models from overfitting to noise.
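
Both limitations can be made concrete with a short sketch. The first part reuses the two leaf equations from the rent example to measure the jump at the 500 sq ft boundary. The second part simulates a 15-row, 10-feature leaf and compares ordinary least squares with a closed-form ridge fit; the random data, the penalty `alpha=5.0`, and the helper names are assumptions chosen for illustration.

```python
import numpy as np

# --- 1. Discontinuity at a split boundary ---------------------------------
# Leaf equations from the rent example; the split sits at 500 sq ft.
def leaf_small(size):
    return 10_000 + 40 * size


def leaf_mid(size):
    return 20_000 + 25 * size


jump = leaf_mid(501) - leaf_small(499)  # two flats only 2 sq ft apart

# --- 2. Coefficient instability in a small leaf ---------------------------
rng = np.random.default_rng(0)
n_samples, n_features = 15, 10
X = rng.normal(size=(n_samples, n_features))
y = 3.0 * X[:, 0] + rng.normal(0, 2.5, size=n_samples)  # only feature 0 matters

# Center the data so the intercept drops out and is not penalized
Xc, yc = X - X.mean(axis=0), y - y.mean()


def ridge_coefficients(Xc, yc, alpha):
    """Closed-form ridge solution; alpha = 0 gives ordinary least squares."""
    penalty = alpha * np.eye(Xc.shape[1])
    return np.linalg.solve(Xc.T @ Xc + penalty, Xc.T @ yc)


ols = ridge_coefficients(Xc, yc, alpha=0.0)
ridge = ridge_coefficients(Xc, yc, alpha=5.0)

print(f"jump across the boundary: {jump}")
print(f"OLS coefficient norm:   {np.linalg.norm(ols):.2f}")
print(f"Ridge coefficient norm: {np.linalg.norm(ridge):.2f}")
```

The ridge coefficient vector is never longer than the least-squares one on the same centered data, which is the shrinkage that stabilizes small leaves. It does nothing for the boundary jump, which comes from the tree structure itself.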

Conclusion

Linear Trees are built on a beautifully pragmatic observation: complex, messy, global data is usually just a collection of simpler, local trends.

A single linear model is often too restrictive for the real world. A standard decision tree is often too crude. A Linear Tree sits between the two, giving you the threshold-finding power of tree splits combined with the smooth, trend-capturing behavior of linear models.

Published via Towards AI


Note: Article content contains the views of the contributing authors and not Towards AI.