1+ import re
12from collections import OrderedDict
3+ from collections .abc import Iterable
4+ from math import isclose
25from xml .etree import ElementTree
36
47from graphviz2drawio .mx .Edge import Edge
58from graphviz2drawio .mx .EdgeFactory import EdgeFactory
6- from graphviz2drawio .mx .Node import Node
9+ from graphviz2drawio .mx .Node import Gradient , Node
710from graphviz2drawio .mx .NodeFactory import NodeFactory
811
12+ from ..mx .Curve import LINE_TOLERANCE
13+ from ..mx .utils import adjust_color_opacity
914from . import SVG
1015from .commented_tree_builder import COMMENT , CommentedTreeBuilder
1116from .CoordsTranslate import CoordsTranslate
@@ -29,17 +34,28 @@ def parse_nodes_edges_clusters(
2934 nodes : OrderedDict [str , Node ] = OrderedDict ()
3035 edges : OrderedDict [str , Edge ] = OrderedDict ()
3136 clusters : OrderedDict [str , Node ] = OrderedDict ()
37+ gradients = dict [str , Gradient ]()
3238
3339 prev_comment = None
3440 for g in root :
3541 if g .tag == COMMENT :
3642 prev_comment = g .text
43+ elif SVG .is_tag (g , "defs" ):
44+ for gradient in _extract_gradients (g ):
45+ gradients [gradient [0 ]] = gradient [1 :]
3746 elif SVG .is_tag (g , "g" ):
3847 title = prev_comment or SVG .get_title (g )
3948 if title is None :
4049 raise MissingTitleError (g )
50+ if (defs := SVG .get_first (g , "defs" )) is not None :
51+ for gradient in _extract_gradients (defs ):
52+ gradients [gradient [0 ]] = gradient [1 :]
4153 if g .attrib ["class" ] == "node" :
42- nodes [title ] = node_factory .from_svg (g , labelloc = "c" )
54+ nodes [title ] = node_factory .from_svg (
55+ g ,
56+ labelloc = "c" ,
57+ gradients = gradients ,
58+ )
4359 elif g .attrib ["class" ] == "edge" :
4460 # We need to merge edges with the same source and target
4561 # GV represents multiple labels with multiple edges
@@ -52,6 +68,64 @@ def parse_nodes_edges_clusters(
5268 else :
5369 edges [edge .key_for_label ] = edge
5470 elif g .attrib ["class" ] == "cluster" :
55- clusters [title ] = node_factory .from_svg (g , labelloc = "t" )
71+ clusters [title ] = node_factory .from_svg (
72+ g ,
73+ labelloc = "t" ,
74+ gradients = gradients ,
75+ )
5676
5777 return nodes , list (edges .values ()), clusters
78+
79+
80+ _stop_color_re = re .compile (r"stop-color:([^;]+);" )
81+ _stop_opacity_re = re .compile (r"stop-opacity:([^;]+);" )
82+
83+
84+ def _extract_stop_color (stop : ElementTree .Element ) -> str | None :
85+ style = stop .attrib .get ("style" , "" )
86+ if (color := _stop_color_re .search (style )) is not None :
87+ if (opacity := _stop_opacity_re .search (style )) is not None :
88+ return adjust_color_opacity (color .group (1 ), float (opacity .group (1 )))
89+ return None
90+
91+
92+ def _extract_gradients (
93+ defs : ElementTree .Element ,
94+ ) -> Iterable [tuple [str , str , str , str ]]:
95+ for radial_gradient in SVG .findall (defs , "radialGradient" ):
96+ stops = SVG .findall (radial_gradient , "stop" )
97+ start_color = _extract_stop_color (stops [0 ])
98+ end_color = _extract_stop_color (stops [- 1 ])
99+ if start_color is None or end_color is None :
100+ continue
101+ yield (
102+ radial_gradient .attrib ["id" ],
103+ start_color ,
104+ end_color ,
105+ "radial" ,
106+ )
107+ for linear_gradient in SVG .findall (defs , "linearGradient" ):
108+ stops = SVG .findall (linear_gradient , "stop" )
109+
110+ start_color = _extract_stop_color (stops [0 ])
111+ end_color = _extract_stop_color (stops [- 1 ])
112+ if start_color is None or end_color is None :
113+ continue
114+
115+ y1 = float (linear_gradient .attrib ["y1" ])
116+ y2 = float (linear_gradient .attrib ["y2" ])
117+
118+ gradient_direction = "north"
119+ if isclose (y1 , y2 , rel_tol = LINE_TOLERANCE ):
120+ x1 = float (linear_gradient .attrib ["y1" ])
121+ x2 = float (linear_gradient .attrib ["y2" ])
122+ gradient_direction = "east" if x1 < x2 else "west"
123+ elif y1 < y2 :
124+ gradient_direction = "south"
125+
126+ yield (
127+ linear_gradient .attrib ["id" ],
128+ start_color ,
129+ end_color ,
130+ gradient_direction ,
131+ )
0 commit comments