Skip to content

Commit

Permalink
WIP: Fixing merge
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelma committed Sep 11, 2023
1 parent df03f00 commit 169aa77
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 121 deletions.
77 changes: 20 additions & 57 deletions spinetoolbox/spine_db_editor/graphics_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,18 @@ def __init__(self, spine_db_editor, x, y, extent, db_map_ids):
self.setToolTip(self._make_tool_tip())
self._highlight_color = Qt.transparent
self._extent = None
self.label_item = EntityLabelItem(self)

def _make_tool_tip(self):
raise NotImplementedError()

def default_parameter_data(self):
raise NotImplementedError()

@property
def has_dimensions(self):
raise NotImplementedError()

@property
def entity_type(self):
raise NotImplementedError()
Expand Down Expand Up @@ -243,7 +248,7 @@ def _get_color(self):

def _get_arc_width(self):
for db_map, id_ in self.db_map_ids:
arc_width = self._spine_db_editor.get_arc_width(db_map, id_)
arc_width = self._spine_db_editor.get_arc_width(db_map, self.entity_type, id_)
if arc_width is not None:
min_val, val, max_val = arc_width
range_ = max_val - min_val
Expand All @@ -253,9 +258,6 @@ def _get_arc_width(self):
# val == max_val => result = 5
return 1 + 9 * (val - min_val) / range_

def _has_name(self):
return True

def _set_up(self):
if self._has_name():
name = self._get_name()
Expand All @@ -280,7 +282,7 @@ def _set_up(self):

def polish(self):
self._renderer = self.db_mngr.entity_class_renderer(
self.first_db_map, self.first_entity_class_id, color_code=self._get_color()
self.first_db_map, self.entity_class_type, self.first_entity_class_id, color_code=self._get_color()
)
self._svg_item.setSharedRenderer(self._renderer)
self._update_arcs_width()
Expand All @@ -301,7 +303,7 @@ def _init_bg(self):
def refresh_icon(self):
"""Refreshes the icon."""
renderer = self.db_mngr.entity_class_renderer(
self.first_db_map, self.first_entity_class_id, color_code=self._get_color()
self.first_db_map, self.entity_class_type, self.first_entity_class_id, color_code=self._get_color()
)
self._set_renderer(renderer)

Expand Down Expand Up @@ -483,7 +485,11 @@ def __init__(self, spine_db_editor, x, y, extent, db_map_ids):
db_map_ids (tuple): tuple of (db_map, id) tuples
"""
super().__init__(spine_db_editor, x, y, extent, db_map_ids=db_map_ids)
self._update_all()
self._set_up()

@property
def has_dimensions(self):
return True

def default_parameter_data(self):
"""Return data to put as default in a parameter table when this item is selected."""
Expand Down Expand Up @@ -530,18 +536,6 @@ def _make_tool_tip(self):
f"""@{self.display_database}</p></html>"""
)

def _update_all(self):
self._extent = 0.5 * self._given_extent
self.setRect(-0.5 * self._extent, -0.5 * self._extent, self._extent, self._extent)
self.refresh_icon()
self._init_bg()

def _init_bg(self):
extent = self._extent
self._bg = QGraphicsEllipseItem(-0.5 * extent, -0.5 * extent, extent, extent, self)
self._bg.setPen(Qt.NoPen)
self._bg_brush = QGuiApplication.palette().button()

def add_arc_item(self, arc_item):
super().add_arc_item(arc_item)
self._rotate_svg_item()
Expand Down Expand Up @@ -595,10 +589,12 @@ def __init__(self, spine_db_editor, x, y, extent, db_map_ids):
"""
super().__init__(spine_db_editor, x, y, extent, db_map_ids=db_map_ids)
self._db_map_relationship_class_lists = {}
self.label_item = ObjectLabelItem(self)
self.setZValue(0.5)
self.update_name()
self._update_all()
self._set_up()

@property
def has_dimensions(self):
return False

def default_parameter_data(self):
"""Return data to put as default in a parameter table when this item is selected."""
Expand All @@ -623,44 +619,11 @@ def shape(self):
def _has_name(self):
return bool(self.label_item.toPlainText())

def _update_all(self):
if not self._has_name():
self.label_item.hide()
self._extent = 0.2 * self._given_extent
else:
self.label_item.show()
self._extent = self._given_extent
self.setRect(-0.5 * self._extent, -0.5 * self._extent, self._extent, self._extent)
self.refresh_icon()
self._init_bg()

def update_name(self):
"""Refreshes the name."""
db_map_ids_by_name = dict()
for db_map, id_ in self.db_map_ids:
name = self._spine_db_editor.get_item_name(db_map, self.entity_type, id_)
db_map_ids_by_name.setdefault(name, list()).append((db_map, id_))
if len(db_map_ids_by_name) == 1:
name = next(iter(db_map_ids_by_name))
self.label_item.setPlainText(name)
return True
current_name = self.label_item.toPlainText()
self._db_map_ids = tuple(db_map_ids_by_name.get(current_name, ()))
return False

def _make_tool_tip(self):
if not self.first_id:
return None
return f"<html><p style='text-align:center;'>{self.entity_name}<br>@{self.display_database}</html>"

def _init_bg(self):
bg_rect = QRectF(-0.5 * self._extent, -0.5 * self._extent, self._extent, self._extent)
self._bg = QGraphicsRectItem(bg_rect, self)
self._bg.setFlag(QGraphicsItem.ItemStacksBehindParent, enabled=True)
pen = self._bg.pen()
pen.setColor(Qt.transparent)
self._bg.setPen(pen)

def mouseDoubleClickEvent(self, e):
add_relationships_menu = QMenu(self._spine_db_editor)
title = TitleWidgetAction("Add relationships", self._spine_db_editor)
Expand Down Expand Up @@ -968,8 +931,8 @@ def _make_pen(self):
return pen


class ObjectLabelItem(QGraphicsTextItem):
"""Provides a label for ObjectItem's."""
class EntityLabelItem(QGraphicsTextItem):
"""Provides a label for EntityItem's."""

entity_name_edited = Signal(str)

Expand Down
138 changes: 74 additions & 64 deletions spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def __init__(self, *args, **kwargs):
self.src_inds = list()
self.dst_inds = list()
self._pvs_by_pname = {}
self._pv_cache = {}
self._val_ranges_by_pname = {}
self._possible_colors = {}
self._adding_relationships = False
self._pos_for_added_objects = None
self.added_db_map_relationship_ids = set()
Expand Down Expand Up @@ -231,7 +231,7 @@ def _handle_objects_updated(self, db_map_data):
"""
updated_ids = {(db_map, x["id"]) for db_map, objs in db_map_data.items() for x in objs}
for item in self.ui.graphicsView.items():
if isinstance(item, ObjectItem) and set(item.db_map_ids).intersection(updated_ids)
if isinstance(item, ObjectItem) and set(item.db_map_ids).intersection(updated_ids):
if not item.has_unique_key():
self.build_graph(persistent=True)
break
Expand Down Expand Up @@ -322,6 +322,7 @@ def hide_removed_entities(self, db_map_data, type_):
self.scene = scene

def _graph_handle_parameter_values_added(self, db_map_data):
# FIXME: not in use at the moment, but maybe handy
pnames = {x["parameter_definition_name"] for db_map in self.db_maps for x in db_map_data.get(db_map, ())}
position_pnames = {self.ui.graphicsView.pos_x_parameter, self.ui.graphicsView.pos_y_parameter}
property_pnames = {
Expand All @@ -345,8 +346,11 @@ def polish_items(self):
def _update_property_pvs(self):
self._pvs_by_pname = {
pname: {
(db_map, ent_id): self._get_pv(db_map, ent_id, pname)
for db_map_ent_ids in self.db_map_entity_id_sets
(db_map, ent_id): self._get_pv(db_map, item_type, ent_id, pname)
for item_type, db_map_ent_id_sets in zip(
("object", "relationship"), (self.db_map_object_id_sets, self.db_map_relationship_id_sets)
)
for db_map_ent_ids in db_map_ent_id_sets
for db_map, ent_id in db_map_ent_ids
}
for pname in (self.ui.graphicsView.color_parameter, self.ui.graphicsView.arc_width_parameter)
Expand Down Expand Up @@ -417,8 +421,8 @@ def build_graph(self, persistent=False):
self._stop_layout_generators()
self._layout_gen_id = monotonic()
self.layout_gens[self._layout_gen_id] = layout_gen = self._make_layout_generator()
self._progress_bar_widget.set_layout_generator(layout_gen)
self._progress_bar_widget.show()
self.ui.progress_bar_widget.set_layout_generator(layout_gen)
self.ui.progress_bar_widget.show()
layout_gen.layout_available.connect(self._complete_graph)
layout_gen.finished.connect(lambda id_: self.layout_gens.pop(id_, None)) # Lambda to avoid issues in Python 3.7
self._thread_pool.start(layout_gen)
Expand Down Expand Up @@ -546,6 +550,7 @@ def _update_graph_data(self):
self.db_map_object_id_sets = new_db_map_object_id_sets
self.db_map_relationship_id_sets = new_db_map_relationship_id_sets
self._update_src_dst_inds(db_map_object_id_lists)
self._update_pv_cache()
self._update_property_pvs()
return True

Expand Down Expand Up @@ -588,63 +593,57 @@ def _update_src_dst_inds(self, db_map_object_id_lists):
self.src_inds.append(src)
self.dst_inds.append(dst)

def _get_pv(self, db_map, item_type, entity_id, pname):
if not pname:
return None
alternative = next(iter(self.db_mngr.get_items(db_map, "alternative", only_visible=False)), None)
if not alternative:
return None
pv = self._pv_cache.get((db_map, entity_id, pname, alternative["id"]))
if not pv:
return None
return from_database(pv["value"], pv["type"])

@busy_effect
def _update_pv_cache(self):
db_map_ent_ids = list(
(db_map, ent_id)
for item_type, db_map_ent_id_sets in zip(
("object", "relationship"), (self.db_map_object_id_sets, self.db_map_relationship_id_sets)
)
for db_map_ent_ids in db_map_ent_id_sets
for db_map, ent_id in db_map_ent_ids
)
pnames = (
self.ui.graphicsView.name_parameter,
self.ui.graphicsView.pos_x_parameter,
self.ui.graphicsView.pos_y_parameter,
self.ui.graphicsView.color_parameter,
self.ui.graphicsView.arc_width_parameter,
)
self._pv_cache.clear()
for db_map in self.db_maps:
for pv in self.db_mngr.get_items(db_map, "parameter_value", only_visible=False):
entity_id = pv["entity_id"]
pname = pv["parameter_name"]
if (db_map, entity_id) not in db_map_ent_ids or pname not in pnames:
continue
alt_id = pv["alternative_id"]
self._pv_cache[db_map, entity_id, pname, alt_id] = pv

def get_item_name(self, db_map, item_type, entity_id):
entity = self.db_mngr.get_item(db_map, item_type, entity_id, only_visible=False)
if not entity:
return ""
if not self.ui.graphicsView.name_parameter:
entity = self.db_mngr.get_item(db_map, item_type, entity_id, only_visible=False)
return entity["name"]
if self._name_by_db_map_entity_id is None:
self._name_by_db_map_entity_id = {
(db_map, pv["entity_id"]): pv
for db_map in self.db_maps
for pv in self.db_mngr.get_items_by_field(
db_map, "parameter_value", "parameter_name", self.ui.graphicsView.name_parameter, only_visible=False
)
}
name_pv = self._name_by_db_map_entity_id.get((db_map, entity_id))
if not name_pv:
return ""
name = from_database(name_pv["value"], name_pv["type"])
if isinstance(name, str):
return name
return ""

def _get_item_color(self, db_map, item_type, entity_id):
entity = self.db_mngr.get_item(db_map, item_type, entity_id, only_visible=False)
if not entity:
return None
if self._color_by_db_map_entity_id is None:
self._color_by_db_map_entity_id = {
(db_map, pv["entity_id"]): pv
for db_map in self.db_maps
for pv in self.db_mngr.get_items_by_field(
db_map,
"parameter_value",
"parameter_name",
self.ui.graphicsView.color_parameter,
only_visible=False,
)
}
color_pv = self._color_by_db_map_entity_id.get((db_map, entity_id))
if not color_pv:
return None
return from_database(color_pv["value"], color_pv["type"])
return self._get_pv(db_map, item_type, entity_id, self.ui.graphicsView.name_parameter)

def get_item_color(self, db_map, item_type, entity_id):
if len(self._possible_colors) == 1:
return None
color = self._get_item_color(db_map, item_type, entity_id)
k = self._possible_colors.get(color)
return k, len(self._possible_colors)

def get_item_color(self, db_map, entity_id):
return self._get_item_property(db_map, entity_id, self.ui.graphicsView.color_parameter)
return self._get_item_property(db_map, item_type, entity_id, self.ui.graphicsView.color_parameter)

def get_arc_width(self, db_map, entity_id):
return self._get_item_property(db_map, entity_id, self.ui.graphicsView.arc_width_parameter)
def get_arc_width(self, db_map, item_type, entity_id):
return self._get_item_property(db_map, item_type, entity_id, self.ui.graphicsView.arc_width_parameter)

def _get_item_property(self, db_map, entity_id, pname):
def _get_item_property(self, db_map, item_type, entity_id, pname):
"""Returns a tuple of (min_value, value, max_value) for given entity and property.
Returns (0, 0, 0) if the property is not defined for the entity.
Returns None if the property is not defined for *any* entity.
Expand All @@ -666,6 +665,15 @@ def _get_item_property(self, db_map, entity_id, pname):
min_val, max_val = val_range
return min_val, val, max_val

def _get_fixed_pos(self, db_map, item_type, entity_id):
pos_x, pos_y = [
self._get_pv(db_map, item_type, entity_id, pname)
for pname in (self.ui.graphicsView.pos_x_parameter, self.ui.graphicsView.pos_y_parameter)
]
if isinstance(pos_x, float) and isinstance(pos_y, float):
return {"x": pos_x, "y": pos_y}
return None

def _get_parameter_positions(self, parameter_name):
if not parameter_name:
yield from []
Expand All @@ -688,23 +696,25 @@ def _make_layout_generator(self):
for item in self.ui.graphicsView.items():
if isinstance(item, EntityItem):
fixed_positions[item.first_db_map, item.first_id] = {"x": item.pos().x(), "y": item.pos().y()}
param_pos_x = dict(self._get_parameter_positions(self.ui.graphicsView.pos_x_parameter))
param_pos_y = dict(self._get_parameter_positions(self.ui.graphicsView.pos_y_parameter))
for db_map_entity_id in param_pos_x.keys() & param_pos_y.keys():
fixed_positions[db_map_entity_id] = {"x": param_pos_x[db_map_entity_id], "y": param_pos_y[db_map_entity_id]}
db_map_entity_ids = self.db_map_object_id_sets + self.db_map_relationship_id_sets
for item_type, db_map_entity_id_sets in zip(
("object", "relationship"), (self.db_map_object_id_sets, self.db_map_relationship_id_sets)
):
for db_map_entity_ids in db_map_entity_id_sets:
for db_map, entity_id in db_map_entity_ids:
fixed_positions[db_map, entity_id] = self._get_fixed_pos(db_map, item_type, entity_id)
db_map_entity_id_sets = self.db_map_object_id_sets + self.db_map_relationship_id_sets
heavy_positions = {
ind: fixed_positions[db_map_entity_id]
for ind, db_map_entity_ids in enumerate(db_map_entity_ids)
for ind, db_map_entity_ids in enumerate(db_map_entity_id_sets)
for db_map_entity_id in db_map_entity_ids
if db_map_entity_id in fixed_positions
if fixed_positions.get(db_map_entity_id)
}
spread_factor = int(self.qsettings.value("appSettings/layoutAlgoSpreadFactor", defaultValue="100")) / 100
neg_weight_exp = int(self.qsettings.value("appSettings/layoutAlgoNegWeightExp", defaultValue="2"))
max_iters = int(self.qsettings.value("appSettings/layoutAlgoMaxIterations", defaultValue="12"))
return GraphLayoutGeneratorRunnable(
self._layout_gen_id,
len(db_map_entity_ids),
len(db_map_entity_id_sets),
self.src_inds,
self.dst_inds,
spread=spread_factor * self._ARC_LENGTH_HINT,
Expand Down

0 comments on commit 169aa77

Please sign in to comment.