(See the previous iteration here.)
(See the next iteration here.)
This time, I improved the iterator so that there is no chance of numeric overflow when computing the total number of iterations. To this end, I switched from `long` to `BigInteger`:
package com.github.coderodde.util;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Consumer;
/**
 * This class implements an iterator that generates index tuples in
 * lexicographic order. For a tuple length {@code k} and a target list length
 * {@code n}, it yields every strictly increasing tuple
 * {@code <i_1, i_2, ..., i_k>} with {@code 0 <= i_1 < i_2 < ... < i_k < n},
 * starting from {@code <0, 1, ..., k - 1>} and ending at
 * {@code <n - k, ..., n - 1>}; there are {@code C(n, k)} such tuples in total.
 * <p>
 * Every call to {@link #next()} returns a fresh, unmodifiable snapshot, so
 * tuples obtained earlier are not affected by further iteration.
 *
 * @version 1.0.0
 * @since 1.0.0
 */
public final class IndexTupleIterator implements Iterator<List<Integer>> {

    /**
     * The length of the index tuples.
     */
    private final int indexTupleLength;

    /**
     * The length of the list for which the indices are being generated.
     */
    private final int lengthOfTargetList;

    /**
     * The actual internal state containing the current index tuple.
     */
    private final List<Integer> indices;

    /**
     * The number of iterated tuples.
     */
    private BigInteger numberOfIteratedIndexTuples = BigInteger.ZERO;

    /**
     * The total number of iterations, i.e. the binomial coefficient
     * {@code C(lengthOfTargetList, indexTupleLength)}.
     */
    private final BigInteger totalNumberOfIterations;

    /**
     * Constructs an iterator over all index tuples of length
     * {@code indexTupleLength} drawn from {@code 0, ..., lengthOfTargetList - 1}.
     *
     * @param indexTupleLength   the length of each generated index tuple.
     * @param lengthOfTargetList the length of the list being indexed.
     * @throws IllegalArgumentException if {@code indexTupleLength < 1},
     *         {@code lengthOfTargetList < 1}, or
     *         {@code indexTupleLength > lengthOfTargetList}.
     */
    public IndexTupleIterator(final int indexTupleLength,
                              final int lengthOfTargetList) {
        checkArguments(indexTupleLength, lengthOfTargetList);

        this.indexTupleLength = indexTupleLength;
        this.lengthOfTargetList = lengthOfTargetList;
        this.indices = new ArrayList<>(indexTupleLength);
        this.totalNumberOfIterations =
                binomial(lengthOfTargetList, indexTupleLength);

        initState();
    }

    /**
     * Tells whether at least one more index tuple remains.
     *
     * @return {@code true} if {@link #next()} will return another tuple.
     */
    @Override
    public boolean hasNext() {
        return numberOfIteratedIndexTuples
                .compareTo(totalNumberOfIterations) < 0;
    }

    /**
     * Returns the next index tuple in lexicographic order.
     *
     * @return an unmodifiable snapshot of the next index tuple.
     * @throws NoSuchElementException if all tuples have been returned.
     */
    @Override
    public List<Integer> next() {
        if (!hasNext()) {
            throw new NoSuchElementException("Nothing to iterate.");
        }

        numberOfIteratedIndexTuples =
                numberOfIteratedIndexTuples.add(BigInteger.ONE);

        advance();

        // Copy before wrapping: a bare unmodifiable *view* of 'indices' would
        // change under the caller's feet on the next call to next().
        return Collections.unmodifiableList(new ArrayList<>(indices));
    }

    /**
     * Moves {@code indices} to its lexicographic successor.
     */
    private void advance() {
        final int last = indices.size() - 1;

        // Fast path: the last index can still grow.
        if (indices.get(last) < lengthOfTargetList - 1) {
            indices.set(last, indices.get(last) + 1);
            return;
        }

        // Otherwise find the rightmost index that has room before its right
        // neighbour, bump it, and reset everything after it to consecutive
        // values (the smallest tuple with that new prefix).
        for (int i = last - 1; i >= 0; --i) {
            final int current = indices.get(i);
            final int following = indices.get(i + 1);

            if (current < following - 1) {
                indices.set(i, current + 1);

                for (int j = i + 1; j <= last; ++j) {
                    indices.set(j, indices.get(j - 1) + 1);
                }

                return;
            }
        }

        throw new IllegalStateException("Should not get here.");
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException(
                "remove() is not supported in IndexTupleIterator.");
    }

    /**
     * Prints all index tuples of length 3 over a list of length 6.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        final Iterator<List<Integer>> iterator = new IndexTupleIterator(3, 6);
        int lineNumber = 1;

        while (iterator.hasNext()) {
            System.out.printf("%2d: %s%n", lineNumber++, iterator.next());
        }
    }

    /**
     * Sets {@code indices} to the position just before the first tuple.
     */
    private void initState() {
        // Create the initial index tuple <0, 1, ..., indexTupleLength - 1>:
        for (int i = 0; i < indexTupleLength; ++i) {
            this.indices.add(i);
        }

        // Step one position back so that the first call to next() produces
        // <0, 1, ..., indexTupleLength - 1> via the ordinary increment path:
        this.indices.set(indexTupleLength - 1,
                         this.indices.get(indexTupleLength - 1) - 1);
    }

    /**
     * Checks that the input arguments are sensible.
     *
     * @param indexTupleLength   the length of the index tuple.
     * @param lengthOfTargetList the length of the list being indexed.
     * @throws IllegalArgumentException if any argument is out of range.
     */
    private static void checkArguments(final int indexTupleLength,
                                       final int lengthOfTargetList) {
        if (indexTupleLength < 1) {
            final String exceptionMessage =
                    String.format("indexTupleLength(%d) < 1",
                                  indexTupleLength);

            throw new IllegalArgumentException(exceptionMessage);
        }

        if (lengthOfTargetList < 1) {
            final String exceptionMessage =
                    String.format("lengthOfTargetList(%d) < 1",
                                  lengthOfTargetList);

            throw new IllegalArgumentException(exceptionMessage);
        }

        if (indexTupleLength > lengthOfTargetList) {
            final String exceptionMessage =
                    String.format(
                            "indexTupleLength(%d) > lengthOfTargetList(%d)",
                            indexTupleLength,
                            lengthOfTargetList);

            throw new IllegalArgumentException(exceptionMessage);
        }
    }

    /**
     * Computes the binomial coefficient {@code C(n, k)} exactly, for
     * {@code 0 <= k <= n}.
     * <p>
     * Only {@code min(k, n - k)} multiplications are needed, instead of
     * materialising three full factorials.
     *
     * @param n the size of the set.
     * @param k the size of the subsets.
     * @return {@code C(n, k)}.
     */
    private static BigInteger binomial(final int n, final int k) {
        final int r = Math.min(k, n - k);
        BigInteger result = BigInteger.ONE;

        for (int i = 1; i <= r; ++i) {
            // Invariant: after this step result == C(n - r + i, i), and
            // C(m, i) == C(m - 1, i - 1) * m / i, so the division is exact.
            result = result.multiply(BigInteger.valueOf(n - r + i))
                           .divide(BigInteger.valueOf(i));
        }

        return result;
    }
}
As always, I am eager to hear any commentary on my attempt.