
I was solving a programming puzzle involving combinations. It led me to the wonderful itertools.combinations function, and I'd like to know how it works under the hood. The documentation says that the algorithm is roughly equivalent to the following:

def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)

I get the idea: we start with the most obvious combination (the first r consecutive elements). Then we change one item (the last one) to get each subsequent combination.

The thing I'm struggling with is the conditional inside the for loop:

for i in reversed(range(r)):
    if indices[i] != i + n - r:
        break

It's very terse, and I suspect that's where all the magic happens. Please give me a hint so I can figure it out.

Update: I've grasped the idea in question. Now, to fully understand this algorithm, I'm going to prove its correctness mathematically. I'd be thankful if you could help me with the proof (where should I start?).
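A possible starting point for the proof (a sketch of the structure, not the proof itself): argue that (1) every index list is strictly increasing, (2) position i never exceeds i + n - r, and (3) each step produces the lexicographically next such list. Since there are exactly C(n, r) strictly increasing r-tuples drawn from range(n), starting at the smallest and always stepping to the next one means every combination is produced exactly once, in order. The code below only checks these properties empirically for small n and r; `index_lists` and `check_invariants` are hypothetical helper names, and `math.comb` needs Python 3.8+.

```python
import itertools
from math import comb


def index_lists(n, r):
    """Yield the index lists produced by the documented algorithm.

    The pool values do not matter, so only the indices are tracked.
    """
    if r > n:
        return
    indices = list(range(r))
    yield tuple(indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j - 1] + 1
        yield tuple(indices)


def check_invariants(n, r):
    lists = list(index_lists(n, r))
    for idx in lists:
        # (1) strictly increasing
        assert all(a < b for a, b in zip(idx, idx[1:]))
        # (2) position i stays within 0 .. i + n - r
        assert all(0 <= v <= i + n - r for i, v in enumerate(idx))
    # (3) successive lists strictly increase in lexicographic order
    assert all(a < b for a, b in zip(lists, lists[1:]))
    # The count matches C(n, r), and the sequence equals the built-in's output.
    assert len(lists) == comb(n, r)
    assert lists == list(itertools.combinations(range(n), r))
    return lists


for n in range(8):
    for r in range(n + 2):  # include r > n, which yields nothing
        check_invariants(n, r)
```

If these checks pass for the small cases, the remaining work is to prove (3) in general: show that bumping the right-most non-maximal position and resetting everything to its right gives the smallest valid list that is larger than the current one.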

Note that this is only part of the loop. From just that bit, it looks like it would break most of the time, but the break is actually what prevents the return in the else clause from happening. – tobias_k 3 hours ago
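The for ... else mechanics described in this comment can be seen in isolation: the else clause of a for loop runs only when the loop is not left via break, and after a break the loop variable keeps its last value. A small, unrelated sketch (first_negative is a made-up example function):

```python
def first_negative(values):
    """Return the position of the first negative value, or None."""
    for i, v in enumerate(values):
        if v < 0:
            break            # leave the loop early; the else clause is skipped
    else:
        return None          # runs only when the loop finished without break
    return i                 # i still holds the position where break happened


print(first_negative([3, 1, -2, 5]))  # 2
print(first_negative([3, 1, 2]))      # None
```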

The loop has two purposes:

  1. Terminating if the last index-list has been reached
  2. Determining the right-most position in the index-list that can be legally increased. This position is then the starting point for resetting all indices to its right.

Let us say you have an iterable over 5 elements, and want combinations of length 3. What you essentially need for this is to generate lists of indexes. The juicy part of the above algorithm generates the next such index-list from the current one:

# obvious 
index-pool:       [0,1,2,3,4]
first index-list: [0,1,2]
                  [0,1,3]
                  ...
                  [1,3,4]
last index-list:  [2,3,4]
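For this example, the built-in itertools.combinations applied to range(5) produces exactly these index-lists, so it can be used to list them all and to read off the largest value each position ever takes, which matches the bound i + n - r. A quick sketch:

```python
from itertools import combinations

n, r = 5, 3

# All index lists for 5 pool elements taken 3 at a time, in generation order.
index_lists = [list(c) for c in combinations(range(n), r)]
print(index_lists[0], index_lists[-1])   # [0, 1, 2] [2, 3, 4]
print(len(index_lists))                  # 10

# Largest value each position i ever takes: i + n - r.
maxima = [max(lst[i] for lst in index_lists) for i in range(r)]
print(maxima)                            # [2, 3, 4]
assert maxima == [i + n - r for i in range(r)]
```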

i + n - r is the maximum value that position i of the index-list can hold:

 index 0: i + n - r = 0 + 5 - 3 = 2 
 index 1: i + n - r = 1 + 5 - 3 = 3
 index 2: i + n - r = 2 + 5 - 3 = 4
 # compare last index-list above

=>

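for i in reversed(range(r)):
    if indices[i] != i + n - r:   # position i can still be increased
        break
else:                             # runs only if the loop never hit break
    return                        # all positions at their maximum: done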

This loops backwards through the current index-list and stops at the first position that doesn't hold its maximum index-value. If all positions hold their maximum index-value, there is no further index-list, so the generator returns.
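The loop can be pulled out into a small helper that makes both purposes explicit: it returns the right-most position that can still be increased, or None when the last index-list has been reached. This is a sketch; find_pivot is a hypothetical name, not part of itertools:

```python
def find_pivot(indices, n):
    """Return the right-most position that can still be increased,
    or None when every position holds its maximum (i + n - r)."""
    r = len(indices)
    for i in reversed(range(r)):
        if indices[i] != i + n - r:
            return i        # purpose 2: start resetting from here
    return None             # purpose 1: last index-list reached


print(find_pivot([0, 1, 2], 5))  # 2
print(find_pivot([0, 1, 4], 5))  # 1
print(find_pivot([2, 3, 4], 5))  # None
```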

Take the general case [0,1,4]: one can verify that the next list should be [0,2,3]. The loop stops at position 1, and the subsequent code

indices[i] += 1

increments the value of indices[i] (1 -> 2). Finally

for j in range(i+1, r):
    indices[j] = indices[j-1] + 1

resets all positions > i to the smallest legal index-values, each 1 larger than its predecessor.
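Putting the scan, the increment and the reset together gives a single step from one index-list to the next. This is a sketch with a hypothetical helper name, next_indices, which returns a new list instead of mutating in place:

```python
def next_indices(indices, n):
    """Return the index-list that follows `indices`, or None if it is the last."""
    r = len(indices)
    nxt = list(indices)           # work on a copy
    for i in reversed(range(r)):
        if nxt[i] != i + n - r:
            break
    else:
        return None               # every position is at its maximum
    nxt[i] += 1                   # bump the right-most increasable position
    for j in range(i + 1, r):     # reset everything to its right
        nxt[j] = nxt[j - 1] + 1
    return nxt


print(next_indices([0, 1, 4], 5))  # [0, 2, 3]

# Walking from the first index-list visits all of them in order.
idx = [0, 1, 2]
while idx is not None:
    print(idx)
    idx = next_indices(idx, 5)
```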


This for loop does a simple thing: it checks whether the algorithm should terminate.

The algorithm starts with the first r items and advances until it reaches the last r items of the iterable, which are [S_{n-r+1} ... S_{n-1}, S_n] (if we let S be the iterable, indexed from 1).

Now, the algorithm scans the entries of indices from right to left and makes sure each one still has room to move: it verifies that the i-th index is not n - r + i, which by the previous paragraph is the largest value that position can take (we drop the +1 here because Python lists are 0-indexed).

If all of these indices are equal to the last r positions, it goes into the else clause, executing the return and terminating the algorithm.


We could create the same functionality by using

if indices == list(range(n-r, n)): return

but the main reason for this "mess" (iterating in reverse and breaking) is that the first index from the end that doesn't match is kept in i and used by the next step of the algorithm, which increments that index and takes care of resetting the rest.
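That the one-liner and the loop are the same termination test can be checked directly over every valid index-list for small n and r. Both functions below are hypothetical helpers written for this check; in the real generator, only the loop form also leaves i pointing at the position to increment.

```python
from itertools import combinations


def last_reached_loop(indices, n):
    # Termination test as written in the documented algorithm.
    r = len(indices)
    for i in reversed(range(r)):
        if indices[i] != i + n - r:
            return False
    return True


def last_reached_simple(indices, n):
    # The one-line alternative: compare against the last r positions.
    r = len(indices)
    return list(indices) == list(range(n - r, n))


# Both tests agree on every valid index-list for small n and r.
for n in range(7):
    for r in range(n + 1):
        for idx in combinations(range(n), r):
            assert last_reached_loop(idx, n) == last_reached_simple(idx, n)
```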


You could check this by replacing the yields with

print('Combination: {}  Indices: {}'.format(tuple(pool[i] for i in indices), indices))
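Instead of editing the yields in place, a variant can yield a snapshot of indices alongside each combination, so the same line can be printed (or inspected) at every step. A sketch; traced_combinations is a hypothetical name:

```python
def traced_combinations(iterable, r):
    """Like the documented generator, but also yields a copy of indices."""
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[k] for k in indices), list(indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j - 1] + 1
        yield tuple(pool[k] for k in indices), list(indices)


for combo, idx in traced_combinations("ABCD", 2):
    print('Combination: {}  Indices: {}'.format(combo, idx))
```

Yielding a copy (list(indices)) matters: yielding the list itself would hand out the same object that the generator keeps mutating.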
[S_{n-r+1} ... S_{n-1}, S_n] in the second paragraph should be [S_{n-r+i} ... S_{n-1}, S_n], right? – Basil C. 3 hours ago
    
No, the 1 is intended: it is about the representation of the values (not the indices), and n-r+1 is a position in S using conventional 1-based indexing (in Python it would be [S[n-r] ... S[n-2], S[n-1]]). – Uriel Eli 3 hours ago
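The 1-based and 0-based descriptions name the same last r items. A quick check with S = 'ABCDE' and r = 3 (S is just example data):

```python
S = "ABCDE"
n, r = len(S), 3

# 0-based Python positions of the last r items: n-r, ..., n-1
last_positions = list(range(n - r, n))
print(last_positions)                               # [2, 3, 4]
print([S[k] for k in last_positions])               # ['C', 'D', 'E']

# 1-based positions n-r+1, ..., n name the same items.
print([S[k - 1] for k in range(n - r + 1, n + 1)])  # ['C', 'D', 'E']
```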

The C source code has some additional information about what is going on.

The first yield, yield tuple(pool[i] for i in indices), just returns the first combination (a_1, ..., a_r) and sets up the index combination for later steps. For example, take A='ABCDE' and r=3. After the first step, indices will be (0, 1, 2), which points to ('A', 'B', 'C').

Looking at the source code, we see the original C version of the loop you are asking about:

/* Scan indices right-to-left until finding one that is not
   at its maximum (i + n - r). */
for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--)
    ;

This loop searches, from the right, for an element of indices that is not at its maximum (remember that after the first yield we have indices (0, 1, 2)). It is easy to see that the loop finds that the last element, 2, is not at its maximum (2 + 5 - 3 = 4) and finishes with i == 2.

The next part of the code increments the i-th element of indices:

/* Increment the current index which we know is not at its
   maximum.  Then move back to the right setting each index
   to its lowest possible value (one higher than the index
   to its left -- this maintains the sort order invariant). */
indices[i]++;

This increases indices[i] (remember that i == 2) by 1. As a result, indices will be (0, 1, 3), which points to ('A', 'B', 'D').

Finally, every index to the right of i is reset to its lowest possible value, one higher than the index to its left:

for (j=i+1 ; j<r ; j++)
    indices[j] = indices[j-1] + 1;

In conclusion, this is how indices advance step by step:

    step  indices
    1     (0, 1, 2)
    2     (0, 1, 3)
    3     (0, 1, 4)
    4     (0, 2, 3)
    5     (0, 2, 4)
    6     (0, 3, 4)
    7     (1, 2, 3)
    ...
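For comparison with the pure-Python version, the C logic can be transliterated almost line by line: a while loop replaces the empty-bodied for scan, and running off the left end (i < 0) means every index is at its maximum. The sketch assumes the C code checks i < 0 right after the scan to stop; c_style_index_lists is a hypothetical name. Its output reproduces the step table for n = 5, r = 3:

```python
def c_style_index_lists(n, r):
    """Python transliteration of the C loop (scan, increment, reset)."""
    if r > n:
        return
    indices = list(range(r))
    yield tuple(indices)
    while True:
        # Scan indices right-to-left until finding one that is not at its max.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            return                          # every index is at its maximum
        indices[i] += 1                     # increment that index
        for j in range(i + 1, r):           # reset the indices to its right
            indices[j] = indices[j - 1] + 1
        yield tuple(indices)


steps = list(c_style_index_lists(5, 3))
for step, idx in enumerate(steps, start=1):
    print(step, idx)
```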
