Two partitions of a set have a greatest lower bound (meet) and a least upper bound (join).
See here: Meets and joins in the lattice of partitions (Math SE)
The meet is easy to calculate. I am trying to improve the calculation of the join.
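For example, on the set {0, 1, 2, 3} the partitions {{0, 1}, {2, 3}} and {{0}, {1, 2}, {3}} have the meet {{0}, {1}, {2}, {3}} (the discrete partition) and the join {{0, 1, 2, 3}} (the single-block partition).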
As part of a library I have written a class for set partitions.
Its join method can be found here.
It uses the method pairs, which can be found here.
Creating all the 2-subsets of a block is not a great idea for big blocks.
I intended this approach only as a prototype.
But to my surprise, it was still faster than my Python implementation of the answer on Math SE.
A SetPart is essentially a list of blocks wrapped in a class.
This is the property pairs: (file)
from itertools import combinations

def pairs(self):
    result = set()
    for block in self.blocks:
        for pair in combinations(block, 2):
            result.add(pair)
    return result
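Just to illustrate the growth: a block with n elements contributes n*(n-1)/2 pairs. This small check uses only itertools, not the class:

    from itertools import combinations

    # Pair counts for the block sizes of the scaled-up partition a further down (40, 30, 20):
    for size in (40, 30, 20):
        print(size, len(list(combinations(range(size), 2))))  # 780, 435, 190 pairs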
The method join_pairs is the essential part of the algorithm: (file)
The method merge_pair merges two elements into the same block.
def join_pairs(self, other):
    from discretehelpers.set_part import SetPart

    result = SetPart([], self.domain)
    for pair in self.pairs.union(other.pairs):
        result.merge_pair(*pair)
    return result
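A minimal usage sketch (assuming the class and the import path from the question), with throwaway partitions defined on the spot:

    from discretehelpers.set_part import SetPart

    p = SetPart([[0, 1], [2, 3]])
    q = SetPart([[1, 2]])
    print(p.join_pairs(q).blocks)  # [[0, 1, 2, 3]]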
The method join uses join_pairs. It removes redundant elements (called trash) from the blocks of both partitions, and adds this trash back to the blocks of the result afterwards. This way only the necessary number of 2-element subsets is used in the calculation.
def join(self, other):
    from itertools import chain
    from discretehelpers.set_part import SetPart

    meet_part = self.meet(other)

    # Each block of the meet keeps only its minimum as a representative;
    # the remaining elements are "trash" and are set aside.
    trash = set()
    rep_to_trash = dict()
    for block in meet_part.blocks:
        block_rep = min(block)
        block_trash = set(block) - {block_rep}
        trash |= block_trash
        rep_to_trash[block_rep] = block_trash

    # Join the partitions with the trash removed from their blocks.
    clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
    clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
    clean_s_part = SetPart(clean_s_blocks)
    clean_o_part = SetPart(clean_o_blocks)
    clean_join_part = clean_s_part.join_pairs(clean_o_part)

    s_elements = set(chain.from_iterable(self.blocks))
    o_elements = set(chain.from_iterable(other.blocks))
    dirty_domain = s_elements | o_elements
    clean_domain = dirty_domain - trash

    # Add the trash back to the blocks of the result.
    dirty_join_blocks = []
    clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
    for clean_block in clean_blocks_with_singletons:
        dirty_block = set(clean_block)
        for element in clean_block:
            if element in rep_to_trash:
                dirty_block |= rep_to_trash[element]
        dirty_join_blocks.append(sorted(dirty_block))

    return SetPart(dirty_join_blocks, domain=self.domain)
My standard example is these two partitions:
a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])
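For reference, these are the results I expect with the class as posted, together with the intermediate values of the trash step described above:

    print(a.meet(b).blocks)  # [[0, 1], [2, 4], [6, 9]]
    print(a.join(b).blocks)  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
    # Inside join: the meet blocks give the representatives 0, 2, 6 and the trash {1, 4, 9},
    # so join_pairs only has to process the cleaned partitions [[0, 2], [5, 6], [7, 8]] and
    # [[2, 3], [6, 8]] (the block [0, 1] shrinks to a singleton and is dropped).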
To test performance, I have increased their size by a factor of 10, e.g. by replacing 5 with 50...59:
import timeit
from discretehelpers.set_part import SetPart
a = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]])
b = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])
print(timeit.timeit(lambda: a.meet(b), number=1000)) # 0.03588864533230662
print(timeit.timeit(lambda: a.meet_pairs(b), number=1000)) # 1.766225082334131
print(timeit.timeit(lambda: a.join(b), number=1000)) # 0.7479327972978354
print(timeit.timeit(lambda: a.join_pairs(b), number=1000)) # 4.01305090310052
My question is mainly about improving the algorithm. But fancy Python tricks are also welcome.
I have edited the question. By removing redundant elements before the calculation, I have avoided choking the process with the quadratic growth of the number of pairs. But this is probably still far from the optimal solution.
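One way to quantify the effect is to compare how many 2-element subsets each approach actually feeds into merge_pair. A sketch, reusing SetPart and the scaled-up a and b from the timing snippet above:

    # Pairs processed by join_pairs: all 2-subsets of all blocks of both partitions.
    print(len(a.pairs | b.pairs))

    # Pairs processed by join: only the 2-subsets of the blocks with the trash removed.
    trash = set()
    for block in a.meet(b).blocks:
        trash |= set(block) - {min(block)}
    clean_a = SetPart([sorted(set(block) - trash) for block in a.blocks])
    clean_b = SetPart([sorted(set(block) - trash) for block in b.blocks])
    print(len(clean_a.pairs | clean_b.pairs))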
A simplified version of the SetPart class can be found here.
It can be run in a single Python file, and has no dependencies.
As requested, here is the whole class, including the methods shown above:
from functools import cached_property
from itertools import combinations, product, chain


def have(val):
    return val is not None
########################################################################################################################
class SetPart(object):

    def __init__(self, blocks=None, domain='N'):
        """
        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
        :param domain: Set of allowed elements of blocks. Can be a finite set. Elements are usually integers.
                       By default, the domain is the set of non-negative integers.
        """
        if domain not in ['N', 'Z']:
            assert type(domain) in [set, list, tuple, range]
            self.domain = set(domain)
        else:
            self.domain = domain  # keep the letters

        if blocks is None:
            self.set_trivial()
            return

        blocks = sorted(sorted(block) for block in blocks if len(block) > 1)
        if not blocks:
            self.set_trivial()
            return

        self.blocks = blocks

        _ = dict()
        for block_index, block in enumerate(self.blocks):
            for element in block:
                _[element] = block_index
        self.non_singleton_to_block_index = _
        self.non_singletons = set(self.non_singleton_to_block_index.keys())
        assert len(self.non_singletons) == sum([len(block) for block in self.blocks])

        if self.domain == 'N':
            self.length = max(self.non_singletons) + 1

        self.trivial = False
    ##############################################################################################

    def set_trivial(self):
        self.trivial = True
        self.blocks = []
        self.non_singleton_to_block_index = dict()
        self.non_singletons = set()
        if self.domain == 'N':
            self.length = 0
    ###############################################

    def __eq__(self, other):
        return self.blocks == other.blocks
    ###############################################

    def element_in_domain(self, element):
        if self.domain == 'N':
            return type(element) == int and element >= 0
        elif self.domain == 'Z':
            return type(element) == int
        else:
            return element in self.domain
    ###############################################

    def merge_pair(self, a, b):
        """
        When elements `a` and `b` are in different blocks, both blocks will be merged.
        Changes the partition. Returns nothing.
        """
        if a == b:
            return  # nothing to do
        assert self.element_in_domain(a) and self.element_in_domain(b)

        a_found = a in self.non_singletons
        b_found = b in self.non_singletons

        if a_found and b_found:
            block_index_a = self.non_singleton_to_block_index[a]
            block_index_b = self.non_singleton_to_block_index[b]
            if block_index_a == block_index_b:
                return  # nothing to do
            block_a = self.blocks[block_index_a]
            block_b = self.blocks[block_index_b]
            merged_block = sorted(block_a + block_b)
            self.blocks.remove(block_a)
            self.blocks.remove(block_b)
            self.blocks.append(merged_block)
        elif not a_found and not b_found:
            self.blocks.append([a, b])
        else:  # exactly one of a and b was found
            if b_found and not a_found:
                a, b = b, a
            block_index_a = self.non_singleton_to_block_index[a]
            self.blocks[block_index_a].append(b)

        self.__init__(self.blocks, self.domain)  # reinitialize
    ###############################################

    def blocks_with_singletons(self, elements=None):
        """
        :param elements: Any subset of the domain.
        :return: Blocks with added singleton-blocks for each element in `elements` that is not in an actual block.
        """
        assert type(elements) in [set, list, range]
        assert self.non_singletons.issubset(set(elements))
        singletons = set(elements).difference(self.non_singletons)
        singleton_blocks = [[_] for _ in singletons]
        return sorted(self.blocks + singleton_blocks)
    ##############################################################################################

    @cached_property
    def pairs(self):
        """
        :return: For each block, the set of all its 2-element subsets, collected into one set.
        Slow for big blocks, because of quadratic growth.
        """
        result = set()
        for block in self.blocks:
            for pair in combinations(block, 2):
                result.add(pair)
        return result
    ###############################################

    def join_pairs(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the property `pairs`, so it is also slow for big blocks.
        """
        assert self.domain == other.domain
        result = SetPart([], self.domain)
        for pair in self.pairs.union(other.pairs):
            result.merge_pair(*pair)
        return result
    ###############################################

    def meet(self, other):
        """
        :param other: another set partition
        :return: The meet of the two set partitions.
        Let M be the meet of partitions A and B. The blocks of M are the intersections of the blocks of A and B.
        """
        meet_blocks = []
        for s_block, o_block in product(self.blocks, other.blocks):
            intersection = set(s_block) & set(o_block)
            if intersection:
                meet_blocks.append(sorted(intersection))
        return SetPart(meet_blocks, self.domain)
    ###############################################

    def join(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the method `join_pairs`.
        The danger of quadratic growth is reduced by making the input partitions smaller.
        """
        meet_part = self.meet(other)

        trash = set()
        rep_to_trash = dict()
        for block in meet_part.blocks:
            block_rep = min(block)
            block_trash = set(block) - {block_rep}
            trash |= block_trash
            rep_to_trash[block_rep] = block_trash

        clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
        clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
        clean_s_part = SetPart(clean_s_blocks)
        clean_o_part = SetPart(clean_o_blocks)
        clean_join_part = clean_s_part.join_pairs(clean_o_part)

        s_elements = set(chain.from_iterable(self.blocks))
        o_elements = set(chain.from_iterable(other.blocks))
        dirty_domain = s_elements | o_elements
        clean_domain = dirty_domain - trash

        dirty_join_blocks = []
        clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
        for clean_block in clean_blocks_with_singletons:
            dirty_block = set(clean_block)
            for element in clean_block:
                if element in rep_to_trash:
                    dirty_block |= rep_to_trash[element]
            dirty_join_blocks.append(sorted(dirty_block))

        return SetPart(dirty_join_blocks, domain=self.domain)
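To make merge_pair and blocks_with_singletons concrete, here is a quick check (a sketch, assuming the class behaves as posted):

    p = SetPart([[0, 1], [3, 4]], domain=range(6))
    p.merge_pair(1, 3)
    print(p.blocks)                                     # [[0, 1, 3, 4]]
    print(p.blocks_with_singletons(elements=range(6)))  # [[0, 1, 3, 4], [2], [5]]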
Comments:

- The `if len(pairs) > 0:` check is redundant and can be safely removed. Also, have you tried using a different interpreter than standard CPython, such as PyPy, to see if there is any performance boost from that?
- In `SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])` I do not see 5. Replacing a single number by appending 0..9 should increase the size by 9. I may have found Discrete helpers on en.wikiversity; it is miles from the source code, which does not follow PEPs 8 & 257 regarding documentation.
- (Asker) `[5, 6, 9]` is a block of partition `a`. It does contain the element 5. I meant an increase by a factor of 10.
- (Asker) `join` uses `meet` and `join_pairs`, and that uses `pairs` and `merge_pair`. But the rest of the class can be ignored. The question is only about the `join` method.