diff --git a/gerrychain/proposals/tree_proposals.py b/gerrychain/proposals/tree_proposals.py index fe4525fe..bba9e3ff 100644 --- a/gerrychain/proposals/tree_proposals.py +++ b/gerrychain/proposals/tree_proposals.py @@ -45,16 +45,17 @@ def recom( partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]] ) - flips = recursive_tree_part( + flips_left, flips_right = bipartition_tree_retworkx( subgraph, - parts_to_merge, pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, node_repeats=node_repeats, - method=method, ) + flips = {node: parts_to_merge[0] for node in flips_left} + flips |= {node: parts_to_merge[1] for node in flips_right} + return partition.flip(flips) diff --git a/gerrychain/tree.py b/gerrychain/tree.py index 519b1d77..863cc1e9 100644 --- a/gerrychain/tree.py +++ b/gerrychain/tree.py @@ -182,8 +182,9 @@ def bipartition_tree_retworkx( ): pops = graph.pygraph_pop_lookup(pop_col) - balanced_nodes = retworkx.bipartition_tree(graph.pygraph, lambda x: random.random(), pops, pop_target, epsilon) - return {graph.retworkx_networkx_mapping[x] for x in choice(balanced_nodes)[1]} + balanced_node_choices = retworkx.bipartition_tree(graph.pygraph, lambda x: random.random(), pops, float(pop_target), float(epsilon)) + balanced_nodes = {graph.retworkx_networkx_mapping[x] for x in choice(balanced_node_choices)[1]} + return (balanced_nodes, graph.node_indicies - balanced_nodes) def bipartition_tree(