Trim a Binary Search Tree

Posted on: 2025-01-24

Today, I was stuck on this coding challenge: given a binary search tree and two values, low and high, remove all nodes whose values are not within the range from low to high.

Looking at the hints, it seemed like an easy task, but the more code I wrote, the more confused I became. I kept adding conditions just to make the failing test pass, but then the next test (or one that used to pass) would fail, until I had to give up and look at the solution.

I’m writing this as a record for my future self and also to clarify what is going on in my mind.

The answer was just 9 lines of code! Crazy! Meanwhile, I had written almost 30 lines and still couldn’t solve it. I’ve noticed that recursion is useful in many cases, but I always hesitated to use it because I didn’t really understand it. So I took some time to break down the code slowly with an example.

I also recorded a video that explains it step by step.

Understanding the Concept

  1. A binary search tree is built so that, for every node, all values in its left branch are smaller than the node’s value, and all values in its right branch are larger.
  2. For this challenge, given a range, e.g., low = 7 and high = 10, we need to trim all numbers that fall outside this boundary. This means we only keep numbers from 7 to 10 (inclusive), removing anything smaller than 7 or larger than 10.
  3. The smart move is to pass the left or right branch back into the same function recursively, and let each call return the trimmed version of that branch.
  4. If the root is smaller than the lower bound, its whole left branch is even smaller, so the answer can only lie somewhere in the right branch. We pass the right branch in as the new root:
      if root.val < low:
          return self.trimBST(root.right, low, high)
  5. This "trims" the root and its left branch just by moving the root pointer. Nothing is deleted: all the nodes in the left branch still exist in memory (in Python, until nothing references them any more and they are garbage-collected). The root pointer simply points at the right branch now.
  6. Similarly, if the root is larger than the upper bound, its whole right branch is even larger, so we move the root pointer to the left branch:
      if root.val > high:
          return self.trimBST(root.left, low, high)
  7. Finally, when the root is within the range, we keep it, trim its left and right branches recursively, and return the root:
      root.left = self.trimBST(root.left, low, high)
      root.right = self.trimBST(root.right, low, high)
      return root
  8. The full code (TreeNode is assumed to be the node class the challenge provides, with val, left and right):
      from typing import Optional
      # TreeNode (val, left, right) is assumed to be provided by the challenge.

      class Solution:
          def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
              # Empty subtree: nothing to trim.
              if not root:
                  return None
              # Too small: root and its left branch are all < low, so search right.
              if root.val < low:
                  return self.trimBST(root.right, low, high)
              # Too large: root and its right branch are all > high, so search left.
              if root.val > high:
                  return self.trimBST(root.left, low, high)
              # In range: keep root and trim both branches.
              root.left = self.trimBST(root.left, low, high)
              root.right = self.trimBST(root.right, low, high)
              return root
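To try the solution outside the challenge's own runner, here is a minimal runnable sketch. It assumes a TreeNode class shaped like the one coding-challenge sites usually provide (val, left, right); the insert and inorder helpers and the example tree are hypothetical additions, only there for testing. It trims the tree to the range 7 to 10 used in the explanation and prints the remaining values in sorted order.

```python
from typing import List, Optional


class TreeNode:
    # Assumed node shape: a value plus left/right children.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        if not root:
            return None
        if root.val < low:
            return self.trimBST(root.right, low, high)
        if root.val > high:
            return self.trimBST(root.left, low, high)
        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)
        return root


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Hypothetical helper: plain (unbalanced) BST insertion."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Hypothetical helper: in-order traversal lists BST values in sorted order."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


# Builds:        8
#              /   \
#             3     12
#            / \   /  \
#           1   6 10   14
#              / \ / \
#             4  7 9  11
root = None
for v in [8, 3, 12, 1, 6, 10, 14, 4, 7, 9, 11]:
    root = insert(root, v)

trimmed = Solution().trimBST(root, 7, 10)
print(inorder(trimmed))  # [7, 8, 9, 10]
```

Note that trimBST rearranges the original nodes in place rather than copying them, so the tree passed in is modified.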
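To break the recursion down with an example, here is a sketch that prints one line per call. trim_traced is a hypothetical, instrumented copy of trimBST: the logic is the same, and the extra depth parameter is used only to indent the printed lines. It runs on a hand-built tree with low = 7 and high = 10.

```python
from typing import Optional


class TreeNode:
    # Assumed node shape: a value plus left/right children.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def trim_traced(root: Optional[TreeNode], low: int, high: int,
                depth: int = 0) -> Optional[TreeNode]:
    """Same logic as trimBST, plus one printed line per recursive call."""
    pad = "  " * depth  # deeper calls are indented further
    if not root:
        print(f"{pad}None -> None")
        return None
    if root.val < low:
        print(f"{pad}{root.val} < {low}: skip it, continue in right branch")
        return trim_traced(root.right, low, high, depth + 1)
    if root.val > high:
        print(f"{pad}{root.val} > {high}: skip it, continue in left branch")
        return trim_traced(root.left, low, high, depth + 1)
    print(f"{pad}{root.val} is in range: keep it, trim both branches")
    root.left = trim_traced(root.left, low, high, depth + 1)
    root.right = trim_traced(root.right, low, high, depth + 1)
    return root


# Hand-built example tree:
#         8
#       /   \
#      3     12
#     / \   /  \
#    1   6 10   14
#       / \ / \
#      4  7 9  11
tree = TreeNode(8,
                TreeNode(3, TreeNode(1), TreeNode(6, TreeNode(4), TreeNode(7))),
                TreeNode(12, TreeNode(10, TreeNode(9), TreeNode(11)), TreeNode(14)))

result = trim_traced(tree, 7, 10)
```

Nodes 1, 4 and 14 never appear in the trace: once a node is skipped, the branch on its out-of-range side is never visited at all. That is why the whole algorithm fits in a handful of lines.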

Hope this explanation helps! And I hope that in the future… Watch the video for more details. See you in the next post!
