Watchduck

Find the join of two set partitions

Two partitions of a set have a greatest lower bound (meet) and a least upper bound (join).
See here: Meets and joins in the lattice of partitions (Math SE)
The meet is easy to calculate. I am trying to improve the calculation of the join.
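To illustrate why the meet is the easy case, here is a minimal sketch on plain lists of blocks. It is not the library's `meet` method; the function name `meet_blocks` is hypothetical, and it assumes that elements missing from all blocks are implicit singletons, so intersections with fewer than two elements are dropped.

```python
def meet_blocks(a_blocks, b_blocks):
    """Meet of two partitions given as lists of blocks (hypothetical helper).

    Assumption: singletons are implicit, so only intersections
    with at least two elements are kept.
    """
    result = []
    for a_block in a_blocks:
        a_set = set(a_block)
        for b_block in b_blocks:
            # Two elements share a block of the meet
            # exactly when they share a block in both partitions.
            common = a_set.intersection(b_block)
            if len(common) >= 2:
                result.append(sorted(common))
    return sorted(result)


a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b = [[0, 1], [2, 3, 4], [6, 8, 9]]
print(meet_blocks(a, b))  # [[0, 1], [2, 4], [6, 9]]
```

The join is harder because blocks can be linked through a chain of overlaps, which a single pass over pairs of blocks does not capture.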

As part of a library I have written a class for set partitions.
Its join method can be found here. It uses the method pairs, which can be found here.

Creating all the 2-subsets of a block is not a great idea for big blocks, because a block of n elements has n(n-1)/2 of them.
I intended this approach only as a prototype.
But to my surprise it was still faster than my Python interpretation of the answer on Math SE.
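To make the cost concrete, the sketch below compares the number of 2-subsets of a block with the pairs that link only neighbouring elements. The helpers `all_pairs` and `chain_pairs` are hypothetical and not part of the library. It rests on the observation that merging each element with its neighbour already puts the whole block into one class.

```python
from itertools import combinations
from math import comb


def all_pairs(block):
    """Every 2-subset of the block, as the prototype generates them."""
    return list(combinations(block, 2))


def chain_pairs(block):
    """Only the pairs of neighbouring elements (hypothetical alternative)."""
    return list(zip(block, block[1:]))


block = list(range(40))
print(len(all_pairs(block)), comb(40, 2))  # 780 780
print(len(chain_pairs(block)))             # 39
```

Whether feeding such chains into `merge_pair` is faster in practice depends on how `merge_pair` is implemented, so this would need to be measured.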

SetPart is essentially a list of blocks wrapped in a class.

This is the property pairs (file):

from itertools import combinations


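@property
def pairs(self):
    """Set of all 2-element subsets (as tuples) of all blocks."""

    result = set()

    for block in self.blocks:
        # combinations yields each unordered pair of the block exactly once
        for pair in combinations(block, 2):
            result.add(pair)

    return result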

The method join_pairs is the essential part of the algorithm (file).
It relies on the method merge_pair, which merges two elements into the same block.

def join_pairs(self, other):

    from discretehelpers.set_part import SetPart

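    # start with a partition that has no non-singleton blocks
    result = SetPart([], self.domain)

    # merge the two elements of every pair that occurs in either partition
    for pair in self.pairs.union(other.pairs):
        result.merge_pair(*pair)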

    return result
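The method `merge_pair` itself is not shown here. As a rough idea of what such an operation does, the following sketch works on a plain list of sets. It is a hypothetical standalone function, not the library's implementation, and it assumes that elements not found in any block are singletons.

```python
def merge_pair(blocks, a, b):
    """Put a and b into the same block, modifying the list of sets in place.

    Hypothetical standalone version; elements not contained
    in any block are treated as singletons.
    """
    block_a = next((blk for blk in blocks if a in blk), None)
    block_b = next((blk for blk in blocks if b in blk), None)

    if block_a is None and block_b is None:
        blocks.append({a, b})        # both were singletons
    elif block_a is None:
        block_b.add(a)               # a joins the block of b
    elif block_b is None:
        block_a.add(b)               # b joins the block of a
    elif block_a is not block_b:
        block_a |= block_b           # fuse the two blocks
        blocks[:] = [blk for blk in blocks if blk is not block_b]


blocks = []
for pair in [(0, 1), (2, 3), (1, 2)]:
    merge_pair(blocks, *pair)
print(blocks)  # [{0, 1, 2, 3}]
```

The linear search for the containing block is the simplest choice for a sketch; a lookup table from element to block would avoid it.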

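The method join uses join_pairs. Elements that lie in the same block of the meet share a block in both partitions, so all but one of them are redundant for the calculation. For each block of the meet only the smallest element is kept as a representative. The others (called trash) are removed from the blocks of both partitions and added back to the blocks of the result afterwards. This way only the necessary number of 2-element subsets is used in the calculation.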

def join(self, other):

    from itertools import chain
    from discretehelpers.set_part import SetPart

    meet_part = self.meet(other)

    # For each block of the meet keep the smallest element as representative.
    # The other elements are trash, remembered per representative.
    trash = set()
    rep_to_trash = dict()
    for block in meet_part.blocks:
        block_rep = min(block)
        block_trash = set(block) - {block_rep}
        trash |= block_trash
        rep_to_trash[block_rep] = block_trash

    # Remove the trash from the blocks of both partitions.
    clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
    clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]

    clean_s_part = SetPart(clean_s_blocks)
    clean_o_part = SetPart(clean_o_blocks)

    # Calculate the join of the reduced partitions with the pair-based method.
    clean_join_part = clean_s_part.join_pairs(clean_o_part)

    s_elements = set(chain.from_iterable(self.blocks))
    o_elements = set(chain.from_iterable(other.blocks))
    dirty_domain = s_elements | o_elements
    clean_domain = dirty_domain - trash

    # Add the trash back to the block that contains its representative.
    dirty_join_blocks = []
    clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
    for clean_block in clean_blocks_with_singletons:
        dirty_block = set(clean_block)
        for element in clean_block:
            if element in rep_to_trash:
                dirty_block |= rep_to_trash[element]
        dirty_join_blocks.append(sorted(dirty_block))

    return SetPart(dirty_join_blocks, domain=self.domain)
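The trash removal can be traced by hand on two small partitions. The sketch below repeats only the first steps of `join` on plain lists, without `SetPart`. The meet is written out by hand, and the variable names are chosen for this illustration.

```python
a_blocks = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b_blocks = [[0, 1], [2, 3, 4], [6, 8, 9]]
meet_of_ab = [[0, 1], [2, 4], [6, 9]]  # non-singleton blocks of the meet, by hand

# Keep the smallest element of each meet block, mark the rest as trash.
trash = set()
rep_to_trash = {}
for block in meet_of_ab:
    rep = min(block)
    rep_to_trash[rep] = set(block) - {rep}
    trash |= rep_to_trash[rep]

clean_a = [sorted(set(block) - trash) for block in a_blocks]
clean_b = [sorted(set(block) - trash) for block in b_blocks]

print(sorted(trash))  # [1, 4, 9]
print(rep_to_trash)   # {0: {1}, 2: {4}, 6: {9}}
print(clean_a)        # [[0, 2], [5, 6], [7, 8]]
print(clean_b)        # [[0], [2, 3], [6, 8]]
```

In this small case the two partitions have 10 and 7 pairs before the cleanup and 3 and 2 afterwards.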

My standard example uses these two partitions:

a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])

To test performance I have increased their size by a factor of 10, e.g. by replacing 5 with 50...59:

import timeit

from discretehelpers.set_part import SetPart


a = SetPart([
    [*range(0, 30), *range(40, 50)],
    [*range(50, 70), *range(90, 100)],
    [*range(70, 90)],
])
b = SetPart([
    [*range(0, 20)],
    [*range(20, 50)],
    [*range(60, 70), *range(80, 100)],
])

print(timeit.timeit(lambda: a.meet(b), number=1000))        # 0.03588864533230662
print(timeit.timeit(lambda: a.meet_pairs(b), number=1000))  # 1.766225082334131
print(timeit.timeit(lambda: a.join(b), number=1000))        # 0.7479327972978354
print(timeit.timeit(lambda: a.join_pairs(b), number=1000))  # 4.01305090310052
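The enlarged blocks can also be generated instead of typed out, which makes it easy to try other factors. The helper `scale_blocks` below is hypothetical and works on plain lists; its result would still have to be passed to `SetPart`.

```python
def scale_blocks(blocks, factor):
    """Replace each element e by the range e*factor ... (e+1)*factor - 1."""
    return [
        [scaled
         for element in block
         for scaled in range(element * factor, (element + 1) * factor)]
        for block in blocks
    ]


small_a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
big_a = scale_blocks(small_a, 10)
print([len(block) for block in big_a])  # [40, 30, 20]
```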

My question is mainly about improving the algorithm. But fancy Python tricks are also welcome.

I have edited the question. By removing redundant elements before the calculation, I have avoided choking the process with the quadratic growth in the number of pairs. But this is probably still far from the optimal solution.
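One commonly used structure for merging classes is a disjoint-set forest (union-find). The sketch below is an assumption about what an alternative could look like, not the library's method, and it has not been benchmarked against `join`. It works on plain lists of blocks, treats missing elements as implicit singletons, and the name `join_blocks` is hypothetical. Each block contributes only the links from its first element to the others, so no 2-subsets are generated.

```python
def join_blocks(a_blocks, b_blocks):
    """Join of two partitions given as lists of blocks (hypothetical helper)."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        root = x
        while parent[root] != root:
            root = parent[root]
        # path compression: point every visited element directly at the root
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(x, y):
        root_x, root_y = find(x), find(y)
        if root_x != root_y:
            parent[root_y] = root_x

    for block in list(a_blocks) + list(b_blocks):
        block = list(block)
        if not block:
            continue
        first = block[0]
        find(first)  # register the element even if the block is a singleton
        for element in block[1:]:
            union(first, element)

    # collect the elements by their root
    groups = {}
    for element in list(parent):
        groups.setdefault(find(element), []).append(element)

    # singletons are implicit, so only bigger blocks are returned
    return sorted(sorted(group) for group in groups.values() if len(group) >= 2)


a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b = [[0, 1], [2, 3, 4], [6, 8, 9]]
print(join_blocks(a, b))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

This version uses path compression only. Union by size or rank is a usual refinement, left out here to keep the sketch short.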

A simplified version of the SetPart class can be found here.
It can be run in a single Python file, and requires only bidict.
