Problem Statement
I am trying to solve a variation of the Maximum Path Sum in a Binary Tree problem in which some nodes of the tree are colored red. A path is valid only if:
- The path starts and ends at a red node.
- The path can contain zero or more additional red nodes in between.
- The path can include non-red nodes as long as it starts and ends at red nodes.
- The path follows parent-child connections (no jumps).
Given this constraint, how do I compute the maximum sum path in the binary tree?
Example
Consider this tree where (R) represents red nodes:
         10(R)
        /     \
      -2       7(R)
     /  \        \
   8(R)  -4       6
   /
 -1(R)
- The original Maximum Path Sum (ignoring red constraints) is: 8 → -2 → 10 → 7 → 6 = 29
- But with the red node constraint, the best valid path must start and end at a red node: 8 → -2 → 10 → 7 = 23
What I Have Tried
The standard approach for Maximum Path Sum uses DFS with recursion while maintaining a global max. I modified it to only update the global max when encountering a red-to-red path, but I am struggling to properly track valid paths and backtrack correctly.
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.red = False

class Solution:
    def solve(self, root):
        ans = float("-inf")

        def dfs(node):
            if not node:
                return [0, False]
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)
            # update the global ans variable based on whether current node is red or not
            nonlocal ans
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            # return the single best rising path from this node
            if node.red:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                return [max(local_max, node.val), node.red]
            else:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                return [local_max, left_red or right_red]

        dfs(root)
        return ans

soln = Solution()
root = Node(10)
root.left = Node(-5)
root.right = Node(20)
root.left.left = Node(4)
root.left.right = Node(3)
root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True
root.right.left.left = Node(-10)
root.right.left.left.red = True
print(soln.solve(root))
This fails on test cases where a red node is a leaf, for example:
          10
         /  \
       -5    20
      /  \  /  \
     4   3 1    6(R)
          /
      -10(R)
The expected answer is
-10 -> 1 -> 20 -> 6 = 17
but my code prints 27.
Is DFS even the right approach here? Or do I have to treat the tree as an undirected graph and run a BFS from each red node to compute the path sum to every other red node?
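To make the graph idea concrete, here is a minimal brute-force sketch of it, meant mainly as a reference for checking expected answers on small trees. The function name `brute_force_red_path` and the optional `red` constructor argument are hypothetical names introduced for illustration. It assumes a valid path needs two distinct red endpoints, and it uses an iterative depth-first walk rather than a BFS, which gives the same sums because the path between two nodes of a tree is unique.

```python
class Node:
    def __init__(self, val, red=False):  # `red` argument is an assumption, added for brevity
        self.val = val
        self.left = None
        self.right = None
        self.red = red


def brute_force_red_path(root):
    """Start a walk at every red node and record the sum at every other red node."""
    # Build an undirected adjacency map (nodes are hashed by identity).
    adj = {}
    stack = [root]
    while stack:
        node = stack.pop()
        adj.setdefault(node, [])
        for child in (node.left, node.right):
            if child:
                adj[node].append(child)
                adj.setdefault(child, []).append(node)
                stack.append(child)

    best = float("-inf")
    for start in adj:
        if not start.red:
            continue
        # In a tree, remembering the previous node is enough to avoid revisits.
        todo = [(start, None, start.val)]
        while todo:
            node, prev, total = todo.pop()
            if node.red and node is not start:
                best = max(best, total)
            for nxt in adj[node]:
                if nxt is not prev:
                    todo.append((nxt, node, total + nxt.val))
    return best


# The failing test case from above.
root = Node(10)
root.left, root.right = Node(-5), Node(20)
root.left.left, root.left.right = Node(4), Node(3)
root.right.left, root.right.right = Node(1), Node(6, red=True)
root.right.left.left = Node(-10, red=True)
print(brute_force_red_path(root))  # 17
```

With R red nodes among N nodes this does O(R · N) work, so it is probably too slow near the 10⁵-node limit when many nodes are red.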
I suspect there is a small bug in how I return the rising path, but I have not been able to pinpoint it.
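One candidate fix, as a minimal sketch rather than a confirmed answer: a non-red node would never return its own value alone, so every value passed upward is the sum of a downward path that really ends at a red node. Here `-inf` stands for "no red node below", which replaces the boolean flag. The name `max_red_path` and the `red` constructor argument are hypothetical, and the sketch assumes a valid path needs two distinct red endpoints.

```python
class Node:
    def __init__(self, val, red=False):  # `red` argument is an assumption, added for brevity
        self.val = val
        self.left = None
        self.right = None
        self.red = red


def max_red_path(root):
    NEG = float("-inf")
    best = NEG

    def dfs(node):
        # Returns the best sum of a downward path that starts at `node` and
        # ends at a red node in its subtree, or -inf if no such path exists.
        nonlocal best
        if node is None:
            return NEG
        left = dfs(node.left)
        right = dfs(node.right)
        down = max(left, right)

        if node.red:
            # `node` is one endpoint, the other is a red node below it.
            # If nothing red is below, this is -inf and changes nothing.
            best = max(best, node.val + down)
        # `node` joins two arms that each end at a red node.
        best = max(best, left + node.val + right)

        if node.red:
            # A red node may also start a fresh path on its own.
            return max(node.val, node.val + down)
        # A non-red node must keep extending to a red node below.
        return node.val + down

    dfs(root)
    return best


# The failing test case from above.
root = Node(10)
root.left, root.right = Node(-5), Node(20)
root.left.left, root.left.right = Node(4), Node(3)
root.right.left, root.right.right = Node(1), Node(6, red=True)
root.right.left.left = Node(-10, red=True)
print(max_red_path(root))  # 17
```

This visits each node once. On a heavily skewed tree near the 10⁵-node limit, the recursion would exceed Python's default recursion limit, so it would need an iterative rewrite or a raised limit.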
Problem Constraints
- The tree has at least two red nodes.
- The values in the tree can be positive, negative, or zero.
- The number of nodes is at most 10⁵.