diff --git a/doc/changelog.qmd b/doc/changelog.qmd index 5d78880164..496a23afa2 100644 --- a/doc/changelog.qmd +++ b/doc/changelog.qmd @@ -44,6 +44,22 @@ title: Changelog - Added [](:class:`~plotnine.composition.inset_element`) with which you can insert plot compositions or images into another plot. +- Position scales gained a `position` parameter, so you can move an axis to the + opposite side of the panel. Use `position="top"` for the x axis and + `position="right"` for the y axis. + + ```python + ggplot(df, aes("x", "y")) + geom_point() + scale_x_continuous(position="top") + ``` + + Each axis side can be styled independently with new per-side themeables, e.g. + `axis_text_x_top`, `axis_title_y_right`, `axis_line_x_top` and + `axis_ticks_major_y_right`. + + ```python + theme(axis_text_x_top=element_text(color="blue")) + ``` + ### API Changes - Removed `geom.to_layer()`, `stat.to_layer()`, `annotate.to_layer()`, diff --git a/plotnine/_mpl/layout_manager/_layout_tree.py b/plotnine/_mpl/layout_manager/_layout_tree.py index e78d36fb16..c69d55d0ac 100644 --- a/plotnine/_mpl/layout_manager/_layout_tree.py +++ b/plotnine/_mpl/layout_manager/_layout_tree.py @@ -552,7 +552,12 @@ def align_axis_titles(self): def axis_title_clearance(s): return s.axis_title_clearance - for spaces in [self.bottom_spaces, self.left_spaces]: + for spaces in [ + self.bottom_spaces, + self.left_spaces, + self.top_spaces, + self.right_spaces, + ]: _align(spaces, axis_title_clearance, "axis_title_alignment") for tree in self.sub_compositions: diff --git a/plotnine/_mpl/layout_manager/_plot_layout_items.py b/plotnine/_mpl/layout_manager/_plot_layout_items.py index 86fab62e85..df42b2c323 100644 --- a/plotnine/_mpl/layout_manager/_plot_layout_items.py +++ b/plotnine/_mpl/layout_manager/_plot_layout_items.py @@ -6,6 +6,7 @@ from matplotlib.text import Text from plotnine._mpl.patches import StripTextPatch +from plotnine._utils import side_artists from plotnine.composition._compose import Compose from plotnine.exceptions import PlotnineError @@ -89,6 +90,10 @@ def get(name: str) -> Any: self.axis_title_x: Text | None = get("axis_title_x") self.axis_title_y: Text | None = get("axis_title_y") + self.axis_title_x_bottom: Text | None = get("axis_title_x_bottom") + self.axis_title_x_top: Text | None = get("axis_title_x_top") + self.axis_title_y_left: Text | None = get("axis_title_y_left") + self.axis_title_y_right: Text | None = get("axis_title_y_right") # # The legends references the structure that contains the # # AnchoredOffsetboxes (groups of legends) @@ -126,9 +131,9 @@ def _filter_axes(self, location: AxesLocation = "all") -> list[Axes]: if getattr(spec, pred_method)() ] - def axis_text_x(self, ax: Axes) -> Iterator[Text]: + def axis_text_x(self, ax: Axes, side: str) -> Iterator[Text]: """ - Return all x-axis labels for an axes that will be shown + Return the visible x-axis labels on one side of an axes """ major, minor = [], [] @@ -136,15 +141,16 @@ def axis_text_x(self, ax: Axes) -> Iterator[Text]: major = ax.xaxis.get_major_ticks() minor = ax.xaxis.get_minor_ticks() + label_attr = side_artists(side)[1] return ( - tick.label1 + getattr(tick, label_attr) for tick in chain(major, minor) - if _text_is_visible(tick.label1) + if _text_is_visible(getattr(tick, label_attr)) ) - def axis_text_y(self, ax: Axes) -> Iterator[Text]: + def axis_text_y(self, ax: Axes, side: str) -> Iterator[Text]: """ - Return all y-axis labels for an axes that will be shown + Return the visible y-axis labels on one side of an axes """ major, minor = [], [] @@ -152,10 +158,11 @@ def axis_text_y(self, ax: Axes) -> Iterator[Text]: major = ax.yaxis.get_major_ticks() minor = ax.yaxis.get_minor_ticks() + label_attr = side_artists(side)[1] return ( - tick.label1 + getattr(tick, label_attr) for tick in chain(major, minor) - if _text_is_visible(tick.label1) + if _text_is_visible(getattr(tick, label_attr)) ) def axis_ticks_x(self, ax: Axes) -> Iterator[Tick]: @@ -238,66 +245,114 @@ def strip_text_y_extra_width(self, position: StripPosition) -> float: return max(widths) - def axis_ticks_x_max_height_at(self, location: AxesLocation) -> float: + def axis_ticks_x_max_height_at( + self, location: AxesLocation, side: str + ) -> float: """ - Return maximum height[figure space] of x ticks + Return maximum height[figure space] of visible x ticks on a side """ + attr = side_artists(side)[0] heights = [ - self.geometry.tight_height(tick.tick1line) + self.geometry.tight_height(getattr(tick, attr)) for ax in self._filter_axes(location) for tick in self.axis_ticks_x(ax) + if getattr(tick, attr).get_visible() ] return max(heights) if len(heights) else 0 - def axis_text_x_max_height(self, ax: Axes) -> float: + def axis_text_x_max_height(self, ax: Axes, side: str) -> float: """ - Return maximum height[figure space] of x tick labels + Return maximum height[figure space] of x tick labels on a side """ heights = [ - self.geometry.tight_height(label) for label in self.axis_text_x(ax) + self.geometry.tight_height(label) + for label in self.axis_text_x(ax, side) ] return max(heights) if len(heights) else 0 - def axis_text_x_max_height_at(self, location: AxesLocation) -> float: + def axis_text_x_max_height_at( + self, location: AxesLocation, side: str + ) -> float: """ - Return maximum height[figure space] of x tick labels + Return maximum height[figure space] of x tick labels on a side """ heights = [ - self.axis_text_x_max_height(ax) + self.axis_text_x_max_height(ax, side) for ax in self._filter_axes(location) ] return max(heights) if len(heights) else 0 - def axis_ticks_y_max_width_at(self, location: AxesLocation) -> float: + def axis_ticks_y_max_width_at( + self, location: AxesLocation, side: str + ) -> float: """ - Return maximum width[figure space] of y ticks + Return maximum width[figure space] of visible y ticks on a side """ + attr = side_artists(side)[0] widths = [ - self.geometry.tight_width(tick.tick1line) + self.geometry.tight_width(getattr(tick, attr)) for ax in self._filter_axes(location) for tick in self.axis_ticks_y(ax) + if getattr(tick, attr).get_visible() ] return max(widths) if len(widths) else 0 - def axis_text_y_max_width(self, ax: Axes) -> float: + def axis_text_y_max_width(self, ax: Axes, side: str) -> float: """ - Return maximum width[figure space] of y tick labels + Return maximum width[figure space] of y tick labels on a side """ widths = [ - self.geometry.tight_width(label) for label in self.axis_text_y(ax) + self.geometry.tight_width(label) + for label in self.axis_text_y(ax, side) ] return max(widths) if len(widths) else 0 - def axis_text_y_max_width_at(self, location: AxesLocation) -> float: + def axis_text_y_max_width_at( + self, location: AxesLocation, side: str + ) -> float: """ - Return maximum width[figure space] of y tick labels + Return maximum width[figure space] of y tick labels on a side """ widths = [ - self.axis_text_y_max_width(ax) + self.axis_text_y_max_width(ax, side) for ax in self._filter_axes(location) ] return max(widths) if len(widths) else 0 + # Side-scoped extents — each names a concrete edge (side picks the + # artist, location picks the panels) and reads 0 when no axis is there. + @property + def axis_text_x_bottom(self) -> float: + return self.axis_text_x_max_height_at("last_row", "bottom") + + @property + def axis_text_x_top(self) -> float: + return self.axis_text_x_max_height_at("first_row", "top") + + @property + def axis_text_y_left(self) -> float: + return self.axis_text_y_max_width_at("first_col", "left") + + @property + def axis_text_y_right(self) -> float: + return self.axis_text_y_max_width_at("last_col", "right") + + @property + def axis_ticks_x_bottom(self) -> float: + return self.axis_ticks_x_max_height_at("last_row", "bottom") + + @property + def axis_ticks_x_top(self) -> float: + return self.axis_ticks_x_max_height_at("first_row", "top") + + @property + def axis_ticks_y_left(self) -> float: + return self.axis_ticks_y_max_width_at("first_col", "left") + + @property + def axis_ticks_y_right(self) -> float: + return self.axis_ticks_y_max_width_at("last_col", "right") + def axis_text_y_top_protrusion(self, location: AxesLocation) -> float: """ Return maximum height[figure space] above the axes of y tick labels @@ -305,9 +360,10 @@ def axis_text_y_top_protrusion(self, location: AxesLocation) -> float: extras = [] for ax in self._filter_axes(location): ax_top_y = self.geometry.top_y(ax) - for label in self.axis_text_y(ax): - label_top_y = self.geometry.top_y(label) - extras.append(max(0, label_top_y - ax_top_y)) + for side in ("left", "right"): + for label in self.axis_text_y(ax, side): + label_top_y = self.geometry.top_y(label) + extras.append(max(0, label_top_y - ax_top_y)) return max(extras) if len(extras) else 0 @@ -318,10 +374,11 @@ def axis_text_y_bottom_protrusion(self, location: AxesLocation) -> float: extras = [] for ax in self._filter_axes(location): ax_bottom_y = self.geometry.bottom_y(ax) - for label in self.axis_text_y(ax): - label_bottom_y = self.geometry.bottom_y(label) - protrusion = abs(min(label_bottom_y - ax_bottom_y, 0)) - extras.append(protrusion) + for side in ("left", "right"): + for label in self.axis_text_y(ax, side): + label_bottom_y = self.geometry.bottom_y(label) + protrusion = abs(min(label_bottom_y - ax_bottom_y, 0)) + extras.append(protrusion) return max(extras) if len(extras) else 0 @@ -332,10 +389,11 @@ def axis_text_x_left_protrusion(self, location: AxesLocation) -> float: extras = [] for ax in self._filter_axes(location): ax_left_x = self.geometry.left_x(ax) - for label in self.axis_text_x(ax): - label_left_x = self.geometry.left_x(label) - protrusion = abs(min(label_left_x - ax_left_x, 0)) - extras.append(protrusion) + for side in ("bottom", "top"): + for label in self.axis_text_x(ax, side): + label_left_x = self.geometry.left_x(label) + protrusion = abs(min(label_left_x - ax_left_x, 0)) + extras.append(protrusion) return max(extras) if len(extras) else 0 @@ -346,9 +404,10 @@ def axis_text_x_right_protrusion(self, location: AxesLocation) -> float: extras = [] for ax in self._filter_axes(location): ax_right_x = self.geometry.right_x(ax) - for label in self.axis_text_x(ax): - label_right_x = self.geometry.right_x(label) - extras.append(max(0, label_right_x - ax_right_x)) + for side in ("bottom", "top"): + for label in self.axis_text_x(ax, side): + label_right_x = self.geometry.right_x(label) + extras.append(max(0, label_right_x - ax_right_x)) return max(extras) if len(extras) else 0 @@ -364,15 +423,25 @@ def _move_artists(self, spaces: PlotSideSpaces): if self.plot_tag: set_plot_tag_position(self.plot_tag, spaces) - if self.axis_title_x: - ha = theme.getp(("axis_title_x", "ha"), "center") - self.axis_title_x.set_y(spaces.b.y1("axis_title_x")) - justify.horizontally_about(self.axis_title_x, ha, "panel") + if self.axis_title_x_bottom: + ha = theme.getp(("axis_title_x_bottom", "ha"), "center") + self.axis_title_x_bottom.set_y(spaces.b.y1("axis_title_x")) + justify.horizontally_about(self.axis_title_x_bottom, ha, "panel") + + if self.axis_title_x_top: + ha = theme.getp(("axis_title_x_top", "ha"), "center") + self.axis_title_x_top.set_y(spaces.t.y1("axis_title_x")) + justify.horizontally_about(self.axis_title_x_top, ha, "panel") - if self.axis_title_y: - va = theme.getp(("axis_title_y", "va"), "center") - self.axis_title_y.set_x(spaces.l.x1("axis_title_y")) - justify.vertically_about(self.axis_title_y, va, "panel") + if self.axis_title_y_left: + va = theme.getp(("axis_title_y_left", "va"), "center") + self.axis_title_y_left.set_x(spaces.l.x1("axis_title_y")) + justify.vertically_about(self.axis_title_y_left, va, "panel") + + if self.axis_title_y_right: + va = theme.getp(("axis_title_y_right", "va"), "center") + self.axis_title_y_right.set_x(spaces.r.x1("axis_title_y")) + justify.vertically_about(self.axis_title_y_right, va, "panel") if self.legends: set_legends_position(self.legends, spaces) @@ -398,20 +467,30 @@ def to_vertical_axis_dimensions(value: float, ax: Axes) -> float: if self._is_blank("axis_text_x"): return - va = self.plot.theme.getp(("axis_text_x", "va"), "top") - - for ax in self.plot.axs: - texts = list(self.axis_text_x(ax)) - axis_text_row_height = to_vertical_axis_dimensions( - self.axis_text_x_max_height(ax), ax + for side in ("bottom", "top"): + va_default = "top" if side == "bottom" else "bottom" + va = self.plot.theme.getp( + (f"axis_text_x_{side}", "va"), va_default ) - for text in texts: - height = to_vertical_axis_dimensions( - self.geometry.tight_height(text), ax + for ax in self.plot.axs: + texts = list(self.axis_text_x(ax, side)) + if not texts: + continue + row_height = to_vertical_axis_dimensions( + self.axis_text_x_max_height(ax, side), ax ) - justify.vertically( - text, va, -axis_text_row_height, 0, height=height + # bottom labels sit below the panel (axes y 0), top labels + # above it (axes y 1) + low, high = ( + (-row_height, 0) + if side == "bottom" + else (1, 1 + row_height) ) + for text in texts: + height = to_vertical_axis_dimensions( + self.geometry.tight_height(text), ax + ) + justify.vertically(text, va, low, high, height=height) def _adjust_axis_text_y(self, justify: TextJustifier): """ @@ -451,20 +530,28 @@ def to_horizontal_axis_dimensions(value: float, ax: Axes) -> float: if self._is_blank("axis_text_y"): return - ha = self.plot.theme.getp(("axis_text_y", "ha"), "right") - - for ax in self.plot.axs: - texts = list(self.axis_text_y(ax)) - axis_text_col_width = to_horizontal_axis_dimensions( - self.axis_text_y_max_width(ax), ax + for side in ("left", "right"): + ha_default = "right" if side == "left" else "left" + ha = self.plot.theme.getp( + (f"axis_text_y_{side}", "ha"), ha_default ) - for text in texts: - width = to_horizontal_axis_dimensions( - self.geometry.tight_width(text), ax + for ax in self.plot.axs: + texts = list(self.axis_text_y(ax, side)) + if not texts: + continue + col_width = to_horizontal_axis_dimensions( + self.axis_text_y_max_width(ax, side), ax ) - justify.horizontally( - text, ha, -axis_text_col_width, 0, width=width + # left labels sit left of the panel (axes x 0), right labels + # to the right of it (axes x 1) + low, high = ( + (-col_width, 0) if side == "left" else (1, 1 + col_width) ) + for text in texts: + width = to_horizontal_axis_dimensions( + self.geometry.tight_width(text), ax + ) + justify.horizontally(text, ha, low, high, width=width) def _strip_text_x_background_equal_heights(self): """ diff --git a/plotnine/_mpl/layout_manager/_plot_side_space.py b/plotnine/_mpl/layout_manager/_plot_side_space.py index 01048298a4..0e332ff616 100644 --- a/plotnine/_mpl/layout_manager/_plot_side_space.py +++ b/plotnine/_mpl/layout_manager/_plot_side_space.py @@ -15,6 +15,7 @@ from functools import cached_property from typing import TYPE_CHECKING +from plotnine._utils import MARGIN_SIDE from plotnine.exceptions import PlotnineError from plotnine.facets import facet_grid, facet_null, facet_wrap @@ -196,7 +197,6 @@ class left_space(_plot_side_space): """ legend: float = 0 legend_box_spacing: float = 0 - axis_title_y_margin_left: float = 0 axis_title_y: float = 0 axis_title_y_margin_right: float = 0 axis_title_alignment: float = 0 @@ -207,7 +207,6 @@ class left_space(_plot_side_space): the difference between the largest and smallest axis_title_clearance among the items in the composition. """ - axis_text_y_margin_left: float = 0 axis_text_y: float = 0 axis_text_y_margin_right: float = 0 axis_ticks_y: float = 0 @@ -229,20 +228,23 @@ def _calculate(self): self.legend = self.legend_width self.legend_box_spacing = theme.getp("legend_box_spacing") - if items.axis_title_y: - m = theme.get_margin("axis_title_y").fig - self.axis_title_y_margin_left = m.l - self.axis_title_y = geometry.width(items.axis_title_y) - self.axis_title_y_margin_right = m.r + # The text<->panel gap is the right margin of the y text/title; it + # sits on the panel-facing (right) side of the left axis. + if items.axis_title_y_left: + self.axis_title_y = geometry.width(items.axis_title_y_left) + self.axis_title_y_margin_right = getattr( + theme.get_margin("axis_title_y_left").fig, + MARGIN_SIDE["left"], + ) - # Account for the space consumed by the axis - self.axis_text_y = items.axis_text_y_max_width_at("first_col") + self.axis_text_y = items.axis_text_y_left if self.axis_text_y: - m = theme.get_margin("axis_text_y").fig - self.axis_text_y_margin_left = m.l - self.axis_text_y_margin_right = m.r + self.axis_text_y_margin_right = getattr( + theme.get_margin("axis_text_y_left").fig, + MARGIN_SIDE["left"], + ) - self.axis_ticks_y = items.axis_ticks_y_max_width_at("first_col") + self.axis_ticks_y = items.axis_ticks_y_left # Adjust plot_margin to make room for ylabels that protude well # beyond the axes @@ -329,6 +331,12 @@ class right_space(_plot_side_space): legend: float = 0 legend_box_spacing: float = 0 strip_text_y_extra_width: float = 0 + axis_title_y: float = 0 + axis_title_y_margin_left: float = 0 + axis_title_alignment: float = 0 + axis_text_y: float = 0 + axis_text_y_margin_left: float = 0 + axis_ticks_y: float = 0 def _calculate(self): items = self.items @@ -349,6 +357,24 @@ def _calculate(self): self.strip_text_y_extra_width = items.strip_text_y_extra_width("right") + # Space consumed by a y-axis on the right. The text<->panel gap is the + # left margin of the y text/title (the edge facing the panel to the + # left). + if items.axis_title_y_right: + self.axis_title_y = geometry.width(items.axis_title_y_right) + self.axis_title_y_margin_left = getattr( + theme.get_margin("axis_title_y_right").fig, + MARGIN_SIDE["right"], + ) + + self.axis_text_y = items.axis_text_y_right + if self.axis_text_y: + self.axis_text_y_margin_left = getattr( + theme.get_margin("axis_text_y_right").fig, + MARGIN_SIDE["right"], + ) + self.axis_ticks_y = items.axis_ticks_y_right + # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large @@ -440,6 +466,12 @@ class top_space(_plot_side_space): legend: float = 0 legend_box_spacing: float = 0 strip_text_x_extra_height: float = 0 + axis_title_x: float = 0 + axis_title_x_margin_bottom: float = 0 + axis_title_alignment: float = 0 + axis_text_x: float = 0 + axis_text_x_margin_bottom: float = 0 + axis_ticks_x: float = 0 def _calculate(self): items = self.items @@ -474,6 +506,23 @@ def _calculate(self): self.strip_text_x_extra_height = items.strip_text_x_extra_height("top") + # Space consumed by an x-axis on the top. The text<->panel gap is the + # bottom margin of the x text/title (the edge facing the panel below). + if items.axis_title_x_top: + self.axis_title_x = geometry.height(items.axis_title_x_top) + self.axis_title_x_margin_bottom = getattr( + theme.get_margin("axis_title_x_top").fig, + MARGIN_SIDE["top"], + ) + + self.axis_text_x = items.axis_text_x_top + if self.axis_text_x: + self.axis_text_x_margin_bottom = getattr( + theme.get_margin("axis_text_x_top").fig, + MARGIN_SIDE["top"], + ) + self.axis_ticks_x = items.axis_ticks_x_top + # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large @@ -567,7 +616,6 @@ class bottom_space(_plot_side_space): plot_caption_margin_top: float = 0 legend: float = 0 legend_box_spacing: float = 0 - axis_title_x_margin_bottom: float = 0 axis_title_x: float = 0 axis_title_x_margin_top: float = 0 axis_title_alignment: float = 0 @@ -579,7 +627,6 @@ class bottom_space(_plot_side_space): composition. It's amount is the difference in height between this axis text (and it's margins) and the tallest axis text (and it's margin). """ - axis_text_x_margin_bottom: float = 0 axis_text_x: float = 0 axis_text_x_margin_top: float = 0 axis_ticks_x: float = 0 @@ -615,19 +662,22 @@ def _calculate(self): self.legend = self.legend_height self.legend_box_spacing = theme.getp("legend_box_spacing") * F - if items.axis_title_x: - m = theme.get_margin("axis_title_x").fig - self.axis_title_x_margin_bottom = m.b - self.axis_title_x = geometry.height(items.axis_title_x) - self.axis_title_x_margin_top = m.t + # The text<->panel gap is the top margin of the x text/title; it + # sits on the panel-facing (top) side of the bottom axis. + if items.axis_title_x_bottom: + self.axis_title_x = geometry.height(items.axis_title_x_bottom) + self.axis_title_x_margin_top = getattr( + theme.get_margin("axis_title_x_bottom").fig, + MARGIN_SIDE["bottom"], + ) - # Account for the space consumed by the axis - self.axis_text_x = items.axis_text_x_max_height_at("last_row") + self.axis_text_x = items.axis_text_x_bottom if self.axis_text_x: - m = theme.get_margin("axis_text_x").fig - self.axis_text_x_margin_bottom = m.b - self.axis_text_x_margin_top = m.t - self.axis_ticks_x = items.axis_ticks_x_max_height_at("last_row") + self.axis_text_x_margin_top = getattr( + theme.get_margin("axis_text_x_bottom").fig, + MARGIN_SIDE["bottom"], + ) + self.axis_ticks_x = items.axis_ticks_x_bottom # Adjust plot_margin to make room for ylabels that protude well # beyond the axes @@ -1082,13 +1132,15 @@ def _calculate_panel_spacing_facet_wrap(self) -> tuple[float, float]: self.sh += self.t.strip_text_x_extra_height * (1 + strip_align_x) if facet.free["x"]: - self.sh += self.items.axis_text_x_max_height_at( - "all" - ) + self.items.axis_ticks_x_max_height_at("all") + for side in ("bottom", "top"): + self.sh += self.items.axis_text_x_max_height_at( + "all", side + ) + self.items.axis_ticks_x_max_height_at("all", side) if facet.free["y"]: - self.sw += self.items.axis_text_y_max_width_at( - "all" - ) + self.items.axis_ticks_y_max_width_at("all") + for side in ("left", "right"): + self.sw += self.items.axis_text_y_max_width_at( + "all", side + ) + self.items.axis_ticks_y_max_width_at("all", side) # width and height of axes as fraction of figure width & height self.w = (self.panel_width - self.sw * (ncol - 1)) / ncol diff --git a/plotnine/_utils/__init__.py b/plotnine/_utils/__init__.py index 4c95504afc..cc83ee5484 100644 --- a/plotnine/_utils/__init__.py +++ b/plotnine/_utils/__init__.py @@ -61,6 +61,38 @@ to_rgba = color_utils.to_rgba +def side_artists(side: str) -> tuple[str, str]: + """ + Return the `(tickline, label)` tick-attribute names for an axis side + + The bottom/left side maps to `tick1line`/`label1` and the top/right side + to `tick2line`/`label2`. + """ + if side in ("top", "right"): + return ("tick2line", "label2") + return ("tick1line", "label1") + + +# The side opposite each axis side +OPPOSITE_SIDE: dict[Side, Side] = { + "top": "bottom", + "bottom": "top", + "left": "right", + "right": "left", +} + +# The margin side that faces inward for an element on each side: for an axis +# the side facing the panel (bottom axis -> top "t", top -> "b", left -> right +# "r", right -> "l"); for a legend title/text the side facing the keys. It is +# the initial of the opposite side; cf. OPPOSITE_SIDE. +MARGIN_SIDE: dict[Side, str] = { + "bottom": "t", + "top": "b", + "left": "r", + "right": "l", +} + + def is_scalar(val): """ Return whether the given object is a scalar @@ -1132,19 +1164,6 @@ def default_field(default: T) -> T: return field(default_factory=lambda: deepcopy(default)) -def get_opposite_side(s: Side) -> Side: - """ - Return the opposite side - """ - lookup: dict[Side, Side] = { - "right": "left", - "left": "right", - "top": "bottom", - "bottom": "top", - } - return lookup[s] - - def ensure_xy_location( loc: Side | Literal["center"] | float | tuple[float, float], ) -> tuple[float, float]: diff --git a/plotnine/coords/coord.py b/plotnine/coords/coord.py index 8dcbc378fa..0e530b67d2 100644 --- a/plotnine/coords/coord.py +++ b/plotnine/coords/coord.py @@ -5,6 +5,7 @@ import numpy as np +from .._utils import OPPOSITE_SIDE from ..iapi import panel_ranges if typing.TYPE_CHECKING: @@ -12,14 +13,34 @@ import numpy.typing as npt import pandas as pd + from matplotlib.axes import Axes from plotnine import ggplot - from plotnine.iapi import labels_view, panel_view - from plotnine.scales.scale import scale + from plotnine.iapi import labels_view, layout_details, panel_view + from plotnine.scales.scale_xy import ScaleX, ScaleY from plotnine.typing import ( FloatArray, FloatArrayLike, FloatSeries, + Side, + ) + + +def _activate_axis(axis, active_side: Side, present: bool): + """ + Show ticks and labels on the active side only; hide the opposite side + + `present` is False on interior facet panels, which hides both sides. + """ + opposite = OPPOSITE_SIDE[active_side] + axis.set_tick_params( + which="both", + **{ + active_side: present, + f"label{active_side}": present, + opposite: False, + f"label{opposite}": False, + }, ) @@ -104,6 +125,77 @@ def aspect(self, panel_params: panel_view) -> float | None: """ return None + def setup_ax( + self, + ax: Axes, + panel_params: panel_view, + layout_info: layout_details, + ) -> None: + """ + Axes state for one panel: limits, breaks, labels, and active side + + Subclasses can override this or call `super().setup_ax(...)` and add + coordinate-specific behavior. Configures only mpl axes state; the + theme styles the visible artists afterwards. + """ + self._setup_ticks_labels(ax, panel_params) + self._setup_axis_sides(ax, panel_params, layout_info) + + def _setup_ticks_labels(self, ax: Axes, panel_params: panel_view) -> None: + """ + Limits, major/minor breaks, tick labels, and fixed formatter on `ax` + """ + from .._mpl.ticker import MyFixedFormatter + + def _inf_to_none( + t: tuple[float, float], + ) -> tuple[float | None, float | None]: + """ + Replace infinities with None + """ + a = t[0] if np.isfinite(t[0]) else None + b = t[1] if np.isfinite(t[1]) else None + return (a, b) + + # limits + ax.set_xlim(*_inf_to_none(panel_params.x.range)) + ax.set_ylim(*_inf_to_none(panel_params.y.range)) + + # breaks, labels + ax.set_xticks(panel_params.x.breaks, panel_params.x.labels) + ax.set_yticks(panel_params.y.breaks, panel_params.y.labels) + + # minor breaks + ax.set_xticks(panel_params.x.minor_breaks, minor=True) + ax.set_yticks(panel_params.y.minor_breaks, minor=True) + + # When you manually set the tick labels MPL changes the locator + # so that it no longer reports the x & y positions + # Fixes https://github.com/has2k1/plotnine/issues/187 + ax.xaxis.set_major_formatter(MyFixedFormatter(panel_params.x.labels)) + ax.yaxis.set_major_formatter(MyFixedFormatter(panel_params.y.labels)) + + def _setup_axis_sides( + self, + ax: Axes, + panel_params: panel_view, + layout_info: layout_details, + ) -> None: + """ + Tick visibility and spine visibility for the side each axis occupies + """ + x_pos = panel_params.x.position # "bottom" | "top" + y_pos = panel_params.y.position # "left" | "right" + _activate_axis(ax.xaxis, x_pos, layout_info.axis_x) + _activate_axis(ax.yaxis, y_pos, layout_info.axis_y) + + # Spine on each axis's active side, on every panel (edge or + # interior); the axis_line themeable styles or blanks it. + ax.spines["top"].set_visible(x_pos == "top") + ax.spines["bottom"].set_visible(x_pos == "bottom") + ax.spines["right"].set_visible(y_pos == "right") + ax.spines["left"].set_visible(y_pos == "left") + def labels(self, cur_labels: labels_view) -> labels_view: """ Modify labels @@ -131,7 +223,9 @@ def transform( """ return data - def setup_panel_params(self, scale_x: scale, scale_y: scale) -> panel_view: + def setup_panel_params( + self, scale_x: ScaleX, scale_y: ScaleY + ) -> panel_view: """ Compute the range and break information for the panel """ diff --git a/plotnine/coords/coord_cartesian.py b/plotnine/coords/coord_cartesian.py index 54730ff2f2..b0338bac5b 100644 --- a/plotnine/coords/coord_cartesian.py +++ b/plotnine/coords/coord_cartesian.py @@ -12,7 +12,7 @@ import pandas as pd - from plotnine.iapi import scale_view + from plotnine.iapi import scale_position_view from plotnine.scales.scale import scale from plotnine.typing import ( FloatArray, @@ -71,7 +71,7 @@ def setup_panel_params(self, scale_x: scale, scale_y: scale) -> panel_view: def get_scale_view( scale: scale, limits: tuple[Any, Any] - ) -> scale_view: + ) -> scale_position_view: coord_limits = ( scale.transform(limits) if limits and isinstance(scale, scale_continuous) @@ -82,7 +82,9 @@ def get_scale_view( scale.final_limits, expansion, coord_limits, identity_trans() ) sv = scale.view(limits=coord_limits, range=ranges.range) - return sv + # x/y scales are always position scales, so the view is a + # scale_position_view + return typing.cast("scale_position_view", sv) out = panel_view( x=get_scale_view(scale_x, self.limits.x), diff --git a/plotnine/coords/coord_flip.py b/plotnine/coords/coord_flip.py index ea409aaf59..4642a2f3da 100644 --- a/plotnine/coords/coord_flip.py +++ b/plotnine/coords/coord_flip.py @@ -11,12 +11,21 @@ from typing import Sequence, TypeVar from plotnine.scales.scale import scale + from plotnine.typing import Side THasLabels = TypeVar( "THasLabels", bound=pd.DataFrame | labels_view | panel_view ) +_FLIP_POSITION: dict[Side, Side] = { + "top": "right", + "bottom": "left", + "left": "bottom", + "right": "top", +} + + class coord_flip(coord_cartesian): """ Flipped cartesian coordinates @@ -47,7 +56,12 @@ def transform( def setup_panel_params(self, scale_x: scale, scale_y: scale) -> panel_view: panel_params = super().setup_panel_params(scale_x, scale_y) - return flip_labels(panel_params) + panel_params = flip_labels(panel_params) + # The axis position rotates with the flip (matches ggplot2's + # scale_flip_axis): top->right, bottom->left, left->bottom, right->top + panel_params.x.position = _FLIP_POSITION[panel_params.x.position] + panel_params.y.position = _FLIP_POSITION[panel_params.y.position] + return panel_params def setup_layout(self, layout: pd.DataFrame) -> pd.DataFrame: # switch the scales diff --git a/plotnine/coords/coord_trans.py b/plotnine/coords/coord_trans.py index b70c68d0c7..98ac8339f8 100644 --- a/plotnine/coords/coord_trans.py +++ b/plotnine/coords/coord_trans.py @@ -15,8 +15,8 @@ import pandas as pd from mizani.transforms import trans - from plotnine.iapi import scale_view - from plotnine.scales.scale import scale + from plotnine.iapi import scale_position_view + from plotnine.scales.scale_xy import ScaleX, ScaleY from plotnine.typing import ( FloatArray, FloatSeries, @@ -98,19 +98,22 @@ def backtransform_range(self, panel_params: panel_view) -> panel_ranges: y=self.trans_y.inverse(panel_params.y.range), ) - def setup_panel_params(self, scale_x: scale, scale_y: scale) -> panel_view: + def setup_panel_params(self, scale_x, scale_y) -> panel_view: """ Compute the range and break information for the panel """ def get_scale_view( - scale: scale, limits: tuple[float, float], trans: trans - ) -> scale_view: + scale: ScaleX | ScaleY, limits: tuple[float, float], trans: trans + ) -> scale_position_view: coord_limits = trans.transform(limits) if limits else limits expansion = scale.default_expansion(expand=self.expand) ranges = scale.expand_limits( - scale.final_limits, expansion, coord_limits, trans + scale.final_limits, # pyright: ignore[reportArgumentType] + expansion, + coord_limits, + trans, ) sv = scale.view( limits=coord_limits, diff --git a/plotnine/facets/facet.py b/plotnine/facets/facet.py index 7f2bb0b34e..99b107d3d4 100644 --- a/plotnine/facets/facet.py +++ b/plotnine/facets/facet.py @@ -26,7 +26,7 @@ from plotnine.coords.coord import coord from plotnine.facets.labelling import CanBeStripLabellingFunc from plotnine.facets.layout import Layout - from plotnine.iapi import layout_details, panel_view + from plotnine.iapi import layout_details from plotnine.layer import Layers from plotnine.mapping import Environment from plotnine.scales.scale import scale @@ -226,6 +226,7 @@ def map(self, data: pd.DataFrame, layout: pd.DataFrame) -> pd.DataFrame: def compute_layout( self, data: list[pd.DataFrame], + scales: Scales, ) -> pd.DataFrame: """ Compute layout @@ -234,6 +235,8 @@ def compute_layout( ---------- data : Dataframe for a each layer + scales : + The plot's scales """ msg = "{} should implement this method." raise NotImplementedError(msg.format(self.__class__.__name__)) @@ -303,62 +306,6 @@ def make_strips(self, layout_info: layout_details, ax: Axes) -> Strips: """ return Strips() - def set_limits_breaks_and_labels(self, panel_params: panel_view, ax: Axes): - """ - Add limits, breaks and labels to the axes - - Parameters - ---------- - panel_params : - range information for the axes - ax : - Axes - """ - from .._mpl.ticker import MyFixedFormatter - - def _inf_to_none( - t: tuple[float, float], - ) -> tuple[float | None, float | None]: - """ - Replace infinities with None - """ - a = t[0] if np.isfinite(t[0]) else None - b = t[1] if np.isfinite(t[1]) else None - return (a, b) - - theme = self.theme - - # limits - ax.set_xlim(*_inf_to_none(panel_params.x.range)) - ax.set_ylim(*_inf_to_none(panel_params.y.range)) - - if typing.TYPE_CHECKING: - assert callable(ax.set_xticks) - assert callable(ax.set_yticks) - - # breaks, labels - ax.set_xticks(panel_params.x.breaks, panel_params.x.labels) - ax.set_yticks(panel_params.y.breaks, panel_params.y.labels) - - # minor breaks - ax.set_xticks(panel_params.x.minor_breaks, minor=True) - ax.set_yticks(panel_params.y.minor_breaks, minor=True) - - # When you manually set the tick labels MPL changes the locator - # so that it no longer reports the x & y positions - # Fixes https://github.com/has2k1/plotnine/issues/187 - ax.xaxis.set_major_formatter(MyFixedFormatter(panel_params.x.labels)) - ax.yaxis.set_major_formatter(MyFixedFormatter(panel_params.y.labels)) - - # Blank axis text is not drawn, so its margin may be absent - # (resolves to None). Skip the tick-label padding in that case. - if not theme.T.is_blank("axis_text_x"): - pad_x = theme.get_margin("axis_text_x").pt.t - ax.tick_params(axis="x", which="major", pad=pad_x) - if not theme.T.is_blank("axis_text_y"): - pad_y = theme.get_margin("axis_text_y").pt.r - ax.tick_params(axis="y", which="major", pad=pad_y) - def __deepcopy__(self, memo: dict[Any, Any]) -> facet: """ Deep copy without copying the dataframe and environment diff --git a/plotnine/facets/facet_grid.py b/plotnine/facets/facet_grid.py index cce4c95dca..07edda6f35 100644 --- a/plotnine/facets/facet_grid.py +++ b/plotnine/facets/facet_grid.py @@ -24,6 +24,8 @@ from plotnine.iapi import layout_details from plotnine.typing import FacetSpaceRatios + from ..scales.scales import Scales + class facet_grid(facet): """ @@ -165,7 +167,11 @@ def _make_gridspec(self): **ratios, ) - def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame: + def compute_layout( + self, + data: list[pd.DataFrame], + scales: Scales, + ) -> pd.DataFrame: if not self.rows and not self.cols: self.nrow, self.ncol = 1, 1 return layout_null() @@ -215,8 +221,15 @@ def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame: # Relax constraints, if necessary layout["SCALE_X"] = layout["COL"] if self.free["x"] else 1 layout["SCALE_Y"] = layout["ROW"] if self.free["y"] else 1 - layout["AXIS_X"] = layout["ROW"] == layout["ROW"].max() - layout["AXIS_Y"] = layout["COL"] == layout["COL"].min() + x_side, y_side = scales.axis_positions + if x_side == "top": + layout["AXIS_X"] = layout["ROW"] == layout["ROW"].min() + else: + layout["AXIS_X"] = layout["ROW"] == layout["ROW"].max() + if y_side == "right": + layout["AXIS_Y"] = layout["COL"] == layout["COL"].max() + else: + layout["AXIS_Y"] = layout["COL"] == layout["COL"].min() self.nrow = layout["ROW"].max() self.ncol = layout["COL"].max() diff --git a/plotnine/facets/facet_null.py b/plotnine/facets/facet_null.py index 8620d39559..c877effd29 100644 --- a/plotnine/facets/facet_null.py +++ b/plotnine/facets/facet_null.py @@ -7,6 +7,8 @@ if typing.TYPE_CHECKING: import pandas as pd + from ..scales.scales import Scales + class facet_null(facet): """ @@ -31,5 +33,6 @@ def map(self, data: pd.DataFrame, layout: pd.DataFrame) -> pd.DataFrame: def compute_layout( self, data: list[pd.DataFrame], + scales: Scales, ) -> pd.DataFrame: return layout_null() diff --git a/plotnine/facets/facet_wrap.py b/plotnine/facets/facet_wrap.py index cdf72d1c9d..9f5b5c9d9b 100644 --- a/plotnine/facets/facet_wrap.py +++ b/plotnine/facets/facet_wrap.py @@ -25,6 +25,8 @@ from plotnine.iapi import layout_details + from ..scales.scales import Scales + class facet_wrap(facet): """ @@ -93,6 +95,7 @@ def __init__( def compute_layout( self, data: list[pd.DataFrame], + scales: Scales, ) -> pd.DataFrame: if not self.vars: self.nrow, self.ncol = 1, 1 @@ -133,10 +136,17 @@ def compute_layout( layout["SCALE_Y"] = range(1, n + 1) if self.free["y"] else 1 # Figure out where axes should go. - # The bottom-most row of each column and the left most - # column of each row - x_idx = [df["ROW"].idxmax() for _, df in layout.groupby("COL")] - y_idx = [df["COL"].idxmin() for _, df in layout.groupby("ROW")] + # The row/column of each panel that shows the axis, on the side the + # axis sits (default: bottom-most row, left-most column) + x_side, y_side = scales.axis_positions + if x_side == "top": + x_idx = [df["ROW"].idxmin() for _, df in layout.groupby("COL")] + else: + x_idx = [df["ROW"].idxmax() for _, df in layout.groupby("COL")] + if y_side == "right": + y_idx = [df["COL"].idxmax() for _, df in layout.groupby("ROW")] + else: + y_idx = [df["COL"].idxmin() for _, df in layout.groupby("ROW")] layout["AXIS_X"] = False layout["AXIS_Y"] = False _loc = layout.columns.get_loc diff --git a/plotnine/facets/layout.py b/plotnine/facets/layout.py index 58b9f3bd5f..4a46f0773c 100644 --- a/plotnine/facets/layout.py +++ b/plotnine/facets/layout.py @@ -74,7 +74,7 @@ def setup(self, layers: Layers, plot: ggplot): # Generate panel layout data = self.facet.setup_data(data) - self.layout = self.facet.compute_layout(data) + self.layout = self.facet.compute_layout(data, plot.scales) self.layout = self.coord.setup_layout(self.layout) self.check_layout() diff --git a/plotnine/ggplot.py b/plotnine/ggplot.py index 7eb70a1033..ce586b0027 100755 --- a/plotnine/ggplot.py +++ b/plotnine/ggplot.py @@ -561,22 +561,7 @@ def _draw_breaks_and_labels(self): pidx = layout_info.panel_index ax = self.axs[pidx] panel_params = self.layout.panel_params[pidx] - self.facet.set_limits_breaks_and_labels(panel_params, ax) - - # Remove unnecessary ticks and labels - if not layout_info.axis_x: - ax.xaxis.set_tick_params( - which="both", bottom=False, labelbottom=False - ) - if not layout_info.axis_y: - ax.yaxis.set_tick_params( - which="both", left=False, labelleft=False - ) - - if layout_info.axis_x: - ax.xaxis.set_tick_params(which="both", bottom=True) - if layout_info.axis_y: - ax.yaxis.set_tick_params(which="both", left=True) + self.coordinates.setup_ax(ax, panel_params, layout_info) def _draw_figure_texts(self): """ @@ -608,11 +593,19 @@ def _draw_figure_texts(self): self.layout.set_xy_labels(self.labels) ) + # The axis title is registered under a per-side target named for the + # axis position. The legacy axis_title_x/_y references point at the + # same artist so existing layout/theme code keeps working. + pp = self.layout.panel_params[0] if labels.x: - targets.axis_title_x = self.figure.add_artist(Text(text=labels.x)) + t = self.figure.add_artist(Text(text=labels.x)) + targets.axis_title_x = t + setattr(targets, f"axis_title_x_{pp.x.position}", t) if labels.y: - targets.axis_title_y = self.figure.add_artist(Text(text=labels.y)) + t = self.figure.add_artist(Text(text=labels.y)) + targets.axis_title_y = t + setattr(targets, f"axis_title_y_{pp.y.position}", t) def _draw_watermarks(self): """ diff --git a/plotnine/guides/guide.py b/plotnine/guides/guide.py index e2cc2d8e6a..d64764ba77 100644 --- a/plotnine/guides/guide.py +++ b/plotnine/guides/guide.py @@ -6,7 +6,7 @@ from types import SimpleNamespace as NS from typing import TYPE_CHECKING, cast -from .._utils import ensure_xy_location, get_opposite_side +from .._utils import MARGIN_SIDE, ensure_xy_location from .._utils.registry import Register from ..themes.theme import theme as Theme @@ -243,8 +243,7 @@ def title(self): ha = self.theme.getp(("legend_title", "ha")) va = self.theme.getp(("legend_title", "va"), "center") _margin = self.theme.getp(("legend_title", "margin")).pt - _loc = get_opposite_side(self.title_position)[0] - margin = getattr(_margin, _loc) + margin = getattr(_margin, MARGIN_SIDE[self.title_position]) top_or_bottom = self.title_position in ("top", "bottom") is_blank = self.theme.T.is_blank("legend_title") @@ -272,8 +271,7 @@ def _text_margin(self) -> Sequence[float]: _margin = self.theme.getp( (f"legend_text_{self.guide_kind}", "margin") ).pt - locs = (get_opposite_side(p)[0] for p in self.text_positions) - return [getattr(_margin, loc) for loc in locs] + return [getattr(_margin, MARGIN_SIDE[p]) for p in self.text_positions] @cached_property def title_position(self) -> Side: diff --git a/plotnine/guides/guide_colorbar.py b/plotnine/guides/guide_colorbar.py index 29a5234d7c..7940367ea5 100644 --- a/plotnine/guides/guide_colorbar.py +++ b/plotnine/guides/guide_colorbar.py @@ -13,7 +13,7 @@ from plotnine.iapi import guide_text -from .._utils import get_opposite_side +from .._utils import OPPOSITE_SIDE from ..exceptions import PlotnineError, PlotnineWarning from ..mapping.aes import rename_aesthetics from ..scales.scale_continuous import scale_continuous @@ -514,7 +514,7 @@ def text(self): centers = ("center",) * n has = (ha,) * n if isinstance(ha, str) else ha vas = (va,) * n if isinstance(va, str) else va - opposite_sides = [get_opposite_side(s) for s in self.text_positions] + opposite_sides = [OPPOSITE_SIDE[s] for s in self.text_positions] if self.is_vertical: has = has or opposite_sides vas = vas or centers diff --git a/plotnine/iapi.py b/plotnine/iapi.py index ae404bc7cf..d5a8cb1010 100644 --- a/plotnine/iapi.py +++ b/plotnine/iapi.py @@ -53,6 +53,15 @@ class scale_view: labels: Sequence[str] +@dataclass +class scale_position_view(scale_view): + """ + Trained position scale information, including the axis side + """ + + position: Side + + @dataclass class range_view: """ @@ -142,8 +151,8 @@ class panel_view: Information from the trained position scales in a panel """ - x: scale_view - y: scale_view + x: scale_position_view + y: scale_position_view @dataclass diff --git a/plotnine/scales/scale_xy.py b/plotnine/scales/scale_xy.py index 2fbdc08022..d7b044096f 100644 --- a/plotnine/scales/scale_xy.py +++ b/plotnine/scales/scale_xy.py @@ -10,7 +10,7 @@ from .._utils import array_kind, match from .._utils.registry import alias from ..exceptions import PlotnineError -from ..iapi import range_view +from ..iapi import range_view, scale_position_view from ._expand import expand_range from ._runtime_typing import TransUser # noqa: TCH001 from .range import RangeContinuous @@ -19,10 +19,43 @@ from .scale_discrete import scale_discrete if TYPE_CHECKING: - from typing import Sequence + from typing import Literal, Sequence, TypeAlias from mizani.transforms import trans + ScaleX: TypeAlias = "scale_x_continuous | scale_x_discrete" + ScaleY: TypeAlias = "scale_y_continuous | scale_y_discrete" + + +# Valid axis sides per position aesthetic +AXIS_SIDES = {"x": ("bottom", "top"), "y": ("left", "right")} + + +class scale_position: + """ + Mixin for position scales — owns the axis side behavior + + `position`, `_aesthetics` and `__post_init__` come from the concrete + position scale this is mixed into. + """ + + def __post_init__(self): + super().__post_init__() # pyright: ignore[reportAttributeAccessIssue] + aesthetic = self._aesthetics[0] # pyright: ignore[reportAttributeAccessIssue] + sides = AXIS_SIDES[aesthetic] + if self.position not in sides: # pyright: ignore[reportAttributeAccessIssue] + raise PlotnineError( + f"Invalid position {self.position!r} for the " # pyright: ignore[reportAttributeAccessIssue] + f"{aesthetic!r} axis. Expected one of {sides}." + ) + + def view(self, limits=None, range=None) -> scale_position_view: + """ + Information about the trained scale, including the axis side + """ + sv = super().view(limits=limits, range=range) # pyright: ignore[reportAttributeAccessIssue] + return scale_position_view(**vars(sv), position=self.position) # pyright: ignore[reportAttributeAccessIssue] + # positions scales have a couple of differences (quirks) that # make necessary to override some of the scale_discrete and @@ -32,7 +65,7 @@ # are intermediate base classes where the required overriding # is done @dataclass(kw_only=True) -class scale_position_discrete(scale_discrete[None]): +class scale_position_discrete(scale_position, scale_discrete[None]): """ Base class for discrete position scales """ @@ -41,7 +74,7 @@ class scale_position_discrete(scale_discrete[None]): guide: None = None def __post_init__(self): - super().__post_init__() + super().__post_init__() # scale_position validates first # Keeps two ranges, range and range_c self._range_c = RangeContinuous() if isinstance(self.limits, tuple): @@ -187,7 +220,7 @@ def expand_limits( @dataclass(kw_only=True) -class scale_position_continuous(scale_continuous[None]): +class scale_position_continuous(scale_position, scale_continuous[None]): """ Base class for continuous position scales """ @@ -214,6 +247,7 @@ class scale_x_discrete(scale_position_discrete): """ _aesthetics = ["x", "xmin", "xmax", "xend", "xintercept"] + position: Literal["bottom", "top"] = "bottom" @dataclass(kw_only=True) @@ -223,6 +257,7 @@ class scale_y_discrete(scale_position_discrete): """ _aesthetics = ["y", "ymin", "ymax", "yend", "yintercept"] + position: Literal["left", "right"] = "left" # Not part of the user API @@ -243,6 +278,7 @@ class scale_x_continuous(scale_position_continuous): """ _aesthetics = ["x", "xmin", "xmax", "xend", "xintercept"] + position: Literal["bottom", "top"] = "bottom" @dataclass(kw_only=True) @@ -263,6 +299,7 @@ class scale_y_continuous(scale_position_continuous): "middle", "upper", ] + position: Literal["left", "right"] = "left" # Transformed scales diff --git a/plotnine/scales/scales.py b/plotnine/scales/scales.py index 24dfb2b57b..fb6c9b1ba7 100644 --- a/plotnine/scales/scales.py +++ b/plotnine/scales/scales.py @@ -3,7 +3,7 @@ import itertools import typing from contextlib import suppress -from typing import List +from typing import List, cast from warnings import warn import numpy as np @@ -18,6 +18,7 @@ if typing.TYPE_CHECKING: import pandas as pd + from plotnine.scales.scale_xy import ScaleX, ScaleY from plotnine.typing import ScaledAestheticsName @@ -85,18 +86,29 @@ def get_scales( return None @property - def x(self) -> scale | None: + def x(self) -> ScaleX | None: """ Return x scale """ - return self.get_scales("x") + return cast("ScaleX | None", self.get_scales("x")) @property - def y(self) -> scale | None: + def y(self) -> ScaleY | None: """ Return y scale """ - return self.get_scales("y") + return cast("ScaleY| None", self.get_scales("y")) + + @property + def axis_positions(self) -> tuple[str, str]: + """ + The sides the x and y axes occupy, as `(x_side, y_side)` + """ + # scales.x / scales.y can be None here if "missing" scales + # have not yet been added. + x_side = "bottom" if self.x is None else self.x.position + y_side = "left" if self.y is None else self.y.position + return x_side, y_side def non_position_scales(self) -> Scales: """ diff --git a/plotnine/stats/stat_density.py b/plotnine/stats/stat_density.py index aac7b189bd..401c52509e 100644 --- a/plotnine/stats/stat_density.py +++ b/plotnine/stats/stat_density.py @@ -294,7 +294,7 @@ def nrd0(x: FloatArrayLike) -> float: "Need at least 2 data points to compute the nrd0 bandwidth." ) - std: float = np.std(x, ddof=1) # pyright: ignore + std: float = np.std(x, ddof=1) std_estimate: float = iqr(x) / 1.349 low_std = min(std, std_estimate) if low_std == 0: diff --git a/plotnine/stats/stat_ellipse.py b/plotnine/stats/stat_ellipse.py index ed0a832832..6720dc00ed 100644 --- a/plotnine/stats/stat_ellipse.py +++ b/plotnine/stats/stat_ellipse.py @@ -205,7 +205,7 @@ def scale_simp(x: FloatArray, center: FloatArray, n: int, p: int): wt = wt[wt > 0] n, _ = x.shape - wt = wt[:, np.newaxis] # pyright: ignore[reportCallIssue,reportArgumentType,reportOptionalSubscript] + wt = wt[:, np.newaxis] # loc use_loc = False diff --git a/plotnine/themes/targets.py b/plotnine/themes/targets.py index 8a34c469ca..d8410e1c02 100644 --- a/plotnine/themes/targets.py +++ b/plotnine/themes/targets.py @@ -28,6 +28,10 @@ class ThemeTargets: axis_title_x: Optional[Text] = None axis_title_y: Optional[Text] = None + axis_title_x_top: Optional[Text] = None + axis_title_x_bottom: Optional[Text] = None + axis_title_y_left: Optional[Text] = None + axis_title_y_right: Optional[Text] = None legend_frame: Optional[Rectangle] = None legend_key: list[ColoredDrawingArea] = field(default_factory=list) legends: Optional[legend_artists] = None diff --git a/plotnine/themes/theme.py b/plotnine/themes/theme.py index 9abda18395..10f968c06b 100644 --- a/plotnine/themes/theme.py +++ b/plotnine/themes/theme.py @@ -111,10 +111,16 @@ def __init__( complete=False, # Generate themeables keyword parameters with # - # from plotnine.themes.themeable import themeable - # for name in themeable.registry(): - # print(f'{name}=None,') + # python -c " + # from plotnine.themes.themeable import themeable + # for name in themeable.registry(): + # print(f'{name}=None,') + # " + axis_title_x_bottom=None, + axis_title_x_top=None, axis_title_x=None, + axis_title_y_left=None, + axis_title_y_right=None, axis_title_y=None, axis_title=None, legend_title=None, @@ -135,16 +141,32 @@ def __init__( strip_text_y=None, strip_text=None, title=None, + axis_text_x_bottom=None, + axis_text_x_top=None, axis_text_x=None, + axis_text_y_left=None, + axis_text_y_right=None, axis_text_y=None, axis_text=None, text=None, + axis_line_x_bottom=None, + axis_line_x_top=None, axis_line_x=None, + axis_line_y_left=None, + axis_line_y_right=None, axis_line_y=None, axis_line=None, + axis_ticks_minor_x_bottom=None, + axis_ticks_minor_x_top=None, axis_ticks_minor_x=None, + axis_ticks_minor_y_left=None, + axis_ticks_minor_y_right=None, axis_ticks_minor_y=None, + axis_ticks_major_x_bottom=None, + axis_ticks_major_x_top=None, axis_ticks_major_x=None, + axis_ticks_major_y_left=None, + axis_ticks_major_y_right=None, axis_ticks_major_y=None, axis_ticks_major=None, axis_ticks_minor=None, diff --git a/plotnine/themes/theme_gray.py b/plotnine/themes/theme_gray.py index 1511bb4bcf..fc150980c0 100644 --- a/plotnine/themes/theme_gray.py +++ b/plotnine/themes/theme_gray.py @@ -56,21 +56,25 @@ def __init__(self, base_size=11, base_family=None): axis_line_x=element_blank(), axis_line_y=element_blank(), axis_text=element_text(size=base_size * 0.8, color="#4D4D4D"), - axis_text_x=element_text(va="top", margin=margin(t=fifth_line)), - axis_text_y=element_text(ha="right", margin=margin(r=fifth_line)), + axis_text_x=element_text( + va="top", margin=margin(t=fifth_line, b=fifth_line) + ), + axis_text_y=element_text( + ha="right", margin=margin(r=fifth_line, l=fifth_line) + ), axis_ticks=element_line(color="#333333"), axis_ticks_length=0, axis_ticks_length_major=quarter_line, axis_ticks_length_minor=eighth_line, axis_ticks_minor=element_blank(), axis_title_x=element_text( - va="bottom", ha="center", margin=margin(t=m, unit="fig") + va="bottom", ha="center", margin=margin(t=m, b=m, unit="fig") ), axis_title_y=element_text( angle=90, va="center", ha="left", - margin=margin(r=m, unit="fig"), + margin=margin(r=m, l=m, unit="fig"), ), dpi=get_option("dpi"), figure_size=get_option("figure_size"), diff --git a/plotnine/themes/theme_matplotlib.py b/plotnine/themes/theme_matplotlib.py index 3d474bf72e..4cf3c08df6 100644 --- a/plotnine/themes/theme_matplotlib.py +++ b/plotnine/themes/theme_matplotlib.py @@ -48,17 +48,18 @@ def __init__(self, rc=None, fname=None, use_defaults=True): ), aspect_ratio=get_option("aspect_ratio"), axis_text=element_text( - size=base_size * 0.8, margin=margin(t=2.4, r=2.4, unit="pt") + size=base_size * 0.8, + margin=margin(t=2.4, b=2.4, r=2.4, l=2.4, unit="pt"), ), axis_title_x=element_text( - va="bottom", ha="center", margin=margin(t=m, unit="fig") + va="bottom", ha="center", margin=margin(t=m, b=m, unit="fig") ), axis_line=element_blank(), axis_title_y=element_text( angle=90, va="center", ha="left", - margin=margin(r=m, unit="fig"), + margin=margin(r=m, l=m, unit="fig"), ), dpi=get_option("dpi"), figure_size=get_option("figure_size"), diff --git a/plotnine/themes/theme_seaborn.py b/plotnine/themes/theme_seaborn.py index 2fe71e86dc..a617ba50d9 100644 --- a/plotnine/themes/theme_seaborn.py +++ b/plotnine/themes/theme_seaborn.py @@ -65,13 +65,13 @@ def __init__( ), ), axis_title_x=element_text( - va="bottom", ha="center", margin=margin(t=m, unit="fig") + va="bottom", ha="center", margin=margin(t=m, b=m, unit="fig") ), axis_title_y=element_text( angle=90, va="center", ha="left", - margin=margin(r=m, unit="fig"), + margin=margin(r=m, l=m, unit="fig"), ), legend_box_margin=0, legend_box_spacing=m * 3, # figure units diff --git a/plotnine/themes/themeable.py b/plotnine/themes/themeable.py index ee97f143bb..1d31f5d35a 100644 --- a/plotnine/themes/themeable.py +++ b/plotnine/themes/themeable.py @@ -17,7 +17,7 @@ import numpy as np -from .._utils import has_alpha_channel, to_rgba +from .._utils import MARGIN_SIDE, has_alpha_channel, side_artists, to_rgba from .._utils.registry import RegistryHierarchyMeta from ..exceptions import PlotnineError, deprecated_themeable_name from .elements import element_blank @@ -34,6 +34,7 @@ from plotnine import theme from plotnine.themes.targets import ThemeTargets + from plotnine.typing import Side class themeable(metaclass=RegistryHierarchyMeta): @@ -56,6 +57,13 @@ class themeable(metaclass=RegistryHierarchyMeta): `y_axis_title`. We are just using multiple inheritance to specify this composition. + A parent's effect is the combined effect of the leaves it composes: + theming `axis_text_x` styles both `axis_text_x_top` and + `axis_text_x_bottom`, and blanking it hides both. Each leaf adds its own + contribution on top of its bases — hence the `super()` call in every + `apply_*` / `blank_*` method — so a leaf that skips it applies alone and + the rest of the composition is silently lost. + When implementing a new themeable based on the ggplot2 documentation, it is important to keep this in mind and reverse the order of the "inherits from" in the documentation. @@ -521,51 +529,156 @@ def blend_alpha( return properties +def _set_axis_text_margin(themeable, ax, axis: str, side: Side): + """ + Set the gap between axis tick and axis text + """ + margin = themeable.properties.get("margin") + if margin is None: + return + pad = getattr(margin.pt, MARGIN_SIDE[side]) + ax.tick_params(axis=axis, which="major", pad=pad) + + # element_text themeables -class axis_title_x(themeable): +class axis_title_x_bottom(themeable): """ - x axis label + x axis label on the bottom Parameters ---------- theme_element : element_text + + Notes + ----- + The gap to the panel is set by the top margin (`t`), as for any + x-axis title; the other margins are ignored. """ def apply_figure(self, figure: Figure, targets: ThemeTargets): super().apply_figure(figure, targets) - if text := targets.axis_title_x: + if text := targets.axis_title_x_bottom: # ha can be a float and is handled by the layout manager text.set(**self._get_properties(omit=("margin", "ha"))) def blank_figure(self, figure: Figure, targets: ThemeTargets): super().blank_figure(figure, targets) - if text := targets.axis_title_x: + if text := targets.axis_title_x_bottom: text.set_visible(False) -class axis_title_y(themeable): +class axis_title_x_top(themeable): """ - y axis label + x axis label on the top + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + The gap to the panel is set by the bottom margin (`b`) — the edge + that faces the panel below; the other margins are ignored. + """ + + def apply_figure(self, figure: Figure, targets: ThemeTargets): + super().apply_figure(figure, targets) + if text := targets.axis_title_x_top: + text.set(**self._get_properties(omit=("margin", "ha"))) + + def blank_figure(self, figure: Figure, targets: ThemeTargets): + super().blank_figure(figure, targets) + if text := targets.axis_title_x_top: + text.set_visible(False) + + +class axis_title_x(axis_title_x_top, axis_title_x_bottom): + """ + x axis label + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + Only the margin on the side that faces the panel has an effect: + the top margin (`t`) when the axis is on the bottom, the bottom + margin (`b`) when it is on the top. Set both to cover either + position. + """ + + +class axis_title_y_left(themeable): + """ + y axis label on the left Parameters ---------- theme_element : element_text + + Notes + ----- + The gap to the panel is set by the right margin (`r`), as for any + y-axis title; the other margins are ignored. """ def apply_figure(self, figure: Figure, targets: ThemeTargets): super().apply_figure(figure, targets) - if text := targets.axis_title_y: + if text := targets.axis_title_y_left: # va can be a float and is handled by the layout manager text.set(**self._get_properties(omit=("margin", "va"))) def blank_figure(self, figure: Figure, targets: ThemeTargets): super().blank_figure(figure, targets) - if text := targets.axis_title_y: + if text := targets.axis_title_y_left: + text.set_visible(False) + + +class axis_title_y_right(themeable): + """ + y axis label on the right + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + The gap to the panel is set by the left margin (`l`) — the edge + that faces the panel to the left; the other margins are ignored. + """ + + def apply_figure(self, figure: Figure, targets: ThemeTargets): + super().apply_figure(figure, targets) + if text := targets.axis_title_y_right: + text.set(**self._get_properties(omit=("margin", "va"))) + + def blank_figure(self, figure: Figure, targets: ThemeTargets): + super().blank_figure(figure, targets) + if text := targets.axis_title_y_right: text.set_visible(False) +class axis_title_y(axis_title_y_left, axis_title_y_right): + """ + y axis label + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + Only the margin on the side that faces the panel has an effect: + the right margin (`r`) when the axis is on the left, the left + margin (`l`) when it is on the right. Set both to cover either + position. + """ + + class axis_title(axis_title_x, axis_title_y): """ Axis labels @@ -573,6 +686,14 @@ class axis_title(axis_title_x, axis_title_y): Parameters ---------- theme_element : element_text + + Notes + ----- + Only the margin on the side that faces the panel has an effect. + For the x-axis that is the top margin (`t`) on the bottom or the + bottom margin (`b`) on the top; for the y-axis the right margin + (`r`) on the left or the left margin (`l`) on the right. Set both + margins of each axis to cover either position. """ @@ -948,9 +1069,9 @@ class title( """ -class axis_text_x(MixinSequenceOfValues): +class axis_text_x_bottom(MixinSequenceOfValues): """ - x-axis tick labels + x-axis tick labels on the bottom Parameters ---------- @@ -958,48 +1079,57 @@ class axis_text_x(MixinSequenceOfValues): Notes ----- - Use the `margin` to control the gap between the ticks and the - text. e.g. - - ```python - theme(axis_text_x=element_text(margin={"t": 5, "units": "pt"})) - ``` - - creates a margin of 5 points. + The gap to the panel is set by the top margin (`t`), as for any + x-axis text; the other margins are ignored. """ def apply_ax(self, ax: Axes): super().apply_ax(ax) + if not ax.xaxis.get_tick_params(which="major").get( + "labelbottom", False + ): + return + labels = [t.label1 for t in ax.xaxis.get_major_ticks()] + self.set(labels, self._get_properties(omit=("margin", "va"))) + _set_axis_text_margin(self, ax, "x", "bottom") - # TODO: Remove this code when the minimum matplotlib >= 3.10.0, - # and use the commented one below it - import matplotlib as mpl - from packaging import version + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + for t in ax.xaxis.get_major_ticks(): + t.label1.set_visible(False) - vinstalled = version.parse(mpl.__version__) - v310 = version.parse("3.10.0") - name = "labelbottom" if vinstalled >= v310 else "labelleft" - if not ax.xaxis.get_tick_params()[name]: - return - # if not ax.xaxis.get_tick_params()["labelbottom"]: - # return +class axis_text_x_top(MixinSequenceOfValues): + """ + x-axis tick labels on the top - labels = [t.label1 for t in ax.xaxis.get_major_ticks()] - self.set( - labels, - self._get_properties(omit=("margin", "va")), - ) + Parameters + ---------- + theme_element : element_text + + Notes + ----- + The gap to the panel is set by the bottom margin (`b`) — the edge + that faces the panel below; the other margins are ignored. + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + if not ax.xaxis.get_tick_params(which="major").get("labeltop", False): + return + labels = [t.label2 for t in ax.xaxis.get_major_ticks()] + self.set(labels, self._get_properties(omit=("margin", "va"))) + _set_axis_text_margin(self, ax, "x", "top") def blank_ax(self, ax: Axes): super().blank_ax(ax) for t in ax.xaxis.get_major_ticks(): - t.label1.set_visible(False) + t.label2.set_visible(False) -class axis_text_y(MixinSequenceOfValues): +class axis_text_x(axis_text_x_top, axis_text_x_bottom): """ - y-axis tick labels + x-axis tick labels Parameters ---------- @@ -1007,27 +1137,40 @@ class axis_text_y(MixinSequenceOfValues): Notes ----- - Use the `margin` to control the gap between the ticks and the - text. e.g. + Only the margin on the side that faces the panel has an effect: + the top margin (`t`) when the axis is on the bottom, the bottom + margin (`b`) when it is on the top. Set both to cover either + position. e.g. ```python - theme(axis_text_y=element_text(margin={"r": 5, "units": "pt"})) + theme(axis_text_x=element_text(margin={"t": 5, "b": 5, "units": "pt"})) ``` - creates a margin of 5 points. + puts a 5 point gap between the labels and the panel on either side. + """ + + +class axis_text_y_left(MixinSequenceOfValues): + """ + y-axis tick labels on the left + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + The gap to the panel is set by the right margin (`r`), as for any + y-axis text; the other margins are ignored. """ def apply_ax(self, ax: Axes): super().apply_ax(ax) - - if not ax.yaxis.get_tick_params()["labelleft"]: + if not ax.yaxis.get_tick_params(which="major").get("labelleft", False): return - labels = [t.label1 for t in ax.yaxis.get_major_ticks()] - self.set( - labels, - self._get_properties(omit=("margin", "ha")), - ) + self.set(labels, self._get_properties(omit=("margin", "ha"))) + _set_axis_text_margin(self, ax, "y", "left") def blank_ax(self, ax: Axes): super().blank_ax(ax) @@ -1035,6 +1178,59 @@ def blank_ax(self, ax: Axes): t.label1.set_visible(False) +class axis_text_y_right(MixinSequenceOfValues): + """ + y-axis tick labels on the right + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + The gap to the panel is set by the left margin (`l`) — the edge + that faces the panel to the left; the other margins are ignored. + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + if not ax.yaxis.get_tick_params(which="major").get( + "labelright", False + ): + return + labels = [t.label2 for t in ax.yaxis.get_major_ticks()] + self.set(labels, self._get_properties(omit=("margin", "ha"))) + _set_axis_text_margin(self, ax, "y", "right") + + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + for t in ax.yaxis.get_major_ticks(): + t.label2.set_visible(False) + + +class axis_text_y(axis_text_y_left, axis_text_y_right): + """ + y-axis tick labels + + Parameters + ---------- + theme_element : element_text + + Notes + ----- + Only the margin on the side that faces the panel has an effect: + the right margin (`r`) when the axis is on the left, the left + margin (`l`) when it is on the right. Set both to cover either + position. e.g. + + ```python + theme(axis_text_y=element_text(margin={"r": 5, "l": 5, "units": "pt"})) + ``` + + puts a 5 point gap between the labels and the panel on either side. + """ + + class axis_text(axis_text_x, axis_text_y): """ Axis tick labels @@ -1045,14 +1241,19 @@ class axis_text(axis_text_x, axis_text_y): Notes ----- - Use the `margin` to control the gap between the ticks and the - text. e.g. + Only the margin on the side that faces the panel has an effect. + For the x-axis that is the top margin (`t`) on the bottom or the + bottom margin (`b`) on the top; for the y-axis the right margin + (`r`) on the left or the left margin (`l`) on the right. Set both + margins of each axis to cover either position. e.g. ```python - theme(axis_text=element_text(margin={"t": 5, "r": 5, "units": "pt"})) + theme(axis_text=element_text( + margin={"t": 5, "b": 5, "r": 5, "l": 5, "units": "pt"} + )) ``` - creates a margin of 5 points. + puts a 5 point gap between the labels and the panel on every side. """ @@ -1096,63 +1297,93 @@ def rcParams(self): # element_line themeables -class axis_line_x(themeable): +def _style_axis_line(themeable, ax, side): """ - x-axis line + Style the spine on one side, when that side carries the axis - Parameters - ---------- - theme_element : element_line + `coord.setup_ax` makes only the active side's spine visible, so a hidden + spine here is one this axis does not sit on. The spine name equals the + side (`bottom`/`top`/`left`/`right`). + """ + if not ax.spines[side].get_visible(): + return + properties = themeable._get_properties(omit=("solid_capstyle",)) + # MPL has a default zorder of 2.5 for spines, so layers 3+ would be + # drawn on top of the spines + if "zorder" not in properties: + properties["zorder"] = 10000 + ax.spines[side].set(**properties) + + +class axis_line_x_bottom(themeable): + """ + x-axis line on the bottom """ - position = "bottom" + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_line(self, ax, "bottom") + + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + ax.spines["bottom"].set_visible(False) + + +class axis_line_x_top(themeable): + """ + x-axis line on the top + """ def apply_ax(self, ax: Axes): super().apply_ax(ax) - properties = self._get_properties(omit=("solid_capstyle",)) - # MPL has a default zorder of 2.5 for spines - # so layers 3+ would be drawn on top of the spines - if "zorder" not in properties: - properties["zorder"] = 10000 - ax.spines["top"].set_visible(False) - ax.spines["bottom"].set(**properties) + _style_axis_line(self, ax, "top") def blank_ax(self, ax: Axes): super().blank_ax(ax) ax.spines["top"].set_visible(False) - ax.spines["bottom"].set_visible(False) -class axis_line_y(themeable): +class axis_line_x(axis_line_x_top, axis_line_x_bottom): """ - y-axis line + x-axis line Parameters ---------- theme_element : element_line """ - position = "left" + +class axis_line_y_left(themeable): + """ + y-axis line on the left + """ def apply_ax(self, ax: Axes): super().apply_ax(ax) - properties = self._get_properties(omit=("solid_capstyle",)) - # MPL has a default zorder of 2.5 for spines - # so layers 3+ would be drawn on top of the spines - if "zorder" not in properties: - properties["zorder"] = 10000 - ax.spines["right"].set_visible(False) - ax.spines["left"].set(**properties) + _style_axis_line(self, ax, "left") def blank_ax(self, ax: Axes): super().blank_ax(ax) ax.spines["left"].set_visible(False) + + +class axis_line_y_right(themeable): + """ + y-axis line on the right + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_line(self, ax, "right") + + def blank_ax(self, ax: Axes): + super().blank_ax(ax) ax.spines["right"].set_visible(False) -class axis_line(axis_line_x, axis_line_y): +class axis_line_y(axis_line_y_left, axis_line_y_right): """ - x & y axis lines + y-axis line Parameters ---------- @@ -1160,164 +1391,211 @@ class axis_line(axis_line_x, axis_line_y): """ -class axis_ticks_minor_x(MixinSequenceOfValues): +class axis_line(axis_line_x, axis_line_y): """ - x-axis tick lines + x & y axis lines Parameters ---------- theme_element : element_line """ + +def _style_axis_ticks(themeable, ax, axis_name, which, side): + """ + Style the tick lines on one side of an axis + """ + axis = getattr(ax, axis_name) + # coord.setup_ax uses set_tick_params to turn off the ticks that will + # not show, setting the side key (e.g. params["bottom"]) to False and + # the artist invisible. Theming should not make them visible again. + if not axis.get_tick_params(which=which).get(side, False): + return + + # We have to use both Axis.set_tick_params() and Tick.tickline.set(). + # Splitting the properties lets set_tick_params keep a record of the + # ones it cares about so it does not undo them. GH703 + # https://github.com/matplotlib/matplotlib/issues/26008 + tick_params = {} + properties = themeable.properties + with suppress(KeyError): + tick_params["width"] = properties.pop("linewidth") + with suppress(KeyError): + tick_params["color"] = properties.pop("color") + + if tick_params: + axis.set_tick_params(which=which, **tick_params) + + attr = side_artists(side)[0] + ticks = ( + axis.get_minor_ticks() if which == "minor" else axis.get_major_ticks() + ) + themeable.set([getattr(t, attr) for t in ticks], properties) + + +def _blank_axis_ticks(ax, axis_name, which, side): + """ + Hide the tick lines on one side of an axis + """ + axis = getattr(ax, axis_name) + attr = side_artists(side)[0] + ticks = ( + axis.get_minor_ticks() if which == "minor" else axis.get_major_ticks() + ) + for tick in ticks: + getattr(tick, attr).set_visible(False) + + +class axis_ticks_minor_x_bottom(MixinSequenceOfValues): + """ + x-axis minor tick lines on the bottom + """ + def apply_ax(self, ax: Axes): super().apply_ax(ax) - # The ggplot._draw_breaks_and_labels uses set_tick_params to - # turn off the ticks that will not show. That sets the location - # key (e.g. params["bottom"]) to False. It also sets the artist - # to invisible. Theming should not change those artists to visible, - # so we return early. - params = ax.xaxis.get_tick_params(which="minor") - if not params.get("bottom", False): - return + _style_axis_ticks(self, ax, "xaxis", "minor", "bottom") - # We have to use both - # 1. Axis.set_tick_params() - # 2. Tick.tick1line.set() - # We split the properties so that set_tick_params keeps - # record of the properties it cares about so that it does - # not undo them. GH703 - # https://github.com/matplotlib/matplotlib/issues/26008 - tick_params = {} - properties = self.properties - with suppress(KeyError): - tick_params["width"] = properties.pop("linewidth") - with suppress(KeyError): - tick_params["color"] = properties.pop("color") + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + _blank_axis_ticks(ax, "xaxis", "minor", "bottom") - if tick_params: - ax.xaxis.set_tick_params(which="minor", **tick_params) - lines = [t.tick1line for t in ax.xaxis.get_minor_ticks()] - self.set(lines, properties) +class axis_ticks_minor_x_top(MixinSequenceOfValues): + """ + x-axis minor tick lines on the top + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_ticks(self, ax, "xaxis", "minor", "top") def blank_ax(self, ax: Axes): super().blank_ax(ax) - for tick in ax.xaxis.get_minor_ticks(): - tick.tick1line.set_visible(False) + _blank_axis_ticks(ax, "xaxis", "minor", "top") -class axis_ticks_minor_y(MixinSequenceOfValues): +class axis_ticks_minor_x(axis_ticks_minor_x_top, axis_ticks_minor_x_bottom): """ - y-axis minor tick lines + x-axis minor tick lines Parameters ---------- theme_element : element_line """ + +class axis_ticks_minor_y_left(MixinSequenceOfValues): + """ + y-axis minor tick lines on the left + """ + def apply_ax(self, ax: Axes): super().apply_ax(ax) - params = ax.yaxis.get_tick_params(which="minor") - if not params.get("left", False): - return + _style_axis_ticks(self, ax, "yaxis", "minor", "left") - tick_params = {} - properties = self.properties - with suppress(KeyError): - tick_params["width"] = properties.pop("linewidth") - with suppress(KeyError): - tick_params["color"] = properties.pop("color") + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + _blank_axis_ticks(ax, "yaxis", "minor", "left") - if tick_params: - ax.yaxis.set_tick_params(which="minor", **tick_params) - lines = [t.tick1line for t in ax.yaxis.get_minor_ticks()] - self.set(lines, properties) +class axis_ticks_minor_y_right(MixinSequenceOfValues): + """ + y-axis minor tick lines on the right + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_ticks(self, ax, "yaxis", "minor", "right") def blank_ax(self, ax: Axes): super().blank_ax(ax) - for tick in ax.yaxis.get_minor_ticks(): - tick.tick1line.set_visible(False) + _blank_axis_ticks(ax, "yaxis", "minor", "right") -class axis_ticks_major_x(MixinSequenceOfValues): +class axis_ticks_minor_y(axis_ticks_minor_y_left, axis_ticks_minor_y_right): """ - x-axis major tick lines + y-axis minor tick lines Parameters ---------- theme_element : element_line """ + +class axis_ticks_major_x_bottom(MixinSequenceOfValues): + """ + x-axis major tick lines on the bottom + """ + def apply_ax(self, ax: Axes): super().apply_ax(ax) - params = ax.xaxis.get_tick_params(which="major") - - # TODO: Remove this code when the minimum matplotlib >= 3.10.0, - # and use the commented one below it - import matplotlib as mpl - from packaging import version + _style_axis_ticks(self, ax, "xaxis", "major", "bottom") - vinstalled = version.parse(mpl.__version__) - v310 = version.parse("3.10.0") - name = "bottom" if vinstalled >= v310 else "left" - if not params.get(name, False): - return + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + _blank_axis_ticks(ax, "xaxis", "major", "bottom") - # if not params.get("bottom", False): - # return - tick_params = {} - properties = self.properties - with suppress(KeyError): - tick_params["width"] = properties.pop("linewidth") - with suppress(KeyError): - tick_params["color"] = properties.pop("color") - - if tick_params: - ax.xaxis.set_tick_params(which="major", **tick_params) +class axis_ticks_major_x_top(MixinSequenceOfValues): + """ + x-axis major tick lines on the top + """ - lines = [t.tick1line for t in ax.xaxis.get_major_ticks()] - self.set(lines, properties) + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_ticks(self, ax, "xaxis", "major", "top") def blank_ax(self, ax: Axes): super().blank_ax(ax) - for tick in ax.xaxis.get_major_ticks(): - tick.tick1line.set_visible(False) + _blank_axis_ticks(ax, "xaxis", "major", "top") -class axis_ticks_major_y(MixinSequenceOfValues): +class axis_ticks_major_x(axis_ticks_major_x_top, axis_ticks_major_x_bottom): """ - y-axis major tick lines + x-axis major tick lines Parameters ---------- theme_element : element_line """ + +class axis_ticks_major_y_left(MixinSequenceOfValues): + """ + y-axis major tick lines on the left + """ + def apply_ax(self, ax: Axes): super().apply_ax(ax) - params = ax.yaxis.get_tick_params(which="major") - if not params.get("left", False): - return + _style_axis_ticks(self, ax, "yaxis", "major", "left") - tick_params = {} - properties = self.properties - with suppress(KeyError): - tick_params["width"] = properties.pop("linewidth") - with suppress(KeyError): - tick_params["color"] = properties.pop("color") + def blank_ax(self, ax: Axes): + super().blank_ax(ax) + _blank_axis_ticks(ax, "yaxis", "major", "left") - if tick_params: - ax.yaxis.set_tick_params(which="major", **tick_params) - lines = [t.tick1line for t in ax.yaxis.get_major_ticks()] - self.set(lines, properties) +class axis_ticks_major_y_right(MixinSequenceOfValues): + """ + y-axis major tick lines on the right + """ + + def apply_ax(self, ax: Axes): + super().apply_ax(ax) + _style_axis_ticks(self, ax, "yaxis", "major", "right") def blank_ax(self, ax: Axes): super().blank_ax(ax) - for tick in ax.yaxis.get_major_ticks(): - tick.tick1line.set_visible(False) + _blank_axis_ticks(ax, "yaxis", "major", "right") + + +class axis_ticks_major_y(axis_ticks_major_y_left, axis_ticks_major_y_right): + """ + y-axis major tick lines + + Parameters + ---------- + theme_element : element_line + """ class axis_ticks_major(axis_ticks_major_x, axis_ticks_major_y): @@ -1838,7 +2116,10 @@ def apply_ax(self, ax: Axes): value: float | complex = self.properties["value"] try: - visible = ax.xaxis.get_major_ticks()[0].tick1line.get_visible() + tick = ax.xaxis.get_major_ticks()[0] + visible = ( + tick.tick1line.get_visible() or tick.tick2line.get_visible() + ) except IndexError: value = 0 else: @@ -1872,7 +2153,10 @@ def apply_ax(self, ax: Axes): value: float | complex = self.properties["value"] try: - visible = ax.yaxis.get_major_ticks()[0].tick1line.get_visible() + tick = ax.yaxis.get_major_ticks()[0] + visible = ( + tick.tick1line.get_visible() or tick.tick2line.get_visible() + ) except IndexError: value = 0 else: @@ -2674,7 +2958,8 @@ def apply_ax(self, ax: Axes): val = self.properties["value"] for t in ax.xaxis.get_major_ticks(): - _val = val if t.tick1line.get_visible() else 0 + visible = t.tick1line.get_visible() or t.tick2line.get_visible() + _val = val if visible else 0 t.set_pad(_val) @@ -2701,7 +2986,8 @@ def apply_ax(self, ax: Axes): val = self.properties["value"] for t in ax.yaxis.get_major_ticks(): - _val = val if t.tick1line.get_visible() else 0 + visible = t.tick1line.get_visible() or t.tick2line.get_visible() + _val = val if visible else 0 t.set_pad(_val) @@ -2744,7 +3030,8 @@ def apply_ax(self, ax: Axes): val = self.properties["value"] for t in ax.xaxis.get_minor_ticks(): - _val = val if t.tick1line.get_visible() else 0 + visible = t.tick1line.get_visible() or t.tick2line.get_visible() + _val = val if visible else 0 t.set_pad(_val) @@ -2770,7 +3057,8 @@ def apply_ax(self, ax: Axes): val = self.properties["value"] for t in ax.yaxis.get_minor_ticks(): - _val = val if t.tick1line.get_visible() else 0 + visible = t.tick1line.get_visible() or t.tick2line.get_visible() + _val = val if visible else 0 t.set_pad(_val) diff --git a/tests/baseline_images/test_axis_position/coord_flip_x_top.png b/tests/baseline_images/test_axis_position/coord_flip_x_top.png new file mode 100644 index 0000000000..29f748ffac Binary files /dev/null and b/tests/baseline_images/test_axis_position/coord_flip_x_top.png differ diff --git a/tests/baseline_images/test_axis_position/facet_wrap_y_right.png b/tests/baseline_images/test_axis_position/facet_wrap_y_right.png new file mode 100644 index 0000000000..948500ca97 Binary files /dev/null and b/tests/baseline_images/test_axis_position/facet_wrap_y_right.png differ diff --git a/tests/baseline_images/test_axis_position/x_axis_top_continuous.png b/tests/baseline_images/test_axis_position/x_axis_top_continuous.png new file mode 100644 index 0000000000..f981ddeefb Binary files /dev/null and b/tests/baseline_images/test_axis_position/x_axis_top_continuous.png differ diff --git a/tests/baseline_images/test_axis_position/x_axis_top_discrete.png b/tests/baseline_images/test_axis_position/x_axis_top_discrete.png new file mode 100644 index 0000000000..77e6fca4fd Binary files /dev/null and b/tests/baseline_images/test_axis_position/x_axis_top_discrete.png differ diff --git a/tests/baseline_images/test_axis_position/y_axis_right_continuous.png b/tests/baseline_images/test_axis_position/y_axis_right_continuous.png new file mode 100644 index 0000000000..52c9f9537d Binary files /dev/null and b/tests/baseline_images/test_axis_position/y_axis_right_continuous.png differ diff --git a/tests/baseline_images/test_theme/theme_seaborn.png b/tests/baseline_images/test_theme/theme_seaborn.png index e9acbf5446..b57f13ffff 100644 Binary files a/tests/baseline_images/test_theme/theme_seaborn.png and b/tests/baseline_images/test_theme/theme_seaborn.png differ diff --git a/tests/test_axis_position.py b/tests/test_axis_position.py new file mode 100644 index 0000000000..e20c3cfa68 --- /dev/null +++ b/tests/test_axis_position.py @@ -0,0 +1,49 @@ +from plotnine import ( + aes, + coord_flip, + facet_wrap, + geom_point, + ggplot, + scale_x_continuous, + scale_x_discrete, + scale_y_continuous, +) +from plotnine.data import mtcars + +p0 = ggplot(mtcars, aes("wt", "mpg")) + geom_point() + + +def test_x_axis_top_continuous(): + p = p0 + scale_x_continuous(position="top") + assert p == "x_axis_top_continuous" + + +def test_y_axis_right_continuous(): + p = p0 + scale_y_continuous(position="right") + assert p == "y_axis_right_continuous" + + +def test_coord_flip_x_top(): + # Before flipping, the y-axis is on the non-default side, + # after flipping the x-axis will be on the non-default side. + p = p0 + scale_y_continuous(position="right") + coord_flip() + assert p == "coord_flip_x_top" + + +def test_facet_wrap_y_right(): + p = ( + ggplot(mtcars, aes("wt", "mpg")) + + geom_point() + + facet_wrap("gear") + + scale_y_continuous(position="right") + ) + assert p == "facet_wrap_y_right" + + +def test_x_axis_top_discrete(): + p = ( + ggplot(mtcars, aes("factor(cyl)", "mpg")) + + geom_point() + + scale_x_discrete(position="top") + ) + assert p == "x_axis_top_discrete" diff --git a/tests/test_scale_internals.py b/tests/test_scale_internals.py index bd74049657..aff976fcd9 100644 --- a/tests/test_scale_internals.py +++ b/tests/test_scale_internals.py @@ -56,7 +56,7 @@ scale_y_continuous, scale_y_log10, ) -from plotnine.scales.scales import make_scale +from plotnine.scales.scales import Scales, make_scale PANDAS_LT_3 = Version(pd.__version__) < Version("3.0") @@ -915,3 +915,26 @@ def test_transform_datetime_aes_param(): + geom_point(y=yparam, color="red") ) assert p == "transform_datetime_aes_param" + + +def test_position_invalid_for_aesthetic(): + with pytest.raises(PlotnineError): + scale_x_continuous(position="left") # pyright: ignore[reportArgumentType] + with pytest.raises(PlotnineError): + scale_y_continuous(position="bottom") # pyright: ignore[reportArgumentType] + with pytest.raises(PlotnineError): + scale_x_continuous(position="middle") # pyright: ignore[reportArgumentType] + + +def test_scales_axis_positions(): + # No position scales -> defaults + assert Scales().axis_positions == ("bottom", "left") + + # Explicit sides are read from the scales + s = Scales( + [ + scale_x_continuous(position="top"), + scale_y_continuous(position="right"), + ] + ) + assert s.axis_positions == ("top", "right")