Find Mode in Binary Search Tree

Given the root of a binary search tree (BST) with duplicates, return all the mode(s) (i.e., the most frequently occurring element) in it.

If the tree has more than one mode, return them in any order.

Assume a BST is defined as follows:

  • The left subtree of a node contains only nodes with keys less than or equal to the node's key.
  • The right subtree of a node contains only nodes with keys greater than or equal to the node's key.
  • Both the left and right subtrees must also be binary search trees.
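This definition allows equal keys on both sides of a node, so the textbook strict-inequality BST check does not apply directly. Below is a minimal Python sketch that checks the definition by passing inclusive lower and upper bounds down the tree. The TreeNode class (with val, left, and right fields, the shape LeetCode uses) and the function name is_bst_with_duplicates are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical node type mirroring LeetCode's TreeNode (an assumption)
@dataclass
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def is_bst_with_duplicates(node: Optional[TreeNode],
                           low: float = float("-inf"),
                           high: float = float("inf")) -> bool:
    """Return True if every key lies in [low, high] and, recursively,
    left-subtree keys <= node key <= right-subtree keys."""
    if node is None:
        return True
    if not (low <= node.val <= high):
        return False
    # Equal keys are allowed on either side, so the bounds are inclusive
    return (is_bst_with_duplicates(node.left, low, node.val)
            and is_bst_with_duplicates(node.right, node.val, high))


# Example 1's tree, [1,null,2,2]: 1 has right child 2, which has left child 2
example = TreeNode(1, right=TreeNode(2, left=TreeNode(2)))
print(is_bst_with_duplicates(example))                       # True
print(is_bst_with_duplicates(TreeNode(2, left=TreeNode(3))))  # False
```

Passing bounds down the recursion enforces each node's constraint against all of its ancestors, not just its parent.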

Example 1:

Input: root = [1,null,2,2]
Output: [2]

Example 2:

Input: root = [0]
Output: [0]
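The root = [...] notation in the examples is LeetCode's level-order serialization. Nodes are listed level by level, null marks a missing child, and trailing nulls are omitted. The Python sketch below decodes that notation, using None for null, and confirms the in-order sequence of each example tree. The TreeNode class and the helper names build_tree and in_order are illustrative assumptions.

```python
from collections import deque
from dataclasses import dataclass
from typing import List, Optional


# Hypothetical node type mirroring LeetCode's TreeNode (an assumption)
@dataclass
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry is the left child, the one after it the right child
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def in_order(node: Optional[TreeNode]) -> List[int]:
    """Values in in-order sequence (non-decreasing for a BST)."""
    if node is None:
        return []
    return in_order(node.left) + [node.val] + in_order(node.right)


print(in_order(build_tree([1, None, 2, 2])))  # [1, 2, 2]
print(in_order(build_tree([0])))              # [0]
```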

Constraints:

  • The number of nodes in the tree is in the range [1, 10^4].
  • -10^5 <= Node.val <= 10^5

Follow up: Could you do that without using any extra space? (Assume that the implicit stack space incurred due to recursion does not count).
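The follow-up can be met by exploiting the ordering. An in-order walk of a BST visits equal values consecutively, so one pass can find the longest run of equal values and a second pass can collect every value whose run reaches that length, with no frequency table. This Python sketch goes a step beyond the recursion allowance: it swaps in Morris traversal, which temporarily links each in-order predecessor to its successor instead of keeping a stack. TreeNode, morris_in_order, and find_mode_constant_space are illustrative names, not part of the problem. Apart from the returned list, the sketch keeps O(1) extra state.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


# Hypothetical node type mirroring LeetCode's TreeNode (an assumption)
@dataclass
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def morris_in_order(root: Optional[TreeNode]) -> Iterator[int]:
    """Yield values in in-order sequence without a stack or recursion.

    Temporary threads from each in-order predecessor back to its successor
    replace the stack; every thread is removed again, so the tree is
    unchanged once the generator is exhausted.
    """
    cur = root
    while cur is not None:
        if cur.left is None:
            yield cur.val
            cur = cur.right
            continue
        # In-order predecessor: rightmost node of the left subtree
        pred = cur.left
        while pred.right is not None and pred.right is not cur:
            pred = pred.right
        if pred.right is None:
            pred.right = cur      # first arrival: add thread, go left
            cur = cur.left
        else:
            pred.right = None     # second arrival: remove thread, visit, go right
            yield cur.val
            cur = cur.right


def find_mode_constant_space(root: Optional[TreeNode]) -> List[int]:
    # Pass 1: length of the longest run of equal consecutive values
    max_count = count = 0
    prev = None
    for val in morris_in_order(root):
        count = count + 1 if val == prev else 1
        prev = val
        max_count = max(max_count, count)

    # Pass 2: a value is a mode when its run reaches max_count
    modes: List[int] = []
    count = 0
    prev = None
    for val in morris_in_order(root):
        count = count + 1 if val == prev else 1
        prev = val
        if count == max_count:
            modes.append(val)
    return modes


example1 = TreeNode(1, right=TreeNode(2, left=TreeNode(2)))
print(find_mode_constant_space(example1))      # [2]
print(find_mode_constant_space(TreeNode(0)))   # [0]
```

Morris traversal modifies the tree while it runs, so the generator should always be consumed to the end, as both loops here do. Stopping early would leave threads in place.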

Note: This problem is from LeetCode.
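program main
    implicit none

    ! Node of a BST whose keys may repeat
    type :: bst_node
        integer :: key = 0
        type(bst_node), pointer :: left => null()
        type(bst_node), pointer :: right => null()
    end type bst_node

    ! Wrapper type so that an array can hold node pointers (the stack)
    type :: node_ptr
        type(bst_node), pointer :: p => null()
    end type node_ptr

    type(bst_node), target :: example(3), single

    ! Example 1: root = [1,null,2,2]
    example(1)%key = 1
    example(2)%key = 2
    example(3)%key = 2
    example(1)%right => example(2)
    example(2)%left => example(3)

    ! Example 2: root = [0]
    single%key = 0

    print '("Example 1:")'
    call solve(example(1))
    print '("Example 2:")'
    call solve(single)

contains

    ! Print the mode(s) of the BST rooted at root. An iterative in-order
    ! traversal visits keys in ascending order, so equal keys are adjacent
    ! and only the length of the current run has to be tracked.
    subroutine solve(root)
        type(bst_node), target :: root
        type(node_ptr), allocatable :: stack(:), tmp(:)
        type(bst_node), pointer :: cur
        integer, allocatable :: modes(:)
        integer :: top, run, best, prev
        logical :: first

        allocate(stack(16), modes(0))
        top = 0
        run = 0
        best = 0
        prev = 0
        first = .true.
        cur => root

        do while (associated(cur) .or. top > 0)
            ! Push the current node and all of its left descendants
            do while (associated(cur))
                if (top == size(stack)) then
                    ! Stack is full: double its capacity
                    allocate(tmp(2 * size(stack)))
                    tmp(1:top) = stack(1:top)
                    call move_alloc(tmp, stack)
                end if
                top = top + 1
                stack(top)%p => cur
                cur => cur%left
            end do

            ! Pop and visit the next node in ascending order
            cur => stack(top)%p
            top = top - 1

            ! Extend the current run or start a new one
            if (first .or. cur%key /= prev) then
                run = 1
            else
                run = run + 1
            end if
            first = .false.
            prev = cur%key

            ! A longer run replaces the modes found so far; a tie adds one
            if (run > best) then
                best = run
                modes = [cur%key]
            else if (run == best) then
                modes = [modes, cur%key]
            end if

            ! Continue with the right subtree
            cur => cur%right
        end do

        print '(*(i0, 1x))', modes
    end subroutine solve

end program main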
module bst_mod
    implicit none
    private
    public :: node, new_node, find_mode

    ! Define the node structure
    type :: node
        integer :: val = 0
        type(node), pointer :: left => null()
        type(node), pointer :: right => null()
    end type node

contains

    ! Allocate a new leaf node holding val
    function new_node(val) result(p)
        integer, intent(in) :: val
        type(node), pointer :: p

        allocate(p)
        p%val = val
    end function new_node

    ! Find the mode(s) of a BST with two in-order passes. Because an
    ! in-order walk visits values in ascending order, equal values are
    ! adjacent, so no frequency table is needed.
    function find_mode(root) result(modes)
        type(node), pointer, intent(in) :: root
        integer, allocatable :: modes(:)
        integer :: prev, count, max_count
        logical :: first

        ! Pass 1: find the length of the longest run of equal values
        first = .true.
        prev = 0
        count = 0
        max_count = 0
        call count_frequency(root, first, prev, count, max_count)

        ! Start with an empty result; store_modes appends each mode it finds
        allocate(modes(0))

        ! Pass 2: walk the BST again and store every value whose run reaches max_count
        first = .true.
        count = 0
        call store_modes(root, first, prev, count, max_count, modes)
    end function find_mode

    ! In-order traversal that tracks the current run length and the longest run
    recursive subroutine count_frequency(root, first, prev, count, max_count)
        type(node), pointer, intent(in) :: root
        logical, intent(inout) :: first
        integer, intent(inout) :: prev, count, max_count

        ! Base case: if the current node is null, return
        if (.not. associated(root)) return

        call count_frequency(root%left, first, prev, count, max_count)

        ! Extend the current run or start a new one
        if (first .or. root%val /= prev) then
            count = 1
        else
            count = count + 1
        end if
        first = .false.
        prev = root%val
        max_count = max(max_count, count)

        call count_frequency(root%right, first, prev, count, max_count)
    end subroutine count_frequency

    ! In-order traversal that appends each value whose run reaches max_count
    recursive subroutine store_modes(root, first, prev, count, max_count, modes)
        type(node), pointer, intent(in) :: root
        logical, intent(inout) :: first
        integer, intent(inout) :: prev, count
        integer, intent(in) :: max_count
        integer, allocatable, intent(inout) :: modes(:)

        ! Base case: if the current node is null, return
        if (.not. associated(root)) return

        call store_modes(root%left, first, prev, count, max_count, modes)

        if (first .or. root%val /= prev) then
            count = 1
        else
            count = count + 1
        end if
        first = .false.
        prev = root%val
        ! Each value's run reaches max_count at most once, so no duplicates
        if (count == max_count) modes = [modes, root%val]

        call store_modes(root%right, first, prev, count, max_count, modes)
    end subroutine store_modes

end module bst_mod

program test_bst
    use bst_mod
    implicit none

    type(node), pointer :: root
    integer, allocatable :: modes(:)

    ! Test case 1: root = [1,null,2,2]
    root => new_node(1)
    root%right => new_node(2)
    root%right%left => new_node(2)
    modes = find_mode(root)
    write (*,*) "Test case 1:", modes

    ! Test case 2: root = [0]
    root => new_node(0)
    modes = find_mode(root)
    write (*,*) "Test case 2:", modes

    ! Test case 3: root = [1,null,2,null,3,null,4]
    root => new_node(1)
    root%right => new_node(2)
    root%right%right => new_node(3)
    root%right%right%right => new_node(4)
    modes = find_mode(root)
    write (*,*) "Test case 3:", modes

    ! Test case 4: root = [1,null,2,null,3,null,4,null,5]
    root => new_node(1)
    root%right => new_node(2)
    root%right%right => new_node(3)
    root%right%right%right => new_node(4)
    root%right%right%right%right => new_node(5)
    modes = find_mode(root)
    write (*,*) "Test case 4:", modes

end program test_bst
🌐 Data from online sources
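import collections

def findMode(root):
    # In-order traversal of a BST yields the values in ascending order
    def InOrder(node):
        if not node:
            return
        yield from InOrder(node.left)
        yield node.val
        yield from InOrder(node.right)

    # Count each value, then keep every value with the highest count
    counts = collections.Counter(InOrder(root))
    max_count = max(counts.values(), default=0)
    return [k for k, v in counts.items() if v == max_count]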

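The algorithm is based on an in-order traversal of the binary search tree (BST). An in-order traversal of a BST visits the nodes in non-decreasing order of value, so equal values are visited consecutively. During the traversal, we count the frequency of each value using a hash map (unordered_map in C++, HashMap in Java), a dictionary (collections.Counter in Python), or a Map (in JavaScript). We also determine the maximum frequency: the C++ version tracks it during the traversal, while the Python version computes it from the counts afterwards.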

Once the traversal is complete, we iterate over the frequency map and add every value whose frequency equals the maximum frequency to the result list.

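The time complexity of this algorithm is O(n), where n is the number of nodes in the tree, since each node is visited once during the traversal (assuming average-case O(1) hash map operations). The extra space is O(n) in the worst case for the frequency map (when all values are distinct), plus O(h) for the recursion stack, where h is the height of the tree.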
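The explanation above relies on an in-order traversal visiting BST nodes in ascending order. This Python sketch replaces the recursive generator with an explicit stack and checks that claim on a heavily right-skewed tree. Such a tree can have up to 10^4 levels under the constraints, and deep recursion can exceed CPython's default recursion limit of 1000. TreeNode, in_order_iterative, and find_mode_iterative are illustrative names chosen here.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Iterator, List, Optional


# Hypothetical node type mirroring LeetCode's TreeNode (an assumption)
@dataclass
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def in_order_iterative(root: Optional[TreeNode]) -> Iterator[int]:
    """Yield BST values in ascending order using an explicit stack."""
    stack: List[TreeNode] = []
    cur = root
    while cur is not None or stack:
        # Walk as far left as possible, remembering the path
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()
        yield cur.val
        cur = cur.right


def find_mode_iterative(root: Optional[TreeNode]) -> List[int]:
    # Same counting step as findMode, fed by the stack-based traversal
    counts = Counter(in_order_iterative(root))
    max_count = max(counts.values(), default=0)
    return [k for k, v in counts.items() if v == max_count]


# Right-skewed chain of 10,000 nodes: 0, 1, ..., 9998, then 9998 again
values = list(range(9999)) + [9998]
root = None
for v in reversed(values):
    root = TreeNode(v, right=root)

print(find_mode_iterative(root))  # [9998]
```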

🌐 Data from online sources
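#include <algorithm>
#include <unordered_map>
#include <vector>
using namespace std;

// Definition for a binary tree node (as provided by LeetCode)
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// In-order traversal that counts each value and tracks the highest count
void InOrder(TreeNode* root, unordered_map<int, int>& frequencies, int& max_frequency) {
    if (!root) return;
    InOrder(root->left, frequencies, max_frequency);
    max_frequency = max(max_frequency, ++frequencies[root->val]);
    InOrder(root->right, frequencies, max_frequency);
}

vector<int> findMode(TreeNode* root) {
    unordered_map<int, int> frequencies;
    int max_frequency = 0;
    InOrder(root, frequencies, max_frequency);
    vector<int> modes;

    // Collect every value whose count equals the maximum
    for (const auto& freq : frequencies) {
        if (freq.second == max_frequency) modes.push_back(freq.first);
    }
    return modes;
}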

