With recursive functions, you should think recursively! Here's how I would think of this function:
I start by writing the signature of the function:
int countNodes( TreeNode *root )
- First, the cases that are not recursive (the base cases). If the given tree is NULL, there are no nodes, so I return 0.
- Then, I observe that the number of nodes in my tree is the number of nodes in the left sub-tree plus the number of nodes in the right sub-tree plus 1 (the root node). So I call the function on the left and right sub-trees and add 1 to the sum of the two results.
- Note that I assume the function already works correctly on those sub-trees!
Why did I do this? Simple: the function is supposed to work on any binary tree, right? Well, the left sub-tree of the root node is itself a binary tree, and so is the right sub-tree. So I can safely assume that the same countNodes function counts the nodes of those trees. The assumption is safe because each call works on a strictly smaller tree, so the calls eventually reach the NULL base case. Once I have both counts, I just compute left + right + 1 and I get my result.
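Putting the base case and the recursive case together gives the whole function. This is a minimal C++ sketch, not necessarily the exact code from the question: the TreeNode struct (a char label plus left/right pointers) is an assumption, since only the signature is given above, and the sketch uses nullptr where the text says NULL.

```cpp
#include <cassert>  // assert(), handy for quick sanity checks

// Hypothetical node layout: only `TreeNode *root` with left/right
// children is implied by the text, so this struct is a stand-in.
struct TreeNode {
    char value;
    TreeNode *left;
    TreeNode *right;
};

int countNodes(TreeNode *root) {
    // Base case: an empty tree has no nodes.
    if (root == nullptr) {
        return 0;
    }
    // Recursive case: trust countNodes on each (smaller) sub-tree,
    // then add 1 for the root itself.
    return countNodes(root->left) + countNodes(root->right) + 1;
}
```

For the example tree walked through below (a with children b and c, where c has children d and e), countNodes returns 5.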
How does the recursive function really work? You could follow the algorithm with pen and paper, but in short it goes something like this:
Let's say you call the function with this tree:
  a
 / \
b   c
   / \
  d   e
You see that the root is not NULL, so you call the function for the left sub-tree:
b
and later for the right sub-tree:
  c
 / \
d   e
Before calling the right sub-tree though, the left sub-tree needs to be evaluated.
So, you are in the call of the function with input:
b
You see that the root is not null, so you call the function for the left sub-tree:
NULL
which returns 0, and the right sub-tree:
NULL
which also returns 0. The number of nodes in this tree is therefore 0+0+1 = 1.
Now you have 1 for the left sub-tree of the original tree, which was
b
and next the function gets called for the right sub-tree:
  c
 / \
d   e
Here, you call the function again for the left sub-tree:
d
which, just like b, returns 1, and then for the right sub-tree:
e
which also returns 1, so the number of nodes in this tree is 1+1+1 = 3.
Finally, you return to the first call of the function and compute the number of nodes in the whole tree as 1+3+1 = 5.
So, as you can see, for each left and right child you call the function again; if those children have children of their own, the function gets called again and again, going one level deeper into the tree each time. That is why nodes such as root->left->left or root->right->left->left are never accessed directly: they are reached only through these nested calls.
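Instead of tracing the calls on paper, you can instrument the same recursion and let it print the trace for you. This is a sketch under the same assumed TreeNode struct; countNodesTraced and its depth and order parameters are hypothetical names made up for illustration. Each call prints a line indented by its depth, and order records the visit sequence, with '-' standing for a NULL child.

```cpp
#include <cassert>  // assert(), handy for quick sanity checks
#include <iostream>
#include <string>

// Hypothetical node layout, same stand-in as before.
struct TreeNode {
    char value;
    TreeNode *left;
    TreeNode *right;
};

// Same recursion as countNodes, plus tracing: `depth` controls the
// indentation, and `order` collects each visited label ('-' for NULL).
int countNodesTraced(TreeNode *root, int depth, std::string &order) {
    std::string indent(2 * depth, ' ');
    if (root == nullptr) {
        std::cout << indent << "NULL -> 0\n";
        order += '-';
        return 0;
    }
    std::cout << indent << "enter " << root->value << "\n";
    order += root->value;
    // The left sub-tree is fully evaluated before the right one starts.
    int left = countNodesTraced(root->left, depth + 1, order);
    int right = countNodesTraced(root->right, depth + 1, order);
    int total = left + right + 1;
    std::cout << indent << root->value << " -> " << left << "+" << right
              << "+1 = " << total << "\n";
    return total;
}
```

Called on the example tree with depth 0 and an empty order string, it returns 5, leaves order equal to "ab--cd--e--", and prints:

```
enter a
  enter b
    NULL -> 0
    NULL -> 0
  b -> 0+0+1 = 1
  enter c
    enter d
      NULL -> 0
      NULL -> 0
    d -> 0+0+1 = 1
    enter e
      NULL -> 0
      NULL -> 0
    e -> 0+0+1 = 1
  c -> 1+1+1 = 3
a -> 1+3+1 = 5
```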