Problem Statement

I am trying to solve a variation of the Maximum Path Sum in a Binary Tree problem where some nodes in the tree are colored red. The path sum is only valid if:

  1. The path starts and ends at a red node.
  2. The path can contain zero or more additional red nodes in between.
  3. The path can include non-red nodes as long as it starts and ends at red nodes.
  4. The path follows parent-child connections (no jumps).

Given this constraint, how do I compute the maximum sum path in the binary tree?

Example

Consider this tree where (R) represents red nodes:

        10(R)
       /     \
    -2       7(R)
    /  \       \
   8(R)  -4     6
         /
       -1(R)
  • The original Maximum Path Sum (ignoring red constraints) is: 8 → -2 → 10 → 7 → 6 = 29
  • But with the red node constraint, the best valid path must start and end at a red node: 8 → -2 → 10 → 7 = 23

What I Have Tried

The standard approach for Maximum Path Sum uses DFS with recursion while maintaining a global max. I modified it to only update the global max when encountering a red-to-red path, but I am struggling to properly track valid paths and backtrack correctly.

class Node:
    def __init__(self, val):
        self.val = val 
        self.left = None
        self.right = None 
        self.red = False 

class Solution:
    def solve(self, root):
        ans = float("-inf")
    
        def dfs(node):
            if not node:
                return [0, False]
        
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)
            
            # update the global ans variable based on whether the current node is red or not
            nonlocal ans 
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            
            # return the single best rising path from this node 
            if node.red:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                
                return [max(local_max, node.val), node.red]
            else:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                
                return [local_max, left_red or right_red]
                
        dfs(root)
        return ans

soln = Solution()

root = Node(10)

root.left = Node(-5)

root.right = Node(20)

root.left.left = Node(4)
root.left.right = Node(3)

root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True

root.right.left.left = Node(-10)
root.right.left.left.red = True

print(soln.solve(root))

This fails for test cases where a red node is a leaf, for example:

        10 
       /   \
     -5     20  
     / \    / \ 
    4  3   1   6(R)
          /
        -10(R)

The actual answer to this should be

-10 -> 1 -> 20 -> 6 = 17

But the output is 27


Is this (using dfs) even the right approach? Or do I have to reframe the tree as a graph and then do a BFS from each red node to compute the distance from red to red node?

I think there might be a small bug in how I am returning the rising path but I am not able to point it out correctly.

Problem Constraints

  • The tree has at least two red nodes.
  • The values in the tree can be positive, negative, or zero.
  • The number of nodes is at most 10⁵.
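
On the graph idea: treating the tree as an undirected graph and walking outward from every red node does give correct answers, but it costs O(n · r) for r red nodes, which can be too slow at 10⁵ nodes when many of them are red. It is still handy as a reference for checking a DFS solution on small trees. A minimal sketch, assuming the Node class from the question with an optional red keyword added for brevity; brute_force_red_path is a hypothetical helper name:

```python
class Node:
    def __init__(self, val, red=False):  # red keyword is an added convenience
        self.val = val
        self.left = None
        self.right = None
        self.red = red


def brute_force_red_path(root):
    """Reference answer: walk the tree as an undirected graph from every red node."""
    # Build an undirected adjacency list keyed by node identity.
    adj = {}
    stack = [root]
    while stack:
        node = stack.pop()
        adj.setdefault(node, [])
        for child in (node.left, node.right):
            if child:
                adj[node].append(child)
                adj.setdefault(child, []).append(node)
                stack.append(child)

    best = float("-inf")
    for start in adj:
        if not start.red:
            continue
        # Each entry: (current node, node we came from, sum including current).
        todo = [(start, None, start.val)]
        while todo:
            node, prev, total = todo.pop()
            if node.red and node is not start:
                best = max(best, total)  # a red-to-red path ends here
            for nxt in adj[node]:
                if nxt is not prev:  # never step back along the same edge
                    todo.append((nxt, node, total + nxt.val))
    return best


# The failing tree from the question.
root = Node(10)
root.left, root.right = Node(-5), Node(20)
root.left.left, root.left.right = Node(4), Node(3)
root.right.left, root.right.right = Node(1), Node(6, red=True)
root.right.left.left = Node(-10, red=True)
print(brute_force_red_path(root))  # 17
```

Because every pair of red nodes is visited from both ends, this does roughly twice the necessary work, which is acceptable for a checker.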
  • Can you share your attempt, so we can see (1) how the trees are encoded (2) what the problem is with your attempt? Commented Feb 10 at 21:01
  • thanks @trincot I have updated the question to include the code of my attempt. Commented Feb 17 at 10:20
  • @btilly - what do you mean by top 0, 1, or 2 rising paths starting at red? Are you suggesting returning a list of top 3 paths from each node to its parent? Commented Feb 17 at 10:22
  • @PodGen4 Just return 1. But every node has three potential directions leading there. From a red node on the left, from a red node on the right, or starting at the node itself. If 2+ of those exist, we have a candidate path. Only the best of which needs to go up. Commented Feb 18 at 1:54
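
The idea in the last comment can be sketched roughly as follows: each node collects up to three "arms" (a rising path from the left subtree, one from the right subtree, and the node itself if it is red), records a candidate whenever at least two arms exist, and passes only the best single arm upward. This is one possible reading of the comment, not the commenter's own code; max_red_path is a hypothetical name, None stands for "no red node below", and the red keyword on Node is an added convenience.

```python
class Node:
    def __init__(self, val, red=False):  # red keyword is an added convenience
        self.val = val
        self.left = None
        self.right = None
        self.red = red


def max_red_path(root):
    best = float("-inf")

    def dfs(node):
        """Best sum of a downward path from node that ends at a red node, or None."""
        nonlocal best
        if node is None:
            return None
        arms = [a for a in (dfs(node.left), dfs(node.right)) if a is not None]
        if node.red:
            arms.append(0)  # the path may end at this node itself
        if len(arms) >= 2:
            # Join the two best arms through this node.
            best = max(best, node.val + sum(sorted(arms)[-2:]))
        return node.val + max(arms) if arms else None

    dfs(root)
    return best


# The example tree from the top of the question.
root = Node(10, red=True)
root.left, root.right = Node(-2), Node(7, red=True)
root.left.left, root.left.right = Node(8, red=True), Node(-4)
root.right.right = Node(6)
root.left.right.left = Node(-1, red=True)
print(max_red_path(root))  # 23
```

The recursion is fine for balanced trees, but a degenerate tree close to 10⁵ nodes deep would exceed Python's default recursion limit, so an iterative post-order traversal or a raised limit would be needed there.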

1 Answer

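The code above has a subtle bug in how the rising path is returned from a non-red node.

The rising path returned from a node must end at a red node somewhere below it (or at the node itself, if it is red). A non-red node therefore cannot offer node.val on its own as a candidate, because that path has no red endpoint. The fix is to drop node.val from the comparison in the non-red branch.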
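
To make the effect concrete, here is the failing tree traced by hand at node 1 and node 20. The variable names are illustrative only and do not appear in the original code.

```python
NEG_INF = float("-inf")

# Node 1 (not red) with a red left child -10, which returns a rising path of -10.
node_val, left = 1, -10

buggy_rising = max(NEG_INF, node_val + left, node_val)  # node.val alone sneaks in
fixed_rising = max(NEG_INF, node_val + left)            # path must reach the red leaf

# Node 20 (not red) joins the left rising path with the right one (red leaf 6).
buggy_ans = 20 + buggy_rising + 6
fixed_ans = 20 + fixed_rising + 6

print(buggy_rising, fixed_rising)  # 1 -9
print(buggy_ans, fixed_ans)        # 27 17
```

The buggy version silently drops the -10 leaf from the sum while still claiming the path ends at a red node, which is where the extra 10 comes from.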

Here is the updated code, which prints 17 for the failing tree:

class Node:
    def __init__(self, val):
        self.val = val 
        self.left = None
        self.right = None 
        self.red = False 

class Solution:
    def solve(self, root):
        ans = float("-inf")
    
        def dfs(node):
            if not node:
                return [0, False]
        
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)
            
            # update the global ans variable based on whether the current node is red or not
            nonlocal ans 
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            
            # return the single best rising path from this node 
            if node.red:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                
                return [max(local_max, node.val), node.red]
            else:
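                # A non-red node cannot end a path, so node.val alone is not a candidate;
                # only extensions of a child path that already ends at a red node count.
                local_max = float("-inf")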
                if left_red:
                    local_max = max(local_max, node.val + left)
                if right_red:
                    local_max = max(local_max, node.val + right)
                
                return [local_max, left_red or right_red]
                
        dfs(root)
        return ans

soln = Solution()

root = Node(10)

root.left = Node(-5)

root.right = Node(20)

root.left.left = Node(4)
root.left.right = Node(3)

root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True

root.right.left.left = Node(-10)
root.right.left.left.red = True

print(soln.solve(root))