
Intro

This time, I have implemented the Shannon-Fano coding.

Code

io.github.coderodde.compression.ShannonFanoEncoder.java

package io.github.coderodde.compression;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;

/**
 * This class contains a static method for computing the Shannon-Fano coding.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.0.0 (Oct 27, 2025)
 * @since 1.0.0 (Oct 27, 2025)
 */
public final class ShannonFanoEncoder {
    
    public static <S extends Comparable<S>> Map<S, CodeWord> 
        encode(final Map<S, Double> probabilityDistribution) {
            
        final List<SymbolEntry<S>> letterDataList =
                new ArrayList<>(probabilityDistribution.size());
        
        for (final Map.Entry<S, Double> entry 
                : probabilityDistribution.entrySet()) {
            
            final Double weight = Objects.requireNonNull(entry.getValue());
            
            checkWeight(weight);
            
            letterDataList.add(new SymbolEntry<>(entry.getKey(), 
                                                 weight));
        }
        
        Collections.sort(letterDataList);
        
        final Map<List<SymbolEntry<S>>, Boolean> assignmentMap = new HashMap<>();
        final Map<List<SymbolEntry<S>>, List<SymbolEntry<S>>> parentsMap = 
                new HashMap<>();
        
        assignmentMap.put(letterDataList, null);
        parentsMap.put(letterDataList, null);
        
        final Deque<List<SymbolEntry<S>>> queue = new ArrayDeque<>();
        
        queue.addLast(letterDataList);
        
        final List<List<SymbolEntry<S>>> lonelySymbols =
                new ArrayList<>(probabilityDistribution.size());
        
        while (!queue.isEmpty()) {
            final List<SymbolEntry<S>> subList = queue.removeFirst();
            final Pair<List<SymbolEntry<S>>> pair = splitEvenly(subList);
            final List<SymbolEntry<S>> loSublist = pair.first;
            final List<SymbolEntry<S>> hiSublist = pair.second;
            
            parentsMap.put(loSublist, subList);
            parentsMap.put(hiSublist, subList);
            
            assignmentMap.put(loSublist, Boolean.FALSE);
            assignmentMap.put(hiSublist, Boolean.TRUE);
            
            if (loSublist.size() > 1) {
                // Split later again:
                queue.addLast(loSublist);
            } else {
                // A singleton; cannot split:
                lonelySymbols.add(loSublist);
            }
            
            if (hiSublist.size() > 1) {
                // Split later again:
                queue.addLast(hiSublist);
            } else {
                // A singleton; cannot split:
                lonelySymbols.add(hiSublist);
            }
        }
        
        return inferCode(parentsMap,
                         assignmentMap,
                         lonelySymbols);
    }
        
    static <S extends Comparable<S>> Map<S, CodeWord> 
        inferCode(final Map<List<SymbolEntry<S>>,
                            List<SymbolEntry<S>>> parentsMap,
                  final Map<List<SymbolEntry<S>>, Boolean> assignmentMap,
                  final List<List<SymbolEntry<S>>> lonelySymbols) {
     
        final Map<S, CodeWord> code = new TreeMap<>();
            
        for (final List<SymbolEntry<S>> lonelySymbol : lonelySymbols) {
            final CodeWord codeWord = computeCodeWord(code,
                                                      lonelySymbol,
                                                      parentsMap,
                                                      assignmentMap);
            
            code.put(lonelySymbol.get(0).symbol, codeWord);
        }
        
        return code;
    }
        
    static <S extends Comparable<S>> 
        CodeWord computeCodeWord(final Map<S, CodeWord> code,
                              final List<SymbolEntry<S>> symbol,
                              final Map<List<SymbolEntry<S>>,
                                        List<SymbolEntry<S>>> parentsMap,
                              final Map<List<SymbolEntry<S>>, 
                                        Boolean> assignmentMap) {
        
        final List<List<SymbolEntry<S>>> path = inferPath(symbol,
                                                          parentsMap);
        
        final int codeWordLength = path.size() - 1;
        final CodeWord codeWord = new CodeWord(codeWordLength);
        
        for (int bitIndex = 1; bitIndex < path.size(); bitIndex++) {
            final List<SymbolEntry<S>> symbols = path.get(bitIndex);
            final boolean bit = assignmentMap.get(symbols);
            
            if (bit) {
                codeWord.set(bitIndex - 1);
            }
        }
        
        return codeWord;
    }
        
    private static void checkWeight(final double weight) {
        if (Double.isNaN(weight)) {
            throw new IllegalArgumentException("weight is NaN");
        }

        if (weight <= 0.0) {
            throw new IllegalArgumentException(
                    String.format("weight(%f) <= 0.0", weight));
        }

        if (Double.isInfinite(weight)) {
            throw new IllegalArgumentException("weight is Infinity");
        }
    }
        
    static <S extends Comparable<S>> List<List<SymbolEntry<S>>>
            inferPath(final List<SymbolEntry<S>> symbols,
                      final Map<List<SymbolEntry<S>>, 
                                List<SymbolEntry<S>>> parentsMap) {
    
        final List<List<SymbolEntry<S>>> path = new ArrayList<>();
        
        for (List<SymbolEntry<S>> current = symbols;
             current != null;
             current = parentsMap.get(current)) {
            
            path.add(current);
        }
        
        Collections.reverse(path);
        return path;
    }
        
    static <S extends Comparable<S>> Pair<List<SymbolEntry<S>>> 
        splitEvenly(final List<SymbolEntry<S>> list) {
        
        final List<SymbolEntry<S>> loSublist  = new ArrayList<>();
        final List<SymbolEntry<S>> hiSublist = new ArrayList<>();
        
        int loIndex = 0;
        int hiIndex = list.size() - 1;
        
        double loProbabilitySum = 0.0;
        double hiProbabilitySum = 0.0;
        
        do {
            if (loProbabilitySum < hiProbabilitySum) {
                loSublist.addLast(list.get(loIndex));
                loProbabilitySum += list.get(loIndex).probability;
                loIndex++;
            } else if (loProbabilitySum > hiProbabilitySum) {
                hiSublist.addLast(list.get(hiIndex));
                hiProbabilitySum += list.get(hiIndex).probability;
                hiIndex--;
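            } else {
                loSublist.addLast(list.get(loIndex));
                loProbabilitySum += list.get(loIndex).probability;
                loIndex++;
                
                // Guard: when the two indices meet on equal sums, the middle
                // element must not be added to both halves.
                if (loIndex <= hiIndex) {
                    hiSublist.addLast(list.get(hiIndex));
                    hiProbabilitySum += list.get(hiIndex).probability;
                    hiIndex--;
                }
            }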
        } while (loIndex <= hiIndex);
        
        Collections.reverse(hiSublist);
        
        return new Pair<>(loSublist,
                          hiSublist);
    }
    
    static final class Pair<E> {
        final E first;
        final E second;
        
        Pair(final E first, final E second) {
            this.first = first;
            this.second = second;
        }
    }
        
    static final class SymbolEntry<S extends Comparable<S>> 
         implements Comparable<SymbolEntry<S>> {
        
        final S symbol;
        final double probability;
        
        SymbolEntry(final S symbol, final double probability) {
            this.symbol = Objects.requireNonNull(symbol);
            this.probability = validateProbability(probability);
        }
        
        private static double validateProbability(final double probability) {
            if (Double.isNaN(probability)) {
                throw new IllegalArgumentException(
                        "The input probability is NaN");
            }
            
            if (probability <= 0.0) {
                throw new IllegalArgumentException("probability <= 0.0");
            }
            
            if (probability > 1.0) {
                throw new IllegalArgumentException("probability > 1.0");
            }
                
            return probability;
        }

        @Override
        public String toString() {
            return String.format("[%s: %f]", symbol.toString(), probability);
        }
        
        @Override
        public int compareTo(final SymbolEntry<S> o) {
            final int cmp = Double.compare(o.probability,
                                             probability);
            
            if (cmp != 0) {
                return cmp;
            }
            
            return symbol.compareTo(o.symbol);
        }
        
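        @Override
        public boolean equals(final Object object) {
            // Avoid a ClassCastException or NullPointerException on foreign input:
            if (!(object instanceof SymbolEntry)) {
                return false;
            }
            
            final SymbolEntry<?> other = (SymbolEntry<?>) object;
            return symbol.equals(other.symbol);
        }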
        
        @Override
        public int hashCode() {
            return symbol.hashCode();
        }
    }
}

io.github.coderodde.compression.CodeWord.java

package io.github.coderodde.compression;

import java.util.BitSet;

/**
 * This class implements a <b>binary</b> code word in data compression 
 * scenarios.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.0.0 (Oct 28, 2025)
 * @since 1.0.0 (Oct 28, 2025)
 */
public class CodeWord {

    private final int length;
    private final BitSet bits;
    
    public CodeWord(final int length) {
        checkLength(length);
        this.length = length;
        this.bits = new BitSet(length);
    }
    
    public int length() {
        return length;
    }
    
    public boolean get(final int index) {
        checkIndex(index);
        return bits.get(index);
    }
    
    public void set(final int index) {
        checkIndex(index);
        bits.set(index);
    }
    
    public void unset(final int index) {
        checkIndex(index);
        bits.set(index, false);
    }
    
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder(length);
        
        for (int i = 0; i < length; ++i) {
            sb.append(get(i) ? "1" : "0");
        }
        
        return sb.toString();
    }
    
    private void checkIndex(final int index) {
        if (index < 0) {
            throw new IndexOutOfBoundsException(
                    String.format("index(%d) < 0", index));
        }
        
        if (index >= this.length) {
            throw new IndexOutOfBoundsException(
                    String.format("index(%d) >= length(%d)", 
                                  index, 
                                  length));
        }
    }
    
    private static void checkLength(final int length) {
        if (length < 1) {
            throw new IllegalArgumentException(
                    String.format("length(%d) < 1", length));
        }
    }
}

io.github.coderodde.compression.ShannonFanoEncoderTest.java

package io.github.coderodde.compression;

import io.github.coderodde.compression.ShannonFanoEncoder.Pair;
import io.github.coderodde.compression.ShannonFanoEncoder.SymbolEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import static org.junit.Assert.*;

public final class ShannonFanoEncoderTest {
   
    @Test
    public void splitEvenly1() {
        final List<SymbolEntry<Character>> list = new ArrayList<>();
        
        final SymbolEntry<Character> a = new SymbolEntry<>('a', 0.4); 
        final SymbolEntry<Character> b = new SymbolEntry<>('b', 0.3); 
        final SymbolEntry<Character> c = new SymbolEntry<>('c', 0.2); 
        final SymbolEntry<Character> d = new SymbolEntry<>('d', 0.1);
        
        list.add(c);
        list.add(b);
        list.add(d);
        list.add(a);
        
        Collections.sort(list);
        
        assertEquals(list.get(0), a);
        assertEquals(list.get(1), b);
        assertEquals(list.get(2), c);
        assertEquals(list.get(3), d);
        
        final Pair<List<SymbolEntry<Character>>> pair = 
                ShannonFanoEncoder.splitEvenly(list);
        
        assertEquals(1, pair.first.size());
        assertEquals(3, pair.second.size());
        
        assertEquals(a, pair.first.get(0));
        assertEquals(b, pair.second.get(0));
        assertEquals(c, pair.second.get(1));
        assertEquals(d, pair.second.get(2));
    }
   
    @Test
    public void splitEvenly2() {
        final List<SymbolEntry<Character>> list = new ArrayList<>();
        
        final SymbolEntry<Character> a = new SymbolEntry<>('a', 0.22); 
        final SymbolEntry<Character> b = new SymbolEntry<>('b', 0.28); 
        final SymbolEntry<Character> c = new SymbolEntry<>('c', 0.15); 
        final SymbolEntry<Character> d = new SymbolEntry<>('d', 0.30);
        final SymbolEntry<Character> e = new SymbolEntry<>('e', 0.05);
        
        list.add(a);
        list.add(b);
        list.add(c);
        list.add(d);
        list.add(e);
        
        Collections.sort(list);
        
        assertEquals(list.get(0), d);
        assertEquals(list.get(1), b);
        assertEquals(list.get(2), a);
        assertEquals(list.get(3), c);
        assertEquals(list.get(4), e);
        
        final Pair<List<SymbolEntry<Character>>> pair = 
                ShannonFanoEncoder.splitEvenly(list);
        
        assertEquals(2, pair.first.size());
        assertEquals(3, pair.second.size());
        
        assertEquals(d, pair.first.get(0));
        assertEquals(b, pair.first.get(1));
        assertEquals(a, pair.second.get(0));
        assertEquals(c, pair.second.get(1));
        assertEquals(e, pair.second.get(2));
    }
    
    @Test
    public void splitEvenlyOnTwoList() {
        final List<SymbolEntry<Character>> list = new ArrayList<>();
        
        final SymbolEntry<Character> a = new SymbolEntry<>('a', 0.4); 
        final SymbolEntry<Character> b = new SymbolEntry<>('b', 0.6);
        
        list.add(a);
        list.add(b);
        
        Collections.sort(list);
        
        assertEquals(list.get(0), b);
        assertEquals(list.get(1), a);
        
        final Pair<List<SymbolEntry<Character>>> pair = 
                ShannonFanoEncoder.splitEvenly(list);
        
        assertEquals(1, pair.first.size());
        assertEquals(1, pair.second.size());
        
        assertEquals(b, pair.first.get(0));
        assertEquals(a, pair.second.get(0));
    }
    
    @Test(expected = IllegalArgumentException.class)
    public void throwsOnInvalidProbabilityDistributionLessThanOne() {
        Map<Character, Double> distribution = new HashMap<>();
        
        distribution.put('A', 0.6);
        distribution.put('B', 0.3);
        
        ShannonFanoEncoder.encode(distribution);
    }
    
    @Test(expected = IllegalArgumentException.class)
    public void throwsOnInvalidProbabilityDistributionGreaterThanOne() {
        Map<Character, Double> distribution = new HashMap<>();
        
        distribution.put('A', 0.6);
        distribution.put('B', 0.5);
        
        ShannonFanoEncoder.encode(distribution);
    }
}

io.github.coderodde.compression.demo.Demo.java

package io.github.coderodde.compression.demo;

import io.github.coderodde.compression.CodeWord;
import io.github.coderodde.compression.ShannonFanoEncoder;
import java.util.HashMap;
import java.util.Map;

/**
 * This class runs some demonstration on Shannon-Fano coding.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.0.0 (Oct 27, 2025)
 * @since 1.0.0 (Oct 27, 2025)
 */
final class Demo {
    
    public static void main(String[] args) {
        final Map<Character, Double> probabilityDistribution = new HashMap<>();
        
        // This probability distribution is from
        // https://www.geeksforgeeks.org/dsa/shannon-fano-algorithm-for-data-compression/
        probabilityDistribution.put('a', 0.22);
        probabilityDistribution.put('b', 0.28);
        probabilityDistribution.put('c', 0.15);
        probabilityDistribution.put('d', 0.30);
        probabilityDistribution.put('e', 0.05);

        final Map<Character, CodeWord> coding = 
                ShannonFanoEncoder.encode(probabilityDistribution);
        
        System.out.println(coding);
    }
}

Critique request

I am eager to hear any constructive critique of my work. Also, is there a way to tune the performance of the Shannon-Fano encoder without compromising readability?

  • You may also be interested in Polar codes or Engel codes, which are typically described as approximating Huffman codes but could equally be said to approximate Shannon-Fano. Both are ways to build prefix codes using only some fast arithmetic, avoiding the heavy-duty data structure manipulation. Commented 8 hours ago
  • @user555045 Do Polar and Engel codes produce codes with better average code lengths than the Shannon-Fano code? Commented 7 hours ago
  • I don't know exactly off the top of my head, but their purpose is to compute prefix codes faster, not so much to be better. For "better" I'd look towards ANS. Commented 7 hours ago
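
One way to make the "better average code length" question measurable is to compute the expected code word length of a code table and compare it with the entropy of the distribution, which is the lower bound for any prefix code. The following is only a sketch: the class name CodeLengthDemo and its methods are illustrative, and the code table in main is written by hand rather than produced by any implementation in this thread.

```java
import java.util.Map;

final class CodeLengthDemo {

    /** Expected code word length: the sum of p(s) * length(code(s)) over all symbols. */
    static double averageLength(Map<Character, Double> probabilities,
                                Map<Character, String> code) {
        double sum = 0.0;
        for (Map.Entry<Character, Double> e : probabilities.entrySet()) {
            sum += e.getValue() * code.get(e.getKey()).length();
        }
        return sum;
    }

    /** Shannon entropy in bits; no prefix code can have a smaller average length. */
    static double entropy(Map<Character, Double> probabilities) {
        double h = 0.0;
        for (double p : probabilities.values()) {
            h -= p * (Math.log(p) / Math.log(2.0));
        }
        return h;
    }

    public static void main(String[] args) {
        Map<Character, Double> p =
                Map.of('a', 0.22, 'b', 0.28, 'c', 0.15, 'd', 0.30, 'e', 0.05);
        // Hypothetical prefix code, written by hand for illustration only.
        Map<Character, String> code =
                Map.of('d', "00", 'b', "01", 'a', "10", 'c', "110", 'e', "111");

        System.out.printf("average length = %.4f bits%n", averageLength(p, code));
        System.out.printf("entropy        = %.4f bits%n", entropy(p));
    }
}
```

Running the same two methods on the tables produced by two different code constructions gives a direct comparison for a given distribution.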

1 Answer


Your algorithm is definitely more readable and has better naming than the Java versions found on the web. However, there still seems to be room for improvement.

To offer a comparison, I wrote the Shannon-Fano algorithm the way I would approach it. Mind that this code is likely faulty with regard to ordering and finding the best split. Your style might even be better, and I did use modern Java constructs.

But take a look:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ShannonFano {

    record Symbol(char name, double probability) {

    }

    static class Encoding {

        Symbol symbol;
        StringBuilder code = new StringBuilder(20); // accumulates the code bits

        @Override
        public String toString() {
            return "%c %.2f \"%s\"".formatted(symbol.name, symbol.probability, code);
        }
    }

    private List<Encoding> encodings;

    public ShannonFano(Symbol... symbols) {
        encodings = new ArrayList<>(symbols.length);
        double total = 0.0;
        for (int i = 0; i < symbols.length; ++i) {
            Encoding encoding = new Encoding();
            encodings.add(encoding);
            encoding.symbol = symbols[i];
            total += symbols[i].probability;
            if (total > 1.0) {
                throw new IllegalArgumentException();
            }
        }
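        // Give the last symbol the remaining probability mass so that the
        // total is exactly 1.0 (total already includes its stated value).
        Encoding last = encodings.getLast();
        last.symbol = new Symbol(
                last.symbol.name,
                1.0 - (total - last.symbol.probability));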
        encodings.sort(Comparator.comparingDouble(enc -> -enc.symbol.probability()));
    }

    public void encode() {
        encode(encodings, 1.0);
    }

    private void dump(List<Encoding> someEncodings, double total) {
        System.out.printf("(%5.2f) %s%n", total, someEncodings);
    }

    private void encode(List<Encoding> someEncodings, double total) {
        dump(someEncodings, total);
        if (someEncodings.size() < 2) {
            return;
        }
        //someEncodings.sort(Comparator.comparingDouble(enc -> enc.symbol.probability()));
        //double total = someEncodings.stream().collect(Collectors.summingDouble(enc -> enc.symbol.probability));
        List<Encoding> leftList = new ArrayList<>();
        List<Encoding> rightList = new ArrayList<>(someEncodings);
        double leftTotal = 0;
        double rightTotal = total;
        double diff = rightTotal - leftTotal;
        while (!rightList.isEmpty()) {
            Encoding enc = rightList.getFirst();
            double newDiff = diff - 2 * enc.symbol.probability;
            boolean shift = Math.abs(newDiff) < diff;
            boolean done = newDiff < 0;
            if (shift) {
                leftTotal += enc.symbol.probability;
                rightTotal -= enc.symbol.probability;
                leftList.add(rightList.removeFirst());
                diff = newDiff;
            }
            if (done || !shift) {
                break;
            }
        }
        leftList.forEach(enc -> enc.code.append('0'));
        rightList.forEach(enc -> enc.code.append('1'));
        encode(leftList, leftTotal);
        encode(rightList, rightTotal);
    }

    public void display() {
        System.out.println();
        System.out.println("Symbol Probability Code");
        for (Encoding enc : encodings.reversed()) {
            System.out.printf("%-7c%11.2f %s%n", enc.symbol.name(), enc.symbol.probability(), enc.code);
        }
    }

    public static void main(String[] args) {
        ShannonFano encoder = new ShannonFano(
                new Symbol('A', 0.22),
                new Symbol('B', 0.28),
                new Symbol('C', 0.15),
                new Symbol('D', 0.30),
                new Symbol('E', 0.05)
        );
        encoder.encode();
        encoder.display();
    }
}

So:

  • The Map and Map<List> bookkeeping disappears. (But sublists are neat.)
  • Instead of Comparable, a Comparator might be more apt here.
  • A StringBuilder for the constructed code is nice (when not using a functional style).
  • It seems you want to retain more structure, which is nice for a graphical tree view, but here a map from letter and probability to code is the most that is required.
  • Maybe I misread your intent, and you want to split an unsorted partial list into two sublists with approximately equal probability sums.
  • I used the newer Java constructs like Stream and record (instead of Pair).
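
The Comparator point can be sketched as follows. This is only an illustration under assumptions: Entry is a hypothetical stand-in for SymbolEntry, and byDescendingProbability is a made-up helper name. The ordering mirrors the one in the question's compareTo (descending probability, then the symbol's natural order), but it lives outside the entry type.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class ComparatorDemo {

    // Hypothetical stand-in for SymbolEntry: a plain record with no compareTo of its own.
    record Entry<S>(S symbol, double probability) {}

    /** Descending probability first, then the symbols' natural order to break ties. */
    static <S extends Comparable<S>> Comparator<Entry<S>> byDescendingProbability() {
        return Comparator
                .comparingDouble((Entry<S> e) -> e.probability())
                .reversed()
                .thenComparing(Entry::symbol);
    }

    public static void main(String[] args) {
        List<Entry<Character>> entries = new ArrayList<>(List.of(
                new Entry<>('c', 0.2),
                new Entry<>('b', 0.3),
                new Entry<>('d', 0.1),
                new Entry<>('a', 0.4)));

        entries.sort(ComparatorDemo.<Character>byDescendingProbability());

        // The entries are now ordered a, b, c, d.
        System.out.println(entries);
    }
}
```

With the ordering supplied from outside, the entry type no longer needs to implement Comparable, and a different ordering can be swapped in without touching it.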

"My" code certainly is neither better nor more presentable, but I think it is more readable, more compact. Simpler.

And that is what led me to write a review: the code is good, but a simpler version may be far more useful, especially for others.
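
To illustrate how small the core can be when only a symbol-to-code map is needed, here is a minimal sketch of top-down splitting that picks the cut with the most balanced halves. It is not the asker's implementation and not a drop-in replacement: the class name CompactShannonFano and the helper split are hypothetical, symbols are fixed to Character for brevity, and the weights are assumed to be positive and finite.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

final class CompactShannonFano {

    /** Returns a symbol -> code string table for the given weights. */
    static Map<Character, String> encode(Map<Character, Double> weights) {
        List<Map.Entry<Character, Double>> sorted = new ArrayList<>(weights.entrySet());
        // Descending weight; ties are broken by symbol so the result is deterministic.
        sorted.sort((x, y) -> {
            int cmp = Double.compare(y.getValue(), x.getValue());
            return cmp != 0 ? cmp : x.getKey().compareTo(y.getKey());
        });

        Map<Character, String> code = new TreeMap<>();
        if (sorted.size() == 1) {
            code.put(sorted.get(0).getKey(), "0"); // a lone symbol still gets one bit
        } else if (sorted.size() > 1) {
            split(sorted, "", code);
        }
        return code;
    }

    private static void split(List<Map.Entry<Character, Double>> group,
                              String prefix,
                              Map<Character, String> code) {
        if (group.size() == 1) {
            code.put(group.get(0).getKey(), prefix);
            return;
        }

        double total = 0.0;
        for (Map.Entry<Character, Double> e : group) {
            total += e.getValue();
        }

        // Try every cut position and keep the one with the most similar halves.
        int bestCut = 1;
        double bestDiff = Double.POSITIVE_INFINITY;
        double left = 0.0;
        for (int cut = 1; cut < group.size(); cut++) {
            left += group.get(cut - 1).getValue();
            double diff = Math.abs(total - 2.0 * left); // |right - left|
            if (diff < bestDiff) {
                bestDiff = diff;
                bestCut = cut;
            }
        }

        split(group.subList(0, bestCut), prefix + "0", code);
        split(group.subList(bestCut, group.size()), prefix + "1", code);
    }

    public static void main(String[] args) {
        Map<Character, String> code = encode(
                Map.of('a', 0.22, 'b', 0.28, 'c', 0.15, 'd', 0.30, 'e', 0.05));
        System.out.println(code); // prints {a=10, b=01, c=110, d=00, e=111}
    }
}
```

The sublists are views created with subList, so no copying or parent bookkeeping is needed; the code prefix is simply passed down the recursion.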

  • Note that - due to generics - my solution is alphabet agnostic. Commented 2 days ago
  • That is indeed an improving abstraction, and I also find simply enumerating 'a', 'b' instead of ch++ better style. I am more of a bread-and-butter programmer, and I am not sure that this facet of the library is really helpful. Though char does not suit Unicode. Commented 2 days ago
  • What's a bread-and-butter programmer? :) Commented 2 days ago
  • Your code looks "academic" to me: well worked out. I (though I did go to university) write library modules professionally, so the code must fit snugly. Notice that in my main there is simply a "usage", and the rest is really simple. The recursive encode would be better named split (as you did). But overall, smart "production quality code" is what I aspire to. No criticism intended! Commented 2 days ago
