\$\begingroup\$

I took inspiration from https://www.geeksforgeeks.org/cpp/kd-trees-in-cpp/ and implemented a k-d tree library, and I would like to get a review of it. The main goal was a k-d tree implementation that is drop-in and generic enough to cover arbitrary use cases. I'm most interested in opinions about the interface.

The library:

#ifndef KDTREE_H
#define KDTREE_H

#include <cmath>
#include <functional>
#include <memory>

/*! @file
 * 
 * Main inspiration: https://www.geeksforgeeks.org/cpp/kd-trees-in-cpp/
 * 
 * KD Tree
 * =======
 * 
 * Props:
 * - Every node in the tree is a k-dimensional point.
 * - Every non-leaf node generates a splitting hyperplane that divides the space into two parts.
 * - Points to the left of the splitting hyperplane are represented by the left subtree of that node
 * and points to the right of the hyperplane are represented by the right subtree.
 * - The hyperplane direction is chosen in the following way: it is perpendicular to the axis
 * corresponding to the depth of the node (modulo k).
 * - The tree is balanced when constructed with points that are uniformly distributed.
 * 
 * */

/*! Enables very verbose logs, you should only enable this if your floating points are compatible
 * with std::to_string. You also need to set kdtree::enable_logging to true. */
#ifdef KDTREE_ENABLE_LOGGING
# ifndef KDTREE_PRINTF
#  define KDTREE_PRINTF printf
# endif
# define KDTREE_DEBUG(fmt, ...)                                                                 \
do {                                                                                            \
    if (kdtree::enable_logging)                                                                 \
        KDTREE_PRINTF("KDTREE: %30s line[%3d] " fmt "\n", __func__, __LINE__, ##__VA_ARGS__);   \
} while (0)
#else
# define KDTREE_DEBUG(fmt, ...) do {} while (0)
#endif /* KDTREE_ENABLE_LOGGING */

namespace kdtree {

/*! Enabling this will soft enable logging everywhere (only if the above is enabled) */
inline bool enable_logging = false;

/*! The type of any point, holding K coordinates. */
template <typename T, size_t K>
using vec_t = std::array<T, K>;

/*! An axis-aligned hyperbox */
template <typename T, size_t K>
struct hyperbox_t {
    vec_t<T, K> min_coords;
    vec_t<T, K> max_coords;
};

/*! A K-dimensional sphere */
template <typename T, size_t K>
struct hypersphere_t {
    vec_t<T, K> center;
    T radius;
};

/*! The generic node inside the tree. Pointers of the type node_t<T, K, D>* should remain valid
 * until a deletion occurs
 * 
 * @param T the type for coordinates of a point. T must be copyable, initializable by 0 and -1 and
 * have comparators (<, ==). T must be signed.
 * @param K the number of coordinates of a point
 * @param D the type of data stored for a point. Data must be copyable and initializable. */
template <typename T, size_t K, typename D>
struct node_t {
    /*! left sub-tree root */
    node_t *left = nullptr;

    /*! right sub-tree root */
    node_t *right = nullptr;

    /*! data that is held by the node. */
    D data;

    /*! The point associated with the data. */
    vec_t<T, K> p;
};

/*! Tree options: functions and variables that change the behaviour of the tree. */
template <typename T, size_t K, typename D>
struct tree_opts_t {
    /*! eq - equal - This is the way this tree checks if two points are exactly the same */
    std::function<bool(const vec_t<T, K>& a, const vec_t<T, K>& b)> eq;

    /*! dist2 - distance squared - The function that computes the square of the distance between two
     * points
     * 
     * @param a the first point
     * @param b the second point */
    std::function<T(const vec_t<T, K>& a, const vec_t<T, K>& b)> dist2;

    /*! rect_intersect - the function that returns true if an axis-aligned hyperbox and a hypersphere
     * intersect.
     * 
     * @param rect the hyperbox
     * @param circle the hypersphere */
    std::function<bool(const hyperbox_t<T, K>& rect, const hypersphere_t<T, K>& circle)>
            rect_intersect;

    /*! inf - a number greater than any coordinate in any point. (-inf) should also be smaller than
     * any coordinate. */
    T inf;
};

/*! Main object that contains a root and some public helpers.
 * 
 * @param root Contains the root of the tree.
 * @param eq The function that checks if two points are equal.
 * @param dist2 The function that computes the distance squared of two points.
 * @param rect_intersect Function that checks if a hypersphere and rectangle intersect.
 * @param inf A value that is greater than any coord value squared
 */
template <typename T, size_t K, typename D>
struct tree_t {
    static constexpr const T exact = 0;     /*! used in find/remove queries */
    static constexpr const T nearest = -1;  /*! used in find/remove queries */

    node_t<T, K, D> *root = nullptr;

    std::shared_ptr<tree_opts_t<T, K, D>> o;

    static std::shared_ptr<tree_t<T, K, D>> create(std::shared_ptr<tree_opts_t<T, K, D>>
            co = nullptr);
};

/*! Initializes a kd-tree structure with default helpers.
 * 
 * @param T The type of the spatial coordinates.
 * @param K The number of spatial coordinates.
 * @param D The type of the remembered data by the nodes.
 * 
 * @param custom_opts A set of custom tree options, those are defined in tree_opts_t
 * 
 * @return A shared pointer to the newly created structure. Is null on error. */
template <typename T, size_t K, typename D>
inline std::shared_ptr<tree_t<T, K, D>> create(std::shared_ptr<tree_opts_t<T, K, D>> custom_opts
        = nullptr);

/*! Inserts a new point into the kd-tree structure
 * 
 * @param tree The tree in which to insert the point.
 * @param p The point to insert.
 * @param data The data which accompanies the point.
 * 
 * @return A pointer to the newly created node. Is null on error. */
template <typename T, size_t K, typename D>
inline node_t<T, K, D> *insert(tree_t<T, K, D> *tree, const vec_t<T, K> &p, D&& data);

/*! same as above, but accepts shared pointer as input */
template <typename T, size_t K, typename D>
inline node_t<T, K, D> *insert(std::shared_ptr<tree_t<T, K, D>> tree, const vec_t<T, K> &p,
        D &&data)
{
    return kdtree::insert<T, K, D>(tree.get(), p, std::forward<D>(data));
}

/*! Finds a list of points inside the kd-tree, given another point and a distance to it.
 * 
 * @param tree The tree in which to search for points.
 * @param p The point to which to compare other points.
 * @param range The distance to which to limit the search. If the range is equal to T{0}, an exact
 * match is searched and if the range is equal to T{-1} the nearest is searched. Any other negative
 * value besides -1 is an error.
 * 
 * @return A vector containing the nodes respective to the matching points. For exact matches an
 * empty returned vector signals an error. */
template <typename T, size_t K, typename D>
inline std::vector<node_t<T, K, D> *> find(tree_t<T, K, D> *tree, const vec_t<T, K> &p,
        const T& range);

/*! same as above, but accepts shared pointer as input */
template <typename T, size_t K, typename D>
inline std::vector<node_t<T, K, D> *> find(std::shared_ptr<tree_t<T, K, D>> tree,
        const vec_t<T, K> &p, const T& range)
{
    return kdtree::find<T, K, D>(tree.get(), p, range);
}


/*! Finds a list of points inside the kd-tree, given another point and a distance to it and removes
 * the found nodes.
 * 
 * @param tree The tree in which to search for points.
 * @param p The point to which to compare other points.
 * @param range The distance to which to limit the search. If the range is equal to T{0}, an exact
 * match is searched and if the range is equal to T{-1} the nearest is searched.
 * 
 * @return The number of removed nodes or a negative number on error. */
template <typename T, size_t K, typename D>
inline int remove(tree_t<T, K, D> *tree, const vec_t<T, K> &p, const T& range);

/*! same as above, but accepts shared pointer as input */
template <typename T, size_t K, typename D>
inline int remove(std::shared_ptr<tree_t<T, K, D>> tree, const vec_t<T, K> &p, const T& range) {
    return kdtree::remove(tree.get(), p, range);
}

/*! Creates a string representation of a kd tree.
 * 
 * @param tree A pointer to the tree in question.
 * @param data_to_string_fn a custom function that converts Data to a string. */
template <typename T, size_t K, typename D>
inline std::string to_string(const tree_t<T, K, D> *tree,
        std::function<std::string(const D&)> data_to_string_fn = [](const D&){ return "[data]"; },
        std::function<std::string(const T&)> coord_to_string_fn = [](const T& val){
                return std::to_string(val); });

/*! same as above, but accepts shared pointer as input */
template <typename T, size_t K, typename D>
inline std::string to_string(std::shared_ptr<tree_t<T, K, D>> tree,
        std::function<std::string(const D&)> data_to_string_fn = [](const D&){ return "[data]"; },
        std::function<std::string(const T&)> coord_to_string_fn = [](const T& val){
                return std::to_string(val); })
{
    return kdtree::to_string<T, K, D>(tree.get(), data_to_string_fn, coord_to_string_fn);
}

/*! Creates a string representation of a kd tree node.
 * 
 * @param node A pointer to the node in question.
 * @param data_to_string_fn a custom function that converts Data to a string. */
template <typename T, size_t K, typename D>
inline std::string to_string(const node_t<T, K, D> *node,
        std::function<std::string(const D&)> data_to_string_fn = [](const D&){ return "[data]"; },
        std::function<std::string(const T&)> coord_to_string_fn = [](const T& val){
                return std::to_string(val); });

/*! Creates a string representation of a kd tree point */
template <typename T, size_t K>
inline std::string to_string(const vec_t<T, K>& p, std::function<std::string(const T&)>
        coord_to_string_fn = [](const T& val){ return std::to_string(val); });

/*! Creates a string representation of a kd tree box */
template <typename T, size_t K>
inline std::string to_string(const hyperbox_t<T, K>& bb, std::function<std::string(const T&)>
        coord_to_string_fn = [](const T& val){ return std::to_string(val); });

/*! Creates a string representation of a kd tree sphere */
template <typename T, size_t K>
inline std::string to_string(const hypersphere_t<T, K>& circle, std::function<std::string(const T&)>
        coord_to_string_fn = [](const T& val){ return std::to_string(val); });

/*! Sanity check, verifies that each node splits its subtree by a hyperplane, as intended */
template <typename T, size_t K, typename D>
inline bool is_tree_valid(tree_t<T, K, D> *tree);

/*! Same as above, but takes a shared pointer as parameter */
template <typename T, size_t K, typename D>
inline bool is_tree_valid(std::shared_ptr<tree_t<T, K, D>> tree) {
    return tree ? kdtree::is_tree_valid<T, K, D>(tree.get()) : false;
}

/* IMPLEMENTATION 
=================================================================================================
=================================================================================================
================================================================================================= */

template <typename T, size_t K, typename D>
inline std::shared_ptr<tree_t<T, K, D>> tree_t<T, K, D>::create(
        std::shared_ptr<tree_opts_t<T, K, D>> co)
{
    return kdtree::create<T, K, D>(co);
}

template <typename T, size_t K, typename D>
inline std::shared_ptr<tree_t<T, K, D>> create(std::shared_ptr<tree_opts_t<T, K, D>> custom_opts) {
    auto ret = std::make_shared<tree_t<T, K, D>>();

    if (custom_opts) {
        ret->o = custom_opts;
        return ret;
    }

    ret->o = std::make_shared<tree_opts_t<T, K, D>>();

    ret->o->eq = [](const vec_t<T, K>& a, const vec_t<T, K>& b) {
        return a == b;
    };
    ret->o->dist2 = [](const vec_t<T, K>& a, const vec_t<T, K>& b) {
        T dst_squared = 0;
        for (size_t i = 0; i < K; i++)
            dst_squared += (a[i] - b[i]) * (a[i] - b[i]);
        return dst_squared;
    };

    // https://stackoverflow.com/questions/401847/circle-rectangle-collision-detection-intersection
    ret->o->rect_intersect = [](const hyperbox_t<T, K>& rect, const hypersphere_t<T, K>& circle) {
        vec_t<T, K> circle_distance;
        for (size_t i = 0; i < K; i++)
            circle_distance[i] = abs(circle.center[i] - rect.min_coords[i]);

        vec_t<T, K> half_rect_sizes;
        for (size_t i = 0; i < K; i++) {
            half_rect_sizes[i] = (rect.max_coords[i] - rect.min_coords[i]) / 2.;
            if (circle_distance[i] > (half_rect_sizes[i] + circle.radius))
                return false;
        }

        for (size_t i = 0; i < K; i++) {
            if (circle_distance[i] <= half_rect_sizes[i])
                return true;
        }

        T dst_squared = 0;
        for (size_t i = 0; i < K; i++)
            dst_squared += (circle_distance[i] - half_rect_sizes[i]) * (circle_distance[i] -
                    half_rect_sizes[i]);

        return dst_squared <= circle.radius * circle.radius; 
    };
    ret->o->inf = std::numeric_limits<T>::max();
    return ret;
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *insert_recursive(node_t<T, K, D> **root,
        const vec_t<T, K> &p, D&& data, size_t depth, tree_t<T, K, D> *tree)
{
    if (!(*root)) {
        KDTREE_DEBUG("Creating new node");
        return (*root) = new node_t<T, K, D>{ .data = std::forward<D>(data), .p = p };
    }
    size_t coord = depth % K;

    KDTREE_DEBUG("point: %s root: %s coord: %zu depth: %zu",
            to_string(p).c_str(), to_string(*root).c_str(), coord, depth);

    if (p[coord] < (*root)->p[coord])
        return insert_recursive<T, K, D>(&(*root)->left, p, std::forward<D>(data), depth + 1, tree);
    else
        return insert_recursive<T, K, D>(&(*root)->right, p, std::forward<D>(data), depth + 1, tree);
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *insert(tree_t<T, K, D> *tree,
        const vec_t<T, K> &p, D&& data)
{
    return insert_recursive<T, K, D>(&(tree->root), p, std::forward<D>(data), 0, tree);
}

template <typename T, size_t K, typename D>
inline void find_in_range_recursive(node_t<T, K, D> *root,
        const vec_t<T, K> &p, const T &range, std::vector<node_t<T, K, D> *> &result,
        size_t depth, tree_t<T, K, D> *tree, const hyperbox_t<T, K>& bb)
{
    if (!root)
        return;

    hypersphere_t<T, K> zone_of_interest = { .center = p, .radius = range };
    size_t coord = depth % K;

    KDTREE_DEBUG("point: %s root: %s coord: %zu depth: %zu",
            to_string(p).c_str(), to_string(root).c_str(), coord, depth);

    hyperbox_t<T, K> left_bb = bb;
    left_bb.max_coords[coord] = root->p[coord];
    if (tree->o->rect_intersect(left_bb, zone_of_interest))
        find_in_range_recursive(root->left, p, range, result, depth + 1, tree, bb);

    hyperbox_t<T, K> right_bb = bb;
    right_bb.min_coords[coord] = root->p[coord];
    if (tree->o->rect_intersect(right_bb, zone_of_interest))
        find_in_range_recursive(root->right, p, range, result, depth + 1, tree, bb);
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *find_exact_recursive(node_t<T, K, D> *root,
        const vec_t<T, K> &p, size_t depth, tree_t<T, K, D> *tree)
{
    size_t coord = depth % K;
    KDTREE_DEBUG("point: %s root: %s coord: %zu",
            to_string(p).c_str(), to_string(root).c_str(), coord);
    if (!root)
        return nullptr;
    if (tree->o->eq(root->p, p))
        return root;
    if (p[coord] < root->p[coord])
        return find_exact_recursive(root->left, p, depth + 1, tree);
    else
        return find_exact_recursive(root->right, p, depth + 1, tree);
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *find_nearest_recursive(node_t<T, K, D> *root,
        const vec_t<T, K> &p, size_t depth, tree_t<T, K, D> *tree, T min_dist_squared,
        hyperbox_t<T, K> bb)
{
    if (!root)
        return nullptr;

    /* We first update the minimum distance to take the current node into account */
    node_t<T, K, D> *ret = nullptr;
    T dist_squared = tree->o->dist2(root->p, p);
    if (dist_squared < min_dist_squared) {
        min_dist_squared = dist_squared;
        ret = root;
    }

    /* We create a circle centered at p and of radius the new minimum distance */
    size_t coord = depth % K;

    KDTREE_DEBUG("point: %s root: %s coord: %zu depth: %zu",
            to_string(p).c_str(), to_string(root).c_str(), coord, depth);

    T min_dist = std::sqrt(min_dist_squared);
    hypersphere_t<T, K> zone_of_interest = { .center = p, .radius = min_dist };

    hyperbox_t<T, K> left_bb = bb;
    left_bb.max_coords[coord] = root->p[coord];

    /* If the bounding box of the left zone intersects with our zone of interest we recurse the left
    branch */
    node_t<T, K, D> *ret_left = nullptr;
    if (tree->o->rect_intersect(left_bb, zone_of_interest)) {
        ret_left = find_nearest_recursive(root->left, p, depth + 1, tree, min_dist_squared,
                left_bb);
    }

    hyperbox_t<T, K> right_bb = bb;
    right_bb.min_coords[coord] = root->p[coord];

    /* If the bounding box of the right zone intersects with our zone of interest we recurse the
    right branch */
    node_t<T, K, D> *ret_right = nullptr;
    if (tree->o->rect_intersect(right_bb, zone_of_interest)) {
        ret_right = find_nearest_recursive(root->right, p, depth + 1, tree, min_dist_squared,
                right_bb); 
    }

    /* Now we have 3 candidates: the current point, the minimum found for the left branch and
    the minimum found on the right branch. So we first compare this point to the left one and
    remember the best in ret (that is if it is not null) */
    T new_min = min_dist_squared;
    if (!ret)
        ret = ret_left;
    else if (ret_left && (new_min = tree->o->dist2(ret->p, ret_left->p)) < min_dist_squared) {
        min_dist_squared = new_min;
        ret = ret_left;
    }

    /* now if we have a point from the previous comparison we compare it with the right result and
    choose the best */
    if (ret && ret_right && tree->o->dist2(ret->p, ret_right->p) < min_dist_squared)
        ret = ret_right;
    if (!ret)
        ret = ret_right;

    return ret;
}

template <typename T, size_t K, typename D>
inline std::vector<node_t<T, K, D> *> find(tree_t<T, K, D> *tree,
        const vec_t<T, K> &p, const T& range)
{
    using ret_t = std::vector<node_t<T, K, D> *>;

    if (range == T{0}) {
        auto ret = find_exact_recursive(tree->root, p, 0, tree);
        return ret ? ret_t{ret} : ret_t{};
    }
    else if (range == T{-1}) {
        hyperbox_t<T, K> bb;
        for (size_t i = 0; i < K; i++) {
            bb.min_coords[i] = -tree->o->inf;
            bb.max_coords[i] = tree->o->inf;
        }
        auto ret = find_nearest_recursive(tree->root, p, 0, tree, tree->o->inf, bb);
        return ret ? ret_t{ret} : ret_t{};
    }
    else {
        ret_t ret;
        hyperbox_t<T, K> bb;
        for (size_t i = 0; i < K; i++) {
            bb.min_coords[i] = -tree->o->inf;
            bb.max_coords[i] = tree->o->inf;
        }
        find_in_range_recursive(tree->root, p, range, ret, 0, tree, bb);
        return ret;
    }
    return ret_t{};
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *find_min_coord(node_t<T, K, D> *root, size_t d, size_t depth) {
    if (!root)
        return nullptr;

    size_t coord = depth % K; /* this is the current coord of the root */

    KDTREE_DEBUG("root: %s coord: %zu depth: %zu d: %zu",
            to_string(root).c_str(), coord, depth, d);

    /* if we are on a node that has d as its active coord then we know from kd-tree props that
    its children are ordered by the tree props */
    if (coord == d) {
        if (root->left == nullptr)
            return root;
        return find_min_coord(root->left, d, depth+1);
    }

    /* else we must search both sub-trees for said min */
    auto res = root;
    auto l = find_min_coord(root->left, d, depth+1);
    auto r = find_min_coord(root->right, d, depth+1);
    if (l && l->p[d] < res->p[d])
        res = l;
    if (r && r->p[d] < res->p[d])
        res = r;
    return res;
}

template <typename T, size_t K, typename D>
inline node_t<T, K, D> *remove_recursive(tree_t<T, K, D> *tree, node_t<T, K, D> *root,
        const vec_t<T, K>& p, size_t depth)
{
    if (!root)
        return nullptr;

    size_t coord = depth % K;

    KDTREE_DEBUG("point: %s root: %s coord: %zu depth: %zu",
            to_string(p).c_str(), to_string(root).c_str(), coord, depth);

    if (tree->o->eq(root->p, p)) {
        if (root->right != nullptr) {
            auto min_node = find_min_coord(root->right, coord, depth+1);

            root->p = min_node->p;
            root->data = min_node->data;

            root->right = remove_recursive(tree, root->right, min_node->p, depth+1);
        }
        else if (root->left != nullptr) {
            auto min_node = find_min_coord(root->left, coord, depth+1);

            root->p = min_node->p;
            root->data = min_node->data;

            root->right = remove_recursive(tree, root->left, min_node->p, depth+1);
            root->left = nullptr;
        }
        else {
            delete root;
            return nullptr;
        }
    }

    if (p[coord] < root->p[coord])
        root->left = remove_recursive(tree, root->left, p, depth+1);
    else
        root->right = remove_recursive(tree, root->right, p, depth+1);
    return root;
}

template <typename T, size_t K, typename D>
inline int remove(tree_t<T, K, D> *tree, const vec_t<T, K> &p, const T& range) {
    auto to_delete_nodes = kdtree::find(tree, p, range);
    int ret = to_delete_nodes.size();
    if (range == T{0} && ret == 0) {
        return -1;
    }
    std::vector<vec_t<T, K>> to_delete_points;
    for (auto node : to_delete_nodes) {
        if (!node) {
            return -1;
        }
        to_delete_points.push_back(node->p);
    }
    for (auto p : to_delete_points)
        tree->root = remove_recursive(tree, tree->root, p, 0);
    return ret;
}

template <typename T, size_t K, typename D>
inline std::string to_string_recursive(const node_t<T, K, D> *root,
        std::function<std::string(const D&)> data_to_string_fn,
        std::function<std::string(const T&)> coord_to_string_fn,
        size_t depth)
{
    std::string ret = std::string(depth * 2, ' ') +
            to_string(root, data_to_string_fn, coord_to_string_fn) + "\n";
    if (root->left)
        ret += to_string_recursive(root->left, data_to_string_fn, coord_to_string_fn, depth+1);
    else
        ret += std::string((depth + 1) * 2, ' ') + "[null_left]\n";
    if (root->right)
        ret += to_string_recursive(root->right, data_to_string_fn, coord_to_string_fn, depth+1);
    else
        ret += std::string((depth + 1) * 2, ' ') + "[null_right]\n";
    return ret;
}

template <typename T, size_t K, typename D>
inline std::string to_string(const tree_t<T, K, D> *tree,
        std::function<std::string(const D&)> data_to_string_fn,
        std::function<std::string(const T&)> coord_to_string_fn)
{
    if (!tree)
        return "[invalid_null_tree]";
    if (!tree->root)
        return "[null_root]";
    return to_string_recursive(tree->root, data_to_string_fn, coord_to_string_fn, 0);
}

template <typename T, size_t K, typename D>
inline std::string to_string(const node_t<T, K, D> *node,
        std::function<std::string(const D&)> data_to_string_fn,
        std::function<std::string(const T&)> coord_to_string_fn)
{
    std::string ret = "[";
    if (!node)
        return "[null_node]";
    for (size_t i = 0; i < K; i++) {
        ret += coord_to_string_fn(node->p[i]);
        if (i != K - 1)
            ret += ", ";
    }
    ret += "]{data: " + data_to_string_fn(node->data) + "}";
    return ret;
}

template <typename T, size_t K>
inline std::string to_string(const vec_t<T, K>& v, std::function<std::string(const T&)>
        coord_to_string_fn) {
    std::string ret = "[";
    for (int i = 0; i < v.size(); i++) {
        ret += coord_to_string_fn(v[i]);
        if (i+1 != v.size())
            ret += ", ";
    }
    return ret + "]";
}

/*! Creates a string representation of a kd tree box */
template <typename T, size_t K>
inline std::string to_string(const hyperbox_t<T, K>& bb, std::function<std::string(const T&)>
        coord_to_string_fn)
{
    return "[" + kdtree::to_string(bb.min_coords) + ", " + kdtree::to_string(bb.max_coords) + "]";
}

/*! Creates a string representation of a kd tree sphere */
template <typename T, size_t K>
inline std::string to_string(const hypersphere_t<T, K>& circle, std::function<std::string(const T&)>
        coord_to_string_fn)
{
    return "[" + kdtree::to_string(circle.center) + ", " + coord_to_string_fn(circle.radius) + "]";
}

template <typename T, size_t K, typename D>
inline bool is_tree_valid_recursive(tree_t<T, K, D> *tree, node_t<T, K, D> *root, size_t depth,
        const hyperbox_t<T, K>& bb){
    if (!root)
        return true;

    for (size_t i = 0; i < K; i++)
        if (root->p[i] < bb.min_coords[i] || root->p[i] > bb.max_coords[i]) {
            KDTREE_DEBUG("FAILED at depth: %zu coord: %zu bb: %s point: %s",
                    depth, i, to_string(bb).c_str(), to_string(root->p).c_str());
            return false;
        }

    int coord = depth % K;
    hyperbox_t<T, K> left_bb = bb;
    left_bb.max_coords[coord] = root->p[coord];
    if (!is_tree_valid_recursive(tree, root->left, depth+1, left_bb)) {
        KDTREE_DEBUG("REV:LEFT");
        return false;
    }

    hyperbox_t<T, K> right_bb = bb;
    right_bb.min_coords[coord] = root->p[coord];
    if (!is_tree_valid_recursive(tree, root->right, depth+1, right_bb)) {
        KDTREE_DEBUG("REV:RIGHT");
        return false;
    }

    return true;
}

template <typename T, size_t K, typename D>
inline bool is_tree_valid(tree_t<T, K, D> *tree) {
    if (!tree)
        return false;
    hyperbox_t<T, K> bb;
    for (size_t i = 0; i < K; i++) {
        bb.min_coords[i] = -tree->o->inf;
        bb.max_coords[i] = tree->o->inf;
    }
    return is_tree_valid_recursive(tree, tree->root, 0, bb);
}

} /* namespace kdtree */

#endif

A short test snippet follows. I plan to increase the test coverage later on, but for now I feel this will do:

#define KDTREE_ENABLE_LOGGING

#include "kdtree.h"
// #include "debug.h"

#include <set>

// approximately from debug.h:
#define DBG(fmt, ...) printf(fmt "\n", ##__VA_ARGS__)
#define ASSERT_FN(fn_call) do { int x = fn_call; if (x < 0) { print("error_str"); return -1; } } while (0)

using data_t = int;
using coord_t = int;
constexpr int coord_cnt = 5;

using vecN_t = kdtree::vec_t<coord_t, coord_cnt>;
using kdt_t = kdtree::tree_t<coord_t, coord_cnt, data_t>;
using kdt_p = std::shared_ptr<kdt_t>;

void print_tree(kdt_p tree) {
    auto tstr = kdtree::to_string<coord_t, coord_cnt, data_t>(tree,
            [](const data_t& data) -> std::string { return std::to_string(data); });
    DBG("TREE:\n%s", tstr.c_str());
}

std::string to_string(const vecN_t& v) {
    std::string ret = "[";
    for (int i = 0; i < v.size(); i++) {
        ret += std::to_string(v[i]);
        if (i+1 != v.size())
            ret += ", ";
    }
    return ret + "]";
}

#define ASSERT_TREE_VALIDITY \
if (!kdtree::is_tree_valid(tree) && repeated_test < 0) { \
    DBG("TREE IS INVALID"); \
    kdtree::enable_logging = true; \
    kdtree::is_tree_valid(tree); \
    print_tree(tree); \
    DBG("REPEATING TEST: %d", test); \
    repeated_test = -1; \
    test--; \
    break; \
} else if (repeated_test == 0) { if (!kdtree::is_tree_valid(tree)) break; }

int main(int argc, char const *argv[])
{
    srand(0);

    /* This variable will help signal to repeat the test only once on error */
    int repeated_test = -1000000000;
    for (int test = 0; test < 10000 && repeated_test < 0; test++) {
        if (test % 1000 == 0) {
            DBG("Passed test: %d", test);
        }
        srand(test);
        repeated_test++;
        auto tree = kdt_t::create();
        std::set<vecN_t> inserted_points;
        ASSERT_TREE_VALIDITY;
        for (int i = 0; i < test / 10; i++) {
            ASSERT_TREE_VALIDITY;
            if (inserted_points.size() && (rand() % (test % 4 + 1) == 0)) {
                if (kdtree::enable_logging) {
                    DBG("BEFORE ERASING: ");
                    print_tree(tree);
                }
                ASSERT_TREE_VALIDITY;
                auto p = *inserted_points.begin();
                KDTREE_DEBUG("finding and erasing node: %s", kdtree::to_string(p).c_str());
                auto nodes = kdtree::find(tree, p, kdt_t::exact);
                if (!nodes.size() || !nodes.front()) {
                    DBG("Failed test: node not found %s sz %zu", to_string(p).c_str(), nodes.size());
                    kdtree::enable_logging = true;
                    kdtree::find(tree, p, kdt_t::exact);
                    print_tree(tree);
                    return -1;
                }
                ASSERT_TREE_VALIDITY;
                if (!tree->o->eq(p, nodes.front()->p)) {
                    kdtree::enable_logging = true;
                    DBG("Failed test: find mismatched %s vs %s",
                            to_string(p).c_str(), to_string(nodes.front()->p).c_str());
                    kdtree::find(tree, p, kdt_t::exact);
                    print_tree(tree);
                    return -1;
                }
                inserted_points.erase(p);
                int ret;
                if ((ret = kdtree::remove(tree, p, kdt_t::exact)) != 1) {
                    DBG("Failed test: remove[%d]", ret);
                    print_tree(tree);
                    return -1;
                }
                if (kdtree::enable_logging) {
                    DBG("AFTER ERASING:");
                    print_tree(tree);
                }
                ASSERT_TREE_VALIDITY;
            }
            else {
                ASSERT_TREE_VALIDITY;
                vecN_t point;
                for (int i = 0; i < coord_cnt; i++)
                    point[i] = (rand() % 20) - 5;
                KDTREE_DEBUG("inserting point: %s", kdtree::to_string(point).c_str());
                kdtree::insert(tree, point, 1000 + i);
                inserted_points.insert(point);
                ASSERT_TREE_VALIDITY;
            }
        }
    }

    return 0;
}
\$\endgroup\$

2 Answers

\$\begingroup\$

Get rid of the macros

I strongly recommend you avoid defining macros whenever possible. They are often a source of subtle bugs, and in most cases they can be replaced by regular C++ code.

For logging, I recommend you use a library like spdlog. However, if you are using logging to help debug your code, then instead spend time to learn how to use a debugger like GDB, or whichever best matches your platform.
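If a small amount of ad-hoc logging is still wanted without a macro or an external dependency, a variadic function template can cover most of what KDTREE_DEBUG does. The following is only a minimal C++17 sketch: the namespace kdtree_sketch and the functions format_line and debug are hypothetical names, not part of the question's library or of spdlog. Unlike the macro, it does not capture __func__ and __LINE__ automatically, so the caller passes a location label (C++20's std::source_location could recover that), and the arguments are evaluated even when logging is switched off.

```cpp
#include <iostream>
#include <sstream>
#include <string>

namespace kdtree_sketch {  // hypothetical namespace, for illustration only

// Runtime switch, mirroring kdtree::enable_logging from the question.
inline bool enable_logging = false;

// Builds one log line from a location label and any streamable arguments.
template <typename... Args>
std::string format_line(const char* where, const Args&... args) {
    std::ostringstream out;
    out << "KDTREE: " << where << ' ';
    (out << ... << args);  // C++17 fold expression: streams each argument in order
    return out.str();
}

// Writes the line to std::clog, but only when logging is enabled.
template <typename... Args>
void debug(const char* where, const Args&... args) {
    if (enable_logging)
        std::clog << format_line(where, args...) << '\n';
}

}  // namespace kdtree_sketch
```

A call site would then read kdtree_sketch::debug("insert", "coord: ", coord, " depth: ", depth); instead of going through the preprocessor.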

The ASSERT_TREE_VALIDITY macro could be replaced with a lambda function defined inside main():

…
int repeated_test = -1000000000;
for (int test = 0; test < 10000 && repeated_test < 0; test++) {
    auto validate_tree = [&]{
        if (!kdtree::is_tree_valid(tree) && repeated_test < 0) {
            …
        }
    };
    …
    validate_tree();
    …
}

Of course, you can't put a break statement inside a lambda and have it break out of the loop the lambda is called from. However, in this case I would just simplify the check to:

auto validate_tree = [&]{
    if (!kdtree::is_tree_valid(tree)) {
        throw std::runtime_error("Tree is invalid");
    }
};

Now it will automatically break out of everything if the exception is thrown. And you can still catch the exception if you want to retry anything. Also, exceptions can be caught by a debugger the moment they are thrown.
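To make the "catch and retry" idea concrete, here is a rough sketch of how the test driver could replay a failing test once with logging enabled. Everything in it is an assumption for illustration: run_single_test and run_all_tests are hypothetical names, and the tree_is_valid callback merely stands in for kdtree::is_tree_valid(tree) so the example stays self-contained.

```cpp
#include <functional>
#include <stdexcept>

// Hypothetical stand-in for the body of one randomized test. `tree_is_valid`
// plays the role of kdtree::is_tree_valid(tree) from the question.
inline void run_single_test(int test, const std::function<bool(int)>& tree_is_valid) {
    auto validate_tree = [&] {
        if (!tree_is_valid(test))
            throw std::runtime_error("Tree is invalid");
    };

    // ... insert or remove a point here ...
    validate_tree();  // on failure, unwinds out of every enclosing loop
}

// Runs tests 0..count-1. On the first failure it turns on logging, replays
// that test once, and returns its index; returns -1 if every test passes.
inline int run_all_tests(int count, const std::function<bool(int)>& tree_is_valid,
                         bool& enable_logging) {
    for (int test = 0; test < count; ++test) {
        try {
            run_single_test(test, tree_is_valid);
        } catch (const std::runtime_error&) {
            enable_logging = true;
            try {
                run_single_test(test, tree_is_valid);  // verbose replay
            } catch (const std::runtime_error&) {
                // expected: the replay fails the same way
            }
            return test;
        }
    }
    return -1;
}
```

This replaces the repeated_test counter and the test-- trick from the macro with ordinary control flow, at the cost of wrapping the test body in a function.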

Use C++ to format and print strings

You don't need C's printf() to format strings anymore: C++20 added std::format(), and C++23 made things even easier with std::print(). The aforementioned spdlog library uses the same format-string syntax.

Type names and aliases

I see you create a lot of types or type aliases ending with _t. It's best to avoid doing this, as the suffix _t is reserved by POSIX. A more common convention in C++ is to start newly defined types with a capital letter. So, write struct Tree instead of struct tree_t.

Another issue is that you create type aliases that can be confusing. Consider vec_t: that sounds like an abbreviation of "vector", and we have std::vector in the standard library. However, you defined it to be a std::array. Better would have been arr_t, but even better is to not use a type alias here, and just explicitly write std::array. It's unambiguous, and a good code editor will be able to autocomplete types, so it's not even more typing for you.

Restructure your code to be more idiomatic C++

The way you structured your code it almost looks like you are writing C, and are not using any of the benefits of C++. I would expect an object of type tree_t to have member functions to insert nodes and to query it, however you made it a struct without any non-static members, and instead have stand-alone functions that take a pointer to a tree_t. I would write it like so:

template <typename CoordinateT, std::size_t K, typename DataT>
class Tree {
public:
    // Declare derived types
    using Coordinates = std::array<CoordinateT, K>;

    struct Hyperbox {
        Coordinates min;
        Coordinates max;
    };

    struct Hypersphere {
        Coordinates center;
        CoordinateT radius;
    };

    struct Node { // struct, so that Tree can access the members
        std::unique_ptr<Node> left;
        std::unique_ptr<Node> right;
        Coordinates point;
        DataT data;
    };

    void insert(const Coordinates& point, DataT&& data);
    void remove(const Hypersphere& range);
    std::vector<std::pair<CoordinateT, DataT>> find(const Hypersphere& range);

private:
    std::unique_ptr<Node> root;

    bool eq(const Coordinates& lhs, const Coordinates& rhs);
    CoordinateT dist2(const Coordinates& lhs, const Coordinates& rhs);
    bool intersects(const Hyperbox& box, const Hypersphere& sphere);
};

Now you can use it like so:

int main() {
    kdtree::Tree<int, 5, int> tree;

    auto random_point = []{
        decltype(tree)::Coordinates point;
        std::ranges::generate(point, rand);
        return point;
    };

    for (…) {
        tree.insert(random_point(), rand());
        …
    }
}

Now there is no more need to pass around a tree using a std::shared_ptr; it now looks more like a regular container. By nesting the related types inside Tree, we avoid having to repeat template<T, K, D> so much. Since your code effectively hardcoded eq(), dist2() and intersects() anyway, I just made them regular member functions here, but you could replace those with member variables of type std::function<…> and still be able to customize them. For the latter case, you'd add a constructor to Tree to take these functions as arguments.

Also note that I made the left and right pointers std::unique_ptr. This makes a tree node own its children, and this way you no longer have to call new and delete manually. This prevents the memory leak your tree has if you don't manually remove() all the nodes.
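To make the ownership point concrete, here is a minimal sketch of such a class with only insertion and exact lookup. The contains() and size() members, the root_ and size_ names, and the rule that ties on the splitting axis go right are assumptions made for illustration; they are not taken from the reviewed library.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>

// Minimal sketch: no balancing, removal or range queries.
template <typename CoordinateT, std::size_t K, typename DataT>
class Tree {
public:
    using Coordinates = std::array<CoordinateT, K>;

    void insert(const Coordinates& point, DataT data) {
        // Walk down to the empty child slot where the new node belongs.
        std::unique_ptr<Node>* slot = &root_;
        for (std::size_t depth = 0; *slot; ++depth) {
            const std::size_t axis = depth % K;
            slot = point[axis] < (*slot)->point[axis] ? &(*slot)->left
                                                      : &(*slot)->right;
        }
        *slot = std::make_unique<Node>(point, std::move(data));
        ++size_;
    }

    bool contains(const Coordinates& point) const {
        const Node* node = root_.get();
        for (std::size_t depth = 0; node; ++depth) {
            if (node->point == point)
                return true;
            const std::size_t axis = depth % K;
            node = (point[axis] < node->point[axis] ? node->left
                                                    : node->right).get();
        }
        return false;
    }

    std::size_t size() const { return size_; }

private:
    struct Node {
        Node(const Coordinates& p, DataT d) : point(p), data(std::move(d)) {}
        std::unique_ptr<Node> left;   // each node owns its children
        std::unique_ptr<Node> right;
        Coordinates point;
        DataT data;
    };

    std::unique_ptr<Node> root_;  // destroying the root frees the whole tree
    std::size_t size_ = 0;
};
```

There is no destructor and no call to new or delete: when the Tree goes out of scope, root_ destroys its node, which in turn destroys its children. One caveat is that this destruction is recursive, so a very deep, unbalanced tree could in principle exhaust the stack.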

Use one template parameter for the type of a point

Your code uses two template parameters for the type of a point: the type of a coordinate and the number of coordinates. You then define vec_t to be a std::array holding the coordinates, but now you also need functions that operate on those arrays to check for equality, distance, and so on.

Instead of that, use a single template parameter to define the type of a point. This type should hold all the coordinates, and have functions to check for equality, distance, and so on. This way, your tree implementation does not have to deal with that at all! Basically, you pass the type of Coordinates directly. Consider:

struct vec3f {
    float x;
    float y;
    float z;

    auto operator<=>(const vec3f&) const = default; // (in)equality checking for free!

    float dist2(const vec3f& other) const {
        return (x - other.x) * (x - other.x) + …;
    }
};

kdtree::Tree<vec3f, std::string> tree;
tree.insert({1, 3.1415, 42}, "Hello, tree!");

You could then also just use the types provided by existing vector math libraries like Eigen, GLM and others.

Of course, your tree requires that you can get the individual coordinates out of a point by index. You have to ensure the point type has an operator[] then, and some way to get the number of dimensions. You can provide a separate template<typename T, std::size_t K> struct Coordinates that users can use if they don't want to create a custom class. Then they'd write code like:

kdtree::Tree<kdtree::Coordinates<int, 5>, int> tree;
tree.insert({1, 2, 3, 4, 5}, 42);
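A possible shape for that helper is sketched below. The namespace kdtree_sketch, the dimensions constant, and the goes_left() function are all hypothetical names chosen for this example; the point is only that tree-side code needs nothing beyond operator[], a dimension count, equality, and a distance function.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

namespace kdtree_sketch {

// Hypothetical helper: a fixed-size point exposing what a tree would need.
template <typename T, std::size_t K>
struct Coordinates {
    std::array<T, K> values;

    static constexpr std::size_t dimensions = K;

    T& operator[](std::size_t i) { return values[i]; }
    const T& operator[](std::size_t i) const { return values[i]; }

    friend bool operator==(const Coordinates& a, const Coordinates& b) {
        return a.values == b.values;
    }
    friend bool operator!=(const Coordinates& a, const Coordinates& b) {
        return !(a == b);
    }

    // Squared Euclidean distance to another point.
    T dist2(const Coordinates& other) const {
        T sum{};
        for (std::size_t i = 0; i < K; ++i) {
            const T d = values[i] - other.values[i];
            sum += d * d;
        }
        return sum;
    }
};

// Example of tree-side code that relies only on the point's interface:
// decides which subtree a point belongs to at a given depth.
template <typename Point>
bool goes_left(const Point& p, const Point& node, std::size_t depth) {
    const std::size_t axis = depth % Point::dimensions;
    return p[axis] < node[axis];
}

}  // namespace kdtree_sketch
```

Because Coordinates stays an aggregate, it can still be initialized with a plain braced list such as {1, 2, 3, 4, 5}. A user-defined point type would only have to provide the same few members to be usable in its place.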
\$\endgroup\$
2
\$\begingroup\$

Missing definitions

You have failed to provide definitions for the following:

  • printf (presumably std::printf from <cstdio>)
  • size_t (std::size_t, which several headers provide, but none included here)
  • std::array
  • std::vector
  • std::string
  • std::forward

A header is much easier for users if it can be included and used without any prerequisites.
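As a sketch, the includes below cover every name in the list above. The include_smoke_test() and pass_through() functions are hypothetical and exist only to touch each name once, so that dropping an include is likely to show up as a compile error; <cassert> is needed only by that check.

```cpp
#include <array>    // std::array
#include <cassert>  // assert (used only by the smoke test below)
#include <cstddef>  // std::size_t
#include <cstdio>   // std::printf
#include <string>   // std::string, std::to_string
#include <utility>  // std::forward
#include <vector>   // std::vector

// Hypothetical helper: forwards its argument unchanged, to exercise std::forward.
template <typename T>
T&& pass_through(T&& value) {
    return std::forward<T>(value);
}

// Hypothetical smoke test that uses each of the names listed above once.
inline std::size_t include_smoke_test() {
    std::array<int, 2> point{1, 2};
    std::vector<std::string> parts{std::to_string(point[0]),
                                   std::to_string(point[1])};
    std::printf("%s %s\n", parts[0].c_str(), parts[1].c_str());
    assert(pass_through(parts.size()) == 2);
    return parts.size();
}
```

A common way to check self-sufficiency is to include the header first in a translation unit that includes nothing before it.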

\$\endgroup\$
