Why your bisection search is wrong

Cosmin Negruseri
19 May 2014

What is bisection search? The bisection method, or bisection search, is a numerical algorithm for finding a value x such that f(x) = 0 for a given continuous function f. It works by starting from an interval on which f changes sign, repeatedly bisecting it, and keeping the subinterval that contains x. It's pretty simple and robust, but it has a few gotchas.
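The method just described can be sketched generically as follows. This is a minimal illustration, not code from this post: the name bisect_root and the default of 100 iterations are assumptions, and the sketch requires that f(lo) and f(hi) do not share a strict sign.

```python
def bisect_root(f, lo, hi, iterations=100):
    """Approximate a root of a continuous function f on [lo, hi].

    Assumes f(lo) and f(hi) do not share a strict sign, so by the
    intermediate value theorem [lo, hi] contains a root.
    """
    f_lo, f_hi = f(lo), f(hi)
    if f_lo == 0:
        return lo
    if f_hi == 0:
        return hi
    if (f_lo > 0) == (f_hi > 0):
        raise ValueError("f(lo) and f(hi) must have opposite signs")
    for _ in range(iterations):
        mid = lo + (hi - lo) / 2
        f_mid = f(mid)
        if f_mid == 0:
            return mid
        # Keep the half whose endpoints still have opposite signs.
        if (f_mid > 0) == (f_lo > 0):
            lo, f_lo = mid, f_mid
        else:
            hi = mid
    return lo + (hi - lo) / 2


# Example: sqrt(2) is the root of x*x - 2 on [1, 2].
print(bisect_root(lambda x: x * x - 2, 1.0, 2.0))
```

The sign test compares signs instead of multiplying f_lo * f_mid, which avoids the product overflowing or underflowing.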

Let's solve the following problem:

For a given number c, find its cube root using only the +, -, *, / operations.

Try solving the problem on your own, before reading below.

Let's choose f(x) = x³ - c. f is continuous, and x is the cube root of c exactly when f(x) = 0. Thus, we can apply the bisection method.

def cubicRoot(c):
    lo, hi = 0.0, c
    while lo * lo * lo != c:
        mid = (lo + hi) / 2
        if mid * mid * mid < c:
            lo = mid
        else:
            hi = mid

    return lo

Any bugs? Well, quite a few. Try to spot as many as you can, before reading on.

You may notice the precision issue right from the start. We'll discuss it a bit later.

What else? The code doesn’t work for negative values of c. This is easily fixable:

def _cubicRoot(c):
    lo, hi = 0.0, c
    while lo * lo * lo != c:
        mid = (lo + hi) / 2
        if mid * mid * mid < c:
            lo = mid
        else:
            hi = mid

    return lo

def cubicRoot(c):
    if c < 0:
        return -_cubicRoot(-c)
    else:
        return _cubicRoot(c)

What else? This code doesn't work for c = 1/1000. Why is that? We're setting the upper bound of the search to c (lo, hi = 0.0, c). But for c in (0, 1) the cube root of c is larger than c (the cube root of 0.001 is 0.1), so the root lies outside [0, c] and the upper bound we're setting is wrong.

Let's fix it:

def _cubicRoot(c):
    lo, hi = 0.0, max(1, c)

    while lo * lo * lo != c:
        mid = (lo + hi) / 2
        if (mid * mid * mid < c):
            lo = mid
        else:
            hi = mid

    return lo

def cubicRoot(c):
    if c < 0:
        return -_cubicRoot(-c)
    else:
        return _cubicRoot(c)
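One way to sanity-check a starting interval is to test whether it brackets the root, i.e. lo³ ≤ c ≤ hi³ (cubing is increasing, so that is exactly the condition). The helper below, brackets_cube_root, is a hypothetical name for this illustration. It confirms that [0, c] fails for c = 0.001, while an upper bound of max(1, c) works on both sides of 1: for c ≥ 1 we have c³ ≥ c, and for c < 1 we have 1³ > c.

```python
def brackets_cube_root(c, lo, hi):
    """Hypothetical helper: does [lo, hi] contain the cube root of c >= 0?

    x -> x**3 is increasing, so cbrt(c) is in [lo, hi] exactly when
    lo**3 <= c <= hi**3.
    """
    return lo * lo * lo <= c <= hi * hi * hi


print(brackets_cube_root(0.001, 0.0, 0.001))           # False: 0.1 > 0.001
print(brackets_cube_root(0.001, 0.0, max(1, 0.001)))   # True
print(brackets_cube_root(1e30, 0.0, max(1, 1e30)))     # True
```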

Going back to the precision issue:

If you've read Joshua Bloch's "Nearly All Binary Searches and Mergesorts are Broken", you know that mid = (lo + hi) / 2 has an overflow problem. So we change that line to mid = lo + (hi - lo) / 2.
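Python integers never overflow, but floats do: once lo + hi exceeds the largest double (about 1.8e308), the sum becomes inf. The two values below are chosen arbitrarily, only to show the difference between the two midpoint formulas.

```python
import math

lo, hi = 1.2e308, 1.6e308      # both finite, near the top of the double range

naive_mid = (lo + hi) / 2      # lo + hi overflows to inf before the division
safe_mid = lo + (hi - lo) / 2  # hi - lo is small, so nothing overflows

print(naive_mid)               # inf
print(safe_mid)                # about 1.4e308
```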

Testing floating point numbers for exact equality doesn't work: lo * lo * lo may never be exactly equal to c, so the loop may never terminate. The first idea that comes to mind is to use an absolute error comparison (|a - b| < eps).
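To see why exact equality is fragile, cube the double closest to 0.1, which is the real cube root of 0.001. Since 0.1 is not exactly representable in binary, the rounded product does not come back to exactly 0.001. This only checks one candidate, but it shows how easily lo * lo * lo and c can miss each other.

```python
c = 0.001
x = 0.1                   # the real cube root of 0.001, rounded to a double

print(x * x * x)          # close to, but not exactly, 0.001
print(x * x * x == c)     # False
```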

Instead of:

while lo * lo * lo != c:

switch to

while abs(c - lo * lo * lo) > 1e-3:

This doesn't work. For large numbers like 1e37 the code goes into an infinite loop. Try to figure out why, and discuss it in the comment section. Let's try using the relative error (|(a - b) / b| < eps) instead. There are some weird cases when a and b are close or equal to 0. Can the code be cleaner?
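A hint for the 1e37 puzzle: near 1e37, consecutive doubles are about 1.2e21 apart (math.ulp, available since Python 3.9, reports this spacing). So c - lo * lo * lo is either zero or vastly larger than 1e-3, and the absolute test quietly turns back into an exact-equality test. The standard library's math.isclose implements a relative comparison, and it also shows the trouble near zero.

```python
import math

c = 1e37
print(math.ulp(c))        # about 1.18e21: the gap to the next double near c

# Relative comparison scales with the magnitude of the numbers...
print(math.isclose(1e37 * (1 + 1e-12), 1e37, rel_tol=1e-9))  # True
# ...but nothing nonzero is "relatively close" to 0 without an abs_tol.
print(math.isclose(1e-300, 0.0, rel_tol=1e-9))               # False
```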

After each iteration of the while loop we learn something new about x's range. A double has only 64 bits, so the interval can be halved only a limited number of times before lo and hi become adjacent doubles. So instead of a tricky floating point stopping criterion, we can run the loop a fixed number of times, enough for the interval to become as small as the precision of our floating point numbers allows.

Tomek Czajka (of TopCoder fame) pointed out that my final version was buggy as well. I chose the number of iterations to be 120, but that's way too small: it doesn't work for c = 1e60.

A double is represented by a sign bit, a mantissa of 52 bits and an exponent of 11 bits (which can be negative). One loop iteration either decreases the exponent of our interval's length by 1 or finds a new bit of the mantissa. The exponent's magnitude is at most about 2^10 = 1024 and the mantissa has 52 bits. Thus we need about 1100 steps to figure out the answer.

def _cubicRoot(c):
    lo, hi = 0.0, max(1, c)

    for _ in range(1100):
        mid = lo + (hi - lo) / 2
        if (mid * mid * mid < c):
            lo = mid
        else:
            hi = mid

    return lo

def cubicRoot(c):
    if c < 0:
        return -_cubicRoot(-c)
    else:
        return _cubicRoot(c)
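To get a feel for these numbers, the sketch below counts how many halvings the search on [0, max(1, c)] needs before lo and hi become adjacent doubles. count_halvings is a hypothetical instrumentation helper, not part of the post's code, and it stops as soon as no double lies strictly between lo and hi. For c = 1e60 the count is above 120, in line with Tomek's observation, yet well below 1100.

```python
def count_halvings(c):
    """Hypothetical instrumentation (not from the post): number of bisection
    steps on [0, max(1, c)] before lo and hi become adjacent doubles."""
    lo, hi = 0.0, max(1.0, c)
    steps = 0
    while True:
        mid = lo + (hi - lo) / 2
        if not lo < mid < hi:   # no double strictly inside: interval is minimal
            return steps
        if mid * mid * mid < c:
            lo = mid
        else:
            hi = mid
        steps += 1


for c in (2.0, 0.001, 1e30, 1e60, 1e300):
    print(c, count_halvings(c))
```

The count grows with the exponent of c, which is why a fixed 1100 iterations wastes time on typical inputs.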

No more epsilon! But because it always runs 1100 iterations to cover inputs with large exponents, the code is now slow on every input. An idea is to stop as soon as an iteration no longer shrinks the lo..hi interval, instead of doing a constant number of iterations. Here's the faster version:

def _cubicRoot(c):
    lo, hi = 0.0, max(1, c)
    prev_range_size = hi - lo
    while True:
        mid = lo + (hi - lo) / 2
        if (mid * mid * mid < c):
            lo = mid
        else:
            hi = mid
        if hi - lo == prev_range_size:
            break
        prev_range_size = hi - lo

    return lo

def cubicRoot(c):
    if c < 0:
        return -_cubicRoot(-c)
    else:
        return _cubicRoot(c)

This is still slow when we're dealing with large numbers. To address that, one can first binary search on the exponent, or get close to the real exponent by dividing the original exponent by 3. Try it out and share your version in the comment section.
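Here is a minimal sketch of the "divide the exponent by 3" idea, under a few assumptions: c is finite, math.frexp and math.ldexp (which only read and build binary exponents) are acceptable helpers, and the names _cube_root_nonneg and cube_root are hypothetical. Writing c = m · 2^e with 0.5 ≤ m < 1, the powers of two 2^⌊(e-1)/3⌋ and 2^⌈e/3⌉ bracket the cube root and differ by a factor of 2, so the bisection needs only roughly 52 halvings. The loop stops when no double lies strictly between lo and hi.

```python
import math


def _cube_root_nonneg(c):
    """Cube root of a finite c >= 0; the search itself uses only +, -, *, /.

    Hypothetical helper: frexp/ldexp only pick a tight starting interval.
    """
    if c == 0:
        return 0.0
    _, e = math.frexp(c)           # c = m * 2**e with 0.5 <= m < 1
    k = (e - 1) // 3               # floor((e - 1) / 3): (2**k)**3 <= 2**(e - 1) <= c
    j = -((-e) // 3)               # ceil(e / 3):        (2**j)**3 >= 2**e > c
    lo, hi = math.ldexp(1.0, k), math.ldexp(1.0, j)   # hi == 2 * lo
    while True:
        mid = lo + (hi - lo) / 2
        if not lo < mid < hi:      # lo and hi are adjacent doubles
            return lo
        # For huge c, mid * mid * mid may overflow to inf; inf < c is False,
        # which correctly moves hi down.
        if mid * mid * mid < c:
            lo = mid
        else:
            hi = mid


def cube_root(c):
    """Hypothetical wrapper mirroring cubicRoot: negatives by symmetry."""
    return -_cube_root_nonneg(-c) if c < 0 else _cube_root_nonneg(c)
```

For c = 1e60, for instance, frexp gives e = 200, so the search starts from [2^66, 2^67] instead of [0, 1e60].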

I'm curious, did your solution have any of these problems?

Notes:

  • Thanks Tomek for pointing out the iteration problem and possible efficient solutions.
  • When c is large, mid * mid * mid overflows to infinity, but the algorithm still works.
  • We've addressed negative numbers, numbers in (0, 1), overflow problems, and absolute and relative error problems.
  • Some tests for your own code: -1, 0, 1, 2, 8, 0.001, 1e30, 1e60.
  • In practice, faster methods such as the Newton-Raphson method are used.
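The overflow note can be checked directly. The values below are only an illustration: a midpoint of 1e200 while searching for the cube root of c = 1e300. Multiplying floats past the largest double gives inf rather than an error. Because inf < c is false, the search moves hi down, which is the right decision since the true cube exceeds c. One caveat for Python: float ** behaves differently in CPython and raises OverflowError, so mid ** 3 is not a safe drop-in for mid * mid * mid.

```python
import math

c = 1e300          # a large input
mid = 1e200        # an example midpoint; its true cube, 1e600, exceeds any double

cube = mid * mid * mid            # float multiplication overflows to inf
print(cube)                       # inf
print(cube < c)                   # False, so the search would set hi = mid

# Caveat: float ** raises instead of returning inf on overflow.
try:
    mid ** 3
    pow_raised = False
except OverflowError:
    pow_raised = True
print(pow_raised)                 # True
```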