diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..cee5f5e5f 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -8,7 +8,7 @@ from enum import Enum from fractions import Fraction from math import pi -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, SupportsComplex, SupportsFloat # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never @@ -26,6 +26,8 @@ from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import _ENCODING, Statevec class OutputFormat(Enum): @@ -439,3 +441,393 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm partial_order_to_str(xzcorr.partial_order_layers, output), ) ) + + +# --- Complex amplitude and quantum-state pretty-printing --------------------- +# +# The recognition of "nice" real numbers (fractions, square roots) relies on a +# square-then-rationalize trick: a real ``x`` is matched against ``sqrt(p / q)`` +# by approximating ``x ** 2`` with a rational ``p / q``. This single mechanism +# uniformly handles plain fractions (``1/4``), surds (``√2/2``, ``√3/2``) and, +# combined with :func:`angle_to_str`, the phase of exponentials (``e^{iπ/3}``). + +_DEFAULT_MAX_DENOMINATOR = 1000 +_DEFAULT_ATOL = 1e-9 +_DEFAULT_PRECISION = 4 + + +def _squarefree_decomposition(n: int) -> tuple[int, int]: + """Decompose a non-negative integer as ``outer ** 2 * inner`` with ``inner`` squarefree. + + Parameters + ---------- + n : int + Non-negative integer to decompose. + + Returns + ------- + tuple[int, int] + ``(outer, inner)`` such that ``outer ** 2 * inner == n`` and ``inner`` is + squarefree. ``n == 0`` returns ``(0, 1)``. + """ + if n == 0: + return 0, 1 + outer = 1 + inner = n + d = 2 + while d * d <= inner: + while inner % (d * d) == 0: + inner //= d * d + outer *= d + d += 1 + return outer, inner + + +def _recognize_sqrt(x: float, max_denominator: int, atol: float) -> tuple[int, int, int] | None: + """Recognize a real number as ``signed_num * sqrt(inner) / den``. + + The recognition approximates ``x ** 2`` by a rational ``p / q``; on success, + ``x = ±sqrt(p / q)`` is rewritten with a rationalized, fully-reduced + denominator. Pure rationals are covered as the special case ``inner == 1``. + + Parameters + ---------- + x : float + Real number to recognize. + max_denominator : int + Maximum denominator allowed when approximating ``x ** 2`` by a rational. + atol : float + Absolute tolerance for that rational approximation. + + Returns + ------- + tuple[int, int, int] or None + ``(signed_num, inner, den)`` with ``den > 0`` and ``inner`` a positive + squarefree integer, encoding ``x = signed_num * sqrt(inner) / den``. + Returns ``None`` when ``x`` is not recognized as such a value. + """ + if x == 0: + return 0, 1, 1 + square = Fraction(x * x).limit_denominator(max_denominator) + if not math.isclose(x * x, float(square), abs_tol=atol): + return None + num_outer, num_inner = _squarefree_decomposition(square.numerator) + den_outer, den_inner = _squarefree_decomposition(square.denominator) + # x = ±(num_outer √num_inner) / (den_outer √den_inner); rationalize by √den_inner. + combined_outer, inner = _squarefree_decomposition(num_inner * den_inner) + num = num_outer * combined_outer + den = den_outer * den_inner + divisor = math.gcd(num, den) + num //= divisor + den //= divisor + sign = -1 if x < 0 else 1 + return sign * num, inner, den + + +def _imaginary_unit(output: OutputFormat) -> str: + return r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + + +def _sqrt_str(inner: int, output: OutputFormat) -> str: + """Return the string for ``sqrt(inner)`` (empty when ``inner == 1``).""" + if inner == 1: + return "" + if output == OutputFormat.LaTeX: + return rf"\sqrt{{{inner}}}" + if output == OutputFormat.Unicode: + return f"√{inner}" + return f"sqrt({inner})" + + +def _fraction_str(num: str, den: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}}}{{{den}}}" + return f"{num}/{den}" + + +def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: + """Render ``signed_num * sqrt(inner) / den`` produced by :func:`_recognize_sqrt`.""" + if signed_num == 0: + return "0" + sign = "-" if signed_num < 0 else "" + magnitude = abs(signed_num) + sqrt_part = _sqrt_str(inner, output) + if inner == 1: + numerator = f"{magnitude}" + elif magnitude == 1: + numerator = sqrt_part + else: + numerator = f"{magnitude}{sqrt_part}" + if den == 1: + return f"{sign}{numerator}" + return f"{sign}{_fraction_str(numerator, str(den), output)}" + + +def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + rec = _recognize_sqrt(x, max_denominator, atol) + if rec is None: + return None + return _render_real(*rec, output) + + +def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render a purely imaginary value ``x * i``.""" + rec = _recognize_sqrt(x, max_denominator, atol) + if rec is None: + return None + signed_num, inner, den = rec + unit = _imaginary_unit(output) + # A unit coefficient collapses to just ``±i``. + if abs(signed_num) == 1 and inner == 1 and den == 1: + return f"{'-' if signed_num < 0 else ''}{unit}" + return f"{_render_real(signed_num, inner, den, output)}{unit}" + + +def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float) -> Fraction | None: + """Return ``theta / pi`` as a simple fraction, or ``None`` if it is not one.""" + value = theta / pi + frac = Fraction(value).limit_denominator(max_denominator) + if math.isclose(value, float(frac), abs_tol=atol): + return frac + return None + + +def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render ``z`` as ``r e^{iθ}`` when both ``r`` and ``θ / π`` are recognized.""" + theta = math.atan2(z.imag, z.real) + angle_frac = _recognize_angle_over_pi(theta, max_denominator, atol) + if angle_frac is None or angle_frac == 0: + return None + radius = _real_to_str(math.hypot(z.real, z.imag), output, max_denominator, atol) + if radius is None: + return None + sign = "-" if angle_frac < 0 else "" + angle_str = angle_to_str(float(abs(angle_frac)), output) + unit = _imaginary_unit(output) + e_sym = r"\mathrm{e}" if output == OutputFormat.LaTeX else "e" + unit_sep = " " if output == OutputFormat.LaTeX else "*" if output == OutputFormat.ASCII else "" + exponent = f"{sign}{unit}{unit_sep}{angle_str}" + body = f"{e_sym}^{{{exponent}}}" if output == OutputFormat.LaTeX else f"{e_sym}^({exponent})" + if radius == "1": + return body + prefix_sep = " " if output == OutputFormat.LaTeX else "·" if output == OutputFormat.Unicode else "*" + return f"{radius}{prefix_sep}{body}" + + +def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render ``re + im i`` when both parts are recognized as nice reals.""" + re_str = _real_to_str(re, output, max_denominator, atol) + im_rec = _recognize_sqrt(im, max_denominator, atol) + if re_str is None or im_rec is None: + return None + signed_num, inner, den = im_rec + connector = " - " if signed_num < 0 else " + " + unit = _imaginary_unit(output) + if abs(signed_num) == 1 and inner == 1 and den == 1: + imag = unit + else: + imag = f"{_render_real(abs(signed_num), inner, den, output)}{unit}" + return f"{re_str}{connector}{imag}" + + +def _decimal_to_str(z: complex, output: OutputFormat, precision: int) -> str: + """Fallback formatting using rounded decimals with ``precision`` significant digits.""" + unit = _imaginary_unit(output) + if abs(z.imag) <= _DEFAULT_ATOL: + return f"{z.real:.{precision}g}" + if abs(z.real) <= _DEFAULT_ATOL: + return f"{z.imag:.{precision}g}{unit}" + return f"{z.real:.{precision}g}{z.imag:+.{precision}g}{unit}" + + +def complex_to_str( + value: object, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a human-friendly string representation of a complex number. + + Common values are rendered exactly rather than as floating-point numbers: + fractions (``0.25`` → ``1/4``), square roots (``0.7071…`` → ``√2/2``) and + complex exponentials (``0.5 + 0.866…j`` → ``e^(iπ/3)``). Values that are not + recognized fall back to a rounded decimal representation, and inputs that + cannot be interpreted as complex numbers (e.g. symbolic parameters) are + returned via :func:`str`. + + Parameters + ---------- + value : object + The number to format. Anything supporting conversion to ``complex`` is + accepted; other objects are stringified. + output : OutputFormat + Desired formatting style: ``Unicode`` (``√``, ``π``), ``LaTeX`` + (``\sqrt``, ``\pi``) or ``ASCII`` (``sqrt``, ``pi``). + max_denominator : int, optional + Maximum denominator used when recognizing rational magnitudes and phases + (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for the decimal fallback when a + value is not recognized as an exact form (default: ``4``). + + Returns + ------- + str + The formatted complex number. + + Examples + -------- + >>> complex_to_str(0.25, OutputFormat.ASCII) + '1/4' + >>> complex_to_str(2**-0.5, OutputFormat.Unicode) + '√2/2' + >>> complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.Unicode) + 'e^(iπ/3)' + >>> complex_to_str(0.123456 + 0.234567j, OutputFormat.ASCII, precision=2) + '0.12+0.23i' + """ + if not isinstance(value, (bool, int, float, complex, SupportsComplex)): + return str(value) + z = complex(value) + if abs(z.real) <= atol and abs(z.imag) <= atol: + return "0" + if abs(z.imag) <= atol: + return _real_to_str(z.real, output, max_denominator, atol) or _decimal_to_str(z, output, precision) + if abs(z.real) <= atol: + return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output, precision) + exponential = _exponential_to_str(z, output, max_denominator, atol) + if exponential is not None: + return exponential + cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol) + if cartesian is not None: + return cartesian + return _decimal_to_str(z, output, precision) + + +def _ket_str(ket: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\lvert {ket}\rangle" + if output == OutputFormat.Unicode: + return f"|{ket}⟩" + return f"|{ket}>" + + +def _needs_parentheses(coefficient: str) -> bool: + """Whether a coefficient is a sum and must be parenthesized before a ket.""" + return " + " in coefficient or " - " in coefficient + + +def statevec_to_str( + statevec: Statevec, + output: OutputFormat, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + rtol: float = 0.0, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a ket-notation string representation of a statevector. + + Amplitudes close to zero are omitted (see :meth:`graphix.sim.statevec.Statevec.to_dict`) + and the remaining ones are pretty-printed with :func:`complex_to_str`. + + Parameters + ---------- + statevec : Statevec + The statevector to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`graphix.sim.statevec.Statevec.to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance used both to drop near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance used to drop near-zero amplitudes (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted statevector, e.g. ``√2/2|00⟩ + √2/2|01⟩``. + """ + amplitudes = statevec.to_dict(encoding, rtol=rtol, atol=atol) + if not amplitudes: + return "0" + result = "" + for index, (ket, amplitude) in enumerate(amplitudes.items()): + coefficient = complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, precision=precision) + ket_str = _ket_str(ket, output) + if coefficient == "1": + term = ket_str + elif coefficient == "-1": + term = f"-{ket_str}" + elif _needs_parentheses(coefficient): + term = f"({coefficient}){ket_str}" + else: + term = f"{coefficient}{ket_str}" + if index == 0: + result = term + elif term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return result + + +def density_matrix_to_str( + density_matrix: DensityMatrix, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a matrix-form string representation of a density matrix. + + Each entry is pretty-printed with :func:`complex_to_str`. ``LaTeX`` output + uses a ``pmatrix`` environment; ``ASCII`` and ``Unicode`` outputs produce a + column-aligned grid. + + Parameters + ---------- + density_matrix : DensityMatrix + The density matrix to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted density matrix. + """ + rows = [ + [ + complex_to_str(entry, output, max_denominator=max_denominator, atol=atol, precision=precision) + for entry in row + ] + for row in density_matrix.rho + ] + if output == OutputFormat.LaTeX: + body = r" \\ ".join(" & ".join(row) for row in rows) + return rf"\begin{{pmatrix}}{body}\end{{pmatrix}}" + widths = [max(len(row[col]) for row in rows) for col in range(len(rows[0]))] if rows else [] + lines = [" ".join(entry.rjust(widths[col]) for col, entry in enumerate(row)) for row in rows] + return "\n".join(f"[ {line} ]" for line in lines) diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..5b5a8f283 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -117,6 +118,39 @@ def __str__(self) -> str: """Return a string description.""" return f"DensityMatrix object, with density matrix {self.rho} and shape {self.dims()}." + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + max_denominator: int = 1000, + atol: float = 1e-9, + precision: int = 4, + ) -> str: + r"""Return a pretty-printed matrix representation of the density matrix. + + Each entry is rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``1/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted density matrix. + """ + return density_matrix_to_str(self, output, max_denominator=max_denominator, atol=atol, precision=precision) + @override def add_nodes(self, nqubit: int, data: Data) -> None: r""" diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..cdfafeca5 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -520,6 +521,64 @@ def to_prob_dict( """ return self._to_dict_map(lambda x: np.abs(x) ** 2, encoding, rtol=rtol, atol=atol) + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = 1000, + atol: float = 1e-9, + rtol: float = 0.0, + precision: int = 4, + ) -> str: + r"""Return a pretty-printed ket-notation representation of the statevector. + + Amplitudes are rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``√2/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for dropping near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for dropping near-zero amplitudes (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to + a decimal representation (default: ``4``). + + Returns + ------- + str + The formatted statevector. + + Examples + -------- + >>> from graphix.transpiler import Circuit + >>> circuit = Circuit(2) + >>> circuit.h(0) + >>> circuit.cz(0, 1) + >>> print(circuit.simulate_statevector().statevec.draw()) + √2/2|00⟩ + √2/2|01⟩ + """ + return statevec_to_str( + self, + output, + encoding=encoding, + max_denominator=max_denominator, + atol=atol, + rtol=rtol, + precision=precision, + ) + def _to_dict_map( self, f: Callable[[npt.NDArray[np.object_ | np.complex128]], npt.NDArray[_ScalarT]], diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..587313d9e 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,8 +1,10 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING import networkx as nx +import numpy as np import pytest from numpy.random import PCG64, Generator @@ -13,8 +15,11 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec +from graphix.states import BasicStates from graphix.transpiler import Circuit if TYPE_CHECKING: @@ -202,3 +207,121 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +def test_complex_to_str_issue_examples() -> None: + # The three canonical examples from the issue. + assert complex_to_str(0.25, OutputFormat.ASCII) == "1/4" + assert complex_to_str(2**-0.5, OutputFormat.Unicode) == "√2/2" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.Unicode) == "e^(iπ/3)" + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (0, "0"), + (1e-12, "0"), + (1, "1"), + (-1, "-1"), + (2, "2"), + (0.5, "1/2"), + (-0.25, "-1/4"), + (2**-0.5, "√2/2"), + (math.sqrt(3) / 2, "√3/2"), + (1j, "i"), + (-1j, "-i"), + (0.5j, "1/2i"), + (-(2**-0.5) * 1j, "-√2/2i"), + ], +) +def test_complex_to_str_unicode_values(value: complex, expected: str) -> None: + assert complex_to_str(value, OutputFormat.Unicode) == expected + + +def test_complex_to_str_exponentials() -> None: + assert complex_to_str(1j, OutputFormat.Unicode) == "i" + assert complex_to_str(math.cos(math.pi / 4) + math.sin(math.pi / 4) * 1j, OutputFormat.Unicode) == "e^(iπ/4)" + # Negative phase keeps the sign inside the exponent. + assert complex_to_str(math.cos(math.pi / 3) - math.sin(math.pi / 3) * 1j, OutputFormat.Unicode) == "e^(-iπ/3)" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.ASCII) == "e^(i*pi/3)" + + +def test_complex_to_str_latex() -> None: + assert complex_to_str(2**-0.5, OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}" + assert complex_to_str(0.25, OutputFormat.LaTeX) == r"\frac{1}{4}" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i} \frac{\pi}{3}}" + + +def test_complex_to_str_fallback_and_symbolic() -> None: + # An unrecognized value falls back to a rounded decimal. + assert complex_to_str(0.123456, OutputFormat.ASCII) == "0.1235" + # A non-numeric object is stringified rather than raising. + assert complex_to_str("alpha", OutputFormat.ASCII) == "alpha" + + +def test_statevec_draw() -> None: + bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) + assert bell.draw(OutputFormat.Unicode) == "√2/2|00⟩ + √2/2|11⟩" + assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2|00> + sqrt(2)/2|11>" + + +def test_statevec_draw_single_basis_state() -> None: + state = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + assert state.draw(OutputFormat.Unicode) == "|01⟩" + # LSB encoding reverses the ket label. + assert state.draw(OutputFormat.Unicode, encoding="LSB") == "|10⟩" + + +def test_density_matrix_draw() -> None: + dm = DensityMatrix(data=[BasicStates.ZERO]) + assert dm.draw(OutputFormat.ASCII) == "[ 1 0 ]\n[ 0 0 ]" + assert dm.draw(OutputFormat.LaTeX) == r"\begin{pmatrix}1 & 0 \\ 0 & 0\end{pmatrix}" + + +def test_complex_to_str_exponential_with_radius() -> None: + # |z| != 1: the radius prefixes the exponential form (1 + i = √2 e^{iπ/4}). + assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "√2·e^(iπ/4)" + assert complex_to_str(1 + 1j, OutputFormat.ASCII) == "sqrt(2)*e^(i*pi/4)" + assert complex_to_str(1 + 1j, OutputFormat.LaTeX) == r"\sqrt{2} \mathrm{e}^{\mathrm{i} \frac{\pi}{4}}" + + +def test_complex_to_str_cartesian_form() -> None: + # Both parts are recognized but the phase is not a simple fraction of π, so the + # cartesian form is used instead of the exponential one. + assert complex_to_str(0.5 + 0.25j, OutputFormat.Unicode) == "1/2 + 1/4i" + assert complex_to_str(0.5 + 0.25j, OutputFormat.LaTeX) == r"\frac{1}{2} + \frac{1}{4}\mathrm{i}" + + +def test_complex_to_str_complex_decimal_fallback() -> None: + # Neither part is a recognized value -> rounded decimal real and imaginary parts. + assert complex_to_str(0.123456 + 0.234567j, OutputFormat.Unicode) == "0.1235+0.2346i" + + +def test_complex_to_str_imaginary_formats() -> None: + assert complex_to_str(0.5j, OutputFormat.LaTeX) == r"\frac{1}{2}\mathrm{i}" + assert complex_to_str(0.5j, OutputFormat.ASCII) == "1/2i" + + +def test_complex_to_str_integer_times_sqrt() -> None: + assert complex_to_str(math.sqrt(12), OutputFormat.Unicode) == "2√3" + + +def test_statevec_draw_negative_and_parenthesized() -> None: + # Negative amplitudes use a `-` separator between terms. + neg = Statevec([0.5, -0.5, 0.5, 0.5]) + assert neg.draw(OutputFormat.Unicode) == "1/2|00⟩ - 1/2|01⟩ + 1/2|10⟩ + 1/2|11⟩" + # A compound (cartesian) amplitude is parenthesized before the ket. Build from a + # numpy array so the amplitudes are ``numpy.complex128`` (Python's ``complex`` only + # gained ``__complex__`` in 3.11, so a bare ``complex`` is rejected on 3.10). + binomial = Statevec(np.array([0.5 + 0.25j, (1 - abs(0.5 + 0.25j) ** 2) ** 0.5])) + assert binomial.draw(OutputFormat.Unicode) == "(1/2 + 1/4i)|0⟩ + √11/4|1⟩" + # A unit negative amplitude collapses to a bare `-|ket⟩`. + assert Statevec([-1.0, 0.0]).draw(OutputFormat.Unicode) == "-|0⟩" + + +def test_complex_to_str_precision_is_configurable() -> None: + z = 0.123456 + 0.234567j + assert complex_to_str(z, OutputFormat.ASCII, precision=2) == "0.12+0.23i" + assert complex_to_str(z, OutputFormat.ASCII, precision=6) == "0.123456+0.234567i" + # The default keeps the previous behaviour (four significant digits). + assert complex_to_str(z, OutputFormat.ASCII) == "0.1235+0.2346i"