Skip to content

Commit

Permalink
Fix linter issues
Browse files Browse the repository at this point in the history
  • Loading branch information
InnovativeInventor committed May 16, 2022
1 parent 384b239 commit 11828f3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
18 changes: 14 additions & 4 deletions gerrychain/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,12 @@ class FrozenGraph:
"networkx_retworkx_mapping"
]

def __init__(self, graph: Graph, pygraph: retworkx.PyGraph = None, mappings: Tuple[dict, dict] = None):
def __init__(
self,
graph: Graph,
pygraph: retworkx.PyGraph = None,
mappings: Tuple[dict, dict] = None
):
self.graph = networkx.classes.function.freeze(graph)
self.graph.join = frozen
self.graph.add_data = frozen
Expand All @@ -366,8 +371,12 @@ def __init__(self, graph: Graph, pygraph: retworkx.PyGraph = None, mappings: Tup
if mappings:
self.retworkx_networkx_mapping, self.networkx_retworkx_mapping = mappings
else:
self.retworkx_networkx_mapping = {node: self.pygraph[node]["__networkx_node__"] for node in self.pygraph.node_indexes()}
self.networkx_retworkx_mapping = {self.pygraph[node]["__networkx_node__"]: node for node in self.pygraph.node_indexes()}
self.retworkx_networkx_mapping = {
node: self.pygraph[node]["__networkx_node__"] for node in self.pygraph.node_indexes()
}
self.networkx_retworkx_mapping = {
self.pygraph[node]["__networkx_node__"]: node for node in self.pygraph.node_indexes()
}

def __len__(self):
return self.size
Expand Down Expand Up @@ -405,7 +414,8 @@ def lookup(self, node, field):
return self.graph.nodes[node][field]

def subgraph(self, nodes):
return FrozenGraph(self.graph.subgraph(nodes),
return FrozenGraph(
self.graph.subgraph(nodes),
self.pygraph.subgraph(
[self.networkx_retworkx_mapping[x] for x in nodes]
)
Expand Down
2 changes: 1 addition & 1 deletion gerrychain/proposals/tree_proposals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from ..random import random

from ..random import random
from ..tree import (
recursive_tree_part, bipartition_tree, bipartition_tree_random,
_bipartition_tree_random_all, uniform_spanning_tree,
Expand Down
8 changes: 7 additions & 1 deletion gerrychain/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,13 @@ def bipartition_tree_retworkx(
):
pops = graph.pygraph_pop_lookup(pop_col)

balanced_node_choices = retworkx.bipartition_tree(graph.pygraph, lambda x: random.random(), pops, float(pop_target), float(epsilon))
balanced_node_choices = retworkx.bipartition_tree(
graph.pygraph,
lambda x: random.random(),
pops,
float(pop_target),
float(epsilon)
)
balanced_nodes = {graph.retworkx_networkx_mapping[x] for x in choice(balanced_node_choices)[1]}
return (balanced_nodes, graph.node_indices - balanced_nodes)

Expand Down

0 comments on commit 11828f3

Please sign in to comment.