Skip to main content
14 of 14
external link has changed
Watchduck
  • 230
  • 2
  • 10

Find the join of two set partitions

Two partitions of a set have a greatest lower bound (meet) and a least upper bound (join).
See here: Meets and joins in the lattice of partitions (Math SE)
The meet is easy to calculate. I am trying to improve the calculation of the join.

As part of a library I have written a class for set partitions.
Its join method can be found here. It uses the method pairs, which can be found here.

Creating all the 2-subsets of a block is not a great idea for big blocks.
I intended this approach only as a prototype.
But to my surprise I was still faster than my Python interpretation of the answer on Math SE.

The SetPart is essentially the list of blocks wrapped in a class.

This is the property pairs:    (file)

from itertools import combinations


def pairs(self):

    result = set()

    for block in self.blocks:
        for pair in combinations(block, 2):
            result.add(pair)

    return result

The method join_pairs is the essential part of the algorithm.:    (file)
The method merge_pair merges two elements into the same block.

def join_pairs(self, other):

    from discretehelpers.set_part import SetPart

    result = SetPart([], self.domain)

    for pair in self.pairs.union(other.pairs):
        result.merge_pair(*pair)

    return result

The method join uses join_pairs. It removes redundant elements (called trash) from the blocks of both partitions. This trash is then added to the blocks of the result again. This way only the necessary amount of 2-element subsets is used in the calculation.

def join(self, other):

    from itertools import chain
    from discretehelpers.set_part import SetPart

    meet_part = self.meet(other)

    trash = set()
    rep_to_trash = dict()
    for block in meet_part.blocks:
        block_rep = min(block)
        block_trash = set(block) - {block_rep}
        trash |= block_trash
        rep_to_trash[block_rep] = block_trash

    clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
    clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]

    clean_s_part = SetPart(clean_s_blocks)
    clean_o_part = SetPart(clean_o_blocks)

    clean_join_part = clean_s_part.join_pairs(clean_o_part)

    s_elements = set(chain.from_iterable(self.blocks))
    o_elements = set(chain.from_iterable(other.blocks))
    dirty_domain = s_elements | o_elements
    clean_domain = dirty_domain - trash

    dirty_join_blocks = []
    clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
    for clean_block in clean_blocks_with_singletons:
        dirty_block = set(clean_block)
        for element in clean_block:
            if element in rep_to_trash:
                dirty_block |= rep_to_trash[element]
        dirty_join_blocks.append(sorted(dirty_block))

    return SetPart(dirty_join_blocks, domain=self.domain)

My standard example are these two partitions:

a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])

To test performance I have increased their size by factor 10, e.g. by replacing 5 with 50...59:

import timeit

from discretehelpers.set_part import SetPart


a = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]])
b = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])

print(timeit.timeit(lambda: a.meet(b), number=1000))        # 0.03588864533230662
print(timeit.timeit(lambda: a.meet_pairs(b), number=1000))  # 1.766225082334131
print(timeit.timeit(lambda: a.join(b), number=1000))        # 0.7479327972978354
print(timeit.timeit(lambda: a.join_pairs(b), number=1000))  # 4.01305090310052

My question is mainly about improving the algorithm. But fancy Python tricks are also welcome.

I have edited the question. By removing redundant elements before the calculation, I have avoided choking the process with factorial growth. But this is probably still far from the optimal solution.

A simplified version of the SetPart class can be found here.
It can be run in a single Python file, and has no dependencies.

As requested, here is the whole class, including the methods shown above:

from functools import cached_property
from itertools import combinations, product, chain
 
 
def have(val):
    return val is not None
 
 
########################################################################################################################
 
 
class SetPart(object):
 
    def __init__(self, blocks=None, domain='N'):
 
        """
        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
        :param domain: Set of allowed elements of blocks. Can be a finite set. Elements are usually integers.
                       By default, the domain is the set of non-negative integers.
        """
 
        if domain not in ['N', 'Z']:
            assert type(domain) in [set, list, tuple, range]
            self.domain = set(domain)
        else:
            self.domain = domain  # keep the letters
 
        if blocks is None:
            self.set_trivial()
            return
 
        blocks = sorted(sorted(block) for block in blocks if len(block) > 1)
        if not blocks:
            self.set_trivial()
            return
 
        self.blocks = blocks
 
        _ = dict()
        for block_index, block in enumerate(self.blocks):
            for element in block:
                _[element] = block_index
        self.non_singleton_to_block_index = _
 
        self.non_singletons = set(self.non_singleton_to_block_index.keys())
        assert len(self.non_singletons) == sum([len(block) for block in self.blocks])
 
        if self.domain == 'N':
            self.length = max(self.non_singletons) + 1
 
        self.trivial = False
 
    ##############################################################################################
 
    def set_trivial(self):
        self.trivial = True
        self.blocks = []
        self.non_singleton_to_block_index = dict()
        self.non_singletons = set()
 
        if self.domain == 'N':
            self.length = 0
 
    ###############################################
 
    def __eq__(self, other):
        return self.blocks == other.blocks
 
    ###############################################
 
    def element_in_domain(self, element):
        if self.domain == 'N':
            return type(element) == int and element >= 0
        elif self.domain == 'Z':
            return type(element) == int
        else:
            return element in self.domain
 
    ###############################################
 
    def merge_pair(self, a, b):
 
        """
        When elements `a` and `b` are in different blocks, both blocks will be merged.
        Changes the partition. Returns nothing.
        """
 
        if a == b:
            return  # nothing to do
 
        assert self.element_in_domain(a) and self.element_in_domain(b)
 
        a_found = a in self.non_singletons
        b_found = b in self.non_singletons
 
        if a_found and b_found:
 
            block_index_a = self.non_singleton_to_block_index[a]
            block_index_b = self.non_singleton_to_block_index[b]
 
            if block_index_a == block_index_b:
                return  # nothing to do
 
            block_a = self.blocks[block_index_a]
            block_b = self.blocks[block_index_b]
            merged_block = sorted(block_a + block_b)
 
            self.blocks.remove(block_a)
            self.blocks.remove(block_b)
            self.blocks.append(merged_block)
 
        elif not a_found and not b_found:
            self.blocks.append([a, b])
 
        else:  # a_found and not b_found
            if b_found and not a_found:
                a, b = b, a
            block_index_a = self.non_singleton_to_block_index[a]
            self.blocks[block_index_a].append(b)
 
        self.__init__(self.blocks, self.domain)  # reinitialize
 
    ###############################################
 
    def blocks_with_singletons(self, elements=None):
        """
        :param elements: Any subset of the domain.
        :return: Blocks with added singleton-blocks for each element in `elements` that is not in an actual block.
        """
        assert type(elements) in [set, list, range]
        assert self.non_singletons.issubset(set(elements))
        singletons = set(elements).difference(self.non_singletons)
        singleton_blocks = [[_] for _ in singletons]
        return sorted(self.blocks + singleton_blocks)
 
    ##############################################################################################
 
    @cached_property
    def pairs(self):
        """
        :return: For each block the set of all is 2-element subsets. All those in one set.
        Slow for big blocks, because of factorial growth.
        """
        result = set()
        for block in self.blocks:
            for pair in combinations(block, 2):
                result.add(pair)
        return result
 
    ###############################################
 
    def join_pairs(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the property `pairs`, so it is also slow for big blocks.
        """
        assert self.domain == other.domain
        result = SetPart([], self.domain)
        for pair in self.pairs.union(other.pairs):
            result.merge_pair(*pair)
        return result
 
    ###############################################
 
    def meet(self, other):
        """
        :param other: another set partition
        :return: The meet of the two set partitions.
        Let M be the meet of partitions A and B. The blocks of M are the intersections of the blocks of A and B.
        """
        meet_blocks = []
        for s_block, o_block in product(self.blocks, other.blocks):
            intersection = set(s_block) & set(o_block)
            if intersection:
                meet_blocks.append(sorted(intersection))
        return SetPart(meet_blocks, self.domain)
 
    ###############################################
 
    def join(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the method `join_pairs`.
        The danger of factorial growth is reduced, by making the input partitions smaller.
        """
        meet_part = self.meet(other)
 
        trash = set()
        rep_to_trash = dict()
        for block in meet_part.blocks:
            block_rep = min(block)
            block_trash = set(block) - {block_rep}
            trash |= block_trash
            rep_to_trash[block_rep] = block_trash
 
        clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
        clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
 
        clean_s_part = SetPart(clean_s_blocks)
        clean_o_part = SetPart(clean_o_blocks)
 
        clean_join_part = clean_s_part.join_pairs(clean_o_part)
 
        s_elements = set(chain.from_iterable(self.blocks))
        o_elements = set(chain.from_iterable(other.blocks))
        dirty_domain = s_elements | o_elements
        clean_domain = dirty_domain - trash
 
        dirty_join_blocks = []
        clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
        for clean_block in clean_blocks_with_singletons:
            dirty_block = set(clean_block)
            for element in clean_block:
                if element in rep_to_trash:
                    dirty_block |= rep_to_trash[element]
            dirty_join_blocks.append(sorted(dirty_block))
 
        return SetPart(dirty_join_blocks, domain=self.domain)
Watchduck
  • 230
  • 2
  • 10