おたくのスタジオ

Counting Bits

好久没更新博客了,今天就来诈尸一下。今天主要讲一道算法题,目前是LeetCode OJ上通过率最高的题了(我是按通过率排序来做题的==)。题目如下:

Given a non negative integer number num. For every numbers i in the range 0 ≤ i ≤ num calculate the number of 1’s in their binary representation and return them as an array.

Example:

For num = 5 you should return [0,1,1,2,1,2].

Follow up:

  • It is very easy to come up with a solution with run time O(n*sizeof(integer)). But can you do it in linear time O(n) /possibly in a single pass?
  • Space complexity should be O(n).
  • Can you do it like a boss? Do it without using any builtin function like __builtin_popcount in c++ or in any other language.

题目大意就是要依次统计所有num以下的数中二进制表示1的个数。首先,很容易想到单独求一个数的二进制表示1的个数需要O(sizeof(integer))时间,然后总共有num个数,故需要花费O(n*sizeof(integer))的时间。说到如果是单独去求一个数的二进制表示1的个数,可以使用位操作符的办法,如下:

int bitcount(unsigned x)
{
    int b;
    for(b = 0; x != 0; b++)
        x &= (x - 1);
    return b;
}

这个思路也比较容易理解。对于一个数而言,最后一位如果是1,减1之后前面的位并不受影响,该位的1已经被纳入统计;而如果是0,减1之后该位为1,1&0=0,该位并不受影响,但这带动了前面减1的过程,从而计算前面的位中1的个数。而只要统计过的位数,该位在之后运算中一定为0。所以,其实这个算法是从最低位到最高位按位统计1的个数的。

不过这道题用这种方法做并不能加速。看题目的尿性,感觉到之前算出来的值应该能重复利用(有点动态规划的感觉==)。没忍住看了下题目的提示,果然:

You should make use of what you have produced already.

然后就是一个找规律的过程。题目又给了另外一个提示:

Divide the numbers in ranges like [2-3], [4-7], [8-15] and so on. And try to generate new range from previous.

看完这个提示,我在草稿纸上画了一通,发现了如下的规律:

0 -> 0

1 -> 1

2 -> 10

3 -> 11

4 -> 100

5 -> 101

6 -> 110

7 -> 111

在计算某个数二进制表示1的个数时,例如7,其实当前最高位的1是已知的,然后剩下的11统计其实是从7-4=3来的。也就是说,根据提示的分组方法,每个组的二进制表示1的个数为上一组的对应某些数的二进制表示1的个数+1,比方说7与3对应,6与2对应,3与1对应,2与0对应,等等。

于是,按照这个思路,我们得到了如下的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
public:
vector<int> countBits(int num) {
vector<int> count;
count.push_back(0);
int base = 1;
while(num >= base) {
int nextbase = base << 1;
for(int i = base; i < nextbase; i++) {
count.push_back(count[i - base] + 1);
if(i == num)
break;
}
base = nextbase;
}
return count;
}
};

这个代码里不太优美的地方,在于每次内循环都要判断什么时候计算完毕,就是,没有把num和分组关联起来。另外,其实还有第三个提示,但是我直接就没有看了。。。提示是这样的:

Or does the odd/even status of the number help you in calculating the number of 1s?

提示的意思是说,一个数的奇偶性在计算其二进制表示1的个数中发挥了重大作用。让我们再回头看一下那个例子。我们发现对于偶数而言,它的最后一位必定是0,也就是说它的二进制表示1的个数与将它的0抹去(右移位,其实就是除以2)之后的那个数的二进制表示1的个数应该是相同的。例如,6与3对应,4与2对应,它们的二进制表示1的个数都各自相同。如果是奇数呢,做法类似,就是比右移位后的那个数的二进制表示1的个数多1。这也是显然的。所以,这种思路的代码如下:

1
2
3
4
5
6
7
8
9
10
11
class Solution {
public:
vector<int> countBits(int num) {
vector<int> count;
count.push_back(0);
for(int i = 1; i <= num; i++) {
count.push_back(count[i >> 1] + (i & 1));
}
return count;
}
};

在discuss板块,我又发现了一个比较漂亮的解法,该解法运用了最开始所说的最低位开始统计1的个数的方法,代码如下:

1
2
3
4
5
6
7
8
9
class Solution {
public:
vector<int> countBits(int num) {
vector<int> ret(num+1, 0);
for (int i = 1; i <= num; ++i)
ret[i] = ret[i&(i-1)] + 1;
return ret;
}
};

该方法的巧妙之处在于不直接统计一个数的二进制表示1的个数,因为i&(i-1)表示消去i二进制表示最低位的1之后剩下的数,所以这个数的二进制表示1的个数一定比i小1。这就是代码中循环所表达的意思。

然而,这篇文章还没有结束。作为压轴,不得不提一下一个专门统计一个数二进制表示1的个数的算法——SWAR算法。算法的解释参见这里这里。但为了让自己理解,我还是手写一遍算法解释。首先,算法如下:

int SWAR(unsigned int i)
{
    i = i - ((i >> 1) & 0x55555555);
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
    return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}

然后来逐一分析。先看第一行:

i = i - ((i >> 1) & 0x55555555);

0x55555555的二进制展开形式是01010101010101010101010101010101,可以看到,这个掩码的作用是将一个数的二进制表示偶数位置0。i >> 1很简单,就是将i右移1位。所以第一行的意思就是,第一步,将i右移1位之后得到的数的二进制表示偶数位置0,奇数位则保持不变,得到中间结果j;第二步,把第一步得到的数ji中减去。这么说太抽象,举个简单的例子来看,假设我们的i用二进制表示后得到ba(前面均为0,这里省略),那么经过第一步操作得到j0b,然后我们i-j可以得到ba-0b的形式。这时,我们有了令人惊奇的发现:

无论ab取何值(实际上它们也只能取0或1),ba-0b的值(转为十进制)永远等于ba中1的个数。那如果i不止两位呢?容易看出,i-j的最后两位依旧符合上述规律。不仅如此,三四位,五六位,乃至更高位,均是如此。这也很容易解释:ba-0b运算时根本不可能向前一位进位,所以前一位不受任何影响;而且,前面位的运算格式也一定类似ba-0b,因为j右移了1位,而且偶数位又置为0。所以综上,这一行就是把32位数分为16个两位为单位的组,每组分别计算1出现的次数。

然后,我们来看第二行:

i = (i & 0x33333333) + ((i >> 2) & 0x33333333);

这一行要简单一些。0x33333333的二进制形式是00110011001100110011001100110011i & 0x33333333就是取了奇数组的值,(i >> 2) & 0x33333333就是取了偶数组的值,两者相加,就是归并成以4位为单位进行分组,每组分别计算1出现的次数了。注意这里是分开去与而不是(i + (i >> 2)) & 0x33333333,因为直接i + (i >> 2)是可能出现进位的。==

接下来,是第三行:

return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;

第三行的意图与第二行类似。而且,因为上一行的保证,以4位为单位,每个4位上的值最多不超过10+10=0100,而就算是0100+0100也不可能引起到上一组的进位,所以可以直接用(i + (i >> 4)) & 0x0F0F0F0F。这样,得到了一个以8位为单位进行分组,每组统计1出现次数的数k了。然后,我们来观察一下0x01010101。这个数其实是:

0x01010101 = (1 << 24) + (1 << 16) + (1 << 8) + 1

所以,我们有

k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k

这里,我们只要关注结果的最高8位就可以了。显然,k << 24的最高8位就是k的最低8位分组,而k的最高8位当然就是k的最高8位分组,四者相加,结果的最高8位就是i的二进制表示1的个数(不用担心有进位)!我们通过右移24位,就可以得到最终想要的结果了!

还有种稍微容易理解的写法,如下:

int SWAR(unsigned int i)  
{  
    x = (x & 0x55555555) + ((x >> 1) & 0x55555555);  
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333);  
    x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);  
    i = (i*(0x01010101) >> 24)  
    return i;  
}  

第一行的意思更加直接明了。就是一个数二进制表示的奇数位和偶数位相加,然后以2位为单位分组。奇数位和偶数位相加得到的值就是每组1出现次数(十进制表示)。