题目描述
给定两个大小分别为 m
和 n
的正序(从小到大)数组 nums1
和 nums2
。请你找出并返回这两个正序数组的 中位数 。
算法的时间复杂度应该为 O(log (m+n))
。
示例
示例 1:
- 输入
1
nums1 = [1,3], nums2 = [2]
- 输出
1
2.00000
- 解释:合并数组 =
[1,2,3]
,中位数2
示例 2:
- 输入
1
nums1 = [1,2], nums2 = [3,4]
- 输出
1
2.50000
- 解释:合并数组 =
[1,2,3,4]
,中位数(2 + 3) / 2 = 2.5
问题分析
这道题的难点在于要求时间复杂度为 $O\left(\log\left(m+n\right)\right)$。最直观的解法是合并两个数组然后找中位数,但这样的时间复杂度是 $O\left(m+n\right)$,不满足题目要求。
要达到 $O\left(\log\left(m+n\right)\right)$ 的时间复杂度,我们需要使用二分查找的思想。事实上,我们可以将问题转化为寻找两个数组中第 $k$ 小的元素的问题。
解决思路
I. 问题转化
首先将「找中位数」的问题转化为「找第 $k$ 小数」的问题:
-
如果合并后数组长度为奇数,中位数是第 $\dfrac{m+n}{2}+1$ 小的元素
-
如果合并后数组长度为偶数,中位数是第 $\dfrac{m+n}{2}$ 小和第 $\dfrac{m+n}{2}+1$ 小的元素的平均值
II. 二分策略
核心思想是在两个有序数组中找到一个分割线,使得:
1
2
3
nums1: [a[1], a[2], a[3], ..., a[i-1] | a[i], a[i+1], ..., a[m]]
nums2: [b[1], b[2], ..., b[j-1] | b[j], b[j+1], ..., b[n]]
- 分割线左边的所有元素 $\leq$ 分割线右边的所有元素
- 分割线左边的元素个数 $=$ 分割线右边的元素个数(或比右边多一个)
为了实现上述目标,我们需要:
- 确保短数组在前,长数组在后(便于处理边界情况)
- 在较短的数组上进行二分查找,寻找合适的分割位置 $\mathrm{i}$
- 根据 $\mathrm{i}$ 计算出较长数组的分割位置 $\mathrm{j}$ ,满足 $\mathrm{i} + \mathrm{j} = \dfrac{m+n+1}{2}$
III. 终止条件
当我们找到满足以下条件的分割位置时,问题解决: $$\mathrm{maxLeft1} \leq \mathrm{minRight2}$$ $$\mathrm{maxLeft2} \leq \mathrm{minRight1}$$ 这表示分割线左边的所有元素都小于或等于右边的所有元素。
代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
if (nums1.size() > nums2.size())
return findMedianSortedArrays(nums2, nums1);
if (nums1.empty())
{
int n = nums2.size();
if (n % 2 == 0)
return (nums2[n/2-1] + nums2[n/2]) / 2.0;
else
return nums2[n/2];
}
int m = nums1.size();
int n = nums2.size();
int left = 0, right = m;
while (left <= right)
{
int i = (left + right) / 2;
int j = (m + n + 1) / 2 - i;
int maxLeft1 = (i == 0) ? -1e6-1 : nums1[i-1];
int minRight1 = (i == m) ? 1e6+1 : nums1[i];
int maxLeft2 = (j == 0) ? -1e6-1 : nums2[j-1];
int minRight2 = (j == n) ? 1e6+1 : nums2[j];
if (maxLeft1 <= minRight2 && maxLeft2 <= minRight1)
{
if ((m + n) % 2 != 0)
return max(maxLeft1, maxLeft2);
else
return (max(maxLeft1, maxLeft2) + min(minRight1, minRight2)) / 2.0;
}
else if (maxLeft1 > minRight2)
right = i - 1;
else
left = i + 1;
}
return 0;
}
};