Skip to content

Commit d72aa3c

Browse files
hbmartincoderabbitai[bot]sourcery-ai[bot]
authored
Support for gradients (#94)
* Support for gradients * typing fixes * Update graphviz2drawio/models/SvgParser.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Update graphviz2drawio/mx/Node.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
1 parent 3986517 commit d72aa3c

7 files changed

Lines changed: 126 additions & 19 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ python -m graphviz2drawio test/directed/hello.gv.txt
108108
## Roadmap
109109

110110
* Migrate to uv/hatch for packaging and dep mgmt
111-
* Support for fill gradient
112111
* Support compatible [arrows](https://graphviz.org/docs/attr-types/arrowType/)
113112
* Support [multiple edges](https://graphviz.org/Gallery/directed/switch.html)
114113
* Support [edges with links](https://graphviz.org/Gallery/directed/pprof.html)

graphviz2drawio/models/SVG.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def get_first(g: Element, tag: str) -> Element | None:
1111
return g.find(f"./{NS_SVG}{tag}")
1212

1313

14-
def count_tags(g: Element, tag: str) -> int:
15-
return len(g.findall(f"./{NS_SVG}{tag}"))
14+
def findall(g: Element, tag: str) -> list[Element]:
15+
return g.findall(f"./{NS_SVG}{tag}")
1616

1717

1818
def get_title(g: Element) -> str | None:

graphviz2drawio/models/SvgParser.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import re
12
from collections import OrderedDict
3+
from collections.abc import Iterable
4+
from math import isclose
25
from xml.etree import ElementTree
36

47
from graphviz2drawio.mx.Edge import Edge
58
from graphviz2drawio.mx.EdgeFactory import EdgeFactory
6-
from graphviz2drawio.mx.Node import Node
9+
from graphviz2drawio.mx.Node import Gradient, Node
710
from graphviz2drawio.mx.NodeFactory import NodeFactory
811

12+
from ..mx.Curve import LINE_TOLERANCE
13+
from ..mx.utils import adjust_color_opacity
914
from . import SVG
1015
from .commented_tree_builder import COMMENT, CommentedTreeBuilder
1116
from .CoordsTranslate import CoordsTranslate
@@ -29,17 +34,28 @@ def parse_nodes_edges_clusters(
2934
nodes: OrderedDict[str, Node] = OrderedDict()
3035
edges: OrderedDict[str, Edge] = OrderedDict()
3136
clusters: OrderedDict[str, Node] = OrderedDict()
37+
gradients = dict[str, Gradient]()
3238

3339
prev_comment = None
3440
for g in root:
3541
if g.tag == COMMENT:
3642
prev_comment = g.text
43+
elif SVG.is_tag(g, "defs"):
44+
for gradient in _extract_gradients(g):
45+
gradients[gradient[0]] = gradient[1:]
3746
elif SVG.is_tag(g, "g"):
3847
title = prev_comment or SVG.get_title(g)
3948
if title is None:
4049
raise MissingTitleError(g)
50+
if (defs := SVG.get_first(g, "defs")) is not None:
51+
for gradient in _extract_gradients(defs):
52+
gradients[gradient[0]] = gradient[1:]
4153
if g.attrib["class"] == "node":
42-
nodes[title] = node_factory.from_svg(g, labelloc="c")
54+
nodes[title] = node_factory.from_svg(
55+
g,
56+
labelloc="c",
57+
gradients=gradients,
58+
)
4359
elif g.attrib["class"] == "edge":
4460
# We need to merge edges with the same source and target
4561
# GV represents multiple labels with multiple edges
@@ -52,6 +68,64 @@ def parse_nodes_edges_clusters(
5268
else:
5369
edges[edge.key_for_label] = edge
5470
elif g.attrib["class"] == "cluster":
55-
clusters[title] = node_factory.from_svg(g, labelloc="t")
71+
clusters[title] = node_factory.from_svg(
72+
g,
73+
labelloc="t",
74+
gradients=gradients,
75+
)
5676

5777
return nodes, list(edges.values()), clusters
78+
79+
80+
_stop_color_re = re.compile(r"stop-color:([^;]+);")
81+
_stop_opacity_re = re.compile(r"stop-opacity:([^;]+);")
82+
83+
84+
def _extract_stop_color(stop: ElementTree.Element) -> str | None:
85+
style = stop.attrib.get("style", "")
86+
if (color := _stop_color_re.search(style)) is not None:
87+
if (opacity := _stop_opacity_re.search(style)) is not None:
88+
return adjust_color_opacity(color.group(1), float(opacity.group(1)))
89+
return None
90+
91+
92+
def _extract_gradients(
93+
defs: ElementTree.Element,
94+
) -> Iterable[tuple[str, str, str, str]]:
95+
for radial_gradient in SVG.findall(defs, "radialGradient"):
96+
stops = SVG.findall(radial_gradient, "stop")
97+
start_color = _extract_stop_color(stops[0])
98+
end_color = _extract_stop_color(stops[-1])
99+
if start_color is None or end_color is None:
100+
continue
101+
yield (
102+
radial_gradient.attrib["id"],
103+
start_color,
104+
end_color,
105+
"radial",
106+
)
107+
for linear_gradient in SVG.findall(defs, "linearGradient"):
108+
stops = SVG.findall(linear_gradient, "stop")
109+
110+
start_color = _extract_stop_color(stops[0])
111+
end_color = _extract_stop_color(stops[-1])
112+
if start_color is None or end_color is None:
113+
continue
114+
115+
y1 = float(linear_gradient.attrib["y1"])
116+
y2 = float(linear_gradient.attrib["y2"])
117+
118+
gradient_direction = "north"
119+
if isclose(y1, y2, rel_tol=LINE_TOLERANCE):
120+
x1 = float(linear_gradient.attrib["y1"])
121+
x2 = float(linear_gradient.attrib["y2"])
122+
gradient_direction = "east" if x1 < x2 else "west"
123+
elif y1 < y2:
124+
gradient_direction = "south"
125+
126+
yield (
127+
linear_gradient.attrib["id"],
128+
start_color,
129+
end_color,
130+
gradient_direction,
131+
)

graphviz2drawio/mx/Node.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from typing import TypeAlias
2+
13
from ..models.Rect import Rect
24
from .GraphObj import GraphObj
35
from .MxConst import VERTICAL_ALIGN
46
from .Styles import Styles
57
from .Text import Text
68

9+
Gradient: TypeAlias = tuple[str, str | None, str]
10+
711

812
class Node(GraphObj):
913
def __init__(
@@ -12,7 +16,7 @@ def __init__(
1216
gid: str,
1317
rect: Rect | None,
1418
texts: list[Text],
15-
fill: str,
19+
fill: str | Gradient,
1620
stroke: str,
1721
shape: str,
1822
labelloc: str,
@@ -46,21 +50,29 @@ def texts_to_mx_value(self) -> str:
4650
def get_node_style(self) -> str:
4751
style_for_shape = Styles.get_for_shape(self.shape)
4852
dashed = 1 if self.dashed else 0
53+
additional_styling = ""
4954

5055
attributes = {
51-
"fill": self.fill,
5256
"stroke": self.stroke,
5357
"stroke_width": self.stroke_width,
5458
"dashed": dashed,
5559
}
60+
if isinstance(self.fill, str):
61+
attributes["fill"] = self.fill
62+
elif type(self.fill) is tuple:
63+
attributes["fill"] = self.fill[0]
64+
additional_styling += (
65+
f"gradientColor={self.fill[1]};gradientDirection={self.fill[2]};"
66+
)
67+
5668
if (rect := self.rect) is not None and (image_path := rect.image) is not None:
5769
from graphviz2drawio.mx.image import image_data_for_path
5870

5971
attributes["image"] = image_data_for_path(image_path)
6072

6173
attributes["vertical_align"] = VERTICAL_ALIGN.get(self.labelloc, "middle")
6274

63-
return style_for_shape.format(**attributes)
75+
return style_for_shape.format(**attributes) + additional_styling
6476

6577
def __repr__(self) -> str:
6678
return (

graphviz2drawio/mx/NodeFactory.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from xml.etree.ElementTree import Element
23

34
from graphviz2drawio.models import SVG
@@ -6,7 +7,7 @@
67
from ..models.Errors import MissingIdentifiersError
78
from . import MxConst, Shape
89
from .MxConst import DEFAULT_STROKE_WIDTH
9-
from .Node import Node
10+
from .Node import Gradient, Node
1011
from .RectFactory import rect_from_ellipse_svg, rect_from_image, rect_from_svg_points
1112
from .Text import Text
1213
from .utils import adjust_color_opacity
@@ -17,11 +18,16 @@ def __init__(self, coords: CoordsTranslate) -> None:
1718
super().__init__()
1819
self.coords = coords
1920

20-
def from_svg(self, g: Element, labelloc: str) -> Node:
21+
def from_svg(
22+
self,
23+
g: Element,
24+
labelloc: str,
25+
gradients: dict[str, Gradient],
26+
) -> Node:
2127
sid = g.attrib["id"]
2228
gid = SVG.get_title(g)
2329
rect = None
24-
fill = MxConst.NONE
30+
fill: str | Gradient = MxConst.NONE
2531
stroke = MxConst.NONE
2632
stroke_width = DEFAULT_STROKE_WIDTH
2733
dashed = False
@@ -36,7 +42,8 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
3642
if (polygon := SVG.get_first(g, "polygon")) is not None:
3743
rect = rect_from_svg_points(self.coords, polygon.attrib["points"])
3844
shape = Shape.RECT
39-
fill, stroke = self._extract_fill_and_stroke(polygon)
45+
fill = self._extract_fill(polygon, gradients)
46+
stroke = self._extract_stroke(polygon)
4047
stroke_width = polygon.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH)
4148
if "stroke-dasharray" in polygon.attrib:
4249
dashed = True
@@ -50,10 +57,11 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
5057
)
5158
shape = (
5259
Shape.ELLIPSE
53-
if SVG.count_tags(g, "ellipse") == 1
60+
if len(SVG.findall(g, "ellipse")) == 1
5461
else Shape.DOUBLE_CIRCLE
5562
)
56-
fill, stroke = self._extract_fill_and_stroke(ellipse)
63+
fill = self._extract_fill(ellipse, gradients)
64+
stroke = self._extract_stroke(ellipse)
5765
stroke_width = ellipse.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH)
5866
if "stroke-dasharray" in ellipse.attrib:
5967
dashed = True
@@ -76,15 +84,25 @@ def from_svg(self, g: Element, labelloc: str) -> Node:
7684
dashed=dashed,
7785
)
7886

87+
_fill_url_re = re.compile(r"url\(#([^)]+)\)")
88+
7989
@staticmethod
80-
def _extract_fill_and_stroke(g: Element) -> tuple[str, str]:
90+
def _extract_fill(g: Element, gradients: dict[str, Gradient]) -> str | Gradient:
8191
fill = g.attrib.get("fill", MxConst.NONE)
82-
stroke = g.attrib.get("stroke", MxConst.NONE)
92+
if fill.startswith("url"):
93+
match = NodeFactory._fill_url_re.search(fill)
94+
if match is not None:
95+
return gradients[match.group(1)]
8396
if "fill-opacity" in g.attrib and fill != MxConst.NONE:
8497
fill = adjust_color_opacity(fill, float(g.attrib["fill-opacity"]))
98+
return fill
99+
100+
@staticmethod
101+
def _extract_stroke(g: Element) -> str:
102+
stroke = g.attrib.get("stroke", MxConst.NONE)
85103
if "stroke-opacity" in g.attrib and stroke != MxConst.NONE:
86104
stroke = adjust_color_opacity(stroke, float(g.attrib["stroke-opacity"]))
87-
return fill, stroke
105+
return stroke
88106

89107
def _extract_texts(self, g: Element) -> tuple[list[Text], complex | None]:
90108
texts = []

graphviz2drawio/mx/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ def adjust_color_opacity(hex_color: str, opacity: float) -> str:
44
hex_color = hex_color.lstrip("#")
55

66
# Convert hex to RGB
7-
r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
7+
try:
8+
r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
9+
except ValueError:
10+
return hex_color
811

912
# Apply opacity over white background
1013
r = int(r * opacity + 255 * (1 - opacity))

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ line-ending = "auto"
2929
"graphviz2drawio/version.py" = ["T201"]
3030
"doc/source/conf.py" = ["A001", "ERA001", "INP001"]
3131
"graphviz2drawio/models/commented_tree_builder.py" = ["ANN001", "ANN201", "ANN204"]
32+
"graphviz2drawio/models/SvgParser.py" = ["C901", "PLR0912"]
3233

3334
[tool.pytest.ini_options]
3435
pythonpath = ". venv/lib/python3.12/site-packages"

0 commit comments

Comments
 (0)