Author(s): Miguel Cardona Polo

Originally published on Towards AI.

“Trees playing Baseball” by author using DALL·E 3.

Decision trees form the backbone of some of the most popular machine learning models in industry today, such as Random Forests, Gradient Boosted Trees, and XGBoost.

Large Language Models (LLMs) are an exciting and very useful tool, but most real-world industry problems are not solved using LLMs. Instead, the majority of machine learning applications deal with structured, tabular data, such as large CSVs, Excel files, and databases. It is estimated that 70-80% of these tabular data tasks are solved using gradient boosting techniques like XGBoost, which rely on simple yet incredibly powerful decision trees.

One of the biggest advantages of decision trees is their interpretability. Unlike modern black-box models, decision trees provide clear, step-by-step reasoning behind their predictions. This transparency helps businesses understand their data better, make smarter decisions, and move beyond just predictions.

In this article, you’ll gain a deep understanding of how decision trees work, including:

- The math behind decision trees (optional for those interested).
- Python code to build your own decision tree from scratch.
- Two hands-on examples (regression and classification) with step-by-step calculations, showing exactly how a decision tree learns.

Don’t miss these detailed walkthroughs to solidify your understanding!

Concept of Decision Trees

A decision tree is like a flowchart used to make decisions. It starts at a single point (called the root node) and splits into branches based on questions about the data. At each step, the tree asks a question such as “Is the value greater than X?” or “Does it belong to category Y?”. Based on the answer, it moves down a branch to the next question (each of these intermediate questions is called a decision node).
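To make the root node, decision node, and leaf vocabulary concrete, here is a minimal Python sketch of how a trained tree could be stored and used for prediction. The `Node` class, the `predict` function, and the two-question toy tree are hypothetical illustrations, not the article's from-scratch implementation. Each decision node in this sketch asks a single "is feature <= threshold?" question, and the thresholds are made up.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """One node of a binary decision tree (hypothetical minimal structure).

    Decision nodes store a feature name and a threshold; leaf nodes store
    only a prediction in `value`.
    """
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None   # branch taken when sample[feature] <= threshold
    right: Optional["Node"] = None  # branch taken when sample[feature] > threshold
    value: Optional[str] = None     # prediction stored at a leaf

    def is_leaf(self) -> bool:
        return self.value is not None


def predict(node: Node, sample: dict) -> str:
    """Start at the root and answer one question per node until a leaf is reached."""
    while not node.is_leaf():
        if sample[node.feature] <= node.threshold:
            node = node.left
        else:
            node = node.right
    return node.value


# Hypothetical two-question tree; structure and thresholds are illustrative only.
root = Node(
    feature="hours_slept", threshold=2.0,
    left=Node(value="Fail"),
    right=Node(
        feature="hours_studied", threshold=2.0,
        left=Node(value="Fail"),
        right=Node(value="Pass"),
    ),
)

print(predict(root, {"hours_slept": 7, "hours_studied": 4}))  # Pass
print(predict(root, {"hours_slept": 1, "hours_studied": 8}))  # Fail
```

Because each sample follows exactly one path from the root to a leaf, the cost of a prediction grows with the depth of the tree rather than with the size of the training set.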
This process continues until the data reaches a final point (called a leaf), which gives the decision or prediction: a “Yes/No”, a specific class, or a continuous number.

Take a look at this decision tree, used to predict which students will pass an exam. It is based on the number of hours the student studied, their hours of sleep the day before the exam, and their previous grade.

Flow chart of decision tree used to predict which students will pass an exam. Image by author.

Each leaf node represents a group of data points that share similar characteristics and are therefore given the same prediction (Pass or Fail). For example, students who studied between 2 and 6 hours and slept more than 6 hours form a group of similar students (based on what was seen in the training data), so the decision tree predicts they’ll pass the exam.

Note that decisions can be made on both numerical data, like hours slept, and categorical data, like the student’s previous grade. This is one reason decision trees are so popular for tabular data such as spreadsheets and databases, which often contain both types.

If you are wondering how this flowchart translates to the data, we can plot it as a graph. The hours of sleep and study are the axes, and the previous grade is shown as a cross (failed previous exam) or a circle (passed previous exam). We can place on the graph some example students whose next exam result we want to predict; the position of each cross or circle indicates that student’s hours of sleep and study.

Partitioned graph of decision tree to predict which students will pass an exam. Image by author.

You can check that this graph represents the same decision tree as the flowchart: the blue dashed lines are the decision boundaries (thresholds), and each highlighted section represents a leaf node of the decision tree.
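A single split can test either kind of feature. One simple way to represent a question in code is to compare numeric features against a threshold and categorical features for equality. The sketch below assumes that convention (`>=` for numbers, `==` for categories). The `Question` class and the student record are hypothetical, and the article's own code may use different operators or names.

```python
from dataclasses import dataclass
from numbers import Number
from typing import Any


@dataclass(frozen=True)
class Question:
    """A single split test: numeric features use >=, categorical ones use ==."""
    feature: str
    value: Any

    def is_numeric(self) -> bool:
        # bool is a subclass of int in Python, so exclude it explicitly
        return isinstance(self.value, Number) and not isinstance(self.value, bool)

    def match(self, sample: dict) -> bool:
        """Return True if the sample goes down the 'yes' branch."""
        if self.is_numeric():
            return sample[self.feature] >= self.value
        return sample[self.feature] == self.value

    def __str__(self) -> str:
        op = ">=" if self.is_numeric() else "=="
        return f"Is {self.feature} {op} {self.value!r}?"


# Hypothetical student record mixing numeric and categorical features.
student = {"hours_studied": 4.5, "hours_slept": 7.0, "previous_grade": "Pass"}

q_numeric = Question("hours_slept", 6.0)
q_categorical = Question("previous_grade", "Pass")

print(q_numeric, q_numeric.match(student))          # Is hours_slept >= 6.0? True
print(q_categorical, q_categorical.match(student))  # Is previous_grade == 'Pass'? True
```

In the graph view, each numeric question corresponds to one of the dashed boundary lines, while a categorical question splits on the cross/circle marker instead of a position on an axis.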
One area is left unhighlighted because the prediction under those conditions depends on the student’s last exam: only students who passed their last exam are predicted to pass.

Now let’s look into how decision trees choose the questions and the numbers (thresholds on features) that make them accurate prediction models.

How Decision Trees Learn

As mentioned earlier, the goal is to split the data into smaller groups so that similar data points are grouped together. Decision trees do this by asking questions and applying thresholds (numbers or categories) to the training data.

A split in a decision tree is a point where the data is divided based on a specific feature and threshold, creating branches. For example, in the case discussed earlier, one feature was the number of hours a student slept, with a threshold of ‘less than 2 hours.’ This split created a branch grouping students who slept less than two hours, who are predicted to fail their next exam.

To choose the best split, decision trees try all possible splits (features and thresholds) and pick the one with the lowest impurity, a value that indicates how mixed or diverse the data in a group is. For a numeric feature, only thresholds between consecutive observed values need to be checked, since any threshold within the same gap produces the same split. Lower impurity means the group contains similar data, which is the aim of the learning process.

Impurity Measure

It’s called an impurity measure because it captures the diversity in a group. For example, if you have a basket of fruit that contains only apples, there is no diversity: the basket is pure, so its impurity is low. On the other hand, if the basket has a mix of apples, oranges, and bananas, it has high diversity and therefore high impurity.

There are impurity measures specific to regression tasks, where we predict a continuous number, and to classification tasks, where the target is a class. Here is one example of each.

Formula for the classification impurity measure Entropy, and for the regression impurity measure Variance. Image by author.
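The split search and the two impurity measures can be sketched in a few lines of Python. This sketch assumes the standard textbook forms: entropy H = Σₖ pₖ log₂(1/pₖ) (equivalently −Σₖ pₖ log₂ pₖ) and population variance (1/n) Σᵢ (yᵢ − ȳ)². It scores each candidate split by the impurity of its two children, weighted by their share of the samples, which is a common convention. The function names and toy data are hypothetical, and only a single numeric feature is searched. A full tree would repeat this search for every feature and then recurse into each child.

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 Shannon entropy: sum over classes of p_k * log2(1 / p_k)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return sum((c / n) * math.log2(n / c) for c in Counter(labels).values())


def variance(values):
    """Population variance: mean squared deviation from the mean."""
    n = len(values)
    if n == 0:
        return 0.0
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n


def split_impurity(xs, ys, threshold, impurity):
    """Impurity of the two children of the question `x <= threshold`,
    each child weighted by its share of the samples."""
    left = [y for x, y in zip(xs, ys) if x <= threshold]
    right = [y for x, y in zip(xs, ys) if x > threshold]
    n = len(ys)
    return len(left) / n * impurity(left) + len(right) / n * impurity(right)


def best_threshold(xs, ys, impurity):
    """Try every midpoint between consecutive distinct values of one numeric
    feature and return (threshold, impurity) for the lowest impurity."""
    values = sorted(set(xs))
    candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
    if not candidates:  # a constant feature cannot be split
        return None
    return min(((t, split_impurity(xs, ys, t, impurity)) for t in candidates),
               key=lambda pair: pair[1])


# Fruit baskets: a pure basket vs. an evenly mixed one.
print(entropy(["apple"] * 6))                      # 0.0
print(entropy(["apple", "orange", "banana"] * 2))  # log2(3), about 1.585

# Hypothetical toy data (not from the article): hours studied per student.
hours = [1, 2, 3, 5, 6, 8]
passed = ["Fail", "Fail", "Fail", "Pass", "Pass", "Pass"]
print(best_threshold(hours, passed, entropy))  # (4.0, 0.0): both children are pure

# The same search for a regression target, using variance as the impurity.
scores = [40.0, 42.0, 44.0, 80.0, 82.0, 84.0]
print(best_threshold(hours, scores, variance))  # threshold 4.0 again
```

Minimizing the weighted child entropy is equivalent to maximizing information gain (parent entropy minus weighted child entropy), because the parent's entropy is the same for every candidate split.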
If you are interested in the intuition behind these formulas and fancy some example calculations, the section below is for you; otherwise, feel free to skip it.

Impurity Measure: A Deeper Look

First, let’s build intuition on the Entropy formula by understanding how different splits yield higher or lower […]