Skip to main content
modernize code using lru_cache
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211

1. Comments on your codeReview

  1. Your function knapsack lacks a docstring that would explain what arguments the function takes (what kind of things are in items? must items be a sequence, or can it be an iterable?) and what it returns.

Also, this kind of function is ideal for doctests.

    """
    Solve the knapsack problem by finding the most valuable
    subsequence of `items` that weighs no more than `maxweight`.

    `items` is a sequence of pairs `(value, weight)`, where `value` is
    a number and `weight` is a non-negative integer.

    `maxweight` is a non-negative integer.

    Return a pair whose first element is the sum of values in the most
    valuable subsequence, and whose second element is the subsequence.

    >>> items = [(4, 12), (2, 1), (6, 4), (1, 1), (2, 2)]
    >>> knapsack(items, 15)
    (11, [(2, 1), (6, 4), (1, 1), (2, 2)])
    """
  1. YourThe function knapsack lacks a docstring that would explain what arguments the function takes (what kind of things are in items? must items be a sequence, or can it be an iterable?) and what it returns.

  2. This kind of function is ideal for doctests.

  3. The comments say things like "Create an (N+1) by (W+1) 2-d list". But what is N and what is W? Presumably N is len(items) and W is maxweight, but this seems needlessly unclear. Better to putso it would be a couple of lines like thisgood idea to use the same names in the comments and the code:

     N = len(items)
     W = maxweight
    

so that the comments match the code (and then use N and W in the remainder of the code).

  1. The comment above bestvalues fails to explain what the values in this table actually meanmean. I would write something like this instead:

     # bestvalues[i][j] is the best sum of values for any
     # subsequence of the first i items, whose weights sum
     # to no more than j.
    

(This This makes it obvious why 0 ≤ i ≤ N\$0 ≤ i ≤ N\$ and why 0 ≤ j ≤ W\$0 ≤ j ≤ W\$.)

  1. In a loop like:

     bestvalues = [[0] * (maxweight + 1)
                   for i in xrange(len(items) + 1)]
    
  1. You can simplify the code by omitting theseThese lines could be omitted:

     # Increment i, because the first row (0) is the case where no items
     # are chosen, and is already initialized as 0, so we're skipping it
     i += 1
    

and then usingin the rest of the code, use i + 1 instead of i and i instead of i - 1.

  1. YourThe reconstruction loop:

     i = N
     while i > 0:
         # code
         i -= 1
    
    for i in xrangereversed(Nrange(1, 0,N -+ 1)):
        # code
  1. You printThe code prints an error message like this:

     print('usage: knapsack.py [file]')
    

Error messages ought to go to standard error (not standard output). And you can't know that yourit's a good idea not to hard-code the name of the program is called "knapsack.py":, because it might have been renamedwould be easy to rename the program but forget to update the code. So write instead:

    sys.stderr.write('usage"usage: {0} [file]\n'[file]\n".format(sys.argv[0]))
  1. YourThe block of code that reads the problem description and prints the result only runs when __name__ == '__main__'. This makes it hard to test, for example from the interactive interpreter. It's usually best to put this kind of code in its own function, like this:

     def main(filename):
         with open(filename) as f:
             # etc.
    
     if __name__ == '__main__':
         if len(sys.argv) != 2:
             print('usage: knapsack.py [file]')
             sys.exit(1)
         main(sys.argv[1])
    
  1. You readThe code reads the whole of the file into memory as a list of lines:

     lines = f.readlines()
    

this is harmless here because the file is small, but it's a bad habit to get into. It's usually bestbetter to process a file one line at a time if you can, like this:

    with open(filename) as f:
        maxweight = int(next(f))
        items = [map[[int(int,word) for word in line.split())] for line in f]

(Note that this results in slightly simpler code too.)

2. A more Pythonic solution?Recursive approach

Any dynamic programming algorithm can be implemented in two ways: by building a table of partial results from the bottom up (as in yourthe code in the post), or by recursively computing the result from the top down, using memoization to avoid computing any partial result more than once.

There are two advantages of theThe top-down approach: first, it often results in slightly simpler and clearer code, and second, it only computes the partial results that are needed for the particular problem instance (whereas in the bottom-up approach computesit's often hard to avoid computing all partial results even if some of them go unused).

So we could use the @memoized decorator from the Python Decorator Library@functools.lru_cache to implement a top-down solution, as shown below. (The Python wiki seems to be down at the moment, but you can find the code on archive.org.)like this:

from functools import lru_cache

def knapsack(items, maxweight):
    """
    Solve"""Solve the knapsack problem by finding the most valuable
  subsequence
   subsequence of `items` subjectitems that weighs no more than
    `maxweight`maxweight.

    `items`items ismust be a sequence of pairs `(value, weight)`, where `value`value is
  a
   a number and `weight`weight is a non-negative integer.

    `maxweight`maxweight is a non-negative integer.

    Return a pair whose first element is the sum of values in the most
    valuable subsequence, and whose second element is the subsequence.

    >>> items = [(4, 12), (2, 1), (6, 4), (1, 1), (2, 2)]
    >>> knapsack(items, 15)
    (11, [(2, 1), (6, 4), (1, 1), (2, 2)]) 

    """
    @lru_cache(maxsize=None)
    def bestvalue(i, j):
        # Return the value of the most valuable subsequence of the first i
        # i elements in items whose weights sum to no more than j.
    @memoized
    def bestvalue(i, j):
        if ij ==< 0: return 0
        value, weight = items[i -return 1]float('-inf')
        if weighti >== j0:
            return bestvalue(i - 1, j)0
        else:
value, weight = items[i - 1]
        return max(bestvalue(i - 1, j),
                       bestvalue(i - 1, j - weight) + value)

    j = maxweight
    result = []
    for i in xrangereversed(range(len(items), 0, -1)):
        if bestvalue(i + 1, j) != bestvalue(i - 1, j):
            result.append(items[i - 1]items[i])
            j -= items[i - 1][1]items[i][1]
    result.reverse()
    return bestvalue(len(items), maxweight), result

To see how many partial solutions this code needs to compute, print len(bestvalue.cachecache_info() just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared tooutputs:

CacheInfo(hits=17, misses=42, maxsize=None, currsize=42)

The 42 entries in the cache are fewer than the 96 partial solutions computed by the bottom up approach).

1. Comments on your code

  1. Your function knapsack lacks a docstring that would explain what arguments the function takes (what kind of things are in items? must items be a sequence, or can it be an iterable?) and what it returns.

Also, this kind of function is ideal for doctests.

    """
    Solve the knapsack problem by finding the most valuable
    subsequence of `items` that weighs no more than `maxweight`.

    `items` is a sequence of pairs `(value, weight)`, where `value` is
    a number and `weight` is a non-negative integer.

    `maxweight` is a non-negative integer.

    Return a pair whose first element is the sum of values in the most
    valuable subsequence, and whose second element is the subsequence.

    >>> items = [(4, 12), (2, 1), (6, 4), (1, 1), (2, 2)]
    >>> knapsack(items, 15)
    (11, [(2, 1), (6, 4), (1, 1), (2, 2)])
    """
  1. Your comments say things like "Create an (N+1) by (W+1) 2-d list". But what is N and what is W? Presumably N is len(items) and W is maxweight, but this seems needlessly unclear. Better to put a couple of lines like this:

     N = len(items)
     W = maxweight
    

so that the comments match the code (and then use N and W in the remainder of the code).

  1. The comment above bestvalues fails to explain what the values in this table actually mean. I would write something like this instead:

     # bestvalues[i][j] is the best sum of values for any
     # subsequence of the first i items, whose weights sum
     # to no more than j.
    

(This makes it obvious why 0 ≤ i ≤ N and why 0 ≤ j ≤ W.)

  1. In a loop like

     bestvalues = [[0] * (maxweight + 1)
                   for i in xrange(len(items) + 1)]
    
  1. You can simplify the code by omitting these lines:

     # Increment i, because the first row (0) is the case where no items
     # are chosen, and is already initialized as 0, so we're skipping it
     i += 1
    

and then using i + 1 instead of i and i instead of i - 1.

  1. Your reconstruction loop:

     i = N
     while i > 0:
         # code
         i -= 1
    
    for i in xrange(N, 0, -1):
        # code
  1. You print an error message like this:

     print('usage: knapsack.py [file]')
    

Error messages ought to go to standard error (not standard output). And you can't know that your program is called "knapsack.py": it might have been renamed. So write instead:

    sys.stderr.write('usage: {0} [file]\n'.format(sys.argv[0]))
  1. Your block of code that reads the problem description and prints the result only runs when __name__ == '__main__'. This makes it hard to test, for example from the interactive interpreter. It's usually best to put this kind of code in its own function, like this:

     def main(filename):
         with open(filename) as f:
             # etc.
    
     if __name__ == '__main__':
         if len(sys.argv) != 2:
             print('usage: knapsack.py [file]')
             sys.exit(1)
         main(sys.argv[1])
    
  1. You read the whole of the file into memory as a list of lines:

     lines = f.readlines()
    

this is harmless here because the file is small, but it's a bad habit to get into. It's usually best to process a file one line at a time if you can, like this:

    with open(filename) as f:
        maxweight = int(next(f))
        items = [map(int, line.split()) for line in f]

(Note that this results in slightly simpler code too.)

2. A more Pythonic solution?

Any dynamic programming algorithm can be implemented in two ways: by building a table of partial results from the bottom up (as in your code), or by recursively computing the result from the top down, using memoization to avoid computing any partial result more than once.

There are two advantages of the top-down approach: first, it often results in slightly simpler and clearer code, and second, it only computes the partial results that are needed for the particular problem instance (whereas the bottom-up approach computes all partial results even if some of them go unused).

So we could use the @memoized decorator from the Python Decorator Library to implement a top-down solution, as shown below. (The Python wiki seems to be down at the moment, but you can find the code on archive.org.)

def knapsack(items, maxweight):
    """
    Solve the knapsack problem by finding the most valuable
     subsequence of `items` subject that weighs no more than
    `maxweight`.

    `items` is a sequence of pairs `(value, weight)`, where `value` is
     a number and `weight` is a non-negative integer.

    `maxweight` is a non-negative integer.

    Return a pair whose first element is the sum of values in the most
    valuable subsequence, and whose second element is the subsequence.

    >>> items = [(4, 12), (2, 1), (6, 4), (1, 1), (2, 2)]
    >>> knapsack(items, 15)
    (11, [(2, 1), (6, 4), (1, 1), (2, 2)])
    """

    # Return the value of the most valuable subsequence of the first i
    # elements in items whose weights sum to no more than j.
    @memoized
    def bestvalue(i, j):
        if i == 0: return 0
        value, weight = items[i - 1]
        if weight > j:
            return bestvalue(i - 1, j)
        else:
            return max(bestvalue(i - 1, j),
                       bestvalue(i - 1, j - weight) + value)

    j = maxweight
    result = []
    for i in xrange(len(items), 0, -1):
        if bestvalue(i, j) != bestvalue(i - 1, j):
            result.append(items[i - 1])
            j -= items[i - 1][1]
    result.reverse()
    return bestvalue(len(items), maxweight), result

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 96 partial solutions computed by the bottom up approach).

1. Review

  1. The function knapsack lacks a docstring that would explain what arguments the function takes (what kind of things are in items? must items be a sequence, or can it be an iterable?) and what it returns.

  2. This kind of function is ideal for doctests.

  3. The comments say things like "Create an (N+1) by (W+1) 2-d list". But what is N and what is W? Presumably N is len(items) and W is maxweight, so it would be a good idea to use the same names in the comments and the code:

     N = len(items)
     W = maxweight
    
  4. The comment above bestvalues fails to explain what the values in this table mean. I would write something like this:

     # bestvalues[i][j] is the best sum of values for any
     # subsequence of the first i items, whose weights sum
     # to no more than j.
    

This makes it obvious why \$0 ≤ i ≤ N\$ and why \$0 ≤ j ≤ W\$.

  1. In a loop like:

     bestvalues = [[0] * (maxweight + 1)
                   for i in xrange(len(items) + 1)]
    
  1. These lines could be omitted:

     # Increment i, because the first row (0) is the case where no items
     # are chosen, and is already initialized as 0, so we're skipping it
     i += 1
    

and then in the rest of the code, use i + 1 instead of i and i instead of i - 1.

  1. The reconstruction loop:

     i = N
     while i > 0:
         # code
         i -= 1
    
    for i in reversed(range(1, N + 1)):
        # code
  1. The code prints an error message like this:

     print('usage: knapsack.py [file]')
    

Error messages ought to go to standard error (not standard output). And it's a good idea not to hard-code the name of the program, because it would be easy to rename the program but forget to update the code. So write instead:

    sys.stderr.write("usage: {0} [file]\n".format(sys.argv[0]))
  1. The block of code that reads the problem description and prints the result only runs when __name__ == '__main__'. This makes it hard to test, for example from the interactive interpreter. It's usually best to put this kind of code in its own function, like this:

     def main(filename):
         with open(filename) as f:
             # etc.
    
     if __name__ == '__main__':
         if len(sys.argv) != 2:
             print('usage: knapsack.py [file]')
             sys.exit(1)
         main(sys.argv[1])
    
  1. The code reads the whole of the file into memory as a list of lines:

     lines = f.readlines()
    

this is harmless here because the file is small, but it's usually better to process a file one line at a time, like this:

    with open(filename) as f:
        maxweight = int(next(f))
        items = [[int(word) for word in line.split()] for line in f]

2. Recursive approach

Any dynamic programming algorithm can be implemented in two ways: by building a table of partial results from the bottom up (as in the code in the post), or by recursively computing the result from the top down, using memoization to avoid computing any partial result more than once.

The top-down approach often results in slightly simpler and clearer code, and it only computes the partial results that are needed for the particular problem instance (whereas in the bottom-up approach it's often hard to avoid computing all partial results even if some of them go unused).

So we could use @functools.lru_cache to implement a top-down solution, like this:

from functools import lru_cache

def knapsack(items, maxweight):
    """Solve the knapsack problem by finding the most valuable subsequence
    of items that weighs no more than maxweight.

    items must be a sequence of pairs (value, weight), where value is a
    number and weight is a non-negative integer.

    maxweight is a non-negative integer.

    Return a pair whose first element is the sum of values in the most
    valuable subsequence, and whose second element is the subsequence.

    >>> items = [(4, 12), (2, 1), (6, 4), (1, 1), (2, 2)]
    >>> knapsack(items, 15)
    (11, [(2, 1), (6, 4), (1, 1), (2, 2)]) 

    """
    @lru_cache(maxsize=None)
    def bestvalue(i, j):
        # Return the value of the most valuable subsequence of the first
        # i elements in items whose weights sum to no more than j.
        if j < 0:
            return float('-inf')
        if i == 0:
            return 0
        value, weight = items[i - 1]
        return max(bestvalue(i - 1, j), bestvalue(i - 1, j - weight) + value)

    j = maxweight
    result = []
    for i in reversed(range(len(items))):
        if bestvalue(i + 1, j) != bestvalue(i, j):
            result.append(items[i])
            j -= items[i][1]
    result.reverse()
    return bestvalue(len(items), maxweight), result

To see how many partial solutions this needs to compute, print bestvalue.cache_info() just before returning the result. When solving the example problem in the docstring, this outputs:

CacheInfo(hits=17, misses=42, maxsize=None, currsize=42)

The 42 entries in the cache are fewer than the 96 partial solutions computed by the bottom up approach.

oops, 96
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 8096 partial solutions computed by the bottom up approach).

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 80 partial solutions computed by the bottom up approach).

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 96 partial solutions computed by the bottom up approach).

show how to compute the solution from the top down with memoization
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 80 partial solutions computed by the bottom up approach).

To see how many partial solutions this code needs to compute, print len(bestvalue.cache) just before returning the result. When solving the example problem in the docstring, I find that this computes 37 partial solutions (compared to the 80 partial solutions computed by the bottom up approach).

show how to compute the solution from the top down with memoization
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211
Loading
added 24 characters in body
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211
Loading
Source Link
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211
Loading