Delete Nodes and Return Forest

Posted on: 27/2/2025

Today, I followed Keith Galli’s video on his thought process for solving LeetCode problems, and I tried to solve each problem on my own before looking at his solution. I was able to solve most of them after a few tries. But for the last problem, “Delete Nodes and Return Forest”, the more I tried, the more confused I got, just like with the Trim Binary Search Tree problem.

At first glance, I thought this was similar to the one I got stuck on previously, which led me to create a video explaining it to myself. But this problem was slightly different. My first thought was, “Oh no, I think recursion is needed here, but I have no idea how to do it. Maybe I’ll figure it out along the way.”

Nope. I tried, but I had no idea how to use recursion. So I just started coding my logic instead.

Where I Got Stuck

When the current node’s value is in to_delete, I need to take its children and append them to my list of trees. But the problem is: how do I remove the link from its parent to the current node? There’s no node.previous, only node.left and node.right.

I overcomplicated things by checking node.left and node.right for nodes to delete and then, if node.left was the one to be deleted, appending node.left.left and node.left.right to my list of trees. Getting more confused? Definitely!
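For comparison, here is a sketch of a different, commonly used technique (recursive DFS, not the approach from Keith Galli’s video) that sidesteps the missing node.previous entirely: the parent never needs to look down two levels, because each recursive call returns None when its node is deleted, and the parent simply overwrites its own link with whatever comes back. The names del_nodes_dfs and visit are my own; TreeNode mirrors LeetCode’s predefined class.

```python
from typing import List, Optional


class TreeNode:
    # Same shape as LeetCode's predefined TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def del_nodes_dfs(root: Optional[TreeNode], to_delete: List[int]) -> List[TreeNode]:
    to_delete_set = set(to_delete)
    forest: List[TreeNode] = []

    def visit(node: Optional[TreeNode], is_root: bool) -> Optional[TreeNode]:
        """Process node; return it if kept, or None if it was deleted."""
        if node is None:
            return None
        deleted = node.val in to_delete_set
        # A kept node with no surviving parent starts a new tree
        if is_root and not deleted:
            forest.append(node)
        # Children of a deleted node become roots; the parent rewires its own links
        node.left = visit(node.left, deleted)
        node.right = visit(node.right, deleted)
        return None if deleted else node

    visit(root, True)
    return forest


# LeetCode example: root = [1,2,3,4,5,6,7], to_delete = [3,5]
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
print([tree.val for tree in del_nodes_dfs(root, [3, 5])])  # [1, 6, 7]
```

The answer to “how do I remove the link from the parent?” is that the parent does it itself, on the way back from the recursive call. One caveat, assuming CPython: the default recursion limit is 1000, so a very deep, skewed tree could raise RecursionError.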

Breaking It Down

I like Keith Galli’s thought process—whenever we see a tree problem, we should first think about traversal. He started by writing a simple breadth-first search (BFS) and modified it from there.

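from collections import deque
from typing import List, Optional

class Solution:  # TreeNode is predefined by LeetCode
    def delNodes(self, root: Optional[TreeNode], to_delete: List[int]) -> List[TreeNode]:
        if not root:
            return []
        queue = deque([root])
        while queue:
            node = queue.popleft()  # First in, first out
            print(node.val)
            # Only enqueue real children, never None
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)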
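To try these snippets outside LeetCode, you need a way to turn LeetCode’s level-order input (like [1,2,3,4,5,6,7], with None for a missing child) into actual nodes. The same queue pattern works in reverse for that. This is a minimal sketch; build_tree and bfs_values are hypothetical helper names I made up, and TreeNode copies the shape of LeetCode’s class.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Same shape as LeetCode's predefined TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is this node's left child
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if any) is its right child
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def bfs_values(root: Optional[TreeNode]) -> List[int]:
    """Return node values in breadth-first (level) order."""
    order = []
    queue = deque([root] if root else [])
    while queue:
        node = queue.popleft()  # first in, first out
        order.append(node.val)
        if node.left:
            queue.append(node.left)
        if node.right:
            queue.append(node.right)
    return order


print(bfs_values(build_tree([1, 2, 3, 4, 5, 6, 7])))  # [1, 2, 3, 4, 5, 6, 7]
print(bfs_values(build_tree([1, 2, 4, None, 3])))     # [1, 2, 4, 3]
```

I used collections.deque instead of a list because popleft() is O(1), while list.pop(0) has to shift every remaining element.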

Step-by-Step Approach

1. Start by simply deleting nodes during traversal: if a child’s value is in to_delete, cut the parent’s link to it by setting node.left or node.right to None.

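from collections import deque
from typing import List, Optional

class Solution:
    def delNodes(self, root: Optional[TreeNode], to_delete: List[int]) -> List[TreeNode]:
        if not root:
            return []
        queue = deque([root])
        while queue:
            node = queue.popleft()
            # Still visit each child, then cut the parent's link if it is deleted
            if node.left:
                queue.append(node.left)
                if node.left.val in to_delete:
                    node.left = None
            if node.right:
                queue.append(node.right)
                if node.right.val in to_delete:
                    node.right = None
        # Incomplete: children of deleted nodes are orphaned (fixed in step 2)
        return [root]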
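Here is a quick standalone check of why step 1 alone is not enough. It repeats the step 1 loop as a plain function (cut_deleted_links and reachable are hypothetical names for this demo) and runs it on the LeetCode example root = [1,2,3,4,5,6,7], to_delete = [3,5].

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def cut_deleted_links(root: Optional[TreeNode], to_delete: List[int]) -> Optional[TreeNode]:
    """Step 1 only: cut each parent's link to a deleted child, return the root."""
    if root is None:
        return None
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node.left:
            queue.append(node.left)
            if node.left.val in to_delete:
                node.left = None
        if node.right:
            queue.append(node.right)
            if node.right.val in to_delete:
                node.right = None
    return root


def reachable(root: Optional[TreeNode]) -> List[int]:
    """Values still reachable from root, in level order."""
    order, queue = [], deque([root] if root else [])
    while queue:
        node = queue.popleft()
        order.append(node.val)
        queue.extend(child for child in (node.left, node.right) if child)
    return order


six, seven = TreeNode(6), TreeNode(7)
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, six, seven))
print(reachable(cut_deleted_links(root, [3, 5])))  # [1, 2, 4]
# 6 and 7 still exist, but nothing in the returned tree points to them anymore
```

The expected answer is [[1,2,null,4],[6],[7]], so the trees rooted at 6 and 7 are exactly what step 2 has to recover.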

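2. Handle disconnected subtrees: if a deleted node has children, each child becomes the root of a separate tree (unless it is deleted too). So, instead of queuing them, we recursively pass the deleted node’s left and right children into delNodes, which adds each one to the shared list of trees if it survives and then traverses it.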

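from collections import deque
from typing import List, Optional

class Solution:
    def __init__(self):
        # List to store the trees, shared by all recursive calls
        # (assumes a fresh Solution instance per test case, as LeetCode does)
        self.trees = []

    def delNodes(self, root: Optional[TreeNode], to_delete: List[int]) -> List[TreeNode]:
        to_delete = set(to_delete)  # Convert to set for faster lookup

        if not root:
            return self.trees  # Empty tree: nothing to traverse

        if root.val not in to_delete:
            self.trees.append(root)  # A surviving root starts a tree

        queue = deque([root])
        while queue:
            node = queue.popleft()

            if node.val in to_delete:
                # Deleted node: each child is now the root of its own
                # subtree, so handle it with a recursive call
                if node.left:
                    self.delNodes(node.left, to_delete)
                if node.right:
                    self.delNodes(node.right, to_delete)
            else:
                # Kept node: keep traversing, but cut links to deleted children
                if node.left:
                    queue.append(node.left)
                    if node.left.val in to_delete:
                        node.left = None
                if node.right:
                    queue.append(node.right)
                    if node.right.val in to_delete:
                        node.right = None

        return self.trees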
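One possible variation (my own sketch, not from Keith Galli’s video) folds the recursion into the queue itself: each node is enqueued together with a flag saying whether it would start a new tree, which is true for the root and for children of a deleted node. That also removes the need for state stored on self. The function names del_nodes_bfs and to_list are hypothetical; to_list converts a tree back into LeetCode’s output format so the result can be compared with the expected answer.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Same shape as LeetCode's predefined TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def del_nodes_bfs(root: Optional[TreeNode], to_delete: List[int]) -> List[TreeNode]:
    to_delete_set = set(to_delete)
    forest: List[TreeNode] = []
    if root is None:
        return forest
    # Each entry: (node, True if it has no surviving parent)
    queue = deque([(root, True)])
    while queue:
        node, is_root = queue.popleft()
        deleted = node.val in to_delete_set
        if is_root and not deleted:
            forest.append(node)
        # A child starts a new tree only if this node is deleted
        if node.left:
            queue.append((node.left, deleted))
            if node.left.val in to_delete_set:
                node.left = None  # cut the link, but the child is still queued
        if node.right:
            queue.append((node.right, deleted))
            if node.right.val in to_delete_set:
                node.right = None
    return forest


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize a tree to LeetCode's level-order format (trailing Nones trimmed)."""
    out, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
print([to_list(tree) for tree in del_nodes_bfs(root, [3, 5])])
# [[1, 2, None, 4], [6], [7]]
```

Either version visits every node once, so both run in O(n) time with the to_delete set giving O(1) lookups.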

This was an insightful video, and I learned more about the thought process behind solving tree problems. Don’t forget to check out Keith Galli’s video, highly recommended!

My own explanation of how this code works.

Until next time!
