• Coding
  • [Exercise] Find the sum of 3 numbers

Given an array A of size n and a number S, find in the most efficient way whether there exist 3 numbers that sum up to S

No need for code, pseudo-code is enough.

I want to check my solution. If you want to solve it yourself, don't read past this point:

sort A using merge sort in O(n log n)
use 2 iterators i and j, initialized to i=0 and j=n-1
while (i < j)
    s2 = A[i] + A[j]
    if (s2 < S)
        perform a binary search on A for S-s2 // the third number needed to add up to S; check that its index differs from i and j
        if we found it, return the three numbers
        else i++
    else
        j--
@Ra8 I thought a bit about your algorithm. It looks good at first glance, but it doesn't hold up. It may sound reasonable to piggyback on the optimal O(n log n) solution of the pair-sum problem, but that doesn't guarantee it works on all inputs. Here is an example:

Assume the set (3,4,5,6,7) and a target sum of 12 for a+b+c. One solution is 3,4,5, but if you trace your algorithm, it won't find it. Here is a short trace:

(3,7): binary search for 2, not found, i++
(4,7): binary search for 1, not found, i++
(5,7): sum not < 12, j--
(5,6): binary search for 1, not found, i++
i == j, loop ends without finding 3,4,5
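The same miss shows up if you run the pseudocode directly. Here is a quick sketch of it in Python (my own translation, with `bisect.bisect_left` standing in for the binary search):

```python
import bisect

def ra8_three_sum(A, S):
    # sketch of the pseudocode above: two pointers + binary search for the third number
    A = sorted(A)
    i, j = 0, len(A) - 1
    while i < j:
        s2 = A[i] + A[j]
        if s2 < S:
            need = S - s2
            k = bisect.bisect_left(A, need)
            if k < len(A) and A[k] == need and k not in (i, j):
                return (A[i], A[j], need)
            i += 1
        else:
            j -= 1
    return None

print(ra8_three_sum([3, 4, 5, 6, 7], 12))  # None -- misses the valid triple (3, 4, 5)
```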

As far as I can see, there isn't a way to do this more efficiently than O(n^2). One approach: find all pairs with 2 nested for loops, O(n^2), and put them into a hash table (sum as key, pair of indices as value). Then loop over the array and, for every element, check whether the difference between the target and that element is a key in the hash. A hash lookup is O(1) and the loop is O(n), so you end up with O(n^2) overall (note the table itself holds O(n^2) pairs).
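A sketch of this pair-hash idea (the function name `three_sum_hash` is mine), with the extra check that all three indices are distinct:

```python
from collections import defaultdict

def three_sum_hash(A, S):
    # phase 1: all pair sums -> list of index pairs, O(n^2)
    pairs = defaultdict(list)
    n = len(A)
    for i in range(n):
        for j in range(i + 1, n):
            pairs[A[i] + A[j]].append((i, j))
    # phase 2: for each element, look up the complementary pair sum
    for k in range(n):
        for i, j in pairs.get(S - A[k], []):
            if k != i and k != j:
                return sorted((A[i], A[j], A[k]))
    return None

print(three_sum_hash([3, 4, 5, 6, 7], 12))  # [3, 4, 5]
```

Unlike the two-pointer variant above, this does find the (3, 4, 5) counterexample.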

That would still be more efficient than the naive O(n^3) solution which would constitute 3 nested loops.
Ayman wrote: @Ra8 I thought a bit about your algorithm. It looks good at first glance, but it doesn't hold up. [...]
Thank you Ayman, I didn't notice my mistake; I thought I could do it in O(n log n).
You can always fix a number and search for its complement as a pair.
That would be O(n^2 log n).
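One way to read that suggestion: fix two numbers, then binary-search a sorted copy for the third, which gives the quoted O(n^2 log n). A sketch (my own, using Python's `bisect`):

```python
import bisect

def three_sum_bsearch(A, S):
    # O(n log n) sort, then O(n^2) pairs, each with an O(log n) binary search
    A = sorted(A)
    n = len(A)
    for i in range(n):
        for j in range(i + 1, n):
            need = S - A[i] - A[j]
            k = bisect.bisect_left(A, need, j + 1)  # only look to the right of j
            if k < n and A[k] == need:
                return (A[i], A[j], A[k])
    return None

print(three_sum_bsearch([3, 4, 5, 6, 7], 12))  # (3, 4, 5)
```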
Have you considered counting the items first? I found it's a lot faster to execute.
Here's the basic idea in Python:
from collections import Counter
import random

asum = random.randrange(200) + 100

# A Counter is a key:value mapping:
# a = { i: occurrences(i) for i in alist }
a = Counter(random.randrange(100) for _ in range(int(1e7)))


def two_sum_gen(counts, total):
    # for each key <= total/2, look up its complement
    for i in (c for c in counts if c <= total / 2):
        if total - i in counts:
            yield (i, total - i)


def three_sum_gen(counts, total):
    # fix one key, then reuse two_sum_gen on the remainder
    for i in (c for c in counts if c <= total / 2):
        for a, b in two_sum_gen(counts, total - i):
            yield (i, a, b)
This code is simplified for the sake of clarity. For instance, it only returns the set of valid groups and purposely omits the count of each occurrence of a grouping. It also assumes that each element occurs an unlimited number of times. So for instance:

list = [1, 5, 3]
total = 9

answer: ((3, 3, 3), (1, 5, 3)) # first one is wrong.

The goal is really to show what I have in mind, not clobber it with details. (I have previously dealt with these issues in the original exercise.)
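The artifact can be reproduced directly. A quick check (re-declaring the two generators from above so the snippet runs standalone):

```python
from collections import Counter

def two_sum_gen(counts, total):
    # simplified: does not respect element multiplicities
    for i in (c for c in counts if c <= total / 2):
        if total - i in counts:
            yield (i, total - i)

def three_sum_gen(counts, total):
    for i in (c for c in counts if c <= total / 2):
        for a, b in two_sum_gen(counts, total - i):
            yield (i, a, b)

counts = Counter([1, 5, 3])
found = {tuple(sorted(t)) for t in three_sum_gen(counts, 9)}
print(found)  # {(1, 3, 5), (3, 3, 3)} -- (3, 3, 3) reuses the single 3
```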

Quick explanation
  • Counter is a hashmap that holds each element of a collection as a key and its number of occurrences as the value. 2sum becomes easy to implement: for each element n in the collection, check whether (total-n) has occurred. Since dictionary access is O(1) in time, the whole function executes in O(n). (You can also roughly halve the work by iterating only over elements less than or equal to total/2.)
  • 3sum is not much harder: for each element n in the collection, retrieve all the pairs in the collection whose sum equals total-n (just like @arithma suggests).
The complete answer is a little more complicated. Not included in the code:
  • Each time n looks for (total-n) in the dictionary, it should first decrease its own count by 1, to avoid reusing the same occurrence.
  • Each time a match is found, we should compute how many times it occurs by multiplying the counts of the elements involved.
But for now this should do.
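Those two fixes can be sketched for the pair case. This is my own variant (`two_sum_counted` is a hypothetical name), which yields each pair together with the number of ways it can be formed from the multiset:

```python
from collections import Counter

def two_sum_counted(counts, total):
    # yields (pair, number of distinct ways to pick it from the multiset)
    for x in counts:
        y = total - x
        if y not in counts or x > y:
            continue  # skip missing complements and mirrored duplicates
        if x < y:
            yield (x, y), counts[x] * counts[y]
        elif counts[x] >= 2:
            # choosing two of the counts[x] identical items
            yield (x, x), counts[x] * (counts[x] - 1) // 2

print(dict(two_sum_counted(Counter([1, 1, 2, 2, 3]), 4)))  # {(1, 3): 2, (2, 2): 1}
```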

Profiling
I am going to test it against this stupid brute-force test:
asum = random.randrange(200) + 100
a = [random.randrange(300) for _ in range(300)]

# a plain list, so every `total - i in counts` test is a linear scan
result = [(i, j, k) for i, j, k in three_sum_gen(a, asum)]
This one clearly executes in O(n^3). For an initial list of 300 elements, it takes approximately 3 seconds on my PC (a recent i5) to execute. Here's the profiling of the execution:
156946 function calls in 3.328 seconds

Ordered by: standard name

ncalls tottime percall cumtime percall filename:lineno(function)
156943 3.296 0.000 3.296 0.000 <string>:1(<genexpr>)
1 0.032 0.032 3.328 3.328 <string>:1(<module>)
1 0.000 0.000 3.328 3.328 {built-in method exec}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
Here's the execution time of my method using Counters (emphasis mine):
7494 function calls in 0.004 seconds

Ordered by: standard name

ncalls tottime percall cumtime percall filename:lineno(function)
1873 0.001 0.000 0.002 0.000 3sum.py:12(two_sum_gen)
1917 0.001 0.000 0.001 0.000 3sum.py:14(<genexpr>)
1824 0.001 0.000 0.003 0.000 3sum.py:18(three_sum_gen)
51 0.000 0.000 0.000 0.000 3sum.py:20(<genexpr>)
1824 0.001 0.000 0.003 0.000 <string>:1(<genexpr>)
1 0.001 0.001 0.004 0.004 <string>:1(<module>)
1 0.000 0.000 0.004 0.004 {built-in method exec}
1 0.000 0.000 0.000 0.000 {built-in method len}
1 0.000 0.000 0.000 0.000 {built-in method print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
I tried going to higher values: for 1000 elements, the Counter-based approach still takes a split second, while the slow approach didn't finish (I gave up after a few minutes).

The largest input I ran the Counter-based approach on was a 10,000,000 (1e7) element Counter, which took 13.2 seconds. After that I got bored.

The point is: it can operate on fairly large arrays. If you need more, then probably some optimizations have to be done.

What next
I will try to generalize to groups of n elements whose sum is total.
Here is another approach without using any extra storage or searching. It is based on the optimal solution of the 2 Sum problem with some additions; here are the two algorithms.

2 Sum O(nlogn):
Sort the list, then loop through it while keeping two pointers j and k at the leftmost and rightmost positions. On each iteration, if set[j] + set[k] equals the target, print the pair and do j++ and k--; if the sum is greater than the target, k--; if less, j++.
function two_sum(target, set)
{
  set.sort(function(a, b) { return a - b; }); // numeric sort; the default sort is lexicographic

  var j = 0;
  var k = set.length - 1;

  while (j < k)
  {
    var sum = set[j] + set[k];
    if (sum == target)
    {
      console.log(set[j] + ", " + set[k]);
      j++;
      k--;
    }
    else if (sum < target)
      j++;
    else
      k--;
  }
}
3 Sum O(n^2):
Sort the list. Introduce an outer loop wrapping the while loop above, with a new index i denoting the position of the fixed first element. It moves through the whole list from 0 to length-1, and on each iteration, while holding i fixed, runs the while loop on the sorted sublist from i+1 to N-1 to find the second and third numbers, just as in the 2 Sum problem. Takes O(n^2) time.
function three_sum(target, set)
{
  set.sort(function(a, b) { return a - b; }); // numeric sort

  for (var i = 0; i < set.length; i++)
  {
    var j = i + 1;
    var k = set.length - 1;

    // j starts at i+1 and k >= j, so neither pointer can collide with i
    while (j < k)
    {
      var sum = set[i] + set[j] + set[k];

      if (sum < target)
        j++;
      else if (sum > target)
        k--;
      else
      {
        console.log(set[i] + ", " + set[j] + ", " + set[k]);
        j++;
        k--;
      }
    }
  }
}
Usage:
var set = [1,8,7,5,9,3,2];
var target = 12;

two_sum(target,set);
three_sum(target,set);
Output:
3, 9
5, 7
1, 2, 9
1, 3, 8
2, 3, 7
I did some reading about it: whether an algorithm substantially faster than O(n^2) exists is among the unsolved problems in computer science (the 3SUM conjecture).
Since dictionary access is constant in time O(1), the whole function executes in O(n). (Actually, you can easily divide that by 2, by iterating on elements lesser than (or equal to) total/2).
Hash access is O(1) on average but O(n) in the worst case. IIRC, that does not qualify as worst-case O(1).
@Ayman: If 3sum uses 2sum as a subroutine, where does the log n go? I'm not sure: either 2sum is linear, or 3sum is n^2 log n.
@arithma the 2 sum program is linearithmic (n log n) due to sorting (merge sort): when we do n log n work (sorting) + n work (the while loop), the result is n log n as an upper bound (big O), since n log n dominates n and we can drop the lower-order term; efficiency is usually expressed in terms of the highest-order term.

For the 3 sum problem, only the while loop is reused, which itself takes O(n) work. Since it sits inside a for loop that also does n iterations, the two are nested, giving n × n = O(n^2). The n log n sorting work is done once, before the loops, not within them, so it is added rather than multiplied. The total work is n log n (sorting) + n^2 (nested loops), which gives an O(n^2) upper bound because n^2 dominates n log n.
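The n × n bound can be illustrated by counting how often the inner loop body runs. A small check (my own Python mirror of the two-pointer structure above):

```python
def three_sum_steps(A, S):
    # returns how many times the inner while-loop body executes
    A = sorted(A)               # n log n, done once, outside the loops
    steps = 0
    for i in range(len(A)):
        j, k = i + 1, len(A) - 1
        while j < k:
            steps += 1
            s = A[i] + A[j] + A[k]
            if s < S:
                j += 1
            elif s > S:
                k -= 1
            else:
                j += 1
                k -= 1
    return steps

n = 200
print(three_sum_steps(list(range(n)), 3 * n))  # on the order of n^2 / 2
```

Each pass of the inner loop moves j or k, so it runs at most n-i-2 times per i, summing to roughly n^2/2 overall.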