
Let's say we have two arrays:

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

I understand how to find the kth element of two sorted arrays in O(log(min(m, n))) time, where m is the length of array1 and n is the length of array2, as follows:

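MIN_VALUE = float('-inf')
MAX_VALUE = float('inf')

def kthelement(arr1, arr2, m, n, k):
    # Binary search over the shorter array, so arr1 must have length m <= n.
    if m > n:
        return kthelement(arr2, arr1, n, m, k)

    # cut1 elements are taken from arr1 and cut2 = k - cut1 from arr2.
    low = max(0, k - n)
    high = min(k, m)

    while low <= high:
        cut1 = (low + high) >> 1
        cut2 = k - cut1
        l1 = MIN_VALUE if cut1 == 0 else arr1[cut1 - 1]
        l2 = MIN_VALUE if cut2 == 0 else arr2[cut2 - 1]
        r1 = MAX_VALUE if cut1 == m else arr1[cut1]
        r2 = MAX_VALUE if cut2 == n else arr2[cut2]

        # Valid cut: everything on the left is <= everything on the right.
        if l1 <= r2 and l2 <= r1:
            return max(l1, l2)
        elif l1 > r2:
            # Too many elements taken from arr1.
            high = cut1 - 1
        else:
            # Too few elements taken from arr1.
            low = cut1 + 1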
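To see the invariant the binary search relies on, the hypothetical helper `valid_cuts` below scans every legal cut instead of binary searching and reports the cuts where `l1 <= r2` and `l2 <= r1` both hold. It is a linear-time illustration for checking small inputs, not a replacement for the logarithmic version.

```python
from typing import List, Tuple

NEG_INF = float("-inf")
POS_INF = float("inf")

def valid_cuts(arr1: List[int], arr2: List[int], k: int) -> List[Tuple[int, int]]:
    # Hypothetical helper: try every split that takes cut1 elements from arr1
    # and cut2 = k - cut1 elements from arr2, keeping the consistent ones.
    m, n = len(arr1), len(arr2)
    cuts = []
    for cut1 in range(max(0, k - n), min(k, m) + 1):
        cut2 = k - cut1
        l1 = NEG_INF if cut1 == 0 else arr1[cut1 - 1]
        l2 = NEG_INF if cut2 == 0 else arr2[cut2 - 1]
        r1 = POS_INF if cut1 == m else arr1[cut1]
        r2 = POS_INF if cut2 == n else arr2[cut2]
        # A cut is valid when everything on the left is <= everything on the right.
        if l1 <= r2 and l2 <= r1:
            cuts.append((cut1, cut2))
    return cuts

array1 = [2, 3, 6, 7, 9]
array2 = [1, 4, 8, 10]
print(valid_cuts(array1, array2, 5))  # [(3, 2)]: the 5th element is max(6, 4) = 6
```

For this input, cuts with a smaller cut1 fail `l2 <= r1` and cuts with a larger cut1 fail `l1 > r2`'s opposite, which is why the two-array version can decide which half to discard from a single probe.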

But I couldn't figure out how to extend this to the case of multiple sorted arrays. For example, given 3 arrays, I want to find the kth element of the final merged, sorted array.

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

array3 = [2,3,5,7]

Is it possible to achieve this in log(min(m,n)) time, as in the two-array case?

  • Just combine them into a new array, sort it, and use the index; I suppose that should work. sorted(array1+array2+array3)[X] Commented Sep 7, 2022 at 14:18
  • @alanturing I have a complexity constraint. That's why your approach is not feasible for me. I need a more optimized way to do it. Commented Sep 7, 2022 at 14:20
  • Oh. Then just ignore it. XD Commented Sep 7, 2022 at 14:20
  • @alanturing That would be preferable but if it's not feasible, I am okay with anything better than O(n). Commented Sep 7, 2022 at 14:24
  • Did you try to use the same logic as with two arrays to write the algorithm for 3 arrays? Where were you stuck when you tried that? Commented Sep 7, 2022 at 14:41

4 Answers


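If k is very large, we can binary search on the answer, which gives a solution with time complexity O(n * log L * log N), where N is the range of the element values, n is the number of arrays, and L is the length of the longest array (each of the log N probes does one binary search per array).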

The key step is checking whether a given integer x is greater than or equal to the correct answer. For each array, binary search to count the elements less than or equal to x, sum these counts, and compare the total with k: if the total is at least k, the answer is at most x.

from typing import List
import bisect

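def query_k_min(vecs: List[List[int]], k: int) -> int:
    # we assume each number >=1 and <=10^9
    # Invariant: fewer than k elements are <= l, and at least k are <= r.
    l, r = 0, 10**9
    while r - l > 1:
        m = (l+r)>>1
        # Count the elements <= m with one binary search per array.
        tot = 0
        for vec in vecs:
            tot += bisect.bisect_right(vec, m)
        if tot >= k: r = m
        else: l = m
    # r is the smallest value with at least k elements <= r, i.e. the kth smallest.
    return r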

a = [[2,3,6,7,9],[1,4,8,10],[2,3,5,7]]
for x in range(1,14):
    print(query_k_min(a,x))
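The version above hard-codes the value range 1 to 10^9. A variant, sketched here under the assumption that the arrays hold integers (possibly negative), at least one array is non-empty, and 1 <= k <= total length, takes the search bounds from the arrays themselves; the name `query_k_min_bounded` is illustrative.

```python
import bisect
from typing import List

def query_k_min_bounded(vecs: List[List[int]], k: int) -> int:
    # Assumes integer elements, at least one non-empty array, and 1 <= k <= total length.
    lo = min(v[0] for v in vecs if v) - 1   # no element is <= lo, so count(<= lo) < k
    hi = max(v[-1] for v in vecs if v)      # every element is <= hi, so count(<= hi) >= k
    # Invariant: count(<= lo) < k <= count(<= hi).
    while hi - lo > 1:
        mid = (lo + hi) // 2
        # Count elements <= mid across all arrays with one binary search each.
        count = sum(bisect.bisect_right(v, mid) for v in vecs)
        if count >= k:
            hi = mid
        else:
            lo = mid
    return hi
```

Each probe still costs one binary search per array; the only change is that the number of probes depends on the actual spread of the data rather than on a fixed bound.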


The general solution is to use a min-heap. If you have n sorted arrays and you want the kth smallest number, then the solution is O(n + k log n): O(n) to build the initial heap, then k pops and pushes at O(log n) each.

The idea is that you insert the first number from each array into the min-heap. Each heap entry is a tuple containing the number, the array it came from, and its index within that array.

You then remove the smallest value from the heap and add the next number from the array that value came from. You do this k times to get the kth smallest number.

See https://www.geeksforgeeks.org/find-m-th-smallest-value-in-k-sorted-arrays/ for the general idea.
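Python's standard library already provides a lazy k-way merge of sorted inputs as `heapq.merge`, which pulls from each input only as needed. A minimal sketch of the heap approach built on it, assuming 1 <= k <= total length (`kth_smallest_heap` is an illustrative name; `next` raises `StopIteration` if k is too large):

```python
import heapq
import itertools
from typing import List

def kth_smallest_heap(arrays: List[List[int]], k: int) -> int:
    # heapq.merge lazily yields the elements of all sorted inputs in sorted order.
    merged = heapq.merge(*arrays)
    # Skip the first k - 1 elements and return the kth one (k is 1-based).
    return next(itertools.islice(merged, k - 1, None))

print(kth_smallest_heap([[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]], 5))  # 3
```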

Comments

The OP still hasn't understood the logic of the algorithm well enough to do it with 3 arrays; and you're throwing a min-heap at them to do it with a variable number of arrays!
This will be slow for finding the median element...
@btilly There's nothing in the question about the median.
The time complexity should be O(n+klogn) to include cases when k < n.
See stackoverflow.com/a/73642445/585411 for a better way to do this with multiple arrays and arbitrary k.

The following looks complicated, but if M is the sum, over all arrays, of log(len(array) + 2), then the average case is O(M) and the worst case is O(M^2). (The +2 is there because even an empty array requires some work, so each array contributes at least log 2 to M.) The worst case is very unlikely.

The performance is independent of k.

The idea is the same as Quickselect: we pick pivots and split the data around each pivot. But we do not look at every element; we only figure out which chunk of each array still under consideration lies before, after, or at the pivot. The average case holds because every time we look at an array, with positive probability we get rid of half of what remains. The worst case arises because every time we look at the array we took the pivot from, we get rid of half of that array, but we may have to binary search every other array only to find that we got rid of nothing else.

from collections import deque

def kth_of_sorted(k, arrays):
    # Counts of elements already eliminated below (known_low) and above
    # (known_high) the answer, plus the total number of elements.
    known_low = 0
    known_high = 0
    total_size = 0

    # in_flight will be a double-ended queue of
    # (array, low, high)
    # Where:
    #    array is an input array
    #    low is the lower bound on where kth might be
    #    high is the upper bound on where kth might be
    in_flight = deque()

    for a in arrays:
        if 0 < len(a):
            total_size += len(a)
            in_flight.append((a, 0, len(a)-1))

    # Sanity check.
    if k < 1 or total_size < k:
        return None

    while 0 < len(in_flight):
        start_a, start_low, start_high = in_flight.popleft()
        start_mid = (start_low + start_high) // 2
        pivot = start_a[start_mid]

        # If pivot is placed, how many are known?
        maybe_low = start_mid - start_low
        maybe_high = start_high - start_mid

        # This will be arrays taken from in_flight with:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        # We are binary searching in these to figure out where the pivot
        # is going to go. Then we copy back to in_flight.
        to_process = deque()

        # This will be arrays taken from in_flight with:
        #
        #    (array, orig_low, mid, orig_high)
        #
        # where at mid we match the pivot.
        is_match = deque()
        # And we know an array with a pivot!
        is_match.append((start_a, start_low, start_mid, start_high))

        # This will be arrays taken from in_flight which we know do not have the pivot:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        no_pivot = deque()

        while 0 < len(in_flight):
            a, low, high = in_flight.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, low, high))
                else:
                    no_pivot.append((a, mid+1, high, low, high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, low, high))
                else:
                    no_pivot.append((a, low, mid-1, low, high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, low, mid, high))

        # We do not yet know where the pivot_pos is.
        pivot_pos = None
        if k <= known_low + maybe_low:
            pivot_pos = 'right'
        elif total_size - known_high - maybe_high < k:
            pivot_pos = 'left'
        elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
            return pivot # WE FOUND IT!

        while pivot_pos is None:
            # This is very similar to how we processed in_flight.
            a, low, high, orig_low, orig_high = to_process.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, orig_low, orig_high))
                else:
                    no_pivot.append((a, mid+1, high, orig_low, orig_high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, orig_low, orig_high))
                else:
                    no_pivot.append((a, low, mid-1, orig_low, orig_high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, orig_low, mid, orig_high))

            if k <= known_low + maybe_low:
                pivot_pos = 'right'
            elif total_size - known_high - maybe_high < k:
                pivot_pos = 'left'
            elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
                return pivot # WE FOUND IT!

        # And now place the pivot in the right position.
        if pivot_pos == 'right':
            known_high += maybe_high + len(is_match)
            # And put back the left side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if orig_low <= high:
                        in_flight.append((a, orig_low, high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if low < mid:
                    in_flight.append((a, low, mid-1))
        else:
            known_low += maybe_low + len(is_match)
            # And put back the right side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if low <= orig_high:
                        in_flight.append((a, low, orig_high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if mid < high:
                    in_flight.append((a, mid+1, high))

list1 = [2,3,6,7,9]
list2 = [1,4,8,10]
list3 = [2,3,5,7]
print(list1, list2, list3)
for i in range(1, len(list1) + len(list2) + len(list3) + 1):
    print(i, kth_of_sorted(i,[list1, list2, list3]))
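For comparison, here is a much shorter, hypothetical sketch of the same Quickselect-over-windows idea; it is not the answer's implementation and does not have its cost guarantees. Assuming each array is sorted ascending, it picks a pivot from the middle of a randomly chosen live window, uses `bisect` to split every window into elements below, equal to, and above the pivot, and keeps only the side that must contain the kth element. Each round does one binary search per array, trading the answer's finer-grained bookkeeping for brevity.

```python
import bisect
import random
from typing import List, Optional

def kth_of_sorted_simple(k: int, arrays: List[List[int]]) -> Optional[int]:
    # Window [lo[i], hi[i]) of arrays[i] may still hold the answer, and k is
    # the rank of the answer within the union of the current windows.
    lo = [0] * len(arrays)
    hi = [len(a) for a in arrays]
    if k < 1 or k > sum(hi):
        return None
    while True:
        # Pivot: middle element of a randomly chosen non-empty window.
        live = [i for i in range(len(arrays)) if lo[i] < hi[i]]
        p = random.choice(live)
        pivot = arrays[p][(lo[p] + hi[p]) // 2]
        # Split each window into (< pivot | == pivot | > pivot) via binary search.
        below = [bisect.bisect_left(a, pivot, lo[i], hi[i]) for i, a in enumerate(arrays)]
        upto = [bisect.bisect_right(a, pivot, lo[i], hi[i]) for i, a in enumerate(arrays)]
        n_less = sum(b - l for b, l in zip(below, lo))
        n_equal = sum(u - b for u, b in zip(upto, below))
        if k <= n_less:
            hi = below              # answer is strictly smaller than the pivot
        elif k <= n_less + n_equal:
            return pivot            # answer equals the pivot
        else:
            k -= n_less + n_equal   # discard everything <= pivot
            lo = upto

list1 = [2, 3, 6, 7, 9]
list2 = [1, 4, 8, 10]
list3 = [2, 3, 5, 7]
print([kth_of_sorted_simple(i, [list1, list2, list3]) for i in range(1, 14)])
# [1, 2, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 10]
```

The pivot's own window always shrinks (its element is either discarded or returned), so the loop terminates; the random choice affects only the number of rounds, not the result.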

This question can be solved with the K-way merge pattern. You can read more about this pattern at Educative; in brief, the pattern merges K sorted arrays. For the question you asked, we only need to merge far enough to find the kth smallest element.

We start by pushing the first element of every array into a min-heap, so the heap holds M elements (M is the number of arrays). Since all arrays are sorted, once an element is popped from the heap, the next candidate from that list is the element right after it. So we do the following steps.

At each step, pop the smallest element from the heap and push the next element of the list it came from. Repeat until k elements have been popped. Since the heap always yields the smallest remaining element, the kth element popped is the kth smallest overall.

Below is how I wrote the code for the problem.

import heapq
from typing import List

def k_smallest_number(lists: List[List[int]], k: int) -> int:
    """
    Finds the k-th smallest number among multiple sorted lists.

    :param lists: A list of sorted lists containing integers.
    :param k: The k-th smallest number to find.
    :return: The k-th smallest number.
    """
    # Seed the heap with the first element of every non-empty list.
    # Each entry is [value, index within its list, index of the list].
    heap = []
    listIndex = 0
    while listIndex < len(lists):
        if len(lists[listIndex]) > 0:
            heapq.heappush(heap, [lists[listIndex][0], 0, listIndex])
        listIndex += 1
    count = 0
    num = 0
    # Pop the smallest entry k times, each time pushing its successor
    # from the same list (if there is one).
    while count < k and len(heap) > 0:
        num, listNumIndex, listIndex = heapq.heappop(heap)
        count += 1
        listNumIndex += 1
        if listNumIndex < len(lists[listIndex]) and count < k:
            heapq.heappush(heap, [lists[listIndex][listNumIndex], listNumIndex, listIndex])

    return num

Total time is O(M log M + k log M) = O((M + k) log M), where M is the number of arrays.
