2044. 统计按位或能得到最大值的子集数目

题目

给你一个整数数组 nums ,请你找出 nums 子集 按位或 可能得到的 最大值 ,并返回按位或能得到最大值的 不同非空子集的数目 。

如果数组 a 可以由数组 b 删除一些元素(或不删除)得到,则认为数组 a 是数组 b 的一个 子集 。如果选中的元素下标位置不一样,则认为两个子集 不同 。

对数组 a 执行 按位或 ,结果等于 a[0] OR a[1] OR ... OR a[a.length - 1](下标从 0 开始)。

示例 1:

输入:nums = [3,1] 输出:2 解释:子集按位或能得到的最大值是 3 。有 2 个子集按位或可以得到 3 :

  • [3]
  • [3,1]

示例 2:

输入:nums = [2,2,2] 输出:7 解释:[2,2,2] 的所有非空子集的按位或都可以得到 2 。总共有 23 - 1 = 7 个子集。

示例 3:

输入:nums = [3,2,1,5] 输出:6 解释:子集按位或可能的最大值是 7 。有 6 个子集按位或可以得到 7 :

  • [3,5]
  • [3,1,5]
  • [3,2,5]
  • [3,2,1,5]
  • [2,5]
  • [2,1,5]

提示:

  • 1 <= nums.length <= 16
  • 1 <= nums[i] <= 10^5

解题思路

这是一道典型的涉及子集和位运算的问题。题目的核心要求分为两步:

  1. 找到所有非空子集的“按位或”能达到的 最大值 是多少。
  2. 统计有多少个不同的非空子集,它们的“按位或”结果等于这个最大值。

看到题目给出的约束条件 1 <= nums.length <= 16,这是一个非常关键的提示。当数组长度 n 这么小的时候(通常 n <= 20),我们基本可以断定,需要使用时间复杂度为指数级别(例如 O(2n⋅poly(n)))的算法。这通常指向了遍历所有子集的解法。

第一步:确定“最大按位或”的值

按位或 (OR) 运算有一个重要的特性:对于任意两个非负整数 aba | b >= aa | b >= b。这意味着,你往一个集合里增加更多的数,对它们进行按位或运算,结果要么保持不变,要么会变得更大(因为可能会有更多的二进制位被置为 1)。

基于这个特性,一个数组的所有元素进行按位或运算,得到的结果一定是所有子集按位或运算可能得到的最大值。

所以,第一步非常简单: 遍历整个 nums 数组,将所有元素进行按位或运算,得到的结果就是我们要找的 最大值,我们称之为 maxOr

例如: 对于 nums = [3, 2, 1, 5]

  • 二进制表示:3 = 011, 2 = 010, 1 = 001, 5 = 101
  • maxOr = 3 | 2 | 1 | 5 = (011_2) | (010_2) | (001_2) | (101_2) = 111_2 = 7
  • 所以,这道题的目标就是找到有多少个子集的按位或结果等于 7。

第二步:找出所有子集,统计满足条件的数量

现在问题转化成了:遍历 nums 的所有非空子集,计算每个子集的按位或,如果结果等于 maxOr,则计数器加一。

由于 n <= 16,总的子集数量是 2n(包括空集),最多是 216=65536 个,这个数量级是完全可以接受的。遍历所有子集主要有两种经典方法:


具体解法

解法一:回溯算法 (DFS)

这是解决子集、排列、组合问题最通用的方法。我们可以设计一个递归函数,通过深度优先搜索(DFS)来探索所有可能的子集。

我们可以定义一个递归函数 dfs(index, currentOr),其中:

  • index:表示当前正要决定 nums[index] 这个元素是选还是不选。
  • currentOr:表示从 nums[0]nums[index-1] 中,已经选择的元素的按位或结果。

递归逻辑如下:

  1. 递归终止条件:当 index 等于数组长度 nums.length 时,说明我们已经对所有元素做出了选择,形成了一个完整的子集。此时,currentOr 就是这个子集的按位或结果。我们判断 currentOr 是否等于 maxOr,如果相等,就将最终的计数器加一。
  2. 递归过程:在 dfs(index, currentOr) 中,我们面临两个选择:
    • 不选 nums[index]:直接进入下一层递归,调用 dfs(index + 1, currentOr)
    • nums[index]:将 nums[index] 加入到当前的或运算中,即 currentOr | nums[index],然后进入下一层递归,调用 dfs(index + 1, currentOr | nums[index])

主函数流程:

  1. 计算出 maxOr
  2. 初始化一个全局计数器 count = 0
  3. 调用 dfs(0, 0) 开始递归。初始的 currentOr 为 0,代表空集的按位或。
  4. 递归结束后,count 中存储的就是所有按位或结果为 maxOr 的子集数量。

注意: 题目要求非空子集。我们的 dfs 会把空集也算进去(当所有元素都不选时),其 currentOr 为 0。但因为题目约束 nums[i] >= 1,所以 maxOr 必然大于 0。因此,空集的或结果 0 绝不会等于 maxOr,所以我们不需要特殊处理空集,最终的计数就是正确的。


解法二:位掩码 (Bitmask)

n 很小的时候,位掩码是生成所有子集的一个非常高效和简洁的方法。

一个 n 位的二进制数可以与一个包含 n 个元素的数组一一对应。如果二进制数的第 i 位是 1,就代表我们选择数组中的第 i 个元素;如果是 0,则不选择。

这样,从 1到 2n−1 的所有整数,就唯一地对应了 nums 数组的所有非空子集

算法流程:

  1. 计算出 maxOr
  2. 获取数组长度 n
  3. 初始化计数器 count = 0
  4. mask = 1 循环到 (1 << n) - 1(即 2n−1)。
    • 对于每一个 mask,我们计算它所代表的子集的按位或:
      • 初始化 currentOr = 0
      • 遍历数组下标 i0n-1
      • 检查 mask 的第 i 位是否为 1(可以通过 (mask >> i) & 1 == 1 来判断)。
      • 如果为 1,则将 nums[i] 并入或运算:currentOr = currentOr | nums[i]
    • 计算完当前子集的 currentOr 后,判断它是否等于 maxOr
    • 如果 currentOr == maxOr,则 count 加一。
  5. 循环结束后,count 就是最终答案。

代码示例

DFS法

class Solution {
private:
    int maxOrValue; // 用于存储目标最大值
    int count;      // 用于计数结果为最大值的子集数量

    /**
     * @brief 回溯辅助函数
     * @param nums 原始数组
     * @param index 当前处理到的元素下标
     * @param currentOrValue 当前子集的按位或结果
     */
    void dfs(const vector<int>& nums, int index, int currentOrValue) {
        // 递归的终止条件:当 index 到达数组末尾时,
        // 说明我们已经对所有元素做出了“选”或“不选”的决定,形成了一个子集。
        if (index == nums.size()) {
            // 检查当前子集的按位或结果是否等于我们寻找的目标最大值
            if (currentOrValue == maxOrValue) {
                count++;
            }
            return;
        }

        // --- 递归过程 ---
        
        // 分支1:不选择 nums[index] 这个元素
        // 直接处理下一个元素,currentOrValue 保持不变
        dfs(nums, index + 1, currentOrValue);

        // 分支2:选择 nums[index] 这个元素
        // 处理下一个元素,并将 nums[index] 的值合并到 currentOrValue 中
        dfs(nums, index + 1, currentOrValue | nums[index]);
    }

public:
    int countMaxOrSubsets(std::vector<int>& nums) {
        // 1. 初始化成员变量
        this->maxOrValue = 0;
        this->count = 0;

        // 2. 计算整个数组的按位或最大值
        for (int num : nums) {
            maxOrValue |= num;
        }
        
        // 如果数组为空,没有任何非空子集,直接返回0
        // (题目限制 1 <= nums.length,所以这里其实不会执行)
        if (maxOrValue == 0) {
            return 0;
        }

        // 3. 从下标 0 开始,启动回溯过程。
        // 初始的按位或值为 0 (代表空集)
        dfs(nums, 0, 0);

        // 4. 返回最终统计的数量
        return count;
    }
};

位掩码法

class Solution {
public:
    int countMaxOrSubsets(std::vector<int>& nums) {
        // 1. 像之前一样,先计算出目标最大值
        int maxOr = 0;
        for (int num : nums) {
            maxOr |= num;
        }

        int n = nums.size();
        int count = 0;
        
        // 2. 计算子集的总数。1 << n 相当于 2^n
        int totalSubsets = 1 << n; // 或者 (int)pow(2, n)

        // 3. 遍历所有非空子集
        // mask 从 1 开始,到 2^n - 1 结束
        int currentOr = 0;
        for (int mask = 1; mask < totalSubsets; ++mask) {
            
            currentOr = 0;

            // 4. 根据 mask 构建子集,并计算其按位或结果
            for (int i = 0; i < n; ++i) {
                // 检查 mask 的第 i 位是否为 1
                // (mask >> i) 将第 i 位移动到最右边
                // & 1         用于判断最右边这位是否是 1
                if ((mask >> i) & 1) {
                    // 如果是 1,说明 nums[i] 在当前子集中
                    currentOr |= nums[i];
                }
            }

            // 5. 检查当前子集的按位或结果是否等于最大值
            if (currentOr == maxOr) {
                count++;
            }
        }

        return count;
    }
};

优化思路

一.剪枝

DFS会构建出每一个子集,直到最后一个元素都考虑完毕,才去检查这个子集的按位或结果是否等于 maxOr

但是按照 按位或 运算的特性:A | B >= A。这个值是只增不减的。 这就带来了一个关键的突破口: 一旦在递归的某个中间步骤,我们当前子集的按位或结果 currentOr 已经等于了最终的 maxOr,那么再往这个子集里添加任何其他数字,最终的按位或结果仍然会是 maxOr

那么加速策略如下:在DFS的任何一步,只要发现 currentOr == maxOr,我们立即停止继续向下递归。我们直接计算出剩下未处理的元素能构成多少个子集,把这个数量加到总数 count 上,然后直接返回。

剩下 k 个元素,它们能构成 $2^k$ 个子集。

二.启发式优化

我们剪枝的效率取决于多快能让 currentOr 达到 maxOrValue按位或运算的特性是,一个数越大,它往往包含的二进制位(尤其是高位)就越多。如果我们在递归时优先考虑那些较大的数,就更有可能更快地“点亮”maxOrValue的所有位,从而更早地触发剪枝条件,砍掉更多不必要的搜索分支。

那么,只需在调用 dfs 之前,对原数组 nums 进行一次降序排序,就能获得更大的剪枝概率。

优化后代码

class Solution {
private:
    int maxOrValue;
    int count;

    void dfs_optimized(const vector<int>& nums, int index, int currentOr) {

        if (currentOr == maxOrValue) { // 位或不会减小,如果已经最大则可以剪枝
            int remaining_elements = nums.size() - index;
            count += (1 << remaining_elements); 
            return;
        }

        if (index == nums.size()) { // 递归完毕,返回
            return;
        }

        // 分支1:不选择 nums[index]
        dfs_optimized(nums, index + 1, currentOr);

        // 分支2:选择 nums[index]
        dfs_optimized(nums, index + 1, currentOr | nums[index]);
    }

public:
    int countMaxOrSubsets(std::vector<int>& nums) {
        this->maxOrValue = 0;
        this->count = 0;

        sort(nums.begin(), nums.end(), greater()); // 进行一次降序排序有利于之后剪枝

        for (int num : nums) {
            maxOrValue |= num;
        }
        
        dfs_optimized(nums, 0, 0);
        return this->count;
    }

};