Java 数据结构与算法-堆

堆的基础知识

堆是一种特殊的树形数据结构。根据根节点的值与子节点的值的大小关系,堆又分为最大堆和最小堆。在最大堆中,每个节点的值总是大于或等于其任意子节点的值,因此最大堆的根节点就是整个堆的最大值。在最小堆中,每个节点的值总是小于或等于其任意子节点的值,因此最小堆的根节点就是整个堆的最小值

堆通常用完全二叉树实现。在完全二叉树中,除最底层之外,其他层都被节点填满,最底层尽可能从左到右插入节点……

完全二叉树又可以用数组实现,因此堆也可以用数组实现……

为了在最大堆中添加新的节点,应该先从上到下、从左到右找出第 1 个空缺的位置,并将新节点添加到该空缺位置。如果新节点的值比它的父节点的值大,那么交换它和它的父节点。重复这个过程,直到新节点的值小于或等于它的父节点,或者它已经到达堆的顶部位置。在最小堆中添加新节点的过程与此类似,唯一的不同是要确保新节点的值要大于或等于它的父节点……

通常只删除位于堆顶部的元素。如果删除最大堆的顶部节点,则将堆最底层最右边的节点移到堆的顶部。如果此时它的左子节点或右子节点的值大于它,那么它和左右子节点中值较大的节点交换。如果交换之后节点的值仍然小于它的子节点的值,则再次交换,直到该节点的值大于或等于它的左右子节点的值,或者到达最底层为止。删除最小堆的顶部节点的过程与此类似,唯一的不同是要确保节点的值要小于它的左右子节点的值……

堆的插入、删除操作都可能需要交换节点,以便把节点放到合适的位置,交换的次数最多为二叉树的深度,因此如果堆中有 n 个节点那么它的插入和删除操作的时间复杂度都是 O(logn)

Java 提供了类型 PriorityQueue 实现数据结构堆。PriorityQueue 在默认情况下是一个最小堆,如果使用最大堆调用构造函数就需要传入 Comparator 改变比较排序的规则。PriorityQueue 实现了接口 Queue,它常用的函数如下表所示:

操作抛异常不抛异常
插入新的元素add(e)offer(e)
删除堆顶元素removepoll
返回堆顶元素elementpeek

PriorityQueue 和其他实现接口 Queue 的类型一样,在某些时候调用函数 add、remove 和 element 时可能会抛出异常,但调用函数 offer、poll 和 peek 不会抛异常。例如,如果调用函数 remove 从一个空堆中删除堆顶元素,就会抛出异常。但如果调用函数 poll 从一个空堆中删除堆顶元素,则会返回 null

值得强调的是,虽然 Java 中的 PriorityQueue 实现了 Queue 接口,但它并不是一个队列,也不是按照 “先入先出” 的顺序删除元素的……PriorityQueue 的删除顺序与元素添加的顺序无关

同理,PriorityQueue 的函数 element 和 peek 都返回位于堆顶的元素,即根据堆的类型返回值最大或最小的元素,这与元素添加的顺序无关

堆的应用

堆最大的特点是最大值或最小值位于堆的顶部,只需要 O(1) 的时间就可以求出一个数据集合中的最大值或最小值,同时在堆中添加或删除元素的时间复杂度都是 O(logn),因此综合来看堆是一个比较高效的数据结构。如果面试题需要求出一个动态数据集合中的最大值或最小值,那么可以考虑使用堆来解决问题

堆经常用来求取一个数据集合中值最大或最小的 k 个元素。通常,最小堆用来求取数据集合中 k 个值最大的元素,最大堆用来求取数据集合中 k 个值最小的元素

接下来使用最小堆或最大堆解决几道典型的算法面试题

面试题 59:数据流的第 k 大数字

题目:请设计一个类型 KthLargest,它每次从一个数据流中读取一个数字,并得出数据流已经读取的数字中第 k (k≥1) 大的数字。该类型的构造函数有两个参数:一个是整数 k,另一个是包含数据流中最开始数字的整数数组 nums(假设数组 nums 的长度大于 k)。该类型还有一个函数 add,用来添加数据流中的新数字并返回数据流中已经读取的数字的第 k 大数字

与数据流相关的题目的特点是输入的数据是动态添加的,也就是说,可以不断地从数据流中读取新的数据,数据流的数据量是无限的。在这个题目中,类型 KthLargest 的函数 add 用来添加从数据流中读出的新数据

解决这个问题的关键在于选择合适的数据结构。如果数据存储在排序的数组中,那么只需要 O(1) 的时间就能找出第 k 大的数字。但这个直观的方法有两个缺点。首先,需要把从数据流中读取的所有数据都存到排序数组中,如果从数据流中读出 n 个数字,那么动态数组的大小为 O(n)。随着不断从数据流中读出新的数据,O(n) 的空间复杂度可能会耗尽所有的内存。其次,在排序数组中添加新的数字的时间复杂度也是 O(n)

……

基于最小堆的参考代码如下所示:

class KthLargest {
    private PriorityQueue<Integer> minHeap;
    private int size;

    public KthLargest(int k, int[] nums) {
        size = k;
        minHeap = new PriorityQueue<>();
        for (int num : nums) {
            add(num);
        }
    }

    public int add(int val) {
        if (minHeap.size() < size) {
            minHeap.offer(val);
        } else if (val > minHeap.peek()) {
            minHeap.poll();
            minHeap.offer(val);
        }
        return minHeap.peek();
    }
}

在上述代码中,minHeap 是一个最小堆。由于 minHeap 中最多保存 k 个数字,因此它的空间复杂度是 O(k)。在函数 add 中,需要在最小堆中添加、删除一个元素,并返回它的堆顶元素,因此每次调用函数 add 的时间复杂度是 O(logk)

假设数据流中总共有 n 个数字。这种解法特别适合 n 远大于 k 的场景。当 n 非常大时,内存可能不能容纳数据流中的所有数字。但使用最小堆之后,内存中只需要保存 k 个数字,空间效率非常高

面试题 60:出现频率最高的 k 个数字

题目:请找出数组中出现频率最高的 k 个数字

……首先要想到的是解决这个题目需要用到哈希表……哈希表的键是数组中出现的数字,而值是数字出现的频率

接下来找出现频率最高的 k 个数字。可以用一个最小堆存储频率最高的 k 个数字,堆中的每个元素是数组中的数字及其在数组中出现的次数……在用哈希表统计完数组中每个数字的频率之后,再逐一扫描哈希表中每个从数字到频率的映射,以便找出出现频率最高的 k 个数字……

public static List<Integer> topKFrequent(int[] nums, int k) {
    Map<Integer, Integer> numToCount = new HashMap<>();
    for (int num : nums) {
        numToCount.put(num, numToCount.getOrDefault(num, 0) + 1);
    }
    Queue<Map.Entry<Integer, Integer>> minHeap = new PriorityQueue<>(Comparator.comparingInt(Map.Entry::getValue));
    for (Map.Entry<Integer, Integer> entry : numToCount.entrySet()) {
        if (minHeap.size() < k) {
            minHeap.offer(entry);
        } else {
            if (entry.getValue() > minHeap.peek().getValue()) {
                minHeap.poll();
                minHeap.offer(entry);
            }
        }
    }
    List<Integer> result = new LinkedList<>();
    for (Map.Entry<Integer, Integer> entry : minHeap) {
        result.add(entry.getKey());
    }
    return result;
}

在上述代码中,哈希表 numToCount 用来统计数字出现的频率,它的键是数组中的数字,值是数字在数组中出现的次数。最小堆 minHeap 中的每个元素是哈希表中从数字到频率的映射。由于最小堆比较的是数字的频率,因此调用构造函数创建 minHeap 设置的比较规则是比较哈希表中映射的值,也就是数字的频率

假设输入数组的长度为 n。上述代码需要一个大小为 O(n) 的哈希表,以及一个大小为 O(k) 的最小堆,因此总的空间复杂度是 O(n)。在大小为 k 的堆中进行添加或删除操作的时间复杂度是 O(logk),因此上述代码的时间复杂度是 O(nlogk)

面试题 61:和最小的 k 个数对

题目:给定两个递增排序的整数数组,从两个数组中各取一个数字 u 和 v 组成一个数对 (u, v),请找出和最小的 k 个数对

使用最大堆(输出是和递减的最小的 k 个数对)

这个题目要求找出和最小的 k 个数对,可以用最大堆来存储这 k 个和最小的数对。逐一将 m*n 个数对添加到最大堆中。当堆中的数对的数目小于 k 时,直接将数对添加到堆中。如果堆中已经有 k 个数对,那么先要比较待添加的数对之和及堆顶的数对之和(也是堆中最大的数对之和)……

接下来考虑如何优化。题目给出的条件是输入的两个数组都是递增排序的,这个特性我们还没有用到。如果从第 1 个数组中选出第 k+1 个数字和第 2 个数组中的某个数字组成数对 p,那么该数对之和一定不是和最小的 k 个数对中的一个,这是因为第 1 个数组的前 k 个数字和第 2 个数组中的同一个数字组成的 k 个数对之和都要小于数对 p 之和。因此,不管输入的数组 nums1 有多长,最多只需要考虑前 k 个数字。同理,不管输入的数组 nums2 有多长,最多也只需要考虑前 k 个数字。优化后的代码如下所示:

public static List<List<Integer>> kSmallestPairs1(int[] nums1, int[] nums2, int k) {
    Queue<int[]> maxHeap = new PriorityQueue<>((p1, p2)
            -> p2[0] + p2[1] - p1[0] - p1[1]);
    for (int i = 0; i < Math.min(k, nums1.length); i++) {
        for (int j = 0; j < Math.min(k, nums2.length); j++) {
            if (maxHeap.size() >= k) {
                int[] root = maxHeap.peek();
                if (root[0] + root[1] > nums1[i] + nums2[j]) {
                    maxHeap.poll();
                    maxHeap.offer(new int[]{nums1[i], nums2[j]});
                }
            } else {
                maxHeap.offer(new int[]{nums1[i], nums2[j]});
            }
        }
    }
    List<List<Integer>> result = new LinkedList<>();
    while (!maxHeap.isEmpty()) {
        int[] vals = maxHeap.poll();
        result.add(Arrays.asList(vals[0], vals[1]));
    }
    return result;
}

在上述代码中,maxHeap 是一个最大堆,它的每个元素都是一个长度为 2 的数组,表示一个数对。每个数对的第 1 个数字来自数组 nums1,第 2 个数字来自数组 nums2。由于希望和最大的数对位于堆的顶部,因此在 PriorityQueue 的构造函数中传入的比较规则比较的是两个数对之和……

上述代码中有两个相互嵌套的 for 循环,每个循环最多执行 k 次。在循环体内可能在最大堆中进行添加和删除操作,由于最大堆中最多包含 k 个元素,因此添加、删除操作的时间复杂度是 O(logk)。这两个 for 循环的时间复杂度是 O(k^2logk)。另外,上述代码还有一个 while 循环,它逐一从最大堆中删除元素并将对应的数对添加到链表中,这个 while 循环的时间复杂度是 O(klogk)。因此,上述代码总的时间复杂度是 O(k^2logk)

使用最小堆(输出是和递增的最小的 k 个数对)

最大堆的方法筛选出最小的数对(通过移出更大的数对),最小堆的方法生产出最小的数对

public static List<List<Integer>> kSmallestPairs2(int[] nums1, int[] nums2, int k) {
    Queue<int[]> minHeap = new PriorityQueue<>((p1, p2)
            -> nums1[p1[0]] + nums2[p1[1]] - nums1[p2[0]] - nums2[p2[1]]);
    if (nums2.length > 0) {
        for (int i = 0; i < Math.min(k, nums1.length); i++) {
            minHeap.offer(new int[]{i, 0});
        }
    }
    List<List<Integer>> result = new ArrayList<>();
    while (k-- > 0 && !minHeap.isEmpty()) {
        int[] ids = minHeap.poll();
        result.add(Arrays.asList(nums1[ids[0]], nums2[ids[1]]));
        if (ids[1] < nums2.length - 1) {
            minHeap.offer(new int[]{ids[0], ids[1] + 1});
        }
    }
    return result;
}

在上述代码中,minHeap 是一个最小堆,它的每个元素都是一个长度为 2 的数组,数组的第 1 个数字表示数对的第 1 个数字在数组 nums1 中的下标,第 2 个数字表示数对的第 2 个数字在数组 nums2 中的下标。由于使用最小堆的目的是找出和最小的数对,因此在创建 minHeap 时在构造函数传入的 lambda 表达式中分别根据数对的两个数字在数组 nums1 和 nums2 的下标读取对应的数字并比较数对之和

上述代码先用一个 for 循环构建一个大小为 k 的最小堆,该循环的时间复杂度是 O(klogk)。接下来是一个执行 k 次的 while 循环,每次对大小为 k 的最小堆进行添加或删除操作,因此这个 while 循环的时间复杂度也是 O(klogk)。上述代码总的时间复杂度为 O(klogk)

本章小结

本章介绍了堆这种数据结构。堆又可以分成最大堆和最小堆。在最大堆中最大值总是位于堆顶,在最小堆中最小值总是位于堆顶。因此,在堆中只需要 O(1) 的时间就能得到堆中的最大值或最小值

堆经常用来解决在数据集合中找出 k 个最大值或最小值相关的问题。通常用最大堆找出数据集合中的 k 个最小值,用最小堆找出数据集合中的 k个最大值

Java 的库中提供了类型 PriorityQueue,虽然该类型实现了接口 Queue,但是它是堆而不是队列。PriorityQueue 的构造函数能传入不同的比较规则,从而创建最大堆或最小堆