First, let k be the smallest power of 2 greater than or equal to sqrt(n). Since k < 2*sqrt(n), k is still O(sqrt(n)) and the k by k table has fewer than 4n entries, so this doesn't change the complexity.
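Choosing k needs only integer arithmetic. Here is a minimal sketch. The name `table_side` is a hypothetical helper, not part of the original answer. It doubles k until k*k reaches n, which avoids floating-point `sqrt`:

```python
def table_side(n):
    """Hypothetical helper: smallest power of 2 that is >= sqrt(n), for n >= 1."""
    k = 1
    # Compare squares instead of calling sqrt, so no floating-point rounding is involved.
    while k * k < n:
        k *= 2
    return k
```

For example, `table_side(10)` is 4, `table_side(17)` is 8 and `table_side(1000)` is 32. Because the previous power of 2 was still too small, k*k < 4n always holds for n >= 1.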
We build the full k by k table one row at a time, so that result[row][j] ends up holding row xor j.
We start with row 0, which is easy because 0 xor j = j.
result = [[0] * k for _ in range(k)]  # allocate the k by k table
for i in range(k):
    result[0][i] = i
Next, we visit the remaining rows in Gray-code order. A Gray code lists every number from 0 to 2^m - 1 exactly once, in an order where consecutive numbers differ in exactly one bit.
Because each new row number differs from the previous one in a single bit, every entry row xor j differs from last xor j in that same bit. So we can compute the new row from the previous one by flipping one bit in each entry.
last = 0
for row in graycount(k):
    if row == 0:
        continue  # row 0 was filled in above
    bit_to_change = find_changed_bit(last, row)
    for i in range(k):
        result[row][i] = flip_bit(result[last][i], bit_to_change)
    last = row
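To see the one-bit property the loop relies on, here is the Gray-order sequence for k = 8 together with the bit that changes at each step. This sketch uses Python's built-in `^` and `>>` purely for illustration (the answer's own code deliberately avoids them), and `gray` is a hypothetical helper name:

```python
def gray(i):
    """Hypothetical helper: binary-reflected Gray code of i (uses ^ for illustration only)."""
    return i ^ (i >> 1)

k = 8
order = [gray(i) for i in range(k)]
# Value (a power of 2) of the bit flipped between consecutive rows.
changed = [a ^ b for a, b in zip(order, order[1:])]

print(order)    # [0, 1, 3, 2, 6, 7, 5, 4]
print(changed)  # [1, 2, 1, 4, 1, 2, 1]

# The update rule used above: row xor j equals (last xor j) with one bit flipped.
for last, row, bit in zip(order, order[1:], changed):
    for j in range(k):
        assert (row ^ j) == (last ^ j) ^ bit
```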
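We need a few helper functions. First, one that finds the lowest bit at which two numbers differ, returned as its value (1, 2, 4, ...) rather than its position. Since consecutive Gray codes differ in exactly one bit, this is the bit that changed.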
def find_changed_bit(a, b):
    # Assumes a != b; otherwise this loops forever.
    i = 1
    while True:
        if a % 2 != b % 2:
            return i
        i *= 2
        a //= 2
        b //= 2
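find_changed_bit loops forever if a == b. That can't happen in the main loop, because consecutive Gray codes always differ, but if you reuse the helper elsewhere a guarded variant may be safer. A sketch, assuming non-negative integer inputs; the ValueError check is an addition, not part of the original:

```python
def find_changed_bit(a, b):
    """Lowest bit (as a power of 2: 1, 2, 4, ...) where a and b differ.

    Variant that rejects equal inputs instead of looping forever.
    Assumes a and b are non-negative integers.
    """
    if a == b:
        raise ValueError("a and b are equal, so no bit differs")
    i = 1
    while a % 2 == b % 2:  # lowest bits agree: shift both right by one
        i *= 2
        a //= 2
        b //= 2
    return i

print(find_changed_bit(3, 2))  # 1
print(find_changed_bit(2, 6))  # 4
```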
Next, a function that flips a given bit (passed as its value, a power of 2) using a constant number of arithmetic operations.
def flip_bit(a, bit):
    thebit = (a // bit) % 2  # 1 if that bit is set in a, else 0
    if thebit:
        return a - bit
    else:
        return a + bit
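A small worked example of one step of the main loop: with k = 4, row 2 follows row 3 in Gray order (0, 1, 3, 2), and the row numbers 3 (binary 11) and 2 (binary 10) differ only in the bit worth 1. Flipping that bit in every entry of row 3 should therefore give row 2:

```python
def flip_bit(a, bit):
    # Flip the bit whose value is `bit` (a power of 2) using only //, %, + and -.
    if (a // bit) % 2:
        return a - bit
    return a + bit

row3 = [3, 2, 1, 0]  # row 3 of a 4 by 4 table: 3 xor j for j = 0..3
bit = 1              # rows 3 and 2 differ only in the bit worth 1
row2 = [flip_bit(x, bit) for x in row3]
print(row2)          # [2, 3, 0, 1], i.e. 2 xor j for j = 0..3
```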
Finally, the tricky bit: counting in Gray codes. As Wikipedia explains, the standard (binary-reflected) Gray code of a can be obtained by computing xor(a, a // 2).
def graycount(a):
    # Yields the Gray codes of 0, 1, ..., a - 1.
    for i in range(a):
        yield slow_xor(i, i // 2)
def slow_xor(a, b):
    result = 0
    k = 1
    while a or b:
        # Set this output bit when the input bits differ.
        result += k * (a % 2 != b % 2)
        a //= 2
        b //= 2
        k *= 2
    return result
Note that slow_xor takes O(number of bits in a and b) time, but that's fine here: it's called only once per row (k times in total), not in the inner loop of the main function, so it adds just O(k log k) work.
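Putting the pieces together, here is a sketch of the whole construction as one runnable Python 3 program, checked against Python's `^` operator. The wrapper name `xor_table` is hypothetical, and the sketch assumes k is a power of 2 so that the Gray codes of 0..k-1 cover every row exactly once. graycount here yields `slow_xor(i, i // 2)`, and slow_xor sets an output bit where the input bits differ:

```python
def find_changed_bit(a, b):
    """Lowest bit (as a power of 2) where a and b differ; assumes a != b."""
    i = 1
    while a % 2 == b % 2:
        i *= 2
        a //= 2
        b //= 2
    return i


def flip_bit(a, bit):
    """Flip the bit with value `bit` (a power of 2) using only arithmetic."""
    return a - bit if (a // bit) % 2 else a + bit


def slow_xor(a, b):
    """Bitwise xor of non-negative ints, one bit at a time: O(number of bits)."""
    result = 0
    k = 1
    while a or b:
        result += k * (a % 2 != b % 2)
        a //= 2
        b //= 2
        k *= 2
    return result


def graycount(a):
    """Yield the Gray codes of 0, 1, ..., a - 1."""
    for i in range(a):
        yield slow_xor(i, i // 2)


def xor_table(k):
    """Hypothetical wrapper: k x k table with result[r][c] == r xor c.

    Assumes k is a power of 2. The inner loop uses only flip_bit.
    """
    result = [[0] * k for _ in range(k)]
    for i in range(k):
        result[0][i] = i  # 0 xor i == i
    last = 0
    for row in graycount(k):
        if row == 0:
            continue
        bit = find_changed_bit(last, row)
        for i in range(k):
            result[row][i] = flip_bit(result[last][i], bit)
        last = row
    return result


table = xor_table(8)
assert all(table[r][c] == r ^ c for r in range(8) for c in range(8))
print(table[3][5])  # 6
```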