import math
import random
import matplotlib.pyplot as plt
%matplotlib inline
PART 1: What Is Impurity?
Exercise 1.1 — Think About Mixing
Imagine a bag of coloured balls — some Red, some Black.
We want a number that captures how mixed the bag is:
- A bag of all Red → not mixed at all → impurity = 0
- A bag of all Black → not mixed at all → impurity = 0
- A bag of half Red, half Black → maximally mixed → impurity = 1
Written as function calls, the same three cases read:
impurity(0.5, 0.5) → 1 # maximally mixed: impurity is 1
impurity(0, 1) → 0 # all one colour: impurity is 0
impurity(1, 0) → 0
Important: impurity = 0 means completely pure (easy to classify), and impurity = 1 means maximally mixed (hard to classify). Higher impurity always means less certainty.
Before coding, answer these questions in the markdown cell below:
- If a bag has 80% Red and 20% Black, would its impurity be closer to 0 or 1? Why?
- What is impurity(0.9, 0.1)? Is it higher or lower than impurity(0.8, 0.2)?
- What mathematical property must the function impurity(p, q) have when p = 1 - q?
- Is impurity(p, q) the same as impurity(q, p)? Why should it be?
Your answers:
- ...
- ...
- ...
- ...
Exercise 1.2 — Invent the Formula
We're going to invent the formula for impurity(p, q) where:
- p = fraction of Red balls (between 0 and 1)
- q = fraction of Black balls (between 0 and 1)
- p + q = 1 always
Step-by-step discovery:
Step A: Consider the function $f(p) = p \times (1 - p)$. Fill in the table below by computing it:
| p | 1-p | p × (1-p) |
|---|---|---|
| 0.0 | 1.0 | ? |
| 0.2 | ? | ? |
| 0.4 | ? | ? |
| 0.5 | ? | ? |
| 0.6 | ? | ? |
| 0.8 | ? | ? |
| 1.0 | ? | ? |
Tasks:
- Compute the table values in Python and print them.
- What is the maximum value of $p \times (1 - p)$? At what p does it occur?
- What are the minimum values? When do they occur?
- Plot $f(p) = p \times (1 - p)$ for $p \in [0, 1]$ using the helper below.
- Does this function behave like our desired impurity? Where are its max and min?
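As a starting point for the table task, here is one minimal way to tabulate $f(p) = p \times (1 - p)$. The names `f`, `table`, `best_p` and `best_value` are illustrative choices, not part of the exercise.

```python
# Tabulate f(p) = p * (1 - p) at the values from the exercise table.
def f(p):
    return p * (1 - p)

ps = [0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
table = [(p, 1 - p, f(p)) for p in ps]

print(f"{'p':<6} {'1-p':<6} {'p*(1-p)':<10}")
print("-" * 24)
for p, one_minus_p, value in table:
    print(f"{p:<6.1f} {one_minus_p:<6.1f} {value:<10.2f}")

# Locate the largest tabulated value and the p at which it occurs.
best_p, _, best_value = max(table, key=lambda row: row[2])
print(f"largest value {best_value} at p = {best_p}")
```

The table is symmetric around p = 0.5, which is worth keeping in mind when you answer the questions about the maximum and minimum.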
# PLOTTING HELPER — run this cell as-is
def plot_impurity_candidate(f, title="Candidate impurity function"):
    ps = [i/100 for i in range(101)]
    vals = [f(p) for p in ps]
    plt.figure(figsize=(7, 4))
    plt.plot(ps, vals, 'b-', linewidth=2)
    plt.xlabel('p (fraction of Red)')
    plt.ylabel('value')
    plt.title(title)
    plt.grid(True)
    plt.show()
# YOUR CODE HERE
# Step 1: Print the table
print(f"{'p':<6} {'1-p':<6} {'p*(1-p)':<10}")
print("-" * 24)
for p in [0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]:
    pass # fill in
# Step 2: Plot it
# plot_impurity_candidate(lambda p: p * (1 - p), title="f(p) = p * (1-p)")
Exercise 1.3 — Scale It
From Exercise 1.2, you found that $f(p) = p \times (1-p)$ peaks at $p = 0.5$ with a value of $0.25$.
But we want impurity(0.5, 0.5) = 1 (max impurity = 1, not 0.25).
Recall the spec:
impurity(0.5, 0.5) → 1 ← maximally mixed (hardest to classify)
impurity(0, 1) → 0 ← pure (certain)
impurity(1, 0) → 0 ← pure (certain)
Here impurity = 1 means maximally mixed/uncertain, and impurity = 0 means completely pure/certain.
So we want a function that:
- Returns 1 when p = 0.5 (most uncertain)
- Returns 0 when p = 0 or p = 1 (most certain)
Your task:
Take $f(p) = p \times (1 - p)$ and apply a simple transformation to it so that:
- Its maximum (at p=0.5) becomes 1
- Its minimum (at p=0 or p=1) becomes 0
- Find the transformation mathematically (algebra only, no code yet).
- Write the resulting formula.
- Verify by plugging in: $p = 0.5$, $p = 0$, $p = 1$.
- Write it in Python and plot it.
Your derivation:
- Maximum of f(p) = p*(1-p) is ... at p = ...
- Transformation: ...
- Final formula: impurity(p) = ...
- Check p=0.5: ...
- Check p=0: ...
- Check p=1: ...
# YOUR CODE HERE
def impurity(p, q):
    """
    Returns the impurity of a set with fraction p of Red and q of Black.
    impurity(0.5, 0.5) should return 1 (most mixed = least certain)
    impurity(1, 0) should return 0 (all same = most certain)
    impurity(0, 1) should return 0
    """
    pass # replace this
# Verify the spec:
print(impurity(0.5, 0.5)) # should be 1
print(impurity(0, 1)) # should be 0
print(impurity(1, 0)) # should be 0
print(impurity(0.8, 0.2)) # should be between 0 and 1
# Verify ordering:
print(impurity(0.8, 0.2) < impurity(0.7, 0.3)) # should be True
print(impurity(0.7, 0.3) < impurity(0.6, 0.4)) # should be True
# Plot your impurity function
plot_impurity_candidate(lambda p: impurity(p, 1 - p), title="My impurity function")
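One possible answer, for comparison with your own derivation: since $p \times (1-p)$ peaks at $0.25$, multiplying by 4 stretches the peak to 1 while leaving the zeros at $p = 0$ and $p = 1$ untouched. The sketch below assumes that scaling is the transformation you chose; `impurity_sketch` is a hypothetical name used so it does not overwrite your own `impurity`.

```python
def impurity_sketch(p, q):
    """Scaled product 4 * p * q, which equals 4 * p * (1 - p) when p + q = 1."""
    return 4 * p * q

# Check the three spec values and one intermediate bag.
for p in (0.5, 0.0, 1.0, 0.8):
    print(f"impurity_sketch({p}, {1 - p:.1f}) = {impurity_sketch(p, 1 - p):.2f}")
```

Because the formula is a product of p and q, swapping the two arguments cannot change the result, which is the symmetry asked about in Exercise 1.1.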
Exercise 1.4 — Explore the Function
Now that you have impurity(p, q), explore it.
Tasks:
- Compute and print the impurity for the following bags. Do the results match your intuition?
| Bag | Contents | p | q |
|---|---|---|---|
| A | 9 Red, 1 Black | 0.9 | 0.1 |
| B | 7 Red, 3 Black | 0.7 | 0.3 |
| C | 6 Red, 4 Black | 0.6 | 0.4 |
| D | 5 Red, 5 Black | 0.5 | 0.5 |
| E | 3 Red, 7 Black | 0.3 | 0.7 |
| F | 10 Red, 0 Black | 1.0 | 0.0 |
- Is bag A more or less certain than bag B? Does impurity agree?
- Is impurity(0.3, 0.7) equal to impurity(0.7, 0.3)? Why should it be?
- What is the impurity of a bag with 50 Red, 50 Black, and 50 Green (3 classes)?
  - Hint: Can you extend your formula? Think about what $p \times (1-p)$ means for each pair of classes.
  - (This is a bonus: don't worry if you can't figure it out now.)
# YOUR CODE HERE
bags = [
("A", 0.9, 0.1),
("B", 0.7, 0.3),
("C", 0.6, 0.4),
("D", 0.5, 0.5),
("E", 0.3, 0.7),
("F", 1.0, 0.0),
]
print(f"{'Bag':<5} {'p':<6} {'q':<6} {'impurity':>10}")
print("-" * 30)
for name, p, q in bags:
    imp = impurity(p, q)
    print(f"{name:<5} {p:<6.1f} {q:<6.1f} {imp:>10.4f}")
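For the three-class bonus question, one common extension (an assumption here, not necessarily the answer the exercise intends) is to sum $p_i \times (1 - p_i)$ over all classes, a quantity usually called Gini impurity, and then rescale by $k / (k - 1)$ so that an even mix of $k$ classes scores 1. The function name `multiclass_impurity` is hypothetical.

```python
def multiclass_impurity(fractions):
    """Sum of p_i * (1 - p_i) over classes, rescaled so a uniform mix scores 1.

    `fractions` must list one fraction per possible class and sum to 1.
    """
    k = len(fractions)
    if k < 2:
        return 0.0  # with a single class there is nothing to mix
    raw = sum(p * (1 - p) for p in fractions)
    return raw * k / (k - 1)

print(multiclass_impurity([0.5, 0.5]))          # two classes, even mix
print(multiclass_impurity([1/3, 1/3, 1/3]))     # 50 Red, 50 Black, 50 Green
print(multiclass_impurity([0.9, 0.1]))          # same as 4 * 0.9 * 0.1
```

With two classes this reduces to $4 \times p \times (1 - p)$, so it agrees with the scaled formula from Exercise 1.3.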
PART 2: Impurity of a List
Exercise 2.1 — From Fractions to Lists
So far, impurity(p, q) takes fractions directly. But in practice, you have a list like:
["R", "B", "B", "R", "R"]
You need to compute p and q from the list yourself.
Tasks:
For the list ["R", "B", "B", "R", "R"], manually compute:
- Count of "R"
- Count of "B"
- Total count
- p (fraction of R)
- q (fraction of B)
Write a function list_impurity(items) that:
- Takes a list of "R" and "B" strings
- Computes p and q
- Returns impurity(p, q)
- Edge case: What should it return for an empty list []? Decide and document your choice.
Test your function:
list_impurity(["R", "B", "B", "R", "R"]) # what do you expect?
list_impurity(["R", "R", "R", "R"]) # all same → what value?
list_impurity(["B", "B", "B", "B"]) # all same → what value?
list_impurity(["R", "B"]) # exactly half → what value?
list_impurity([]) # empty → your decision
Try these additional lists and rank them by impurity (most mixed to least mixed):
list_A = ["R", "R", "R", "R", "R", "R", "R", "R", "B", "B"]
list_B = ["R", "R", "R", "R", "R", "B", "B", "B", "B", "B"]
list_C = ["R", "R", "R", "R", "R", "R", "R", "B", "B", "B"]
# YOUR CODE HERE
def list_impurity(items):
    """
    Computes the impurity of a list of "R" and "B" elements.
    Returns a value between 0 (pure/certain) and 1 (maximally mixed).
    """
    pass # replace this
# Test cases
test_cases = [
["R", "B", "B", "R", "R"],
["R", "R", "R", "R"],
["B", "B", "B", "B"],
["R", "B"],
[],
]
for lst in test_cases:
    print(f"{str(lst):<40} impurity = {list_impurity(lst):.4f}")
# Ranking exercise
list_A = ["R", "R", "R", "R", "R", "R", "R", "R", "B", "B"]
list_B = ["R", "R", "R", "R", "R", "B", "B", "B", "B", "B"]
list_C = ["R", "R", "R", "R", "R", "R", "R", "B", "B", "B"]
for name, lst in [("A", list_A), ("B", list_B), ("C", list_C)]:
    print(f"List {name}: impurity = {list_impurity(lst):.4f}")
# Which is most mixed? Least mixed?
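A minimal sketch of the list version, assuming the scaled formula $4 \times p \times q$ and assuming an empty list is treated as pure (that edge case is a design choice, and yours may differ). The `_sketch` names are hypothetical.

```python
from collections import Counter

def impurity_sketch(p, q):
    return 4 * p * q

def list_impurity_sketch(items):
    """Impurity of a list of "R"/"B" labels; an empty list returns 0.0."""
    if not items:
        return 0.0  # assumption: an empty bag has nothing to misclassify
    counts = Counter(items)  # a missing label counts as 0
    p = counts["R"] / len(items)
    q = counts["B"] / len(items)
    return impurity_sketch(p, q)

print(list_impurity_sketch(["R", "B", "B", "R", "R"]))  # p = 0.6, q = 0.4
```

Returning 0.0 for an empty list is convenient later on, because an empty group then contributes nothing to a weighted total.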
Exercise 2.2 — Weighted Impurity of Two Groups
When we split a list into two parts, we get two separate impurity values, one for each part.
But how do we combine them into one number?
Think about it: If one group has 100 items and another has 2 items, should they count equally?
Your task:
Invent a formula for the total impurity of two groups, left and right, that accounts for their sizes.
- What should happen if one group is empty?
- What should happen if both groups are identical?
- Hint: think about a weighted average.
Write a function total_impurity(left, right) that takes two lists and returns a single number.
Test it:
total_impurity(["R", "R", "R"], ["B", "B", "B"]) # perfectly separated
total_impurity(["R", "B", "R"], ["B", "R", "B"]) # each group is mixed
total_impurity(["R", "R", "R", "R", "R", "R", "R", "R", "R", "B"], ["B"]) # unequal sizes
Reflection: For the last test case, the right group has 1 item (["B"]) so it's pure (impurity = 0). But it only has 1 item out of 11 total. Does your formula give it less weight? It should.
# YOUR CODE HERE
def total_impurity(left, right):
    """
    Returns the weighted total impurity of two groups.
    Each group's impurity is weighted by its size relative to the combined total.
    """
    pass # replace this
# Test cases
print(total_impurity(["R", "R", "R"], ["B", "B", "B"])) # perfectly separated → expect 0
print(total_impurity(["R", "B", "R"], ["B", "R", "B"])) # both mixed
print(total_impurity(["R"]*9 + ["B"], ["B"])) # unequal sizes
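The weighted-average hint can be sketched as follows. This assumes the scaled impurity $4 \times p \times (1 - p)$ and treats an empty group as contributing nothing; `binary_impurity` and `total_impurity_sketch` are hypothetical names.

```python
def binary_impurity(labels):
    """4 * p * (1 - p), where p is the fraction of "R"; an empty list counts as 0."""
    if not labels:
        return 0.0
    p = labels.count("R") / len(labels)
    return 4 * p * (1 - p)

def total_impurity_sketch(left, right):
    """Size-weighted average of the two group impurities."""
    n = len(left) + len(right)
    if n == 0:
        return 0.0
    weighted_sum = len(left) * binary_impurity(left) + len(right) * binary_impurity(right)
    return weighted_sum / n

# Unequal sizes: the 10-item group dominates the single pure item.
print(total_impurity_sketch(["R"] * 9 + ["B"], ["B"]))
```

In the unequal-size case the left group (impurity 0.36) carries weight 10/11 and the pure right group carries weight 1/11, so the total sits just below 0.36.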
PART 3: Finding the Best Split
Exercise 3.1 — Splitting a List at a Position
You are given an ordered list:
items = ["R", "B", "B", "R", "R"]
You can split it at position k (0-indexed), meaning:
- Left part: items[:k], the first k items
- Right part: items[k:], the remaining items
For example, splitting at position k=2:
- Left: ["R", "B"]
- Right: ["B", "R", "R"]
Tasks:
For items = ["R", "B", "B", "R", "R"], print the left and right parts for every possible split position k = 1, 2, 3, 4.
- (Why not k=0 or k=5? What would those give?)
For each split position, compute total_impurity(left, right). Print all values.
Which split position gives the minimum total impurity?
- What are the left and right groups at that position?
- Does that split make intuitive sense?
# YOUR CODE HERE
items = ["R", "B", "B", "R", "R"]
print(f"{'k':<5} {'left':<25} {'right':<25} {'total impurity':>15}")
print("-" * 70)
for k in range(1, len(items)):
    left = items[:k]
    right = items[k:]
    imp = total_impurity(left, right)
    print(f"{k:<5} {str(left):<25} {str(right):<25} {imp:>15.4f}")
Exercise 3.2 — The Best Split Function
Now automate it: write a function find_best_split_position(items) that:
- Tries every possible split position k = 1 to len(items) - 1
- Computes total_impurity at each position
- Returns the position k that gives the minimum total impurity
Tasks:
Implement find_best_split_position(items). It should return the best k and the minimum total impurity.
Test it on these lists:
["R", "B", "B", "R", "R"]
["R", "R", "R", "B", "B", "B"]
["B", "R", "B", "R", "B", "R"]
["R", "R", "R", "R", "B"]
For ["R", "R", "R", "B", "B", "B"], the answer should be obvious from visual inspection. Does your function agree?
For ["B", "R", "B", "R", "B", "R"] (perfectly alternating), what split position does it find? Are there multiple "tied" positions? How does your function handle ties?
# YOUR CODE HERE
def find_best_split_position(items):
    """
    Finds the split position k that minimises total_impurity(items[:k], items[k:]).
    Returns (best_k, min_impurity).
    """
    pass # replace this
# Test
test_lists = [
["R", "B", "B", "R", "R"],
["R", "R", "R", "B", "B", "B"],
["B", "R", "B", "R", "B", "R"],
["R", "R", "R", "R", "B"],
]
for lst in test_lists:
    best_k, min_imp = find_best_split_position(lst)
    print(f"List: {lst}")
    print(f" Best k={best_k}: left={lst[:best_k]}, right={lst[best_k:]}")
    print(f" Min impurity: {min_imp:.4f}")
    print()
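One way to structure the search is a simple loop that keeps the best score seen so far. The sketch below is self-contained and assumes the weighted-average impurity with the $4 \times p \times (1 - p)$ formula. The helper names are hypothetical, and ties go to the smallest k because the comparison is strict.

```python
def binary_impurity(labels):
    if not labels:
        return 0.0
    p = labels.count("R") / len(labels)
    return 4 * p * (1 - p)

def weighted_impurity(left, right):
    n = len(left) + len(right)
    return (len(left) * binary_impurity(left) + len(right) * binary_impurity(right)) / n

def best_split_sketch(items):
    """Returns (best_k, min_impurity); (None, inf) if the list has fewer than 2 items."""
    best_k, best_score = None, float("inf")
    for k in range(1, len(items)):
        score = weighted_impurity(items[:k], items[k:])
        if score < best_score:  # strict '<' keeps the first of any tied positions
            best_k, best_score = k, score
    return best_k, best_score

print(best_split_sketch(["R", "R", "R", "B", "B", "B"]))
```

Changing `<` to `<=` would instead return the last tied position, which is worth trying on the alternating list.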
Exercise 3.3 — Visualise the Split Scores
Use the plotting helper below to visualise how total impurity changes at each split position.
Tasks:
- Run plot_split_scores on all four test lists from Exercise 3.2.
- For ["R", "R", "R", "B", "B", "B"]: is the minimum sharp or flat? Why?
- For ["B", "R", "B", "R", "B", "R"]: what does the plot look like? Is there a clear winner?
# PLOTTING HELPER — run this cell as-is
def plot_split_scores(items, title=None):
    """
    Plots total_impurity vs split position k for a given list.
    Marks the best (minimum impurity) position.
    """
    ks = list(range(1, len(items)))
    scores = [total_impurity(items[:k], items[k:]) for k in ks]
    best_k = ks[scores.index(min(scores))]
    plt.figure(figsize=(8, 4))
    plt.plot(ks, scores, 'b-o')
    plt.axvline(x=best_k, color='red', linestyle='--', label=f'best k={best_k}')
    plt.xlabel('Split position k')
    plt.ylabel('Total impurity')
    plt.title(title or f'Split scores for {items}')
    plt.legend()
    plt.grid(True)
    plt.show()
# YOUR CODE HERE — plot all four lists
for lst in test_lists:
    plot_split_scores(lst)
# DATASET — run this cell as-is
random.seed(42)
heights = [158, 162, 165, 167, 168, 170, 171, 172, 174, 175,
176, 178, 179, 180, 181, 182, 183, 185, 187, 190]
genders = ["F", "F", "F", "F", "M", "F", "M", "M", "F", "M",
"M", "M", "M", "M", "F", "M", "M", "M", "M", "M"]
print("Height Gender")
print("-" * 16)
for h, g in zip(heights, genders):
    print(f"{h:<8} {g}")
Before coding, answer these questions:
- Just by looking at the data, at roughly what height does gender start being predominantly Male?
- If you had to draw a single horizontal line through the data to separate M from F as cleanly as possible, where would you draw it?
- Is a perfect separation possible? Why or why not?
Your answers:
- ...
- ...
- ...
Exercise 4.2 — Visualise the Data
Run the plotting helper below to see the data.
# PLOTTING HELPER — run this cell as-is
def plot_height_gender(heights, genders, split_height=None, title="Height vs Gender"):
    """
    Scatter plot of height vs gender.
    Optionally draws a vertical split line at split_height.
    """
    colors = {"M": "blue", "F": "red"}
    y_jitter = {"M": 1, "F": 0}
    plt.figure(figsize=(10, 3))
    for h, g in zip(heights, genders):
        plt.scatter(h, y_jitter[g], color=colors[g], s=80, zorder=5)
        plt.text(h, y_jitter[g] + 0.05, g, ha='center', fontsize=8)
    if split_height is not None:
        plt.axvline(x=split_height, color='green', linestyle='--', linewidth=2,
                    label=f'split at {split_height}')
        plt.legend()
    plt.yticks([0, 1], ["F", "M"])
    plt.xlabel("Height (cm)")
    plt.title(title)
    plt.grid(axis='x')
    plt.tight_layout()
    plt.show()
# Plot the data
plot_height_gender(heights, genders)
Exercise 4.3 — Split by Height Threshold
Now the key idea: instead of splitting at a position in a list, we split at a threshold value in a feature.
For a given threshold t:
- Left group: all rows where height <= t
- Right group: all rows where height > t
We collect the gender labels in each group, then calculate total_impurity. Note that the labels are now "M" and "F" rather than "R" and "B", so make sure your list_impurity works for any two labels.
Tasks:
Write a function split_by_threshold(heights, genders, t) that:
- Splits the data at threshold t
- Returns (left_genders, right_genders) where each is a list of gender labels
Test it: split_by_threshold(heights, genders, 172) should return:
- left: genders of everyone with height ≤ 172
- right: genders of everyone with height > 172
For the thresholds below, compute and print total_impurity(left, right). Notice the pattern.
t = 160, 165, 170, 172, 175, 178, 182, 185, 188
Which threshold gives the lowest total impurity? Verify visually using plot_height_gender with split_height=t.
# YOUR CODE HERE
def split_by_threshold(heights, genders, t):
    """
    Splits the dataset at threshold t on height.
    Returns (left_genders, right_genders) where:
    left_genders = genders where height <= t
    right_genders = genders where height > t
    """
    pass # replace this
# Test manually
left, right = split_by_threshold(heights, genders, 172)
print("Left (height <= 172):", left)
print("Right (height > 172):", right)
print("Total impurity:", total_impurity(left, right))
# Test multiple thresholds
thresholds = [160, 165, 170, 172, 175, 178, 182, 185, 188]
print(f"{'Threshold':<12} {'Left size':<12} {'Right size':<12} {'Total impurity':>15}")
print("-" * 55)
for t in thresholds:
    left, right = split_by_threshold(heights, genders, t)
    imp = total_impurity(left, right)
    print(f"{t:<12} {len(left):<12} {len(right):<12} {imp:>15.4f}")
Exercise 4.4 — Find the Best Threshold Automatically
Instead of manually trying thresholds, automate it.
Key insight: You only need to try thresholds at the unique values in your data (or midpoints between consecutive values). Why? Because two thresholds with no data point between them produce exactly the same groups. With heights 172 and 174 in the data, splitting at 172 or at 173.5 makes no difference, while splitting at 174 moves the 174 row into the left group.
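The claim about equivalent thresholds can be checked on a small illustrative list (the values below are a made-up subset, not the full dataset). The sketch also shows one way to compute the midpoints between consecutive unique values; `left_group` is a hypothetical helper.

```python
heights_demo = [170, 171, 172, 174, 175]  # illustrative subset

def left_group(values, t):
    """Values that fall on the left side of threshold t (value <= t)."""
    return [v for v in values if v <= t]

# No height lies strictly between 172 and 173.5, so both thresholds select the same rows.
same_split = left_group(heights_demo, 172) == left_group(heights_demo, 173.5)
print("172 and 173.5 give the same left group:", same_split)

# Candidate thresholds as midpoints between consecutive unique values.
unique = sorted(set(heights_demo))
midpoints = [(a + b) / 2 for a, b in zip(unique, unique[1:])]
print("midpoint candidates:", midpoints)
```

Using the unique values themselves gives the same set of distinct splits as using the midpoints, plus one extra threshold at the maximum that leaves the right group empty.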
Tasks:
Write find_best_threshold(heights, genders) that:
- Tries every unique height value as a threshold
- Computes total_impurity for each
- Returns the threshold with minimum total impurity, and that minimum impurity value
Run it and print the result. Does it match your manual inspection?
Visualise the best threshold using plot_height_gender(heights, genders, split_height=best_t).
Reflection: What are the genders in the left and right groups at the best threshold? Is the split clean?
# YOUR CODE HERE
def find_best_threshold(heights, genders):
    """
    Finds the height threshold that minimises total_impurity.
    Returns (best_threshold, min_impurity).
    """
    pass # replace this
best_t, best_imp = find_best_threshold(heights, genders)
print(f"Best threshold: height <= {best_t}")
print(f"Minimum total impurity: {best_imp:.4f}")
left, right = split_by_threshold(heights, genders, best_t)
print(f"Left group: {left}")
print(f"Right group: {right}")
plot_height_gender(heights, genders, split_height=best_t)
Exercise 4.5 — Plot Impurity vs Threshold
Use the plotting helper below to visualise how impurity changes as the threshold varies.
# PLOTTING HELPER — run this cell as-is
def plot_threshold_scores(heights, genders, title="Impurity vs Threshold"):
    """
    Plots total impurity for each unique height threshold.
    """
    unique_heights = sorted(set(heights))
    scores = []
    for t in unique_heights:
        left = [g for h, g in zip(heights, genders) if h <= t]
        right = [g for h, g in zip(heights, genders) if h > t]
        if left and right:
            scores.append((t, total_impurity(left, right)))
    ts, imps = zip(*scores)
    best_t = ts[imps.index(min(imps))]
    plt.figure(figsize=(9, 4))
    plt.plot(ts, imps, 'b-o')
    plt.axvline(x=best_t, color='red', linestyle='--', label=f'best t={best_t}')
    plt.xlabel('Threshold (height cm)')
    plt.ylabel('Total impurity')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()
# YOUR CODE HERE — run the plot
plot_threshold_scores(heights, genders)
PART 5: A Decision Node — The Job Offer
Exercise 5.1 — What Is a Decision Tree Node?
Before we build a full tree from data, let's understand the structure of a decision tree by constructing one by hand.
A decision tree is made of nodes. There are two kinds:
- Boundary node: Has a rule — "if feature X ≤ value, go left; else go right." Has two children.
- Leaf node: Has a final answer, just True (YES) or False (NO). No children.
Here is a tree for deciding whether to accept a job offer:
Is salary > 1000?
├── NO → Reject the offer
└── YES → Is distance <= 40?
    ├── YES → Accept the offer
    └── NO → Reject the offer
Before coding, answer these questions:
- If a job pays 800 and is 20km away — what does the tree say?
- If a job pays 1200 and is 30km away — what does the tree say?
- If a job pays 1500 and is 60km away — what does the tree say?
- If a job pays 1000 exactly, which branch do you take? (Careful: the rule is salary > 1000.)
Write your answers below before coding.
Your answers:
- salary=800, distance=20 → ...
- salary=1200, distance=30 → ...
- salary=1500, distance=60 → ...
- salary=1000, distance=any → ...
Exercise 5.2 — Build the DecisionTreeNode Class
Design a class DecisionTreeNode that can be either a boundary node or a leaf node.
Requirements:
- A leaf node stores a boolean decision (True = YES, False = NO).
- A boundary node stores:
  - boundary: the feature name (a string, e.g. "salary")
  - boundary_value: the threshold (a number)
  - left: the child node for when feature <= boundary_value
  - right: the child node for when feature > boundary_value
- A method check(features) that:
  - Takes a dict like {"salary": 1200, "distance": 30}
  - Traverses the tree
  - Returns the leaf's boolean decision
Also create two helper functions:
- YES(): returns a leaf node with decision True
- NO(): returns a leaf node with decision False
Hint for the class structure:
class DecisionTreeNode:
    def __init__(self, ...):
        # your attributes here
        pass
    def check(self, features):
        # if leaf: return decision
        # if boundary: go left or right depending on features[self.boundary]
        pass
Tasks:
- Implement DecisionTreeNode, YES(), and NO().
- Build the job offer tree manually:
  Is salary <= 1000? → NO()
  Else: Is distance <= 40? → YES() else NO()
- Test it against your answers from Exercise 5.1.
# YOUR CODE HERE
class DecisionTreeNode:
    def __init__(self, decision=None, boundary=None, boundary_value=None, left=None, right=None):
        """
        A node in a decision tree.
        If decision is not None: this is a leaf node.
        Otherwise: this is a boundary node with boundary, boundary_value, left, right.
        """
        pass # replace this
    def check(self, features):
        """
        Traverses the tree and returns the leaf's boolean decision.
        features: dict of {feature_name: value}
        Rule: go LEFT if features[boundary] <= boundary_value, else go RIGHT.
        """
        pass # replace this
def YES():
    """Returns a leaf node that decides True (accept)."""
    pass
def NO():
    """Returns a leaf node that decides False (reject)."""
    pass
# Build the job offer tree
# Tree:
# salary <= 1000 → NO
# salary > 1000:
# distance <= 40 → YES
# distance > 40 → NO
job_tree = DecisionTreeNode(
    boundary="salary",
    boundary_value=1000,
    left=NO(),
    right=DecisionTreeNode(
        boundary="distance",
        boundary_value=40,
        left=YES(),
        right=NO()
    )
)
# Test the four cases from Exercise 5.1
test_jobs = [
{"salary": 800, "distance": 20, "expected": False},
{"salary": 1200, "distance": 30, "expected": True},
{"salary": 1500, "distance": 60, "expected": False},
{"salary": 1000, "distance": 15, "expected": False},
]
print(f"{'Salary':<10} {'Distance':<12} {'Result':<10} {'Expected':<10} {'Match'}")
print("-" * 55)
for job in test_jobs:
    result = job_tree.check(job)
    match = "✓" if result == job["expected"] else "✗"
    print(f"{job['salary']:<10} {job['distance']:<12} {str(result):<10} {str(job['expected']):<10} {match}")
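For reference, here is one way the node class could look. It is a sketch under the conventions stated in the exercise (a node is a leaf when `decision` is not None; go left when the feature is `<=` the threshold). The names `SketchNode`, `yes_leaf` and `no_leaf` are hypothetical stand-ins so the block does not overwrite your own `DecisionTreeNode`, `YES` and `NO`.

```python
class SketchNode:
    def __init__(self, decision=None, boundary=None, boundary_value=None, left=None, right=None):
        self.decision = decision              # True/False for a leaf, None for a boundary node
        self.boundary = boundary              # feature name, e.g. "salary"
        self.boundary_value = boundary_value  # numeric threshold
        self.left = left                      # subtree for feature <= threshold
        self.right = right                    # subtree for feature > threshold

    def check(self, features):
        if self.decision is not None:         # leaf: return the stored answer
            return self.decision
        if features[self.boundary] <= self.boundary_value:
            return self.left.check(features)
        return self.right.check(features)

def yes_leaf():
    return SketchNode(decision=True)

def no_leaf():
    return SketchNode(decision=False)

sketch_tree = SketchNode(
    boundary="salary", boundary_value=1000,
    left=no_leaf(),
    right=SketchNode(boundary="distance", boundary_value=40, left=yes_leaf(), right=no_leaf()),
)
print(sketch_tree.check({"salary": 1200, "distance": 30}))
```

Testing `decision is not None` rather than the truthiness of `decision` matters here, because a leaf that decides False must still be recognised as a leaf.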
Exercise 5.3 — A Bigger Tree
Now build a more complex tree by hand. This tree has 3 levels:
Is salary <= 1000?
├── YES → NO (reject)
└── NO (salary > 1000):
    Is distance <= 40?
    ├── YES (close enough):
    │   Is coffee == 1? ← 1 means "office has free coffee"
    │   ├── YES → YES (accept)
    │   └── NO → NO (reject)
    └── NO (too far) → NO (reject)
Tasks:
- Build this tree using DecisionTreeNode, YES(), and NO().
- Test with these job offers:
| salary | distance | coffee | Expected |
|---|---|---|---|
| 1200 | 30 | 1 | YES |
| 1200 | 30 | 0 | NO |
| 1200 | 60 | 1 | NO |
| 900 | 10 | 1 | NO |
- Extend the tree by adding one more rule of your own. Document it in a markdown cell.
# YOUR CODE HERE — build the 3-level tree
job_tree_v2 = None # replace with your tree
# Test cases
test_v2 = [
{"salary": 1200, "distance": 30, "coffee": 1, "expected": True},
{"salary": 1200, "distance": 30, "coffee": 0, "expected": False},
{"salary": 1200, "distance": 60, "coffee": 1, "expected": False},
{"salary": 900, "distance": 10, "coffee": 1, "expected": False},
]
for job in test_v2:
    result = job_tree_v2.check(job)
    match = "✓" if result == job["expected"] else "✗"
    print(f"salary={job['salary']}, distance={job['distance']}, coffee={job['coffee']} → {result} {match}")
# EXTENDED DATASET — run this cell as-is
data = [
{"height": 158, "weight": 52, "gender": "F"},
{"height": 162, "weight": 55, "gender": "F"},
{"height": 165, "weight": 58, "gender": "F"},
{"height": 167, "weight": 61, "gender": "F"},
{"height": 168, "weight": 70, "gender": "M"},
{"height": 170, "weight": 60, "gender": "F"},
{"height": 171, "weight": 73, "gender": "M"},
{"height": 172, "weight": 75, "gender": "M"},
{"height": 174, "weight": 63, "gender": "F"},
{"height": 175, "weight": 78, "gender": "M"},
{"height": 176, "weight": 80, "gender": "M"},
{"height": 178, "weight": 82, "gender": "M"},
{"height": 179, "weight": 84, "gender": "M"},
{"height": 180, "weight": 85, "gender": "M"},
{"height": 181, "weight": 66, "gender": "F"},
{"height": 182, "weight": 88, "gender": "M"},
{"height": 183, "weight": 90, "gender": "M"},
{"height": 185, "weight": 92, "gender": "M"},
{"height": 187, "weight": 95, "gender": "M"},
{"height": 190, "weight": 98, "gender": "M"},
]
print(f"{'Height':<8} {'Weight':<8} {'Gender'}")
print("-" * 24)
for row in data:
    print(f"{row['height']:<8} {row['weight']:<8} {row['gender']}")
Exercise 6.2 — Find the Best Feature and Threshold
The question: Should we split on height or weight, and at what value?
Your task:
Write a function find_decision_boundary(data, features, label) that:
- Takes a list of dicts (data), a list of feature names to consider (features), and the label column name (label)
- For each feature in features:
  - Tries every unique value of that feature as a threshold
  - Computes total_impurity for that split
- Returns the best feature name, the best threshold, and the minimum impurity
Step-by-step hints:
- Extract all unique values for a feature: sorted(set(row[feature] for row in data))
- For each threshold t, the left group is [row[label] for row in data if row[feature] <= t]
- Track the best (feature, threshold, impurity) triple as you loop
Tasks:
- Implement find_decision_boundary.
- Run it on the dataset with features=["height", "weight"] and label="gender".
- Print the result. Which feature won? At what threshold?
- Does this match what you expected from looking at the data?
# YOUR CODE HERE
def find_decision_boundary(data, features, label):
    """
    Finds the best (feature, threshold) split across all given features.
    Returns (best_feature, best_threshold, min_impurity).
    """
    pass # replace this
best_feat, best_thresh, best_imp = find_decision_boundary(data, ["height", "weight"], "gender")
print(f"Best split: {best_feat} <= {best_thresh}")
print(f"Minimum total impurity: {best_imp:.4f}")
# What are the two groups?
left_group = [row["gender"] for row in data if row[best_feat] <= best_thresh]
right_group = [row["gender"] for row in data if row[best_feat] > best_thresh]
print(f"Left ({best_feat} <= {best_thresh}): {left_group}")
print(f"Right ({best_feat} > {best_thresh}): {right_group}")
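The search over features and thresholds can be sketched as two nested loops. This version is self-contained and makes a few assumptions that are not fixed by the exercise: it uses $4 \times p \times (1 - p)$ for a two-label list, skips thresholds that leave one side empty, and resolves ties in favour of the first candidate found. All names and the toy rows are hypothetical.

```python
def label_impurity(labels):
    """4 * p * (1 - p) for a two-label list; an empty list counts as 0."""
    if not labels:
        return 0.0
    p = labels.count(labels[0]) / len(labels)  # fraction of whichever label appears first
    return 4 * p * (1 - p)

def boundary_sketch(rows, features, label):
    """Returns (feature, threshold, impurity) with the lowest weighted impurity."""
    best = (None, None, float("inf"))
    for feature in features:
        for t in sorted({row[feature] for row in rows}):
            left = [row[label] for row in rows if row[feature] <= t]
            right = [row[label] for row in rows if row[feature] > t]
            if not left or not right:
                continue  # a split with an empty side separates nothing
            score = (len(left) * label_impurity(left)
                     + len(right) * label_impurity(right)) / len(rows)
            if score < best[2]:
                best = (feature, t, score)
    return best

toy_rows = [
    {"height": 160, "weight": 50, "gender": "F"},
    {"height": 170, "weight": 55, "gender": "F"},
    {"height": 165, "weight": 75, "gender": "M"},
    {"height": 180, "weight": 80, "gender": "M"},
]
result = boundary_sketch(toy_rows, ["height", "weight"], "gender")
print(result)
```

In the toy rows, weight separates the two labels perfectly while height does not, so the search picks weight.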
Exercise 6.3 — Visualise Both Features
Use the helper below to visualise both the height split and the weight split.
# PLOTTING HELPER — run this cell as-is
def plot_feature_split(data, feature, label, split_value=None, title=None):
    """
    Plots a feature vs label. Optionally marks a split threshold.
    """
    color_map = {"M": "blue", "F": "red"}
    y_map = {"M": 1, "F": 0}
    plt.figure(figsize=(10, 3))
    for row in data:
        x = row[feature]
        y = y_map[row[label]]
        c = color_map[row[label]]
        plt.scatter(x, y, color=c, s=80, zorder=5)
        plt.text(x, y + 0.06, row[label], ha='center', fontsize=8)
    if split_value is not None:
        plt.axvline(x=split_value, color='green', linestyle='--', linewidth=2,
                    label=f'split at {split_value}')
        plt.legend()
    plt.yticks([0, 1], ["F", "M"])
    plt.xlabel(feature)
    plt.title(title or f"{feature} vs {label}")
    plt.grid(axis='x')
    plt.tight_layout()
    plt.show()
# YOUR CODE HERE — plot height split and weight split
# Use find_decision_boundary first on each feature separately
best_h, best_ht, imp_h = find_decision_boundary(data, ["height"], "gender")
best_w, best_wt, imp_w = find_decision_boundary(data, ["weight"], "gender")
print(f"Best height split: height <= {best_ht}, impurity = {imp_h:.4f}")
print(f"Best weight split: weight <= {best_wt}, impurity = {imp_w:.4f}")
plot_feature_split(data, "height", "gender", split_value=best_ht, title=f"Height split at {best_ht} (impurity={imp_h:.3f})")
plot_feature_split(data, "weight", "gender", split_value=best_wt, title=f"Weight split at {best_wt} (impurity={imp_w:.3f})")
PART 7: Growing the Tree — Recursive Splitting
Exercise 7.1 — What Happens After the First Split?
After the first split, you have two groups. Each group may still be impure (mixed).
The idea: repeat the splitting on each group independently, until each group is pure (impurity = 0) or has only 1 element.
Before coding, trace through manually:
Using the height-only dataset from Part 4:
heights = [158, 162, 165, 167, 168, 170, 171, 172, 174, 175,
176, 178, 179, 180, 181, 182, 183, 185, 187, 190]
genders = ["F", "F", "F", "F", "M", "F", "M", "M", "F", "M",
"M", "M", "M", "M", "F", "M", "M", "M", "M", "M"]
- What is the best first split threshold? (You found this in Part 4.)
- After splitting, what are the left and right groups?
- Is the left group pure? Is the right group pure?
- For the impure group(s), what is the best next split?
- Draw this out as a tree structure in the markdown cell below.
Your manual tree sketch:
height <= ???
├── Left: ??? (pure? yes/no)
│   └── If not pure, split again at height <= ???
│       ├── ...
│       └── ...
└── Right: ??? (pure? yes/no)
    └── If not pure, split again at height <= ???
        ├── ...
        └── ...
Exercise 7.2 — Stopping Conditions
Before writing the recursive function, think about when to stop.
When should we NOT split further?
- The group has only 1 element — nothing to split.
- The group is completely pure: impurity = 0 (all same label). No need to split.
- (Optional / harder): All elements have the same feature values — we can't find a meaningful threshold.
When we stop, what do we return?
A leaf node with the majority label in that group.
Tasks:
Write a helper majority_label(labels) that takes a list of labels (like ["M", "M", "F", "M"]) and returns the most common one.
- What if there's a tie? Decide and document your choice.
Write a helper is_pure(labels) that returns True if all labels are the same.
Test both helpers:
majority_label(["M", "M", "F", "M"]) # → "M"
majority_label(["F", "M"]) # → tie: your choice
is_pure(["M", "M", "M"]) # → True
is_pure(["M", "M", "F"]) # → False
# YOUR CODE HERE
def majority_label(labels):
    """
    Returns the most common label in the list.
    Document your tie-breaking rule here.
    """
    pass # replace this
def is_pure(labels):
    """
    Returns True if all labels are the same.
    """
    pass # replace this
# Tests
print(majority_label(["M", "M", "F", "M"])) # M
print(majority_label(["F", "M"])) # tie — what do you return?
print(is_pure(["M", "M", "M"])) # True
print(is_pure(["M", "M", "F"])) # False
print(is_pure([])) # edge case — what do you return?
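Both helpers are short once the standard library does the counting. The sketch below makes two choices that the exercise leaves open, so treat them as assumptions: ties go to the label seen first in the list, and an empty list counts as pure. The `_sketch` names are hypothetical.

```python
from collections import Counter

def majority_label_sketch(labels):
    """Most common label; ties go to the label that appears first in the list."""
    if not labels:
        return None  # assumption: no sensible majority for an empty list
    # most_common orders equal counts by first appearance (Python 3.7+).
    return Counter(labels).most_common(1)[0][0]

def is_pure_sketch(labels):
    """True if the list holds at most one distinct label (an empty list counts as pure)."""
    return len(set(labels)) <= 1

print(majority_label_sketch(["M", "M", "F", "M"]))
print(is_pure_sketch(["M", "M", "F"]))
```

Treating the empty list as pure means a recursive tree builder stops on it without needing a separate check.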
Exercise 7.3 — Build the Tree Recursively
Now write build_tree(data, features, label) that:
- Base case: If data is empty, or has 1 element, or is pure, return a DecisionTreeNode leaf with the majority label.
- Recursive case:
  - Find the best feature and threshold using find_decision_boundary.
  - Split data into left_data and right_data.
  - Recursively call build_tree on each.
  - Return a DecisionTreeNode boundary node with the two subtrees.
Important: The leaf node now stores a label string (like "M" or "F"), not just True/False.
You may need to adapt your DecisionTreeNode slightly, or use a convention like:
- decision = "M" means this leaf predicts Male.
Tasks:
- Implement build_tree(data, features, label).
- Run it on the height+weight dataset from Part 6.
- Write a helper print_tree(node, depth=0) that prints the tree in an indented format.
Hint for print_tree:
[height <= 175]
  LEFT:
    [weight <= 63]
      LEFT: LEAF → F
      RIGHT: LEAF → M
  RIGHT:
    LEAF → M
# YOUR CODE HERE
def build_tree(data, features, label):
    """
    Recursively builds a decision tree.
    data : list of dicts
    features : list of feature names to split on
    label : the target column name
    Returns a DecisionTreeNode (either leaf or boundary).
    """
    labels = [row[label] for row in data]
    # Base case 1: empty data
    if len(data) == 0:
        pass # return a leaf — but with what label?
    # Base case 2: pure or single element
    if is_pure(labels) or len(data) == 1:
        pass # return a leaf
    # Recursive case
    best_feat, best_thresh, best_imp = find_decision_boundary(data, features, label)
    left_data = [row for row in data if row[best_feat] <= best_thresh]
    right_data = [row for row in data if row[best_feat] > best_thresh]
    # Safety: if the split doesn't separate anything, stop
    if len(left_data) == 0 or len(right_data) == 0:
        pass # return a leaf
    left_tree = build_tree(left_data, features, label)
    right_tree = build_tree(right_data, features, label)
    return DecisionTreeNode(
        boundary=best_feat,
        boundary_value=best_thresh,
        left=left_tree,
        right=right_tree
    )
# Helper to print the tree
def print_tree(node, depth=0, label="ROOT"):
    """
    Prints the tree in an indented format.
    Implement this yourself!
    """
    indent = "    " * depth
    if node.decision is not None:  # leaf node
        print(f"{indent}[{label}] LEAF → {node.decision}")
    else:  # boundary node
        print(f"{indent}[{label}] {node.boundary} <= {node.boundary_value}?")
        print_tree(node.left, depth + 1, label="YES (left)")
        print_tree(node.right, depth + 1, label="NO (right)")

# Build and print the tree
tree = build_tree(data, features=["height", "weight"], label="gender")
print_tree(tree)
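The base case and recursive case from this exercise can be sketched end to end in a self-contained form. The version below is only an illustration under stated assumptions: it does not use your DecisionTreeNode or find_decision_boundary, it stores nodes as plain dicts, it scores splits with Gini impurity rather than the formula from Part 1, and the toy rows are made up. Every `sketch_` name is hypothetical.

```python
from collections import Counter

def sketch_gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a list of labels (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def sketch_best_split(rows, features, label):
    """Try each observed value of each feature as a threshold.

    Returns the (feature, threshold) with the lowest weighted impurity,
    or None if no threshold separates the rows into two non-empty groups.
    """
    best, best_score = None, float("inf")
    for feat in features:
        for thresh in sorted({r[feat] for r in rows}):
            left = [r[label] for r in rows if r[feat] <= thresh]
            right = [r[label] for r in rows if r[feat] > thresh]
            if not left or not right:
                continue  # this threshold separates nothing
            score = (len(left) * sketch_gini(left)
                     + len(right) * sketch_gini(right)) / len(rows)
            if score < best_score:
                best, best_score = (feat, thresh), score
    return best

def sketch_build_tree(rows, features, label):
    """Return a nested dict: {"leaf": label} or a node with left/right subtrees."""
    if not rows:
        return {"leaf": None}                      # base case: nothing to learn from
    labels = [r[label] for r in rows]
    majority = Counter(labels).most_common(1)[0][0]
    if len(set(labels)) == 1:
        return {"leaf": majority}                  # base case: pure (covers one row too)
    split = sketch_best_split(rows, features, label)
    if split is None:
        return {"leaf": majority}                  # safety: no useful split exists
    feat, thresh = split
    return {
        "feature": feat,
        "threshold": thresh,
        "left": sketch_build_tree([r for r in rows if r[feat] <= thresh], features, label),
        "right": sketch_build_tree([r for r in rows if r[feat] > thresh], features, label),
    }

def sketch_predict(node, row):
    """Walk from the root to a leaf, going left when the value is <= the threshold."""
    while "leaf" not in node:
        node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

# Made-up toy rows, not the dataset from Part 6
toy = [
    {"height": 160, "weight": 55, "gender": "F"},
    {"height": 165, "weight": 60, "gender": "F"},
    {"height": 180, "weight": 80, "gender": "M"},
    {"height": 185, "weight": 90, "gender": "M"},
]
toy_tree = sketch_build_tree(toy, ["height", "weight"], "gender")
print(toy_tree["feature"], toy_tree["threshold"])            # height 165
print(sketch_predict(toy_tree, {"height": 170, "weight": 70}))  # M
```

Note that every base case returns immediately. In the template above, the `pass` placeholders must be replaced by `return` statements for the same reason: otherwise execution falls through into the recursive case.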
Exercise 7.4 — Use the Tree to Predict¶
Now use your built tree to make predictions. Start with the rows it was trained on; new (unseen) individuals come in Part 8.
Tasks:
- Use tree.check(row) to predict the gender for each row in the original data.
- Compare the prediction to the true label.
- Count how many predictions are correct. What is the accuracy (correct / total)?
- Which rows does the tree get wrong? Can you explain why?
# YOUR CODE HERE
correct = 0
print(f"{'Height':<8} {'Weight':<8} {'True':<8} {'Predicted':<12} {'Match'}")
print("-" * 50)
for row in data:
    predicted = tree.check(row)
    true_label = row["gender"]
    match = "✓" if predicted == true_label else "✗"
    if predicted == true_label:
        correct += 1
    print(f"{row['height']:<8} {row['weight']:<8} {true_label:<8} {str(predicted):<12} {match}")
print(f"\nAccuracy: {correct}/{len(data)} = {correct/len(data)*100:.1f}%")
PART 8: The Full Decision Tree — fit and predict¶
Exercise 8.1 — The Interface¶
Real ML libraries use a standard interface:
- fit(X, y): train the model on data X (features) and labels y
- predict(X): given new data, return predictions
Here, X will be a list of dicts (one per sample), and y will be a list of labels.
Before coding, think about these questions:
- In fit(X, y): how do you know which features to try splitting on?
- In predict(X): X is a list of rows, so predict should return a list of predictions.
- What should a leaf return when check is called on it?
  - Currently it stores the majority label as a string.
  - Should predict return raw labels, or something else?
Write your answers below.
Your answers:
- Features to split on: ...
- predict returns: ...
- Leaf behaviour: ...
Exercise 8.2 — Implement SimpleDecisionTree¶
Now wrap everything into a clean class.
class SimpleDecisionTree:
    def fit(self, X, y):
        """
        X : list of dicts (each dict is one row of features)
        y : list of labels (strings or booleans)
        Builds the decision tree and stores it in self.root.
        Also stores the feature names in self.features.
        """
        pass

    def predict(self, X):
        """
        X : list of dicts
        Returns a list of predicted labels, one per row.
        """
        pass
Tasks:
- Implement SimpleDecisionTree using build_tree inside fit.
- Separate X and y before fitting. Combine them inside fit to pass to build_tree.
- Test it on the height+weight dataset:
X = [{"height": row["height"], "weight": row["weight"]} for row in data]
y = [row["gender"] for row in data]
clf = SimpleDecisionTree()
clf.fit(X, y)
predictions = clf.predict(X)
- Compute accuracy.
- Now predict on some new unseen data:
new_people = [
    {"height": 163, "weight": 54},  # your guess: F or M?
    {"height": 185, "weight": 91},  # your guess: F or M?
    {"height": 172, "weight": 65},  # your guess: F or M?
]
- Does the tree's prediction match your intuition?
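The "combine X and y inside fit" step is mostly bookkeeping, and can be sketched on its own. The helper names below (`sketch_combine`, `sketch_feature_names`), the `label_key` parameter, and the two demo rows are assumptions made for illustration. The sketch also assumes every row has the same keys and that `label_key` does not clash with a feature name.

```python
def sketch_combine(X, y, label_key="label"):
    """Merge each feature dict with its label into one row dict."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    # {**features, key: value} builds a new dict, so the caller's X is not modified
    return [{**features, label_key: target} for features, target in zip(X, y)]

def sketch_feature_names(X):
    """Read feature names from the keys of the first row (empty list if X is empty)."""
    return list(X[0].keys()) if X else []

# Made-up demo rows
demo_X = [{"height": 160, "weight": 55}, {"height": 182, "weight": 84}]
demo_y = ["F", "M"]
demo_rows = sketch_combine(demo_X, demo_y, label_key="gender")

print(sketch_feature_names(demo_X))  # ['height', 'weight']
print(demo_rows[0])                  # {'height': 160, 'weight': 55, 'gender': 'F'}
```

Inside fit, rows built this way can be handed to build_tree together with the feature names and the label key, which also answers the question of how fit knows which features to try.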
# YOUR CODE HERE
class SimpleDecisionTree:
    def __init__(self):
        self.root = None
        self.features = None

    def fit(self, X, y):
        """
        X : list of dicts (features)
        y : list of labels
        """
        pass  # replace this

    def predict(self, X):
        """
        X : list of dicts
        Returns a list of predicted labels.
        """
        pass  # replace this

# Test the full pipeline
X = [{"height": row["height"], "weight": row["weight"]} for row in data]
y = [row["gender"] for row in data]
clf = SimpleDecisionTree()
clf.fit(X, y)
predictions = clf.predict(X)
# Accuracy
correct = sum(p == t for p, t in zip(predictions, y))
print(f"Training accuracy: {correct}/{len(y)} = {correct/len(y)*100:.1f}%")
# Predict on new data
new_people = [
    {"height": 163, "weight": 54},
    {"height": 185, "weight": 91},
    {"height": 172, "weight": 65},
]
new_preds = clf.predict(new_people)
print("\nNew predictions:")
for person, pred in zip(new_people, new_preds):
    print(f"  height={person['height']}, weight={person['weight']} → {pred}")
Exercise 8.3 — Print the Learned Tree¶
Use your print_tree function to display the tree learned by SimpleDecisionTree.
Tasks:
- Call print_tree(clf.root).
- How deep is the tree? Count the levels.
- Does the tree make intuitive sense? Are the splits reasonable?
- What would happen if you trained on a dataset with more noise — would the tree be deeper or shallower? Why?
# YOUR CODE HERE
print_tree(clf.root)
PART 9: Bonus Challenges¶
If you've made it here, you've built a decision tree from scratch — impurity formula, split search, recursive tree growth, and a full ML-style class. Here are open-ended extensions.
Challenge A — Max Depth¶
Your tree keeps splitting until each leaf is pure. This can lead to overfitting — the tree memorises the training data but won't generalise to new data.
Add a max_depth parameter to SimpleDecisionTree (and build_tree):
- If the current depth equals max_depth, stop and return a leaf regardless of purity.
- Try max_depth=1, 2, 3. How does accuracy on the training set change?
- What does the tree look like with max_depth=1? This is called a decision stump.
- In general: deeper tree = higher training accuracy. Is that always better? Why not?
# Challenge A — YOUR CODE
# Modify build_tree and SimpleDecisionTree to support max_depth
Challenge B — Min Samples to Split¶
Another way to prevent overfitting: don't split a group if it has fewer than min_samples rows.
Add a min_samples parameter:
- If len(data) < min_samples, return a leaf.
Try min_samples=1 (default, current behaviour) vs min_samples=3 vs min_samples=5.
How does the tree structure change?
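Both stopping rules (max_depth from Challenge A and min_samples from this challenge) can live in one small predicate that build_tree consults before it searches for a split. The function name `sketch_should_stop` and its parameters are hypothetical, and the sketch assumes the root is at depth 0, so max_depth=1 allows exactly one split.

```python
def sketch_should_stop(labels, depth, max_depth=None, min_samples=1):
    """Return True if the current group should become a leaf.

    labels      : labels of the rows that reached this node
    depth       : depth of this node (assumed convention: root is 0)
    max_depth   : None means no depth limit
    min_samples : groups smaller than this are not split further
    """
    if len(labels) < min_samples:
        return True                      # too few rows to justify a split
    if max_depth is not None and depth >= max_depth:
        return True                      # depth budget used up
    return len(set(labels)) <= 1         # pure (or empty): nothing left to separate

print(sketch_should_stop(["M", "F"], depth=0, max_depth=1))         # False: root may split once
print(sketch_should_stop(["M", "F"], depth=1, max_depth=1))         # True: stump stops here
print(sketch_should_stop(["M", "F", "M"], depth=0, min_samples=5))  # True: group too small
print(sketch_should_stop(["M", "M"], depth=0))                      # True: already pure
```

When the predicate returns True for an impure group, the leaf is labelled with the majority label, which is why training accuracy can drop below 100% once these limits are switched on.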
# Challenge B — YOUR CODE
Challenge C — A New Dataset¶
Try your tree on the Iris dataset — a classic ML benchmark.
# Load the iris dataset
from sklearn.datasets import load_iris
iris = load_iris()
# Convert to the format your tree expects
feature_names = iris.feature_names # 4 features
X_iris = [dict(zip(feature_names, row)) for row in iris.data]
y_iris = [iris.target_names[t] for t in iris.target] # "setosa", "versicolor", "virginica"
Tasks:
- Fit your SimpleDecisionTree on the Iris dataset.
- Compute training accuracy.
- Print the tree. How deep is it?
- Compare to sklearn.tree.DecisionTreeClassifier. Do they make the same splits?
- Bonus: Split into train/test sets (80/20) and measure test accuracy.
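For the bonus task, an 80/20 split can be done with the standard library alone. The sketch below is one way to do it; the function name `sketch_train_test_split`, its `test_fraction` and `seed` parameters, and the toy rows are all assumptions made for illustration.

```python
import random

def sketch_train_test_split(X, y, test_fraction=0.2, seed=0):
    """Shuffle row indices reproducibly and hold out a fraction for testing.

    Returns (X_train, X_test, y_train, y_test).
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # private RNG: global random state is untouched
    n_test = round(len(X) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return (
        [X[i] for i in train_idx],
        [X[i] for i in test_idx],
        [y[i] for i in train_idx],
        [y[i] for i in test_idx],
    )

# Made-up toy rows
toy_X = [{"value": i} for i in range(10)]
toy_y = ["even" if i % 2 == 0 else "odd" for i in range(10)]
X_train, X_test, y_train, y_test = sketch_train_test_split(toy_X, toy_y)
print(len(X_train), len(X_test))  # 8 2
```

Shuffling matters for Iris in particular: the rows returned by load_iris are ordered by class, so taking the last 20% without shuffling would leave the test set with only one species.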
# Challenge C — YOUR CODE
from sklearn.datasets import load_iris
iris = load_iris()
feature_names = iris.feature_names
X_iris = [dict(zip(feature_names, row)) for row in iris.data]
y_iris = [iris.target_names[t] for t in iris.target]
clf_iris = SimpleDecisionTree()
clf_iris.fit(X_iris, y_iris)
preds = clf_iris.predict(X_iris)
acc = sum(p == t for p, t in zip(preds, y_iris)) / len(y_iris)
print(f"Iris training accuracy: {acc*100:.1f}%")
print("\nTree structure:")
print_tree(clf_iris.root)
Challenge D — What Is Gini Impurity?¶
The impurity function you invented is based on $p(1-p)$. Look up Gini impurity, which is:
$$\text{Gini} = 1 - \sum_k p_k^2$$
where the sum is over all classes k, and $p_k$ is the fraction of class k.
- For two classes (R and B), write out the Gini formula in full.
- Is it the same as your formula, or different?
- Replace your impurity function with the Gini formula and re-run everything. Do the results change?
- Look up Entropy as another impurity measure. Try implementing that too.
# Challenge D — YOUR CODE
def gini_impurity(p, q):
    """
    Gini impurity for two classes.
    Formula: 1 - (p^2 + q^2)
    Note: Gini=0 when pure, Gini=0.5 when maximally mixed.
    """
    pass

def entropy_impurity(p, q):
    """
    Entropy impurity for two classes.
    Formula: -(p * log2(p) + q * log2(q))   [handle p=0 or q=0 as 0]
    """
    pass

# Plot all three impurity measures on the same graph
ps = [i/100 for i in range(101)]
plt.figure(figsize=(8, 4))
plt.plot(ps, [impurity(p, 1-p) for p in ps], label='Your formula')
plt.plot(ps, [gini_impurity(p, 1-p) for p in ps], label='Gini', linestyle='--')
plt.plot(ps, [entropy_impurity(p, 1-p) for p in ps], label='Entropy', linestyle=':')
plt.xlabel('p'); plt.ylabel('impurity'); plt.title('Comparing impurity measures')
plt.legend(); plt.grid(True)
plt.show()
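If you want something to check your own versions against, the two formulas from the docstrings can be written out as below. The `sketch_` names are hypothetical and chosen so they do not overwrite gini_impurity or entropy_impurity; the sketch assumes p + q = 1 and does not validate its inputs.

```python
import math

def sketch_gini_two_class(p, q):
    """Gini impurity for two classes: 1 - (p^2 + q^2)."""
    return 1.0 - (p * p + q * q)

def sketch_entropy_two_class(p, q):
    """Entropy in bits for two classes, treating 0 * log2(0) as 0."""
    total = 0.0
    for frac in (p, q):
        if frac > 0:  # skip zero fractions: log2(0) is undefined
            total -= frac * math.log2(frac)
    return total

print(sketch_gini_two_class(0.5, 0.5))     # 0.5
print(sketch_gini_two_class(1.0, 0.0))     # 0.0
print(sketch_entropy_two_class(0.5, 0.5))  # 1.0
print(sketch_entropy_two_class(1.0, 0.0))  # 0.0

# With q = 1 - p, Gini simplifies to 2 * p * (1 - p)
for p in (0.1, 0.3, 0.8):
    assert abs(sketch_gini_two_class(p, 1 - p) - 2 * p * (1 - p)) < 1e-12
```

The final loop checks the algebra: 1 - p² - (1 - p)² = 2p(1 - p), so two-class Gini is a constant multiple of $p(1-p)$. Multiplying an impurity by a positive constant does not change which split has the lowest score, which is worth keeping in mind when you compare results.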
Reflection — What Did You Just Build?¶
Take a moment to answer these questions in your own words.
What is impurity? Why does it matter for decision trees?
What is a decision boundary? How did you find the best one?
What is recursive splitting? When does it stop?
What is overfitting? How does a decision tree overfit, and how can you prevent it?
What does fit do? What does predict do?
What would a Random Forest be? (Think: what if you built many trees on random subsets of data and features, then took a vote?)
Your answers:
Impurity is...
A decision boundary is...
Recursive splitting means...
Overfitting means...
fit does... predict does...
A Random Forest would be...
You've just invented a decision tree classifier from scratch: impurity, split search, recursive tree growth, and a full ML interface. That's the real thing. sklearn's DecisionTreeClassifier is built on the same core ideas, in the form of an optimised version of the CART algorithm.