10.9. Sorted matrix search

Given an M x N matrix in which each row and each column is sorted in ascending order, write a method to find an element.

example

A = [[1,2,4,5,7],
     [2,3,5,6,8],
     [3,6,7,8,10],
     [4,7,8,9,11]]

num = 3

First idea

Let M and N be the number of rows and columns. If A[0][N-1] < num and A[M-1][0] < num, compare num with the last element: if A[M-1][N-1] < num, return -1. Otherwise num may still be in the matrix.

If A[0][N-1] >= num or A[M-1][0] >= num, iterate through the last column and find the first row where A[row][N-1] >= num, then iterate over that row from the back. If num is not found, iterate through the last row and do the same with its columns. This should take O(N+M) in the worst case.
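The walk described above is close to what is usually called the staircase search. The following is a minimal sketch of that technique, not the exact procedure from the note; the function name `staircase_search` is made up for illustration. It starts in the top-right corner and drops one column or one row per step, so it visits at most M+N cells.

```python
def staircase_search(A, num):
    """Return 1 if num occurs in the row- and column-sorted matrix A, else 0."""
    if not A or not A[0]:
        return 0
    row, col = 0, len(A[0]) - 1  # start at the top-right corner
    while row < len(A) and col >= 0:
        if A[row][col] == num:
            return 1
        if A[row][col] > num:
            col -= 1  # the rest of this column (below) is even larger
        else:
            row += 1  # the rest of this row (to the left) is even smaller
    return 0


A = [[1, 2, 4, 5, 7],
     [2, 3, 5, 6, 8],
     [3, 6, 7, 8, 10],
     [4, 7, 8, 9, 11]]

print(staircase_search(A, 3))   # 3 sits in rows 1 and 2
print(staircase_search(A, 12))  # larger than every element
```

Each comparison discards a whole row or a whole column, which is where the O(N+M) bound comes from.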

def search_sorted_matrix(A, num):
    M, N = len(A), len(A[0])  # number of rows and columns
    if num < A[0][0] or num > A[M-1][N-1]:
        return 0
    if num == A[0][0] or num == A[M-1][N-1]:
        return 1
    if num < A[0][N-1]:  # case for small numbers: binary search the first row
        initidx = 0
        maxidx = N - 1
        # invariant: A[0][initidx] < num < A[0][maxidx]
        while maxidx - initidx > 1:
            i = (maxidx + initidx) // 2
            if A[0][i] == num:
                return 1
            elif A[0][i] > num:
                maxidx = i
            else:
                initidx = i

        # columns right of initidx start above num, so look in columns 0..initidx
        for j in range(initidx + 1):
            for i in range(M):
                if A[i][j] == num:
                    return 1
                if A[i][j] > num:
                    break

    elif num < A[M-1][0]:  # case for small numbers: binary search the first column
        initidx = 0
        maxidx = M - 1
        # invariant: A[initidx][0] < num < A[maxidx][0]
        while maxidx - initidx > 1:
            i = (maxidx + initidx) // 2
            if A[i][0] == num:
                return 1
            elif A[i][0] > num:
                maxidx = i
            else:
                initidx = i

        # rows below initidx start above num, so look in rows 0..initidx
        for i in range(initidx + 1):
            for j in range(N):
                if A[i][j] == num:
                    return 1
                if A[i][j] > num:
                    break

    else:  # case for large numbers: binary search the last row
        initidx = 0
        maxidx = N - 1
        # invariant: A[M-1][initidx] <= num < A[M-1][maxidx]
        while maxidx - initidx > 1:
            i = (maxidx + initidx) // 2
            if A[M-1][i] == num:
                return 1
            elif A[M-1][i] > num:
                maxidx = i
            else:
                initidx = i

        # columns left of initidx cannot contain num, so look in columns initidx..N-1
        for j in range(initidx, N):
            for i in range(M - 1, -1, -1):
                if A[i][j] == num:
                    return 1
                if A[i][j] < num:
                    break

    return 0

Complexity is O(log N) or O(log M) for the binary search, plus the cost of the scan that follows, but I don't know exactly. O(1) space.

We only want to check rows where num > A[i][0] and num < A[i][N-1]. If num equals one of those boundary elements, return 1. For example, maybe only rows 1 and 2 qualify. We check the columns the same way, and maybe only columns 3 and 4 qualify. Then we only have to check the mini-matrix formed by those rows and columns.

example:

data = [
        ([
            [0, 1, 3, 5, 6],
            [1, 2, 4, 9, 10],
            [3, 12, 13, 14, 15],
            [6, 17, 18, 25, 26],
            [8, 22, 23, 28, 30]
        ], 4, 0)
    ]
minimatrix = [[0, 1, 3],
              [1, 2, 4],
              [3, 12, 13]]

This repeats recursively until we either find num or end up with a mini-matrix of size 1 where A[i][j] != num, in which case we return 0.
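To make a single narrowing step concrete, here is a small sketch that reproduces the mini-matrix from the example. The helper name `candidate_bounds` is hypothetical, and the sketch assumes the equality checks against the boundary elements are handled separately, as described above.

```python
def candidate_bounds(A, num):
    """Return the rows and columns whose first/last elements bracket num."""
    M, N = len(A), len(A[0])
    rows = [i for i in range(M) if A[i][0] < num < A[i][N - 1]]
    cols = [j for j in range(N) if A[0][j] < num < A[M - 1][j]]
    return rows, cols


A = [[0, 1, 3, 5, 6],
     [1, 2, 4, 9, 10],
     [3, 12, 13, 14, 15],
     [6, 17, 18, 25, 26],
     [8, 22, 23, 28, 30]]

rows, cols = candidate_bounds(A, 4)
minimatrix = [[A[i][j] for j in cols] for i in rows]
print(rows, cols)  # [0, 1, 2] [0, 1, 2]
print(minimatrix)  # [[0, 1, 3], [1, 2, 4], [3, 12, 13]]
```

Repeating the same step on the mini-matrix, using its own first and last rows and columns, shrinks the search area further.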

def search_sorted_matrix(A, num):
    M, N = len(A), len(A[0])  # number of rows and columns
    if num < A[0][0] or num > A[M-1][N-1]:
        return 0
    if num == A[0][0] or num == A[M-1][N-1]:
        return 1

    rows = [0, M-1]
    cols = [0, N-1]

    while True:
        new_rows = []
        new_cols = []
        flag = False

        # check bounded rows
        for i in range(rows[0], rows[1]+1):
            if A[i][cols[0]] == num or A[i][cols[1]] == num:
                return 1
            if A[i][cols[0]] > num:
                new_rows.append(i - 1)
                flag = False
                break
            if i == rows[1]:
                new_rows.append(i)
                flag = False
                break
            if A[i][cols[0]] < num and A[i][cols[1]] > num:
                if not flag:
                    new_rows.append(i)
                    flag = True

        # check bounded cols
        for j in range(cols[0], cols[1]+1):
            if A[rows[0]][j] == num or A[rows[1]][j] == num:
                return 1
            if A[rows[0]][j] > num:
                new_cols.append(j - 1)
                flag = False
                break
            if j == cols[1]:
                new_cols.append(j)
                flag = False
                break
            if A[rows[0]][j] < num and A[rows[1]][j] > num:
                if not flag:
                    new_cols.append(j)
                    flag = True

        rows = new_rows
        cols = new_cols

    return 0

The matrix keeps getting smaller, but we still need to handle the cases where num is missing:

def search_sorted_matrix(A, num):
    M, N = len(A), len(A[0])  # number of rows and columns
    if num < A[0][0] or num > A[M-1][N-1]:
        return 0
    if num == A[0][0] or num == A[M-1][N-1]:
        return 1

    rows = [0, M-1]
    cols = [0, N-1]

    while True:
        new_rows = []
        new_cols = []
        flag = False

        if rows[0] == rows[1]:
            for j in range(cols[0], cols[1]+1):
                if A[rows[0]][j] == num:
                    return 1
            return 0

        if cols[0] == cols[1]:
            for i in range(rows[0], rows[1]+1):
                if A[i][cols[0]] == num:
                    return 1
            return 0

        # check bounded rows
        for i in range(rows[0], rows[1]+1):
            if A[i][cols[0]] == num or A[i][cols[1]] == num:
                return 1
            if A[i][cols[0]] > num:
                new_rows.append(i - 1)
                flag = False
                break
            if i == rows[1]:
                new_rows.append(i)
                flag = False
                break
            if A[i][cols[0]] < num and A[i][cols[1]] > num:
                if not flag:
                    new_rows.append(i)
                    flag = True

        # check bounded cols
        for j in range(cols[0], cols[1]+1):
            if A[rows[0]][j] == num or A[rows[1]][j] == num:
                return 1
            if A[rows[0]][j] > num:
                new_cols.append(j - 1)
                flag = False
                break
            if j == cols[1]:
                new_cols.append(j)
                flag = False
                break
            if A[rows[0]][j] < num and A[rows[1]][j] > num:
                if not flag:
                    new_cols.append(j)
                    flag = True

        # a single bound means no candidate row or column is left unchecked
        if len(new_rows) < 2 or len(new_cols) < 2:
            return 0
        if new_rows[0] == new_rows[1] and new_cols[0] == new_cols[1]:
            return 0

        rows = new_rows
        cols = new_cols

    return 0

This works for numbers that occur in the matrix and for missing numbers as well.

Hints and solution

We can do a binary search in each row, but that takes O(M log N): one binary search per row.
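A minimal sketch of that per-row binary search, using `bisect_left` from the standard library; the function name `search_by_rows` is just for illustration.

```python
from bisect import bisect_left


def search_by_rows(A, num):
    """Binary search every row: O(M log N) for an M x N matrix."""
    for row in A:
        k = bisect_left(row, num)  # leftmost position where num could sit
        if k < len(row) and row[k] == num:
            return 1
    return 0


A = [[1, 2, 4, 5, 7],
     [2, 3, 5, 6, 8],
     [3, 6, 7, 8, 10],
     [4, 7, 8, 9, 11]]

print(search_by_rows(A, 3))
print(search_by_rows(A, 12))
```

This only uses the fact that the rows are sorted, so it ignores the column ordering entirely.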

Keep track of the possible rows and columns, where num > first element and num < last element, using arrays, and move up, down, left and right to narrow the rows and the columns.

If we compare num to the center number in the matrix, we can eliminate roughly one quarter. hm.
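One way to turn the center comparison into code is sketched below. This is a simplified variant written for these notes, not necessarily the book's version; the name `search_quadrants` and its bound parameters are assumptions. If the center is smaller than num, the upper-left quarter (up to and including the center) is discarded; if it is larger, the lower-right quarter is discarded. The rest is covered by two recursive calls.

```python
def search_quadrants(A, num, r0=0, r1=None, c0=0, c1=None):
    """Return 1 if num is in A[r0..r1][c0..c1] (inclusive bounds), else 0."""
    if r1 is None:
        r1, c1 = len(A) - 1, len(A[0]) - 1
    if r0 > r1 or c0 > c1:
        return 0  # empty submatrix
    if num < A[r0][c0] or num > A[r1][c1]:
        return 0  # outside the range of this submatrix
    rm, cm = (r0 + r1) // 2, (c0 + c1) // 2
    mid = A[rm][cm]
    if mid == num:
        return 1
    if mid < num:
        # rows r0..rm, cols c0..cm are all <= mid < num: discard them
        return (search_quadrants(A, num, r0, rm, cm + 1, c1)      # upper right
                or search_quadrants(A, num, rm + 1, r1, c0, c1))  # lower half
    # rows rm..r1, cols cm..c1 are all >= mid > num: discard them
    return (search_quadrants(A, num, r0, rm - 1, c0, c1)      # upper half
            or search_quadrants(A, num, rm, r1, c0, cm - 1))  # lower left


A = [[0, 1, 3, 5, 6],
     [1, 2, 4, 9, 10],
     [3, 12, 13, 14, 15],
     [6, 17, 18, 25, 26],
     [8, 22, 23, 28, 30]]

print(search_quadrants(A, 4))
print(search_quadrants(A, 7))
```

Every call works on a strictly smaller rectangle, so the recursion terminates.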

See p. 410; I don't really understand it. We can discard quarters of the matrix by checking the middle values in the matrix, then recursively check the lower-left quadrant and the upper-right quadrant.