Given a matrix in which each row and each column is sorted, write a method to find an element in it.

**My initial thoughts:**

We are kind of doing a 2D binary search here. Each time, we look at the center of the (sub)matrix. If it is the element we are looking for, we return it. Notice that the center divides the matrix into four parts: upper-left, upper-right, bottom-left and bottom-right. If the element we are looking for is less than the center, then it cannot be in the bottom-right part, where every element is greater than or equal to the center. If the element is greater than the center, then it cannot be in the upper-left part, where every element is less than or equal to the center. Hence, each time we eliminate approximately one quarter of the matrix and recurse into the remaining three quarters. With $N$ the number of elements, this gives the recurrence $T(N) = 3T(N/4) + O(1)$, whose solution is $T(N) = O(N^{\log_4 3}) \approx O(N^{0.79})$.
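The recurrence can be solved with the master theorem. The following is a sketch that takes $N$ as the number of elements and assumes the worst case: all three surviving quadrants are searched and each holds roughly $N/4$ elements.

$$
T(N) = 3\,T\!\left(\frac{N}{4}\right) + O(1), \qquad a = 3,\quad b = 4,\quad f(N) = O(1)
$$

Since $\log_4 3 \approx 0.79 > 0$, we have $f(N) = O(N^{\log_4 3 - \varepsilon})$ for some $\varepsilon > 0$, so the first case of the master theorem applies:

$$
T(N) = O\!\left(N^{\log_4 3}\right) \approx O\!\left(N^{0.79}\right)
$$

For an $n \times n$ matrix, $N = n^2$, so in terms of the side length:

$$
N^{\log_4 3} = n^{2\log_4 3} = n^{\log_2 3} \approx n^{1.58}
$$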

**My initial code:**

```java
// Pair<T> is a simple two-field holder class defined elsewhere in my code.
public static Pair<Integer> findInMatrix(int[][] data, int rowLow, int rowHigh,
                                         int colLow, int colHigh, int element) {
    // Empty sub-matrix: nothing to search.
    if (rowLow > rowHigh || colLow > colHigh)
        return null;

    int rowMid = (rowLow + rowHigh) / 2;
    int colMid = (colLow + colHigh) / 2;
    int mid = data[rowMid][colMid];

    if (mid == element)
        return new Pair<Integer>(rowMid, colMid);
    else if (element < mid) {
        // Skip the bottom-right block (rows >= rowMid, cols >= colMid): all >= mid.
        Pair<Integer> topLeft = findInMatrix(data, rowLow, rowMid - 1, colLow, colMid - 1, element);
        if (topLeft == null) {
            Pair<Integer> topRight = findInMatrix(data, rowLow, rowMid - 1, colMid, colHigh, element);
            if (topRight == null) {
                Pair<Integer> bottomLeft = findInMatrix(data, rowMid, rowHigh, colLow, colMid - 1, element);
                return bottomLeft;
            } else
                return topRight;
        } else
            return topLeft;
    } else { // element > mid
        // Skip the top-left block (rows <= rowMid, cols <= colMid): all <= mid.
        Pair<Integer> topRight = findInMatrix(data, rowLow, rowMid - 1, colMid + 1, colHigh, element);
        if (topRight == null) {
            Pair<Integer> bottomLeft = findInMatrix(data, rowMid + 1, rowHigh, colLow, colMid, element);
            if (bottomLeft == null) {
                Pair<Integer> bottomRight = findInMatrix(data, rowMid, rowHigh, colMid + 1, colHigh, element);
                return bottomRight;
            } else
                return bottomLeft;
        } else
            return topRight;
    }
}
```
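To try the quadrant-elimination search end to end, here is a self-contained sketch of the same recursion. It is an adaptation, not the original: it returns a hypothetical `int[]{row, col}` instead of `Pair<Integer>` (so it compiles without that helper class), adds a hypothetical two-argument `find` overload that covers the whole matrix, and uses a small sample matrix made up for illustration.

```java
import java.util.Arrays;

// Self-contained sketch of the quadrant-elimination search.
// Assumption: positions are returned as int[]{row, col} instead of Pair<Integer>.
final class QuadrantSearch {

    // Searches data[rowLow..rowHigh][colLow..colHigh] (bounds inclusive) for target.
    static int[] find(int[][] data, int rowLow, int rowHigh,
                      int colLow, int colHigh, int target) {
        if (rowLow > rowHigh || colLow > colHigh) {
            return null; // empty sub-matrix
        }
        int rowMid = (rowLow + rowHigh) / 2;
        int colMid = (colLow + colHigh) / 2;
        int mid = data[rowMid][colMid];
        if (mid == target) {
            return new int[] {rowMid, colMid};
        }
        int[] hit;
        if (target < mid) {
            // Bottom-right block (rows >= rowMid, cols >= colMid) is all >= mid: skip it.
            hit = find(data, rowLow, rowMid - 1, colLow, colMid - 1, target);                // top-left
            if (hit == null) hit = find(data, rowLow, rowMid - 1, colMid, colHigh, target);  // top-right
            if (hit == null) hit = find(data, rowMid, rowHigh, colLow, colMid - 1, target);  // bottom-left
        } else {
            // Top-left block (rows <= rowMid, cols <= colMid) is all <= mid: skip it.
            hit = find(data, rowLow, rowMid - 1, colMid + 1, colHigh, target);               // top-right
            if (hit == null) hit = find(data, rowMid + 1, rowHigh, colLow, colMid, target);  // bottom-left
            if (hit == null) hit = find(data, rowMid, rowHigh, colMid + 1, colHigh, target); // bottom-right
        }
        return hit;
    }

    // Hypothetical convenience overload: search the whole matrix.
    static int[] find(int[][] data, int target) {
        if (data.length == 0) return null;
        return find(data, 0, data.length - 1, 0, data[0].length - 1, target);
    }

    public static void main(String[] args) {
        int[][] m = {
            {1, 4, 7, 11},
            {2, 5, 8, 12},
            {3, 6, 9, 16},
            {10, 13, 14, 17}
        };
        System.out.println(Arrays.toString(find(m, 9)));  // prints "[2, 2]"
        System.out.println(Arrays.toString(find(m, 15))); // prints "null"
    }
}
```

In the worst case (an absent element) every call still descends into all three surviving quadrants, which is where the $3T(N/4)$ term comes from.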

**Solution:**

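This algorithm works by elimination, starting from the top-right corner. Every move to the left (`col--`) eliminates all the elements below the current cell in that column: they are all at least as large as the current cell, which is already greater than the target. Likewise, every move down (`row++`) eliminates all the elements to the left of the current cell in that row: they are all at most the current cell, which is already less than the target.

```java
// Searches an M x N matrix whose rows and columns are sorted in ascending order.
public static Pair<Integer> FindElem(int[][] mat, int elem, int M, int N) {
    int row = 0;
    int col = N - 1; // start at the top-right corner
    while (row < M && col >= 0) {
        if (mat[row][col] == elem) {
            return new Pair<Integer>(row, col);
        } else if (mat[row][col] > elem) {
            col--; // the rest of this column is even larger: discard it
        } else {
            row++; // the rest of this row is even smaller: discard it
        }
    }
    return null; // not found
}
```

Comments: Brilliant. Each step discards one row or one column, so the loop runs at most M + N - 1 times: the time complexity is $O(M + N)$.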
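A self-contained sketch of the same top-right "staircase" search, with a small driver. Assumptions, all mine rather than the original's: it returns a hypothetical `int[]{row, col}` instead of `Pair<Integer>`, reads the dimensions from the array instead of taking `M` and `N`, and keeps a hypothetical `probes` counter so the $O(M + N)$ bound can be observed on a made-up sample matrix.

```java
import java.util.Arrays;

// Sketch of the staircase search from the top-right corner.
// Assumption: positions are returned as int[]{row, col} instead of Pair<Integer>.
final class StaircaseSearch {
    static int probes; // hypothetical: number of cells examined by the latest call

    static int[] find(int[][] mat, int elem) {
        probes = 0;
        int m = mat.length;
        int n = (m == 0) ? 0 : mat[0].length;
        int row = 0, col = n - 1; // start at the top-right corner
        while (row < m && col >= 0) {
            probes++;
            int cur = mat[row][col];
            if (cur == elem) {
                return new int[] {row, col};
            } else if (cur > elem) {
                col--; // cur and everything below it in this column exceed elem
            } else {
                row++; // cur and everything left of it in this row are below elem
            }
        }
        return null; // walked off the matrix: elem is absent
    }

    public static void main(String[] args) {
        int[][] m = {
            {1, 4, 7, 11},
            {2, 5, 8, 12},
            {3, 6, 9, 16},
            {10, 13, 14, 17}
        };
        // prints "[2, 2] after 4 probes"
        System.out.println(Arrays.toString(find(m, 9)) + " after " + probes + " probes");
        // prints "null after 5 probes"
        System.out.println(Arrays.toString(find(m, 15)) + " after " + probes + " probes");
    }
}
```

On this 4 x 4 matrix no search can take more than 4 + 4 - 1 = 7 probes, since each probe that does not hit either removes a row or a column.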