9
\$\begingroup\$

LRU cache is a classical design pattern frequently questioned in programming technical interviews. It is best illustrated with pictures. Here are two examples.

enter image description here

enter image description here

The following are two alternative implementations of an LRU cache class template in C++20.

The first one stores the cached nodes in a std::unordered_map and uses a manually implemented double-linked list between them.

#pragma once

#include <cassert>
#include <unordered_map>
#include <memory>

namespace vstl {

template <class Key, class Value>
class lru_cache {
public:
    explicit lru_cache(size_t capacity) : m_capacity{capacity} {
        assert(m_capacity > 0 && "Capacity should be greater than 0.");
        // Reserving the maximum capacity of the LRU cache guarantees that no
        // rehashing will occur on the insertion and the stored in the nodes
        // iterators will stay valid.
        m_cache.reserve(m_capacity);
    }

    template <class ForwardingKey, class ForwardingValue>
        requires std::convertible_to<ForwardingKey, Key> && std::convertible_to<ForwardingValue, Value>
    void update(ForwardingKey&& key, ForwardingValue&& value) {
        if (const auto it = m_cache.find(key); it != m_cache.end()) {
            it->second.data = std::forward<ForwardingValue>(value);
            update_node_position(&it->second);
            return;
        }
        const auto res = m_cache.emplace(std::forward<ForwardingKey>(key), Node{
            .data = std::forward<ForwardingValue>(value),
            .back = m_cache.end(),
            .prev = m_head,
            .next = nullptr
        });
        assert(res.second && "Emplacement must be successful.");
        m_head = &res.first->second;
        if (m_cache.size() == 1) m_tail = m_head;
        else m_head->prev->next = m_head;
        // Here we are storing an iterator to the unordered_map node that we
        // will be used for the deletion of the tail element when the maximum
        // capacity is exceeded.
        res.first->second.back = res.first;
        if (m_cache.size() > m_capacity) {
            const auto to_remove = m_tail;
            m_tail = m_tail->next;
            m_tail->prev = nullptr;
            m_cache.erase(to_remove->back);
        }
    }

    const Value* get(const Key& key) const {
        const auto it = m_cache.find(key);
        if (it == m_cache.end()) return {};
        update_node_position(&it->second);
        return &m_head->data;
    }

private:
    struct Node;

    void update_node_position(Node* node) const {
        if (m_head == node) return;
        // If the node is not at the beginning of the linked list, remove it ...
        const auto next = node->next;
        const auto prev = node->prev;
        assert(next && "After the node is different than the head of the list we must have a next node.");
        next->prev = prev;
        if (prev) prev->next = next;
        else {
            // If there is no previous element we are removing the tail of
            // the linked list and must set a new one.
            m_tail = next;
            assert(!m_tail->prev && "The tail must not have previous node.");
         }
        // ... and put it at the beginning.
        node->next = nullptr;
        node->prev = m_head;
        m_head->next = node;
        m_head = node;
    }

    using Map = std::unordered_map<Key, Node>;

    struct Node {
        Value data;
        typename Map::const_iterator back;
        Node* prev = nullptr;
        Node* next = nullptr;
    };

    mutable Map m_cache;
    mutable Node* m_head = nullptr;
    mutable Node* m_tail = nullptr;
    const size_t m_capacity;
};

}

The second one is a shorter version that stores the nodes of the cache in a std::list and uses std::unordered_map only as an index container.

#pragma once

#include <cassert>
#include <unordered_map>
#include <list>

template <class Key, class Value>
class lru_cache {
public:
    explicit lru_cache(size_t capacity) : m_capacity{capacity} {
        assert(m_capacity > 0 && "Capacity should be greater than 0.");
        // Reserving the maximum capacity of the index guarantees that no
        // rehashing will occur on the insertion and the stored in the nodes
        // iterators will stay valid.
        m_index.reserve(m_capacity);
    }

    template <class ForwardingKey, class ForwardingValue>
        requires std::convertible_to<ForwardingKey, Key> && std::convertible_to<ForwardingValue, Value>
    void update(ForwardingKey&& key, ForwardingValue&& value) {
        if (const auto it = m_index.find(key); it != m_index.end()) {
            it->second->data = std::forward<ForwardingValue>(value);
            m_cache.splice(m_cache.begin(), m_cache, it->second);
            return;
        }
        m_cache.push_front(Node{
            .data = std::forward<ForwardingValue>(value),
            .back = m_index.end()
        });
        const auto res = m_index.emplace(std::forward<ForwardingKey>(key), m_cache.begin());
        assert(res.second && "Emplacement must be successful.");
        // Here we are storing an iterator to the unordered_map node that we
        // will be used for the deletion of the tail element when the maximum
        // capacity is exceeded.
        m_cache.front().back = res.first;
        if (m_cache.size() > m_capacity) {
            m_index.erase(m_cache.back().back);
            m_cache.pop_back();
        }
    }

    const Value* get(const Key& key) const {
        const auto it = m_index.find(key);
        if (it == m_index.end()) return {};
        m_cache.splice(m_cache.begin(), m_cache, it->second);
        return &m_cache.front().data;
    }

private:
    struct Node;

    using Cache = std::list<Node>;
    using Index = std::unordered_map<Key, typename Cache::iterator>;

    struct Node {
        Value data;
        typename Index::const_iterator back;
    };

    mutable Cache m_cache;
    Index m_index;

    const size_t m_capacity;
};

Here is the test code for both variants with the doctest library:

#include "doctest/doctest.h"
#include "lru_cache.hpp"

TEST_CASE("update/get") {
    lru_cache<std::string, int> cache{3};

    // Check insertions
    cache.update("one", 1);          // ["one": 1]
    CHECK(*cache.get("one") == 1);   // ["one": 1]
    cache.update("two", 2);          // ["two": 2, "one": 1]
    CHECK(*cache.get("two") == 2);   // ["two": 2, "one": 1]
    cache.update("three", 3);        // ["three: 3, "two": 2, "one": 1]
    CHECK(*cache.get("three") == 3); // ["three: 3, "two": 2, "one": 1]

    // Check updates
    cache.update("one", 4);          // ["one": 4, "three": 3, "two": 2]
    CHECK(*cache.get("one") == 4);   // ["one": 4, "three": 3, "two": 2]
    cache.update("three", 5);        // ["three": 5, "one": 4, "two": 2]
    CHECK(*cache.get("three") == 5); // ["three": 5, "one": 4, "two": 2]
    cache.update("three", 6);        // ["three": 6, "one": 4, "two": 2]
    CHECK(*cache.get("three") == 6); // ["three": 6, "one": 4, "two": 2]

    // Check that the least recently used element is removed.
    cache.update("four", 7);         // ["four": 7, "three": 6, "one": 4]
    CHECK(*cache.get("one") == 4);   // ["one": 4, "four": 7, "three": 6]
    CHECK(!cache.get("two"));
    CHECK(*cache.get("three") == 6); // ["three": 6, "one": 4, "four": 7]
    CHECK(*cache.get("four") == 7);  // ["four": 7, "three": 6, "one": 4]

    // Check that when updating an existing element the order is updated.
    cache.update("one", 8);          // ["one": 8, "four": 7, "three": 6]
    cache.update("five", 9);         // ["five": 9, "one": 8, "four": 7]
    CHECK(!cache.get("three"));
    CHECK(*cache.get("one") == 8);   // ["one": 8, "five": 9, "four": 7]
}

class non_default_constructable {
public:
    non_default_constructable(int value) : m_value{value} {}
    operator int() const { return m_value; }
private:
    int m_value;
};

namespace std {

template <> struct hash<non_default_constructable> {
    size_t operator()(const non_default_constructable& key) const {
        return std::hash<int>{}(key);
    }
};

}

TEST_CASE("test with non-default constructable key and value") {
    vstl::lru_cache<non_default_constructable, non_default_constructable> cache{1};
    cache.update(1, 1);
    CHECK(*cache.get(1) == 1);
}

Do you see any issues or have any suggestions for improvement? Which one do you prefer?

\$\endgroup\$
1
  • 7
    \$\begingroup\$ Oh, look, it comes with an automated test suite! Excellent, thank you. \$\endgroup\$ Commented Mar 10 at 15:46

2 Answers 2

8
\$\begingroup\$

This looks like a nice, straightforward implementation of a LRU cache, leveraging STL containers to do most of the heavy lifting (and the second version all of the heavy lifting). Still, there is some room for improvement:

Make it look more like std::unordered_map

The LRU cache is basically an unordered map, with the only special feature being that it forgets nodes when it reaches capacity. It would be great if you could make the interface look as much as possible as std::unordered_map, as that has several benefits:

  • Easier for programmers that already know std::unordered_map
  • Allows STL algorithms and other generic code work with your lru_cache
  • Adds functionality that is helpful and/or improve performance

For example, your update() takes a key and value parameter, but that means those have to be created beforehand. While they are moved if possible, it's better if at least the value could be constructed in place. So, something that looks like std::unordered_map::try_emplace() would be really nice.

There is also a lot of common functionality for STL containers, like clear(), swap() and so on that would be nice for an LRU cache as well.

Allow queries that don't change a node's position

Sometimes you know that it is very unlikely you will need to access an entry again in the short term. In that case, you want to be able to query the LRU cache for an item without moving it into the most recently used position. It might be interesting to add a function to do that.

Performance

There is nothing you can improve regarding time complexity, however one big drawback of your implementation is that for each node a separate memory allocation is made. This means the nodes are spread around in memory, increasing cache pressure and increasing memory fragmentation.

Ideally, you do one memory allocation to store all capacity nodes along with the bookkeeping overhead. That means implementing your own hash table, and also manually constructing and destructing elements in that storage. That's going to be a lot more work though.

\$\endgroup\$
5
\$\begingroup\$

Overview

I don't see any good reason to implement our own linked list. The version with std::list is shorter and clearer, and my guess is that we won't measure a difference in performance. So I'm not even going to look at the first version.


It might be good to let the user choose between std::unordered_map and std::map for the Map type, since not all types are hashable and not all types are totally ordered. And for those that are both, there may be good performance reasons to prefer one or the other.


There's no declaration for size_t. I'm guessing it's a typo for std::size_t, in which case we should include one of the headers that defines it, and spell it correctly. Not doing so is non-portable.

We're also missing an include of <utility>, for std::forward().


Don't use assert() for argument checking - that's not its purpose, and we should be preventing misuse of our interface even in non-debug builds. Prefer a plain if test instead and throw a std::invalid_argument if it fails.


Consider making the capacity non-const, so that caches may be copied or (more usefully) moved, like other containers. A public setter would also allow users to change the cache size later, possibly dropping elements. If we do make the capacity modifiable, we'll need to implement or delete the copy operations so that the stored iterators continue to point to the correct right elements in the copy.

Constructors

m_index.reserve(m_capacity) is potentially quite wasteful. The way I avoid that is to keep the actual key in the list, and use std::reference_wrapper<Key> as the key type of the index. This does add some small overhead to lookup, as there's an extra indirection.

Implementation note - this does require us to have transparent comparison in the map, which is easily arranged, and usually helpful:

    struct key_equal_transparent : std::equal_to<Key>
    {
        using is_transparent = Key;
    };

    struct key_hash_transparent : std::hash<Key>
    {
        using is_transparent = Key;
    };

    struct Node {
        Key key;
        Value data;
    };

    using Cache = std::list<Node>;
    using Index = std::unordered_map<
        std::reference_wrapper<Key>,
        typename Cache::iterator,
        key_hash_transparent, key_equal_transparent
        >;

This means that we are no longer constructing std::string objects from our string-literals when we perform lookups in the test suite.

Insertion

Consider using emplace_front() instead of push_front() when adding entries.

It's probably better to prune the last entry from the cache before constructing our new entry, especially if the values could be large.

Think about what happens when the first insertion (to the list) succeeds, but the second (to the map) throws an exception. We should undo the list insertion if that occurs.

Since we always write the value into the cache (unlike the key, whose equal might be found), we could accept this as a Value with any implicit conversion occurring in the caller. This simplifies interface and implementation (std::move() instead of std::forward()).

Access

Perhaps get() should accept anything that's comparable with Key. I'd expect

    template <typename K>
    const Value* get(const K& key) const;

Modified code

#include <cassert>
#include <cstddef>
#include <list>
#include <memory>
#include <exception>
#include <unordered_map>
#include <utility>

template <class Key, class Value>
class lru_cache
{
    struct Hash : std::hash<Key> { using is_transparent = Key; };
    struct Equal : std::equal_to<Key> { using is_transparent = Key; };

    using Node = std::pair<Key, Value>;
    using Cache = std::list<Node>;
    using Ref = std::reference_wrapper<const Key>;
    using Iter = typename Cache::iterator;
    using Index = std::unordered_map<Ref, Iter, Hash, Equal>;

    std::size_t m_capacity;
    mutable Cache m_cache = {};
    Index m_index = {};

public:
    explicit lru_cache(std::size_t capacity)
        : m_capacity{capacity}
    {
        if (!capacity) {
            throw std::invalid_argument("LRU capacity == 0");
        }
    }

    lru_cache(const lru_cache& other)
        : m_capacity{other.m_capacity},
          m_cache{other.m_cache}
    {
        // rebuild the index
        m_index.reserve(m_cache.size());
        for (auto it = m_cache.begin();  it != m_cache.end();  ++it) {
            m_index.emplace(it->first, it);
        }
    }
    lru_cache(lru_cache&&) = default;

    lru_cache& operator=(lru_cache other)
    {
        swap(other);
        return *this;
    }
    lru_cache& operator=(lru_cache&&) = default;

    void swap(lru_cache& other)
    {
        using std::swap;
        swap(m_capacity, other.m_capacity);
        swap(m_cache, other.m_cache);
        swap(m_index, other.m_index);
    }

    void resize(std::size_t new_capacity)
    {
        if (!new_capacity) {
            throw std::invalid_argument("LRU capacity == 0");
        }
        m_capacity = new_capacity;
        while (m_cache.size() > m_capacity) {
            m_index.erase(m_cache.back().first);
            m_cache.pop_back();
        }
    }

    void clear()
    {
        m_index.clear();
        m_cache.clear();
    }

    template <class K>
        requires std::convertible_to<K, Key>
    void update(K&& key, Value value)
    {
        if (const auto index_it = m_index.find(key);  index_it != m_index.end()) {
            const auto cache_it = index_it->second;
            cache_it->second = std::move(value);
            m_cache.splice(m_cache.begin(), m_cache, cache_it);
            return;
        }
        if (m_cache.size() >= m_capacity) {
            m_index.erase(m_cache.back().first);
            m_cache.pop_back();
        }
        auto& added = m_cache.emplace_front(std::forward<K>(key), std::move(value));
        try {
            m_index.emplace(added.first, m_cache.begin());
        } catch (...) {
            // undo the addition
            m_cache.pop_front();
        }
    }

    template <class K>
        requires requires(const Key& k, const K& v, Index::hasher h) { k == v; h(v); }
    const Value* get(const K& key) const {
        const auto it = m_index.find(key);
        if (it == m_index.end()) return {};
        m_cache.splice(m_cache.begin(), m_cache, it->second);
        return &m_cache.front().second;
    }
};
\$\endgroup\$
2
  • 2
    \$\begingroup\$ Not sure why I got a downvote on this - if you think there's something wrong, I really would appreciate some words explaining what I could improve. \$\endgroup\$ Commented Mar 11 at 14:24
  • \$\begingroup\$ A further thought on this - when we pop_back() as part of update, perhaps we could re-use the list node for the new element? That would save the overhead of returning it to the allocator just to claim it again. I'm not able to test and show that right now; possibly around the end of this week. \$\endgroup\$ Commented Mar 17 at 11:35

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.