'
@@ -221,12 +261,9 @@ def write_specs_to_html(
representation = writer_context.representation.value
if isinstance(representation, HTMLRepresentation):
- _stream = StringIO('')
+ _stream = StringIO("")
- representation.to_html(
- out=_stream,
- context=writer_context.context
- )
+ representation.to_html(out=_stream, context=writer_context.context)
s += f"""
@@ -239,36 +276,82 @@ def write_specs_to_html(
@dataclass(frozen=True, slots=True)
class _ElementInTree:
+ """
+ Represents an element within a tree structure.
+
+ This class holds information about an element, its index, children, and
+ potential subelement type. It provides functionality to create copies of the element.
+
+ Attributes:
+ element: The actual element being represented.
+ element_index: The index of the element in its parent's list of children.
+ children: A list of child _ElementInTree objects.
+ subelement_type: The type of subelement, if applicable.
+
+ Methods:
+ create_copy(subelement_type): Creates a copy of the current element.
+ """
+
element: Specsable
element_index: int
- children: list['_ElementInTree'] = field(default_factory=list)
+ children: list["_ElementInTree"] = field(default_factory=list)
subelement_type: str | None = None
- def create_copy(
- self, subelement_type: str | None
- ) -> '_ElementInTree':
+ def create_copy(self, subelement_type: str | None) -> "_ElementInTree":
return _ElementInTree(
element=self.element,
element_index=self.element_index,
children=self.children,
- subelement_type=subelement_type
+ subelement_type=subelement_type,
)
class _ElementsIterator:
+ """
+ Iterates through a collection of iterables, yielding elements with their index and writer context.
+
+ This class provides an iterator that handles multiple iterables, subelements, and ensures
+ each element is written only once while building a tree structure of the iterated elements.
+ """
+
def __init__(self, *iterables: Specsable, directory: str | Path) -> None:
+ """
+ Initializes a new instance of the class.
+
+ Args:
+ *iterables: The iterables to process. Can be multiple.
+ directory: The directory associated with the iterables.
+
+ Returns:
+ None
+ """
self.iterables = tuple(iterables)
self.directory = directory
self._iterated: dict[int, _ElementInTree] = {}
self._tree: list[_ElementInTree] | None = None
def __iter__(
- self
+ self,
) -> Generator[tuple[int, Specsable, _WriterContextGenerator], None, None]:
+ """
+ Iterates through the specsables and yields them with their index and writer context.
+
+ This method traverses a collection of specsables, handling subelements and ensuring
+ each element is written only once. It maintains an internal state to track iterated
+ elements and builds a tree structure as it iterates.
+
+ Args:
+ None
+
+ Returns:
+ Generator[tuple[int, Specsable, _WriterContextGenerator]]: A generator that yields tuples
+ containing the index of the element, the element itself (Specsable or SubelementSpecs),
+ and a writer context generator for the element.
+ """
def f(
specsables: Iterable[Specsable | SubelementSpecs],
- parent_children: list[_ElementInTree]
+ parent_children: list[_ElementInTree],
):
for element in specsables:
@@ -298,9 +381,7 @@ def f(
# Create a new tree element
element_in_tree = _ElementInTree(
- element,
- index,
- subelement_type=element_name
+ element, index, subelement_type=element_name
)
self._iterated[id(element)] = element_in_tree
parent_children.append(element_in_tree)
@@ -334,16 +415,26 @@ def write_elements_tree_to_str(
tree: list[_ElementInTree],
stream: TextIO,
):
- stream.write('\n\nTree:\n')
+ """
+ Writes the elements tree to a string stream.
+
+ Args:
+ tree: The list of root elements in the tree.
+ stream: The output stream to write to.
+
+ Returns:
+ None
+ """
+ stream.write("\n\nTree:\n")
def _write_element(tree_level: int, element: _ElementInTree):
- stream.write(' ' * (8 * tree_level))
+ stream.write(" " * (8 * tree_level))
element_name = element.element.__class__.__name__
- indexed_name = f'({element.element_index}) {element_name}'
+ indexed_name = f"({element.element_index}) {element_name}"
if element.subelement_type is not None:
- stream.write(f'[{element.subelement_type}] ')
- stream.write(f'{indexed_name}\n')
+ stream.write(f"[{element.subelement_type}] ")
+ stream.write(f"{indexed_name}\n")
for subelement in element.children:
_write_element(tree_level + 1, subelement)
@@ -356,16 +447,26 @@ def write_elements_tree_to_markdown(
tree: list[_ElementInTree],
stream: TextIO,
):
- stream.write('\n\n# Tree:\n')
+ """
+ Writes the elements tree to a markdown formatted string.
+
+ Args:
+ tree: The list of root elements in the tree.
+ stream: A file-like object to write the markdown output to.
+
+ Returns:
+ None
+ """
+ stream.write("\n\n# Tree:\n")
def _write_element(tree_level: int, element: _ElementInTree):
- stream.write(' ' * (4 * tree_level) + '* ')
+ stream.write(" " * (4 * tree_level) + "* ")
element_name = element.element.__class__.__name__
- indexed_name = f'`({element.element_index}) {element_name}`'
+ indexed_name = f"`({element.element_index}) {element_name}`"
if element.subelement_type is not None:
- stream.write(f'[{element.subelement_type}] ')
- stream.write(f'{indexed_name}\n')
+ stream.write(f"[{element.subelement_type}] ")
+ stream.write(f"{indexed_name}\n")
for subelement in element.children:
_write_element(tree_level + 1, subelement)
@@ -376,39 +477,47 @@ def _write_element(tree_level: int, element: _ElementInTree):
def write_specs(
*iterables: Specsable,
- filename: str = 'specs.txt',
- directory: str | Path = 'specs',
+ filename: str = "specs.txt",
+ directory: str | Path = "specs",
):
+ """
+ Writes specifications from iterables to a file.
+
+ Creates a directory if it doesn't exist and writes the specs to either a
+ text or markdown file based on the filename extension.
+
+ Args:
+ *iterables: One or more iterable objects containing specification data.
+ filename: The name of the output file (e.g., 'specs.txt' or 'specs.md').
+ directory: The directory to write the file to. Defaults to 'specs'.
+
+ Returns:
+ _ElementsIterator: An iterator object representing the written elements and their tree structure.
+ """
Path.mkdir(Path(directory), parents=True, exist_ok=True)
path = Path(directory, filename)
elements = _ElementsIterator(*iterables, directory=directory)
- with open(path, 'w') as file:
- if filename.endswith('.txt'):
+ with open(path, "w") as file:
+ if filename.endswith(".txt"):
for elemennt_index, element, writer_context_generator in elements:
write_specs_to_str(
element=element,
element_index=elemennt_index,
writer_context_generator=writer_context_generator,
- stream=file
+ stream=file,
)
- write_elements_tree_to_str(
- tree=elements.tree,
- stream=file
- )
- elif filename.endswith('.md'):
+ write_elements_tree_to_str(tree=elements.tree, stream=file)
+ elif filename.endswith(".md"):
for elemennt_index, element, writer_context_generator in elements:
write_specs_to_markdown(
element=element,
element_index=elemennt_index,
writer_context_generator=writer_context_generator,
- stream=file
+ stream=file,
)
- write_elements_tree_to_markdown(
- tree=elements.tree,
- stream=file
- )
+ write_elements_tree_to_markdown(tree=elements.tree, stream=file)
else:
raise ValueError(
"Unknown file extension. ' \
diff --git a/svetlanna/transforms.py b/svetlanna/transforms.py
index 804fd39..03af916 100644
--- a/svetlanna/transforms.py
+++ b/svetlanna/transforms.py
@@ -17,6 +17,7 @@ class ToWavefront(nn.Module):
(3) modulation_type='amp&phase' (any other str)
tensor values transforms to amplitude and phase simultaneously
"""
+
def __init__(self, modulation_type=None):
"""
Parameters
@@ -47,26 +48,32 @@ def forward(self, img_tensor: torch.Tensor) -> Wavefront:
# creation of a wavefront based on an image
if img_tensor.size()[0] == 1: # only one channel
# squeeze 0th channel dimension of image tensor
- normalized_tensor = torch.squeeze(img_tensor, 0) # values from 0 to 1, shape=[H, W]
+ normalized_tensor = torch.squeeze(
+ img_tensor, 0
+ ) # values from 0 to 1, shape=[H, W]
else: # more than 1 color channels
normalized_tensor = img_tensor # values from 0 to 1, shape=[C, H, W]
# TODO: check that in simulation parameters we have the same number of wavelengths?
- if self.modulation_type == 'amp': # amplitude modulation
+ if self.modulation_type == "amp": # amplitude modulation
amplitudes = normalized_tensor
phases = torch.zeros(size=normalized_tensor.size())
else:
# image -> phases from -pi + eps to pi - eps
normalized_tensor_fix = normalized_tensor
- normalized_tensor_fix[normalized_tensor_fix == 1.] -= self.eps # maximal values - eps
- normalized_tensor_fix[normalized_tensor_fix == 0.] += self.eps # 0 + eps
+ normalized_tensor_fix[
+ normalized_tensor_fix == 1.0
+ ] -= self.eps # maximal values - eps
+ normalized_tensor_fix[normalized_tensor_fix == 0.0] += self.eps # 0 + eps
# [0, 1] --> [-pi + eps, pi - eps]
phases = normalized_tensor_fix * 2 * torch.pi - torch.pi
- if self.modulation_type == 'phase': # phase modulation
+ if self.modulation_type == "phase": # phase modulation
# TODO: What is with an amplitude?
- amplitudes = torch.ones(size=normalized_tensor.size()) # constant amplitude
+ amplitudes = torch.ones(
+ size=normalized_tensor.size()
+ ) # constant amplitude
else: # phase AND amplitude modulation 'amp&phase'
amplitudes = normalized_tensor
@@ -80,11 +87,9 @@ class GaussModulation(nn.Module):
"""
Multiplies an amplitude of a Wavefront on a gaussian.
"""
+
def __init__(
- self,
- sim_params: SimulationParameters,
- fwhm_x, fwhm_y,
- peak_x=0., peak_y=0.
+ self, sim_params: SimulationParameters, fwhm_x, fwhm_y, peak_x=0.0, peak_y=0.0
):
"""
Parameters
@@ -113,12 +118,13 @@ def get_gauss(self):
gauss_2d : torch.Tensor
A gaussian distribution in a 2D plane.
"""
- x_grid, y_grid = self.sim_params.meshgrid(x_axis='W', y_axis='H')
+ x_grid, y_grid = self.sim_params.meshgrid(x_axis="W", y_axis="H")
gauss_2d = 1 * torch.exp(
- -1 * (
- (x_grid - self.peak_x) ** 2 / 2 / self.sigma_x ** 2 +
- (y_grid - self.peak_y) ** 2 / 2 / self.sigma_y ** 2
+ -1
+ * (
+ (x_grid - self.peak_x) ** 2 / 2 / self.sigma_x**2
+ + (y_grid - self.peak_y) ** 2 / 2 / self.sigma_y**2
)
)
return gauss_2d
@@ -138,11 +144,11 @@ def forward(self, wf: Wavefront) -> Wavefront:
wf_gauss : Wavefront
A gaussian distribution in a 2D plane.
"""
- sim_nodes_shape = self.sim_params.axes_size(axs=('H', 'W')) # [H, W]
+ sim_nodes_shape = self.sim_params.axes_size(axs=("H", "W")) # [H, W]
if not wf.size()[-2:] == sim_nodes_shape:
warnings.warn(
- message='A shape of an input Wavefront does not match with SimulationParameters! Gauss was not applied!'
+ message="A shape of an input Wavefront does not match with SimulationParameters! Gauss was not applied!"
)
wf_gauss = wf
else:
diff --git a/svetlanna/units.py b/svetlanna/units.py
index 31e25e8..15ff192 100644
--- a/svetlanna/units.py
+++ b/svetlanna/units.py
@@ -22,6 +22,7 @@ class ureg(Enum):
var = 10
assert var * ureg.mm == 10*1e-2
"""
+
Gm = _G
Mm = _M
km = _k
@@ -57,24 +58,83 @@ class ureg(Enum):
pHz = _p
def __mul__(self, other):
+ """
+ Multiplies the value of this object by another number.
+
+ Args:
+ other: The number to multiply this object's value by.
+
+ Returns:
+ float: The result of multiplying the object's value by the other number.
+ """
return self.value * other
def __rmul__(self, other):
+ """
+ Returns the result of multiplying 'other' by the value.
+
+ This method enables multiplication with this object on either side
+ (e.g., `2 * MyObject` or `MyObject * 2`). It leverages Python's
+ multiplication operator to achieve this.
+
+ Args:
+ other: The value to multiply by the object's value.
+
+ Returns:
+ The result of the multiplication.
+ """
return other * self.value
def __truediv__(self, other):
+ """
+ Divides the value of this object by another.
+
+ Args:
+ other: The number to divide this object's value by.
+
+ Returns:
+ float: The result of dividing this object's value by the given number.
+ """
return self.value / other
def __rtruediv__(self, other):
+ """
+ Divides another number by the value of this instance.
+
+ Args:
+ other: The number to be divided by the instance's value.
+
+ Returns:
+ float: The result of dividing `other` by the instance's `value`.
+ """
return other / self.value
def __pow__(self, other):
- return self.value ** other
+ """
+ Calculates the power of this value.
+
+ Args:
+ other: The exponent to raise the value to.
+
+ Returns:
+ float: The result of raising the value to the power of 'other'.
+ """
+ return self.value**other
def __array__(self, dtype=None, copy=None):
+ """
+ Returns an array representation of the value.
+
+ Args:
+ dtype: The desired data type of the returned array.
+ copy: Whether to allocate a copy of the underlying data.
+
+ Returns:
+ numpy.ndarray: A NumPy array containing the values. A copy is always created,
+ so attempting to set `copy=False` will raise a ValueError.
+ """
import numpy
+
if copy is False:
- raise ValueError(
- "`copy=False` isn't supported. A copy is always created."
- )
+ raise ValueError("`copy=False` isn't supported. A copy is always created.")
return numpy.array(self.value, dtype=dtype)
diff --git a/svetlanna/visualization/__init__.py b/svetlanna/visualization/__init__.py
index f688bf3..80d1c29 100644
--- a/svetlanna/visualization/__init__.py
+++ b/svetlanna/visualization/__init__.py
@@ -2,9 +2,9 @@
from .widgets import jinja_env, ElementHTML
__all__ = [
- 'show_specs',
- 'show_structure',
- 'show_stepwise_forward',
- 'jinja_env',
- 'ElementHTML'
+ "show_specs",
+ "show_structure",
+ "show_stepwise_forward",
+ "jinja_env",
+ "ElementHTML",
]
diff --git a/svetlanna/visualization/widgets.py b/svetlanna/visualization/widgets.py
index 75d85f8..c9e547f 100644
--- a/svetlanna/visualization/widgets.py
+++ b/svetlanna/visualization/widgets.py
@@ -15,52 +15,51 @@
import base64
-STATIC_FOLDER = pathlib.Path(__file__).parent / 'static'
-TEMPLATES_FOLDER = pathlib.Path(__file__).parent / 'templates'
+STATIC_FOLDER = pathlib.Path(__file__).parent / "static"
+TEMPLATES_FOLDER = pathlib.Path(__file__).parent / "templates"
jinja_env = Environment(
- loader=FileSystemLoader(TEMPLATES_FOLDER),
- autoescape=select_autoescape()
+ loader=FileSystemLoader(TEMPLATES_FOLDER), autoescape=select_autoescape()
)
StepwisePlotTypes = Union[
- Literal['A'],
- Literal['I'],
- Literal['phase'],
- Literal['Re'],
- Literal['Im']
+ Literal["A"], Literal["I"], Literal["phase"], Literal["Re"], Literal["Im"]
]
class StepwiseForwardWidget(anywidget.AnyWidget):
- _esm = STATIC_FOLDER / 'stepwise_forward_widget.js'
- _css = STATIC_FOLDER / 'setup_widget.css'
+ """
+ A widget for stepwise forward selection visualization."""
+
+ _esm = STATIC_FOLDER / "stepwise_forward_widget.js"
+ _css = STATIC_FOLDER / "setup_widget.css"
elements = traitlets.List([]).tag(sync=True)
- structure_html = traitlets.Unicode('').tag(sync=True)
+ structure_html = traitlets.Unicode("").tag(sync=True)
class SpecsWidget(anywidget.AnyWidget):
- _esm = STATIC_FOLDER / 'specs_widget.js'
- _css = STATIC_FOLDER / 'setup_widget.css'
+ """
+ A widget for displaying and interacting with specifications."""
+
+ _esm = STATIC_FOLDER / "specs_widget.js"
+ _css = STATIC_FOLDER / "setup_widget.css"
elements = traitlets.List([]).tag(sync=True)
- structure_html = traitlets.Unicode('').tag(sync=True)
+ structure_html = traitlets.Unicode("").tag(sync=True)
@dataclass(frozen=True, slots=True)
class ElementHTML:
"""Representation of an element in HTML format."""
+
element_type: str | None
html: str
def default_widget_html_method(
- index: int,
- name: str,
- element_type: str | None,
- subelements: list[ElementHTML]
+ index: int, name: str, element_type: str | None, subelements: list[ElementHTML]
) -> str:
"""Default `_widget_html_` method used for rendering `Specsable` elements.
@@ -81,14 +80,12 @@ def default_widget_html_method(
str
rendered HTML
"""
- return jinja_env.get_template('widget_default.html.jinja').render(
+ return jinja_env.get_template("widget_default.html.jinja").render(
index=index, name=name, subelements=subelements
)
-def _get_widget_html_method(
- element: Specsable
-) -> Callable[..., str]:
+def _get_widget_html_method(element: Specsable) -> Callable[..., str]:
"""Returns `_widget_html_` method based on type of element.
Parameters
@@ -101,8 +98,8 @@ def _get_widget_html_method(
Any
`_widget_html_` method
"""
- if hasattr(element, '_widget_html_'):
- return getattr(element, '_widget_html_')
+ if hasattr(element, "_widget_html_"):
+ return getattr(element, "_widget_html_")
return default_widget_html_method
@@ -129,15 +126,10 @@ def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:
index=subelement.element_index,
name=subelement.element.__class__.__name__,
element_type=subelement.subelement_type,
- subelements=_subelements_html(subelement.children)
+ subelements=_subelements_html(subelement.children),
)
- res.append(
- ElementHTML(
- subelement.subelement_type,
- html=raw_subelement_html
- )
- )
+ res.append(ElementHTML(subelement.subelement_type, html=raw_subelement_html))
return res
@@ -158,9 +150,9 @@ def generate_structure_html(subelements: list[_ElementInTree]) -> str:
elements_html = _subelements_html(subelements)
- return jinja_env.get_template(
- 'widget_structure_container.html.jinja'
- ).render(elements_html=elements_html)
+ return jinja_env.get_template("widget_structure_container.html.jinja").render(
+ elements_html=elements_html
+ )
def show_structure(*specsable: Specsable):
@@ -171,7 +163,7 @@ def show_structure(*specsable: Specsable):
from IPython.display import HTML, display
# Generate HTML
- elements = _ElementsIterator(*specsable, directory='')
+ elements = _ElementsIterator(*specsable, directory="")
structure_html = generate_structure_html(elements.tree)
# Display HTML
@@ -190,22 +182,20 @@ def show_specs(*specsable: Specsable) -> SpecsWidget:
The widget
"""
- elements = _ElementsIterator(*specsable, directory='')
+ elements = _ElementsIterator(*specsable, directory="")
# Prepare elements data for widget
elements_json = []
for element_index, element, writer_context_generator in elements:
- stream = StringIO('')
+ stream = StringIO("")
# Write element's parameter specs to the stream
- write_specs_to_html(
- element, element_index, writer_context_generator, stream
- )
+ write_specs_to_html(element, element_index, writer_context_generator, stream)
elements_json.append(
{
- 'index': element_index,
- 'name': element.__class__.__name__,
- 'specs_html': stream.getvalue()
+ "index": element_index,
+ "name": element.__class__.__name__,
+ "specs_html": stream.getvalue(),
}
)
@@ -213,10 +203,7 @@ def show_specs(*specsable: Specsable) -> SpecsWidget:
structure_html = generate_structure_html(elements.tree)
# Create a widget
- widget = SpecsWidget(
- structure_html=structure_html,
- elements=elements_json
- )
+ widget = SpecsWidget(structure_html=structure_html, elements=elements_json)
return widget
@@ -224,7 +211,7 @@ def show_specs(*specsable: Specsable) -> SpecsWidget:
def draw_wavefront(
wavefront: torch.Tensor,
simulation_parameters: SimulationParameters,
- types_to_plot: tuple[StepwisePlotTypes, ...] = ('I', 'phase')
+ types_to_plot: tuple[StepwisePlotTypes, ...] = ("I", "phase"),
) -> bytes:
"""Show field propagation in the setup via widget.
Currently only wavefronts of shape `(W, H)` are supported.
@@ -253,16 +240,10 @@ def draw_wavefront(
n_plots = len(types_to_plot)
- width_to_height = (
- width.max() - width.min()
- ) / (
- height.max() - height.min()
- )
+ width_to_height = (width.max() - width.min()) / (height.max() - height.min())
figure, ax = plt.subplots(
- 1, n_plots,
- figsize=(2+3*n_plots*width_to_height, 3),
- dpi=120
+ 1, n_plots, figsize=(2 + 3 * n_plots * width_to_height, 3), dpi=120
)
for i, plot_type in enumerate(types_to_plot):
@@ -272,25 +253,17 @@ def draw_wavefront(
axes = ax[i]
axes = cast(Axes, axes)
- if plot_type == 'A':
+ if plot_type == "A":
# Plot the wavefront amplitude
- axes.pcolorfast(
- width,
- height,
- wavefront.abs().numpy(force=True)
- )
- axes.set_title('Amplitude')
+ axes.pcolorfast(width, height, wavefront.abs().numpy(force=True))
+ axes.set_title("Amplitude")
- elif plot_type == 'I':
+ elif plot_type == "I":
# Plot the wavefront intensity
- axes.pcolorfast(
- width,
- height,
- (wavefront.abs()**2).numpy(force=True)
- )
- axes.set_title('Intensity')
+ axes.pcolorfast(width, height, (wavefront.abs() ** 2).numpy(force=True))
+ axes.set_title("Intensity")
- elif plot_type == 'phase':
+ elif plot_type == "phase":
# Plot the wavefront phase
axes.pcolorfast(
width,
@@ -299,27 +272,27 @@ def draw_wavefront(
vmin=-torch.pi,
vmax=torch.pi,
)
- axes.set_title('Phase')
+ axes.set_title("Phase")
- elif plot_type == 'Re':
+ elif plot_type == "Re":
# Plot the wavefront real part
axes.pcolorfast(
width,
height,
wavefront.real.numpy(force=True),
)
- axes.set_title('Real part')
+ axes.set_title("Real part")
- elif plot_type == 'Im':
+ elif plot_type == "Im":
# Plot the wavefront imaginary part
axes.pcolorfast(
width,
height,
wavefront.imag.numpy(force=True),
)
- axes.set_title('Imaginary part')
+ axes.set_title("Imaginary part")
- axes.set_aspect('equal')
+ axes.set_aspect("equal")
plt.tight_layout()
figure.savefig(stream)
@@ -332,7 +305,7 @@ def show_stepwise_forward(
*specsable: Specsable,
input: torch.Tensor,
simulation_parameters: SimulationParameters,
- types_to_plot: tuple[StepwisePlotTypes, ...] = ('I', 'phase')
+ types_to_plot: tuple[StepwisePlotTypes, ...] = ("I", "phase"),
) -> StepwiseForwardWidget:
"""Display the wavefront propagation through a setup structure
using a widget interface. Currently only wavefronts
@@ -354,7 +327,7 @@ def show_stepwise_forward(
"""
elements_to_call = tuple(s for s in specsable)
- elements = _ElementsIterator(*elements_to_call, directory='')
+ elements = _ElementsIterator(*elements_to_call, directory="")
outputs = {}
@@ -397,19 +370,19 @@ def capture_output_hook(module, args, output):
draw_wavefront(
wavefront=outputs[element],
simulation_parameters=simulation_parameters,
- types_to_plot=types_to_plot
+ types_to_plot=types_to_plot,
)
).decode()
except Exception as e:
- output_image = f'\n{e}'
+ output_image = f"\n{e}"
else:
output_image = None
elements_json.append(
{
- 'index': element_index,
- 'name': element.__class__.__name__,
- 'output_image': output_image
+ "index": element_index,
+ "name": element.__class__.__name__,
+ "output_image": output_image,
}
)
@@ -418,8 +391,7 @@ def capture_output_hook(module, args, output):
# Create a widget
widget = StepwiseForwardWidget(
- structure_html=structure_html,
- elements=elements_json
+ structure_html=structure_html, elements=elements_json
)
return widget
diff --git a/svetlanna/wavefront.py b/svetlanna/wavefront.py
index bb0272d..9afea54 100644
--- a/svetlanna/wavefront.py
+++ b/svetlanna/wavefront.py
@@ -6,8 +6,21 @@
class Wavefront(torch.Tensor):
"""Class that represents wavefront"""
+
@staticmethod
def __new__(cls, data, *args, **kwargs):
+ """
+ Creates a new Wavefront object from the given data.
+
+ Args:
+ data: The input data to be converted into a tensor.
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ Wavefront: A new instance of the Wavefront class with the data
+ converted to a PyTorch tensor.
+ """
# see https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py # noqa: E501
data = torch.as_tensor(data)
return super(cls, Wavefront).__new__(cls, data)
@@ -47,10 +60,7 @@ def phase(self) -> torch.Tensor:
res = torch.angle(torch.Tensor(self) + 0.0)
return res
- def fwhm(
- self,
- simulation_parameters: SimulationParameters
- ) -> tuple[float, float]:
+ def fwhm(self, simulation_parameters: SimulationParameters) -> tuple[float, float]:
"""Calculates full width at half maximum of the wavefront
Returns
@@ -79,9 +89,9 @@ def fwhm(
def plane_wave(
cls,
simulation_parameters: SimulationParameters,
- distance: float = 0.,
+ distance: float = 0.0,
wave_direction: Any = None,
- initial_phase: float = 0.
+ initial_phase: float = 0.0,
) -> Self:
"""Generate wavefront of the plane wave
@@ -105,25 +115,21 @@ def plane_wave(
"""
# by default the wave propagates along z direction
if wave_direction is None:
- wave_direction = [0., 0., 1.]
+ wave_direction = [0.0, 0.0, 1.0]
wave_direction = torch.tensor(
- wave_direction,
- dtype=torch.float32,
- device=simulation_parameters.device
+ wave_direction, dtype=torch.float32, device=simulation_parameters.device
)
if wave_direction.shape != torch.Size([3]):
- raise ValueError(
- "wave_direction should contain exactly three components"
- )
+ raise ValueError("wave_direction should contain exactly three components")
wave_direction = wave_direction / torch.norm(wave_direction)
wave_number = 2 * torch.pi / simulation_parameters.axes.wavelength
x = simulation_parameters.axes.W[None, :]
y = simulation_parameters.axes.H[:, None]
- kxx, axes = tensor_dot(wave_number, x, 'wavelength', ('H', 'W'))
- kyy, _ = tensor_dot(wave_number, y, 'wavelength', ('H', 'W'))
+ kxx, axes = tensor_dot(wave_number, x, "wavelength", ("H", "W"))
+ kyy, _ = tensor_dot(wave_number, y, "wavelength", ("H", "W"))
kzz = wave_number[..., None, None] * distance
field = torch.exp(1j * wave_direction[0] * kxx)
@@ -137,9 +143,9 @@ def gaussian_beam(
cls,
simulation_parameters: SimulationParameters,
waist_radius: float,
- distance: float = 0.,
- dx: float = 0.,
- dy: float = 0.,
+ distance: float = 0.0,
+ dx: float = 0.0,
+ dy: float = 0.0,
) -> Self:
"""Generates the Gaussian beam.
@@ -164,15 +170,21 @@ def gaussian_beam(
wave_number = 2 * torch.pi / simulation_parameters.axes.wavelength
- rayleigh_range = torch.pi * (waist_radius**2) / simulation_parameters.axes.wavelength # noqa: E501
+ rayleigh_range = (
+ torch.pi * (waist_radius**2) / simulation_parameters.axes.wavelength
+ ) # noqa: E501
x = simulation_parameters.axes.W[None, :] - dx
y = simulation_parameters.axes.H[:, None] - dy
radial_distance_squared = x**2 + y**2
- hyperbolic_relation = waist_radius * (1 + (distance / rayleigh_range)**2)**(1/2) # noqa: E501
+ hyperbolic_relation = waist_radius * (1 + (distance / rayleigh_range) ** 2) ** (
+ 1 / 2
+ ) # noqa: E501
- inverse_radius_of_curvature = distance / (distance**2 + rayleigh_range**2) # noqa: E501
+ inverse_radius_of_curvature = distance / (
+ distance**2 + rayleigh_range**2
+ ) # noqa: E501
# Gouy phase
gouy_phase = torch.arctan(distance / rayleigh_range)
@@ -180,54 +192,56 @@ def gaussian_beam(
phase1, axes1 = tensor_dot(
a=1j * wave_number * inverse_radius_of_curvature / 2,
b=radial_distance_squared,
- a_axis='wavelength',
- b_axis=('H', 'W')
+ a_axis="wavelength",
+ b_axis=("H", "W"),
)
field = torch.exp(phase1)
field, _ = tensor_dot(
a=field,
b=torch.exp(1j * wave_number * distance),
- a_axis=axes1, b_axis='wavelength', preserve_a_axis=True
+ a_axis=axes1,
+ b_axis="wavelength",
+ preserve_a_axis=True,
)
field, _ = tensor_dot(
a=field,
b=torch.exp(-1j * gouy_phase),
- a_axis=axes1, b_axis='wavelength', preserve_a_axis=True
+ a_axis=axes1,
+ b_axis="wavelength",
+ preserve_a_axis=True,
)
phase2, axes2 = tensor_dot(
- a=-1/(hyperbolic_relation)**2,
+ a=-1 / (hyperbolic_relation) ** 2,
b=radial_distance_squared,
- a_axis='wavelength',
- b_axis=('H', 'W')
+ a_axis="wavelength",
+ b_axis=("H", "W"),
)
field, axes = tensor_dot(
a=field,
b=torch.exp(phase2),
a_axis=axes1,
b_axis=axes2,
- preserve_a_axis=True
+ preserve_a_axis=True,
)
field, _ = tensor_dot(
a=field,
b=waist_radius / hyperbolic_relation,
a_axis=axes,
- b_axis='wavelength',
- preserve_a_axis=True
+ b_axis="wavelength",
+ preserve_a_axis=True,
)
- return cls(
- cast_tensor(field, axes, simulation_parameters.axes.names)
- )
+ return cls(cast_tensor(field, axes, simulation_parameters.axes.names))
@classmethod
def spherical_wave(
cls,
simulation_parameters: SimulationParameters,
distance: float,
- initial_phase: float = 0.,
- dx: float = 0.,
- dy: float = 0.,
+ initial_phase: float = 0.0,
+ dx: float = 0.0,
+ dy: float = 0.0,
) -> Self:
"""Generate wavefront of the spherical wave
@@ -254,22 +268,17 @@ def spherical_wave(
x = simulation_parameters.axes.W[None, :] - dx
y = simulation_parameters.axes.H[:, None] - dy
- radius = torch.sqrt(
- (x**2 + y**2) + distance**2
- )
+ radius = torch.sqrt((x**2 + y**2) + distance**2)
phase, axes = tensor_dot(
- a=wave_number,
- b=radius,
- a_axis='wavelength',
- b_axis=('H', 'W')
+ a=wave_number, b=radius, a_axis="wavelength", b_axis=("H", "W")
)
field, _ = tensor_dot(
a=torch.exp(1j * (phase + initial_phase)),
b=1 / radius,
a_axis=axes,
- b_axis=('H', 'W'),
- preserve_a_axis=True
+ b_axis=("H", "W"),
+ preserve_a_axis=True,
)
return cls(cast_tensor(field, axes, simulation_parameters.axes.names))
@@ -277,30 +286,25 @@ def spherical_wave(
# === methods below are added for typing only ===
if TYPE_CHECKING:
- def __mul__(self, other: Any) -> Self:
- ...
- def __rmul__(self, other: Any) -> Self:
- ...
+ def __mul__(self, other: Any) -> Self: ...
+
+ def __rmul__(self, other: Any) -> Self: ...
- def __add__(self, other: Any) -> Self:
- ...
+ def __add__(self, other: Any) -> Self: ...
- def __radd__(self, other: Any) -> Self:
- ...
+ def __radd__(self, other: Any) -> Self: ...
- def __truediv__(self, other: Any) -> Self:
- ...
+ def __truediv__(self, other: Any) -> Self: ...
- def __rtruediv__(self, other: Any) -> Self:
- ...
+ def __rtruediv__(self, other: Any) -> Self: ...
DEFAULT_LAST_AXES_NAMES = (
# 'pol',
# 'wavelength',
- 'H',
- 'W'
+ "H",
+ "W",
)
@@ -308,7 +312,7 @@ def mul(
wf: Wavefront,
b: Any,
b_axis: str | Iterable[str],
- sim_params: SimulationParameters | None = None
+ sim_params: SimulationParameters | None = None,
) -> Wavefront:
"""Multiplication of the wavefront and tensor.
diff --git a/tests/analytical_solutions.py b/tests/analytical_solutions.py
index 1900900..0436447 100644
--- a/tests/analytical_solutions.py
+++ b/tests/analytical_solutions.py
@@ -8,8 +8,9 @@
class RectangleFresnel:
"""A class describing the analytical solution for the problem of free
- propagation after planar wave passes a rectangular aperture
+ propagation after planar wave passes a rectangular aperture
"""
+
def __init__(
self,
distance: float,
@@ -19,7 +20,7 @@ def __init__(
y_nodes: int,
width: float,
height: float,
- wavelength: torch.Tensor | float
+ wavelength: torch.Tensor | float,
):
"""Constructor method
@@ -71,24 +72,29 @@ def field(self) -> np.ndarray:
x_grid = x_grid[None, :]
y_grid = y_grid[None, :]
- psi1 = -np.sqrt(wave_number/(np.pi*self.distance))*(self.width/2
- + x_grid)
- psi2 = np.sqrt(wave_number/(np.pi*self.distance))*(self.width / 2
- - x_grid)
- eta1 = -np.sqrt(wave_number/(np.pi*self.distance))*(self.height / 2
- + y_grid)
- eta2 = np.sqrt(wave_number/(np.pi*self.distance))*(self.height / 2
- - y_grid)
+ psi1 = -np.sqrt(wave_number / (np.pi * self.distance)) * (
+ self.width / 2 + x_grid
+ )
+ psi2 = np.sqrt(wave_number / (np.pi * self.distance)) * (
+ self.width / 2 - x_grid
+ )
+ eta1 = -np.sqrt(wave_number / (np.pi * self.distance)) * (
+ self.height / 2 + y_grid
+ )
+ eta2 = np.sqrt(wave_number / (np.pi * self.distance)) * (
+ self.height / 2 - y_grid
+ )
s_psi1, c_psi1 = sp.special.fresnel(psi1)
s_psi2, c_psi2 = sp.special.fresnel(psi2)
s_eta1, c_eta1 = sp.special.fresnel(eta1)
s_eta2, c_eta2 = sp.special.fresnel(eta2)
- field = np.exp(1j * wave_number * self.distance) * (1 / 2j) * (
- (c_psi2 - c_psi1) + 1j * (s_psi2 - s_psi1)
- ) * (
- (c_eta2 - c_eta1) + 1j * (s_eta2 - s_eta1)
+ field = (
+ np.exp(1j * wave_number * self.distance)
+ * (1 / 2j)
+ * ((c_psi2 - c_psi1) + 1j * (s_psi2 - s_psi1))
+ * ((c_eta2 - c_eta1) + 1j * (s_eta2 - s_eta1))
)
# intensity = (1/4)*(np.power((c_psi2 - c_psi1), 2) +
@@ -99,13 +105,26 @@ def field(self) -> np.ndarray:
return field
def intensity(self) -> np.ndarray:
+ """
+ Calculates the intensity of the field.
+
+ Returns the squared magnitude of the electric or magnetic field.
+
+ Args:
+ None
+
+ Returns:
+ np.ndarray: The intensity, which is the absolute value of the field
+ squared.
+ """
return np.abs(self.field()) ** 2
class CircleFresnel:
"""A class describing the analytical solution for the problem of free
- propagation after planar wave passes a circular aperture aperture
+ propagation after planar wave passes a circular aperture aperture
"""
+
def __init__(
self,
distance: float,
@@ -115,8 +134,24 @@ def __init__(
y_nodes: int,
radius: float,
wavelength: torch.Tensor | float,
- summation_number: int = 50
+ summation_number: int = 50,
):
+ """
+ Initializes a new instance of the class.
+
+ Args:
+ distance: The distance parameter.
+ x_size: The x size parameter.
+ y_size: The y size parameter.
+ x_nodes: The number of nodes in the x dimension.
+ y_nodes: The number of nodes in the y dimension.
+ radius: The radius parameter.
+ wavelength: The wavelength parameter (can be a torch.Tensor or float).
+ summation_number: The number of summation terms to use, defaults to 50.
+
+ Returns:
+ None
+ """
self.distance = distance
self.x_size = x_size
@@ -128,6 +163,15 @@ def __init__(
self.wavelength = wavelength
def field(self) -> np.ndarray:
+ """
+ Calculates the complex-valued electromagnetic field distribution.
+
+ Args:
+ None
+
+ Returns:
+ np.ndarray: A 2D NumPy array representing the calculated field.
+ """
x_linear = np.linspace(-self.x_size / 2, self.x_size / 2, self.x_nodes)
y_linear = np.linspace(-self.y_size / 2, self.y_size / 2, self.y_nodes)
@@ -145,23 +189,34 @@ def field(self) -> np.ndarray:
series = np.zeros_like(x_grid, dtype=np.complex128)
for n in tqdm(range(self.summation_number)):
- series += ((
- -1j * radius / (self.radius)
- ) ** n) * jv(
- n, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance) # noqa: E501
+ series += ((-1j * radius / (self.radius)) ** n) * jv(
+ n,
+ 2
+ * np.pi
+ * self.radius
+ * radius
+ / (self.wavelength * self.distance), # noqa: E501
)
self.field = np.exp(1j * wave_number * self.distance) * (
- 1 - np.exp(
- 1j * np.pi * radius**2 / (self.wavelength * self.distance)
- ) * np.exp(
- 1j * np.pi * self.radius**2 / (self.wavelength * self.distance)
- ) * series
+ 1
+ - np.exp(1j * np.pi * radius**2 / (self.wavelength * self.distance))
+ * np.exp(1j * np.pi * self.radius**2 / (self.wavelength * self.distance))
+ * series
)
return self.field
def intensity(self) -> np.ndarray:
+ """
+ Calculates the intensity pattern of a circular aperture.
+
+ Args:
+ None
+
+ Returns:
+ np.ndarray: A 2D NumPy array representing the calculated intensity pattern.
+ """
x_linear = np.linspace(-self.x_size / 2, self.x_size / 2, self.x_nodes)
y_linear = np.linspace(-self.y_size / 2, self.y_size / 2, self.y_nodes)
x_grid, y_grid = np.meshgrid(x_linear, y_linear)
@@ -175,9 +230,33 @@ def intensity(self) -> np.ndarray:
radius = np.sqrt(x_grid**2 + y_grid**2)
- intensity = 1 / (1 + np.exp((radius / self.radius)**2))**2 * (
- 1 + jv(0, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance))**2 - 2*np.cos(
- np.pi * self.radius**2/(self.wavelength * self.distance) + np.pi * radius**2 / (self.distance*self.wavelength)
- ) * jv(0, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance))
+ intensity = (
+ 1
+ / (1 + np.exp((radius / self.radius) ** 2)) ** 2
+ * (
+ 1
+ + jv(
+ 0,
+ 2
+ * np.pi
+ * self.radius
+ * radius
+ / (self.wavelength * self.distance),
+ )
+ ** 2
+ - 2
+ * np.cos(
+ np.pi * self.radius**2 / (self.wavelength * self.distance)
+ + np.pi * radius**2 / (self.distance * self.wavelength)
+ )
+ * jv(
+ 0,
+ 2
+ * np.pi
+ * self.radius
+ * radius
+ / (self.wavelength * self.distance),
+ )
+ )
)
return intensity
diff --git a/tests/test_analytic.py b/tests/test_analytic.py
index c565411..ce3f6e0 100644
--- a/tests/test_analytic.py
+++ b/tests/test_analytic.py
@@ -19,7 +19,7 @@
"width_test",
"height_test",
"expected_error",
- "error_energy"
+ "error_energy",
]
@@ -29,52 +29,58 @@
(
8, # ox_size, mm
8, # oy_size, mm
- 1200, # ox_nodes
- 1300, # oy_nodes
+ 1200, # ox_nodes
+ 1300, # oy_nodes
540 * 1e-6, # wavelength_test, mm
- 600, # distance_test, mm
+ 600, # distance_test, mm
4, # width_test, mm
2, # height_test, mm
0.075, # expected error
- 0.05 # error_energy
+ 0.05, # error_energy
),
(
10, # ox_size, mm
10, # oy_size, mm
- 1400, # ox_nodes
- 1300, # oy_nodes
- torch.linspace(330 * 1e-6, 660 * 1e-6, 5), # wavelength_test tensor, mm # noqa: E501
- 150, # distance_test, mm
- 3, # width_test, mm
- 3, # height_test, mm
+ 1400, # ox_nodes
+ 1300, # oy_nodes
+ torch.linspace(
+ 330 * 1e-6, 660 * 1e-6, 5
+ ), # wavelength_test tensor, mm # noqa: E501
+ 150, # distance_test, mm
+ 3, # width_test, mm
+ 3, # height_test, mm
0.065, # expected error
- 0.05 # error_energy
+ 0.05, # error_energy
),
(
8, # ox_size, mm
8, # oy_size, mm
- 1200, # ox_nodes
- 1300, # oy_nodes
- torch.linspace(330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64), # wavelength_test tensor, mm # noqa: E501
- 600, # distance_test, mm
+ 1200, # ox_nodes
+ 1300, # oy_nodes
+ torch.linspace(
+ 330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64
+ ), # wavelength_test tensor, mm # noqa: E501
+ 600, # distance_test, mm
2, # width_test, mm
2, # height_test, mm
0.075, # expected error
- 0.05 # error_energy
+ 0.05, # error_energy
),
(
8, # ox_size, mm
8, # oy_size, mm
- 1200, # ox_nodes
- 1300, # oy_nodes
- torch.linspace(330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64), # wavelength_test tensor, mm # noqa: E501
- 600, # distance_test, mm
+ 1200, # ox_nodes
+ 1300, # oy_nodes
+ torch.linspace(
+ 330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64
+ ), # wavelength_test tensor, mm # noqa: E501
+ 600, # distance_test, mm
4, # width_test, mm
2, # height_test, mm
0.075, # expected std
- 0.05 # error_energy
- )
- ]
+ 0.05, # error_energy
+ ),
+ ],
)
def test_rectangle_fresnel(
ox_size: float,
@@ -86,7 +92,7 @@ def test_rectangle_fresnel(
width_test: float,
height_test: float,
expected_error: float,
- error_energy: float
+ error_energy: float,
):
"""Test for the free propagation problem on the example of diffraction of
the plane wave on the rectangular aperture
@@ -117,13 +123,13 @@ def test_rectangle_fresnel(
params = SimulationParameters(
{
- 'W': torch.linspace(
- -ox_size/2, ox_size/2, ox_nodes, dtype=torch.float64
+ "W": torch.linspace(
+ -ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64
),
- 'H': torch.linspace(
- -oy_size/2, oy_size/2, oy_nodes, dtype=torch.float64
+ "H": torch.linspace(
+ -oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64
),
- 'wavelength': wavelength_test
+ "wavelength": wavelength_test,
}
)
@@ -131,30 +137,22 @@ def test_rectangle_fresnel(
dy = oy_size / oy_nodes
incident_field = Wavefront.plane_wave(
- simulation_parameters=params,
- distance=distance_test,
- wave_direction=[0, 0, 1]
+ simulation_parameters=params, distance=distance_test, wave_direction=[0, 0, 1]
)
# field after the square aperture
transmission_field = elements.RectangularAperture(
- simulation_parameters=params,
- height=height_test,
- width=width_test
+ simulation_parameters=params, height=height_test, width=width_test
)(incident_field)
# field on the screen by using Fresnel propagation method
output_field_fresnel = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance_test,
- method='fresnel'
- )(transmission_field)
+ simulation_parameters=params, distance=distance_test, method="fresnel"
+ )(transmission_field)
# field on the screen by using Angular Spectrum method
output_field_as = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance_test,
- method='AS'
- )(transmission_field)
+ simulation_parameters=params, distance=distance_test, method="AS"
+ )(transmission_field)
# intensity distribution on the screen by using Fresnel propagation method
intensity_output_fresnel = output_field_fresnel.intensity
@@ -170,40 +168,36 @@ def test_rectangle_fresnel(
y_nodes=oy_nodes,
width=width_test,
height=height_test,
- wavelength=wavelength_test
+ wavelength=wavelength_test,
).intensity()
if isinstance(intensity_analytic, np.ndarray):
intensity_analytic = torch.from_numpy(intensity_analytic)
energy_analytic = torch.sum(intensity_analytic, dim=(-2, -1)) * dx * dy
- energy_numeric_fresnel = torch.sum(
- intensity_output_fresnel, dim=(-2, -1)
- ) * dx * dy
+ energy_numeric_fresnel = torch.sum(intensity_output_fresnel, dim=(-2, -1)) * dx * dy
energy_numeric_as = torch.sum(intensity_output_as, dim=(-2, -1)) * dx * dy
intensity_difference_fresnel = torch.abs(
intensity_analytic - intensity_output_fresnel
) / (ox_nodes * oy_nodes)
- intensity_difference_as = torch.abs(
- intensity_analytic - intensity_output_as
- ) / (ox_nodes * oy_nodes)
+ intensity_difference_as = torch.abs(intensity_analytic - intensity_output_as) / (
+ ox_nodes * oy_nodes
+ )
error_fresnel, _ = intensity_difference_fresnel.view(
intensity_difference_fresnel.size(0), -1
).max(dim=1)
- error_as, _ = intensity_difference_as.view(
- intensity_difference_as.size(0), -1
- ).max(dim=1)
+ error_as, _ = intensity_difference_as.view(intensity_difference_as.size(0), -1).max(
+ dim=1
+ )
energy_error_fresnel = torch.abs(
(energy_analytic - energy_numeric_fresnel) / energy_analytic
)
- energy_error_as = torch.abs(
- (energy_analytic - energy_numeric_as) / energy_analytic
- )
+ energy_error_as = torch.abs((energy_analytic - energy_numeric_as) / energy_analytic)
assert (error_fresnel < expected_error).all()
assert (error_as < expected_error).all()
diff --git a/tests/test_apertures.py b/tests/test_apertures.py
index 1fe6e73..36a07fa 100644
--- a/tests/test_apertures.py
+++ b/tests/test_apertures.py
@@ -13,14 +13,16 @@
"wavelength_test",
"height_test",
"width_test",
- "expected_std"
+ "expected_std",
]
@pytest.mark.parametrize(
rectangle_parameters,
- [(10, 10, 1000, 1200, 1064 * 1e-6, 4, 10, 1e-5),
- (4, 4, 1300, 1000, 1064 * 1e-6, 3, 1, 1e-5)]
+ [
+ (10, 10, 1000, 1200, 1064 * 1e-6, 4, 10, 1e-5),
+ (4, 4, 1300, 1000, 1064 * 1e-6, 3, 1, 1e-5),
+ ],
)
def test_rectangle_aperture(
ox_size: float,
@@ -55,27 +57,27 @@ def test_rectangle_aperture(
"""
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
# transmission function of the rectangular aperture as a class method
aperture = elements.RectangularAperture(
- simulation_parameters=params,
- height=height_test,
- width=width_test
+ simulation_parameters=params, height=height_test, width=width_test
)
transmission_function = aperture.get_transmission_function()
x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
- x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy')
+ x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy")
- transmission_function_analytic = 1 * (
- torch.abs(x_grid) <= width_test / 2
- ) * (torch.abs(y_grid) <= height_test / 2)
+ transmission_function_analytic = (
+ 1
+ * (torch.abs(x_grid) <= width_test / 2)
+ * (torch.abs(y_grid) <= height_test / 2)
+ )
standard_deviation = torch.std(
transmission_function - transmission_function_analytic
@@ -85,9 +87,7 @@ def test_rectangle_aperture(
# test forward calculations
wavefront = svetlanna.Wavefront.plane_wave(params)
- torch.testing.assert_close(
- aperture(wavefront), transmission_function * wavefront
- )
+ torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront)
round_parameters = [
@@ -97,14 +97,16 @@ def test_rectangle_aperture(
"oy_nodes",
"wavelength_test",
"radius_test",
- "expected_std"
+ "expected_std",
]
@pytest.mark.parametrize(
round_parameters,
- [(10, 15, 1200, 1000, 1064 * 1e-6, 4, 1e-5),
- (8, 4, 1000, 1300, 1064 * 1e-6, 2.5, 1e-5)]
+ [
+ (10, 15, 1200, 1000, 1064 * 1e-6, 4, 1e-5),
+ (8, 4, 1000, 1300, 1064 * 1e-6, 2.5, 1e-5),
+ ],
)
def test_round_aperture(
ox_size: float,
@@ -137,22 +139,19 @@ def test_round_aperture(
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
# transmission function of the round aperture as a class method
- aperture = elements.RoundAperture(
- simulation_parameters=params,
- radius=radius_test
- )
+ aperture = elements.RoundAperture(simulation_parameters=params, radius=radius_test)
transmission_function = aperture.get_transmission_function()
x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
- x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy')
+ x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy")
transmission_function_analytic = 1 * (
torch.pow(x_grid, 2) + torch.pow(y_grid, 2) <= radius_test**2
@@ -166,9 +165,7 @@ def test_round_aperture(
# test forward calculations
wavefront = svetlanna.Wavefront.plane_wave(params)
- torch.testing.assert_close(
- aperture(wavefront), transmission_function * wavefront
- )
+ torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront)
arbitrary_parameters = [
@@ -178,14 +175,16 @@ def test_round_aperture(
"oy_nodes",
"wavelength_test",
"mask_test",
- "expected_std"
+ "expected_std",
]
@pytest.mark.parametrize(
arbitrary_parameters,
- [(10, 15, 1200, 1000, 1064 * 1e-6, torch.rand(1000, 1200), 1e-5),
- (8, 4, 1100, 1000, 1064 * 1e-6, torch.rand(1000, 1100), 1e-5)]
+ [
+ (10, 15, 1200, 1000, 1064 * 1e-6, torch.rand(1000, 1200), 1e-5),
+ (8, 4, 1100, 1000, 1064 * 1e-6, torch.rand(1000, 1100), 1e-5),
+ ],
)
def test_aperture(
ox_size: float,
@@ -218,17 +217,14 @@ def test_aperture(
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
# transmission function for the aperture with arbitrary shape as a
# class method
- aperture = elements.Aperture(
- simulation_parameters=params,
- mask=mask_test
- )
+ aperture = elements.Aperture(simulation_parameters=params, mask=mask_test)
transmission_function = aperture.get_transmission_function()
transmission_function_analytic = mask_test
@@ -241,6 +237,4 @@ def test_aperture(
# test forward calculations
wavefront = svetlanna.Wavefront.plane_wave(params)
- torch.testing.assert_close(
- aperture(wavefront), transmission_function * wavefront
- )
+ torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront)
diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py
index 9d1468d..08a973d 100644
--- a/tests/test_autoencoder.py
+++ b/tests/test_autoencoder.py
@@ -14,24 +14,26 @@ def empty_encoder_or_decoder(zero_free_space):
@pytest.mark.parametrize(
- "wf_real, wf_imag", [
+ "wf_real, wf_imag",
+ [
(1.00, 0.00),
(0.00, 1.00),
(2.50, 1.25),
- ]
+ ],
)
def test_autoencoder_forward(
- sim_params, empty_encoder_or_decoder, # fixtures
- wf_real, wf_imag
+ sim_params, empty_encoder_or_decoder, wf_real, wf_imag # fixtures
):
"""Test forward function for a single Wavefront sequence."""
- h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters
+ h, w = sim_params.axes_size(
+ axs=("H", "W")
+ ) # size of a wavefront according to SimulationParameters
test_wavefront = Wavefront(
- torch.ones(size=(h, w), dtype=torch.float64) * wf_real +
- torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j
+ torch.ones(size=(h, w), dtype=torch.float64) * wf_real
+ + torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j
)
- for to_return in ['wf', 'amps']:
+ for to_return in ["wf", "amps"]:
autoencoder = LinearAutoencoder(
sim_params,
encoder_elements_list=empty_encoder_or_decoder,
@@ -45,9 +47,9 @@ def test_autoencoder_forward(
for wf in [wf_encoded, wf_decoded]:
assert isinstance(wf, Wavefront)
- if to_return == 'wf':
+ if to_return == "wf":
assert torch.allclose(wf, test_wavefront)
- if to_return == 'amps':
+ if to_return == "amps":
assert torch.allclose(wf, test_wavefront.abs() + 0j)
@@ -58,13 +60,15 @@ def test_autoencoder_device(sim_params, empty_encoder_or_decoder):
sim_params,
encoder_elements_list=empty_encoder_or_decoder,
decoder_elements_list=empty_encoder_or_decoder,
- device='cpu',
+ device="cpu",
)
- assert autoencoder.device == torch.device('cpu')
+ assert autoencoder.device == torch.device("cpu")
- new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is
- new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+ new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if new_device == torch.device(
+ "cpu"
+ ): # if cuda is not available - check if `mps` is
+ new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
new_autoencoder = autoencoder.to(new_device)
diff --git a/tests/test_axes_math.py b/tests/test_axes_math.py
index a8e0379..be0b481 100644
--- a/tests/test_axes_math.py
+++ b/tests/test_axes_math.py
@@ -10,8 +10,8 @@
def test_append_slice():
"""Test that append slice"""
- axes = ('a',)
- new_axes = ('a', 'b')
+ axes = ("a",)
+ new_axes = ("a", "b")
full_slice = slice(None, None, None)
# no additional axes
@@ -21,16 +21,16 @@ def test_append_slice():
assert _append_slice(axes, new_axes) == (..., full_slice, None)
# two additional axis should be at the end
- for new_axes in permutations(('a', 'b', 'c')):
+ for new_axes in permutations(("a", "b", "c")):
assert _append_slice(axes, new_axes) == (..., full_slice, None, None)
def test_axes_indices_to_sort():
"""Test for `_axes_indices_to_sort` function"""
- axes = ('a', 'b')
- new_axes = ('b', 'd', 'a', 'c')
+ axes = ("a", "b")
+ new_axes = ("b", "d", "a", "c")
# axes of the tensor expanded with _append_slice
- appended_tensor_axes = ('a', 'b', 'd', 'c')
+ appended_tensor_axes = ("a", "b", "d", "c")
assert _axes_indices_to_sort(axes, new_axes) == tuple(
new_axes.index(axis) for axis in appended_tensor_axes
@@ -45,51 +45,75 @@ def test_swaps():
# elements swap
for i, j in _swaps(new_axes):
- new_axes_list[i], new_axes_list[j] \
- = new_axes_list[j], new_axes_list[i]
+ new_axes_list[i], new_axes_list[j] = new_axes_list[j], new_axes_list[i]
# test if new_axes_list is sorted after swapping
assert sorted(axes) == new_axes_list
def test_cast_tensor():
+ """
+ Tests the cast_tensor function with various axis configurations.
+
+ This tests checks that the `cast_tensor` function correctly adds, maintains, and swaps axes
+ of a given tensor while raising ValueErrors when invalid configurations are provided.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
a = torch.tensor([[1, 2], [3, 4]])
# additional axes
- b = cast_tensor(a=a, axes=('a',), new_axes=('a', 'b', 'c'))
+ b = cast_tensor(a=a, axes=("a",), new_axes=("a", "b", "c"))
assert len(b.shape) == 4
assert b.shape[-1] == b.shape[-2] == 1
- b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'b', 'c'))
+ b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "b", "c"))
assert len(b.shape) == 3
assert b.shape[-1] == 1
# same axes test
- b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'b'))
+ b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "b"))
assert len(b.shape) == 2
# swap axes test
- b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('b', 'a'))
+ b = cast_tensor(a=a, axes=("a", "b"), new_axes=("b", "a"))
assert torch.allclose(a, b.T)
with pytest.raises(ValueError):
- b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'c'))
+ b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "c"))
def test_axis_to_tuple():
+ """
+ Tests the _axis_to_tuple function for various inputs and caching behavior.
+
+ This test verifies that _axis_to_tuple correctly converts different input types
+ (empty tuple, string, tuple of strings) into tuples and also confirms that it
+ caches results for identical inputs.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
a = _axis_to_tuple(())
- b = _axis_to_tuple('a')
- c = _axis_to_tuple(('a', 'b'))
+ b = _axis_to_tuple("a")
+ c = _axis_to_tuple(("a", "b"))
# test for values
assert a == ()
- assert b == ('a',)
- assert c == ('a', 'b')
+ assert b == ("a",)
+ assert c == ("a", "b")
# check for cache
assert a is _axis_to_tuple(())
- assert b is _axis_to_tuple('a')
- assert c is _axis_to_tuple(('a', 'b'))
+ assert b is _axis_to_tuple("a")
+ assert c is _axis_to_tuple(("a", "b"))
def test_new_axes():
@@ -101,38 +125,48 @@ def test_new_axes():
```
"""
- assert _new_axes(('a', 'b'), ('a',)) == ('a', 'b')
+ assert _new_axes(("a", "b"), ("a",)) == ("a", "b")
- assert _new_axes(('a', 'b'), ('c',)) == ('a', 'b', 'c')
- assert _new_axes(('a', 'b'), ('c', 'd')) == ('a', 'b', 'c', 'd')
+ assert _new_axes(("a", "b"), ("c",)) == ("a", "b", "c")
+ assert _new_axes(("a", "b"), ("c", "d")) == ("a", "b", "c", "d")
- assert _new_axes(('a', 'b'), ('a', 'c')) == ('a', 'b', 'c')
- assert _new_axes(('a', 'b'), ('c', 'a')) == ('a', 'b', 'c')
- assert _new_axes(('a', 'b'), ('b', 'c')) == ('a', 'b', 'c')
- assert _new_axes(('a', 'b'), ('c', 'b')) == ('a', 'b', 'c')
- assert _new_axes(('a', 'b'), ('c', 'd', 'b', 'e')) \
- == ('a', 'b', 'c', 'd', 'e')
+ assert _new_axes(("a", "b"), ("a", "c")) == ("a", "b", "c")
+ assert _new_axes(("a", "b"), ("c", "a")) == ("a", "b", "c")
+ assert _new_axes(("a", "b"), ("b", "c")) == ("a", "b", "c")
+ assert _new_axes(("a", "b"), ("c", "b")) == ("a", "b", "c")
+ assert _new_axes(("a", "b"), ("c", "d", "b", "e")) == ("a", "b", "c", "d", "e")
def test_is_scalar():
- assert is_scalar(123.)
- assert is_scalar(torch.tensor(123.))
- assert not is_scalar(torch.tensor([123.]))
- assert not is_scalar(torch.tensor([123., 123]))
- assert not is_scalar(torch.tensor([[123., 123]]))
+ """
+ Tests the is_scalar function with various inputs.
+
+ This tests that single number values and tensors containing a single value are correctly identified as scalar,
+ while tensors containing multiple values are not.
+
+ Returns:
+ None
+ """
+ assert is_scalar(123.0)
+ assert is_scalar(torch.tensor(123.0))
+ assert not is_scalar(torch.tensor([123.0]))
+ assert not is_scalar(torch.tensor([123.0, 123]))
+ assert not is_scalar(torch.tensor([[123.0, 123]]))
def test_check_axis():
+ """
+ Tests the _check_axis function for various error conditions."""
# test for unique
with pytest.raises(ValueError):
- _check_axis(torch.tensor([[[123]]]), ('a', 'a', 'b'))
+ _check_axis(torch.tensor([[[123]]]), ("a", "a", "b"))
# test for number of axes in tensor
with pytest.raises(ValueError):
- _check_axis(torch.tensor([123]), ('a', 'b'))
+ _check_axis(torch.tensor([123]), ("a", "b"))
# and for number of axes in float
- assert _check_axis(123, ('a', 'b')) is None
+ assert _check_axis(123, ("a", "b")) is None
def test_tensor_dot():
@@ -140,17 +174,17 @@ def test_tensor_dot():
e = 123
d = 321
# product of a scalar and a scalar
- c, c_axis = tensor_dot(d, e, ('a', 'b'), ('c'))
+ c, c_axis = tensor_dot(d, e, ("a", "b"), ("c"))
assert 123 * d == c
assert c_axis == ()
- c, c_axis = tensor_dot(d, e, ('a', 'b'), ('c'), preserve_a_axis=True)
+ c, c_axis = tensor_dot(d, e, ("a", "b"), ("c"), preserve_a_axis=True)
assert e * d == c
- assert c_axis == ('a', 'b')
+ assert c_axis == ("a", "b")
# product of a tensor and a scalar
- a = torch.tensor([1.])
- b = torch.tensor([[1., 2], [3., 4.]])
+ a = torch.tensor([1.0])
+ b = torch.tensor([[1.0, 2], [3.0, 4.0]])
c, c_axis = tensor_dot(a, e, (), ())
assert e * a == c
@@ -160,25 +194,25 @@ def test_tensor_dot():
assert torch.allclose(e * b, c)
assert c_axis == ()
- c, c_axis = tensor_dot(a, e, ('a',), ())
+ c, c_axis = tensor_dot(a, e, ("a",), ())
assert e * a == c
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(b, e, ('a',), ())
+ c, c_axis = tensor_dot(b, e, ("a",), ())
assert torch.allclose(e * b, c)
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(a, e, ('a',), ('b', 'c'))
+ c, c_axis = tensor_dot(a, e, ("a",), ("b", "c"))
assert e * a == c
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(b, e, ('a',), ('b', 'c'))
+ c, c_axis = tensor_dot(b, e, ("a",), ("b", "c"))
assert torch.allclose(e * b, c)
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(b, e, ('a', 'd'), ('b', 'c'))
+ c, c_axis = tensor_dot(b, e, ("a", "d"), ("b", "c"))
assert torch.allclose(e * b, c)
- assert c_axis == ('a', 'd')
+ assert c_axis == ("a", "d")
# product of a scalar and a tensor
c, c_axis = tensor_dot(e, a, (), ())
@@ -189,96 +223,98 @@ def test_tensor_dot():
assert torch.allclose(e * b, c)
assert c_axis == ()
- c, c_axis = tensor_dot(e, a, (), ('a'))
+ c, c_axis = tensor_dot(e, a, (), ("a"))
assert e * a == c
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(e, b, (), ('a'))
+ c, c_axis = tensor_dot(e, b, (), ("a"))
assert torch.allclose(e * b, c)
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(e, a, ('a',), ('a'))
+ c, c_axis = tensor_dot(e, a, ("a",), ("a"))
assert e * a == c
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(e, a, ('a', 'c'), ('a'))
+ c, c_axis = tensor_dot(e, a, ("a", "c"), ("a"))
assert e * a == c
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(e, a, ('a', 'c'), ('a'), preserve_a_axis=True)
+ c, c_axis = tensor_dot(e, a, ("a", "c"), ("a"), preserve_a_axis=True)
assert e * a == c
- assert c_axis == ('a', 'c')
+ assert c_axis == ("a", "c")
- c, c_axis = tensor_dot(e, b, ('a', 'c'), ('a'), preserve_a_axis=True)
+ c, c_axis = tensor_dot(e, b, ("a", "c"), ("a"), preserve_a_axis=True)
assert torch.allclose((e * b)[..., None], c)
- assert c_axis == ('a', 'c')
+ assert c_axis == ("a", "c")
# product of a tensor and a tensor
c, c_axis = tensor_dot(a, b, (), ())
assert torch.allclose((a * b), c)
assert c_axis == ()
- c, c_axis = tensor_dot(a, b, ('a'), ('a', 'b'))
+ c, c_axis = tensor_dot(a, b, ("a"), ("a", "b"))
d = b.clone()
d[:] *= a[:]
assert torch.allclose(c, d)
- assert c_axis == ('a', 'b')
+ assert c_axis == ("a", "b")
- c, c_axis = tensor_dot(a, b, ('a'), ('a'))
+ c, c_axis = tensor_dot(a, b, ("a"), ("a"))
d = b.clone()
d[..., :] *= a[:]
assert torch.allclose(c, d)
- assert c_axis == ('a',)
+ assert c_axis == ("a",)
- c, c_axis = tensor_dot(b, a, ('a', 'b'), ('a'))
+ c, c_axis = tensor_dot(b, a, ("a", "b"), ("a"))
d = b.clone()
d[:, ...] *= a[:]
assert torch.allclose(c, d)
- assert c_axis == ('a', 'b')
+ assert c_axis == ("a", "b")
- c, c_axis = tensor_dot(b, a, ('a', 'b'), ('c'))
+ c, c_axis = tensor_dot(b, a, ("a", "b"), ("c"))
d = b.clone()[..., None]
d[..., :] *= a[:]
assert torch.allclose(c, d)
- assert c_axis == ('a', 'b', 'c')
+ assert c_axis == ("a", "b", "c")
def test_mul():
- wf = Wavefront([[1.+1j]])
+ """
+ Tests the mul function with various Wavefront objects and tensors."""
+ wf = Wavefront([[1.0 + 1j]])
# test wf and non-tensor product
assert mul(wf, 123, ()) == wf * 123
# test default axes
- wf = Wavefront([[1., 2], [3, 4]])
+ wf = Wavefront([[1.0, 2], [3, 4]])
a = torch.tensor([10, 20])
- assert torch.allclose(mul(wf, a, ('H')), wf * a[:, None])
- assert torch.allclose(mul(wf, a, ('W')), wf * a[None, :])
+ assert torch.allclose(mul(wf, a, ("H")), wf * a[:, None])
+ assert torch.allclose(mul(wf, a, ("W")), wf * a[None, :])
with pytest.raises(AssertionError):
- mul(wf, torch.tensor([123]), ('s'))
+ mul(wf, torch.tensor([123]), ("s"))
# test non default axes
sim_params1 = SimulationParameters(
axes={
- 'H': torch.linspace(-1, 1, 2),
- 'W': torch.linspace(-1, 1, 2),
- 'wavelength': torch.tensor([1]),
+ "H": torch.linspace(-1, 1, 2),
+ "W": torch.linspace(-1, 1, 2),
+ "wavelength": torch.tensor([1]),
}
)
- wf1 = Wavefront([[[1., 2], [3, 4]]])
- assert torch.allclose(mul(wf1, 123, 'wavelength', sim_params1), 123 * wf1)
+ wf1 = Wavefront([[[1.0, 2], [3, 4]]])
+ assert torch.allclose(mul(wf1, 123, "wavelength", sim_params1), 123 * wf1)
r = wf1 * a[None, :]
- assert torch.allclose(mul(wf1, a, 'H', sim_params1), r)
+ assert torch.allclose(mul(wf1, a, "H", sim_params1), r)
# test the same product but with other simulation parameters
sim_params2 = SimulationParameters(
axes={
- 'wavelength': torch.tensor([1]),
- 'W': torch.linspace(-1, 1, 2),
- 'H': torch.linspace(-1, 1, 2),
+ "wavelength": torch.tensor([1]),
+ "W": torch.linspace(-1, 1, 2),
+ "H": torch.linspace(-1, 1, 2),
}
)
wf2 = Wavefront(wf1.swapaxes(0, 2))
- assert torch.allclose(mul(wf2, a, 'H', sim_params2), r.swapaxes(0, 2))
+ assert torch.allclose(mul(wf2, a, "H", sim_params2), r.swapaxes(0, 2))
diff --git a/tests/test_clerk.py b/tests/test_clerk.py
index fc598d4..5108456 100644
--- a/tests/test_clerk.py
+++ b/tests/test_clerk.py
@@ -5,14 +5,23 @@
def test_init(tmp_path):
+ """
+ Tests the Clerk initialization and experiment directory validation.
+
+ Args:
+ tmp_path: A temporary path to be used for testing.
+
+ Returns:
+ None
+ """
# Test the experiment directory
clerk = Clerk(tmp_path)
assert clerk.experiment_directory == tmp_path
# Test if the experiment directory is not a directory case
- new_path = tmp_path / 'test'
+ new_path = tmp_path / "test"
assert not new_path.exists()
- with open(new_path, 'w'):
+ with open(new_path, "w"):
pass
with pytest.raises(ValueError):
@@ -20,7 +29,16 @@ def test_init(tmp_path):
def test_make_experiment_dir(tmp_path):
- new_path = tmp_path / 'test'
+ """
+ Tests the _make_experiment_dir method to ensure it creates a directory.
+
+ Args:
+ tmp_path: A temporary path for testing purposes.
+
+ Returns:
+ None
+ """
+ new_path = tmp_path / "test"
clerk = Clerk(new_path)
assert not new_path.exists()
@@ -29,12 +47,14 @@ def test_make_experiment_dir(tmp_path):
def test_get_log_stream(tmp_path):
+ """
+ Tests the _get_log_stream method."""
clerk = Clerk(tmp_path)
- tag = '123'
+ tag = "123"
with clerk._get_log_stream(tag) as stream:
# Test if the file was created
- assert (tmp_path / (tag + '.jsonl')).exists()
+ assert (tmp_path / (tag + ".jsonl")).exists()
# Test if the stream is not closed after the context is closed
assert not stream.closed
@@ -44,23 +64,32 @@ def test_get_log_stream(tmp_path):
assert stream is stream2
# Test if the same stream is not used for the different tag
- other_tag = '312'
+ other_tag = "312"
assert tag != other_tag
with clerk._get_log_stream(other_tag) as stream3:
assert stream is not stream3
def test_get_log_stream_mode(tmp_path):
+ """
+ Tests that the log stream mode is 'w' for new runs and 'a' for resumed runs.
+
+ Args:
+ tmp_path: A temporary path to use for the Clerk instance.
+
+ Returns:
+ None
+ """
clerk = Clerk(tmp_path)
- tag = '123'
+ tag = "123"
# Test if the stream mode is 'w' for 'new_run' mode
# By default 'new_run' mode is used
with clerk:
with clerk._get_log_stream(tag) as stream:
assert clerk._mode == ClerkMode.new_run
- assert stream.mode == 'w'
+ assert stream.mode == "w"
# Test if the stream mode is 'a' for 'resume' mode
# The clerk.begin() should be used to set 'resume' mode
@@ -68,13 +97,27 @@ def test_get_log_stream_mode(tmp_path):
with clerk._get_log_stream(tag) as stream:
assert clerk._mode == ClerkMode.resume
- assert stream.mode == 'a'
+ assert stream.mode == "a"
def test_get_log_stream_flushed(tmp_path):
+ """
+ Tests the behavior of log stream flushing within a context manager.
+
+ This test verifies that the flush method is called on the underlying stream
+ only when explicitly requested via the `flush` parameter to
+ `_get_log_stream`. It also checks that flush isn't called if the context
+ manager exits normally without requesting a flush.
+
+ Args:
+ tmp_path: A temporary path for Clerk initialization.
+
+ Returns:
+ None
+ """
# TODO: refactoring
clerk = Clerk(tmp_path)
- tag = '123'
+ tag = "123"
with clerk._get_log_stream(tag) as stream:
pass
@@ -99,26 +142,33 @@ def monkey_flush():
def test_conditions(tmp_path):
- experiment_dir = tmp_path / 'experiment'
+ """
+ Tests the saving and loading of experiment conditions.
+
+ This method creates a Clerk instance, saves a dictionary of conditions to
+ a JSON file within an experiment directory, and then verifies that the
+ directory and file are created correctly. It also loads the conditions
+ from the saved file and asserts that they match the original conditions.
+
+ Args:
+ tmp_path: A temporary path for creating the experiment directory.
+
+ Returns:
+ None
+ """
+ experiment_dir = tmp_path / "experiment"
clerk = Clerk(experiment_dir)
conditions = {
- 'test1': 123,
- 'test2': [
- 123,
- 10.,
- 'a'
- ],
- 'test3': {
- 't': 'e',
- 's': 't'
- }
+ "test1": 123,
+ "test2": [123, 10.0, "a"],
+ "test3": {"t": "e", "s": "t"},
}
clerk.save_conditions(conditions)
# Test if the folder and the file are created
assert experiment_dir.exists()
- assert (experiment_dir / 'conditions.json').exists()
+ assert (experiment_dir / "conditions.json").exists()
# Test if when loaded, the conditions are the same
new_clerk = Clerk(experiment_dir)
@@ -129,13 +179,23 @@ def test_conditions(tmp_path):
def test_logs(tmp_path):
+ """
+ Tests the log writing and loading functionality of the Clerk.
+
+ This tests checks that logs cannot be written before a context is active,
+ that files are created/appended to correctly, and that loaded messages match
+ the original messages in both regular and resume modes.
+
+ Args:
+ tmp_path: A temporary path for creating log files.
+
+ Returns:
+ None
+ """
clerk = Clerk(tmp_path)
- tag = 'test'
- messages = [
- {'a': 123, 'b': 321.},
- {'a': 321, 'b': 5423}
- ]
+ tag = "test"
+ messages = [{"a": 123, "b": 321.0}, {"a": 321, "b": 5423}]
# Test if log can't be written before the clerk is used in any context
with pytest.raises(RuntimeError):
@@ -143,7 +203,7 @@ def test_logs(tmp_path):
clerk.write_log(tag, message)
# Test if log file does not exist
- assert not (tmp_path / (tag + '.jsonl')).exists()
+ assert not (tmp_path / (tag + ".jsonl")).exists()
# Write the logs
with clerk:
@@ -151,7 +211,7 @@ def test_logs(tmp_path):
clerk.write_log(tag, message)
# Test if log file is created
- assert (tmp_path / (tag + '.jsonl')).exists()
+ assert (tmp_path / (tag + ".jsonl")).exists()
# Test if when loaded, the messages are the same
loaded_messages = list(clerk.load_logs(tag))
@@ -159,14 +219,14 @@ def test_logs(tmp_path):
assert loaded_messages == messages
# Test if in resume mode the logs are appended in existing file
- tag2 = 'test2'
- assert not (tmp_path / (tag2 + '.jsonl')).exists()
+ tag2 = "test2"
+ assert not (tmp_path / (tag2 + ".jsonl")).exists()
with clerk.begin(resume=True):
for message in messages:
clerk.write_log(tag2, message)
- assert (tmp_path / (tag2 + '.jsonl')).exists()
+ assert (tmp_path / (tag2 + ".jsonl")).exists()
loaded_messages = clerk.load_logs(tag2)
for i, message in enumerate(loaded_messages):
@@ -183,13 +243,23 @@ def test_logs(tmp_path):
def test_logs_pandas(tmp_path):
+ """
+ Tests loading logs to a Pandas DataFrame.
+
+ This test writes log messages with a specific tag, then loads them into a
+ Pandas DataFrame and verifies that the loaded data matches the original
+ messages.
+
+ Args:
+ tmp_path: A temporary path for storing log files.
+
+ Returns:
+ None
+ """
clerk = Clerk(tmp_path)
- tag = 'test'
- messages = [
- {'a': 123, 'b': 321.},
- {'a': 321, 'b': 5423}
- ]
+ tag = "test"
+ messages = [{"a": 123, "b": 321.0}, {"a": 321, "b": 5423}]
with clerk:
for message in messages:
@@ -201,8 +271,20 @@ def test_logs_pandas(tmp_path):
def test_checkpoints(tmp_path):
+ """
+ Tests the Clerk's checkpointing functionality.
+
+ This tests various scenarios including writing checkpoints, loading them,
+ cleaning up old checkpoints, and handling metadata and targets.
+
+ Args:
+ tmp_path: A temporary path for storing checkpoint files.
+
+ Returns:
+ None
+ """
clerk = Clerk(tmp_path)
- checkpoints_filepath = tmp_path / 'checkpoints.txt'
+ checkpoints_filepath = tmp_path / "checkpoints.txt"
# Test if checkpoint can't be written before
# the clerk is used in any context
@@ -217,9 +299,7 @@ def test_checkpoints(tmp_path):
# Write checkpoint with metadata and no targets
with clerk:
for i in range(11):
- clerk.write_checkpoint(metadata={
- 'i': i
- })
+ clerk.write_checkpoint(metadata={"i": i})
assert checkpoints_filepath.exists()
@@ -228,15 +308,15 @@ def test_checkpoints(tmp_path):
first_run_checkpoint_filenames: list[str] = []
with open(checkpoints_filepath) as file:
for i, line in enumerate(file.readlines()):
- checkpoint_filename = f'{i}.pt'
+ checkpoint_filename = f"{i}.pt"
first_run_checkpoint_filenames.append(checkpoint_filename)
- assert line == checkpoint_filename + '\n'
+ assert line == checkpoint_filename + "\n"
assert (tmp_path / checkpoint_filename).exists()
metadata = clerk.load_checkpoint(i)
- assert metadata == {'i': i}
+ assert metadata == {"i": i}
same_metadata = clerk.load_checkpoint(checkpoint_filename)
assert same_metadata == metadata
@@ -254,7 +334,7 @@ def test_checkpoints(tmp_path):
with open(checkpoints_filepath) as file:
assert len(file.readlines()) == 3
for i in range(3):
- second_run_checkpoint_filenames.append(f'{i}.pt')
+ second_run_checkpoint_filenames.append(f"{i}.pt")
# Test clean_checkpoints
clerk.clean_checkpoints()
@@ -274,15 +354,12 @@ class ObjectWithState(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.test_parameter: torch.Tensor
- self.register_buffer('test_parameter', torch.tensor(0.0))
+ self.register_buffer("test_parameter", torch.tensor(0.0))
clerk = Clerk(tmp_path)
object1 = ObjectWithState()
object2 = ObjectWithState()
- clerk.set_checkpoint_targets({
- '1': object1,
- '2': object2
- })
+ clerk.set_checkpoint_targets({"1": object1, "2": object2})
# Write checkpoint with targets
with clerk:
@@ -294,16 +371,14 @@ def __init__(self) -> None:
for i in range(6):
clerk.load_checkpoint(i)
assert object1.test_parameter.item() == 0
- assert object2.test_parameter.item() == 2. + 2. * i
+ assert object2.test_parameter.item() == 2.0 + 2.0 * i
# Test load_checkpoint for specific target
object1.test_parameter = torch.tensor(123)
object2.test_parameter = torch.tensor(321)
object3 = ObjectWithState()
for i in range(6):
- clerk.load_checkpoint(i, targets={
- '2': object3
- })
+ clerk.load_checkpoint(i, targets={"2": object3})
# Test if object1 and object2 does not change
assert object1.test_parameter.item() == 123
assert object2.test_parameter.item() == 321
@@ -312,10 +387,7 @@ def __init__(self) -> None:
# Test if more checkpoints has been saved when resume mode
clerk = Clerk(tmp_path)
- clerk.set_checkpoint_targets({
- '1': object1,
- '2': object2
- })
+ clerk.set_checkpoint_targets({"1": object1, "2": object2})
assert object1.test_parameter.item() != 0
assert object2.test_parameter.item() != 16
@@ -333,10 +405,7 @@ def __init__(self) -> None:
# Test if resume_load_last_checkpoint can be turned off
clerk = Clerk(tmp_path)
- clerk.set_checkpoint_targets({
- '1': object1,
- '2': object2
- })
+ clerk.set_checkpoint_targets({"1": object1, "2": object2})
object1.test_parameter = torch.tensor(123)
object2.test_parameter = torch.tensor(321)
@@ -351,7 +420,19 @@ def __init__(self) -> None:
def test_context(tmp_path):
- new_path = tmp_path / 'test'
+ """
+ Tests the Clerk context manager functionality.
+
+ This tests nested context usage, directory creation, and automatic stream closing.
+ It also verifies exception handling when detaching a stream within a clerk context.
+
+ Args:
+ tmp_path: A temporary path for testing purposes.
+
+ Returns:
+ None
+ """
+ new_path = tmp_path / "test"
clerk = Clerk(new_path)
assert not new_path.exists()
@@ -365,7 +446,7 @@ def test_context(tmp_path):
assert new_path.exists()
# Test if all streams are closed automatically
- with clerk._get_log_stream('test', flush=False) as stream:
+ with clerk._get_log_stream("test", flush=False) as stream:
assert not stream.closed
with clerk:
@@ -374,14 +455,28 @@ def test_context(tmp_path):
assert stream.closed
with pytest.raises(ExceptionGroup):
- with clerk._get_log_stream('test', flush=False) as stream:
+ with clerk._get_log_stream("test", flush=False) as stream:
with clerk:
stream.detach()
def test_backup_checkpoint(tmp_path):
+ """
+ Tests the backup checkpoint functionality of the Clerk class.
+
+ This test verifies that a backup checkpoint is created when an exception
+ occurs during a clerk session with autosave enabled, and that it can be
+ loaded and cleaned up correctly. It also tests the handling of exceptions
+ during the preparation of checkpoint data.
+
+ Args:
+ tmp_path: A temporary path for creating test files.
+
+ Returns:
+ None
+ """
clerk = Clerk(tmp_path)
- checkpoints_filepath = tmp_path / 'checkpoints.txt'
+ checkpoints_filepath = tmp_path / "checkpoints.txt"
class SpecificException(Exception):
pass
@@ -395,12 +490,12 @@ class SpecificException(Exception):
# Test if the backup checkpoint is not in 'checkpoints.txt'
with open(checkpoints_filepath) as file:
- assert file.readlines() == ['0.pt\n']
+ assert file.readlines() == ["0.pt\n"]
backup_checkpoints: list[str] = []
# Find backup checkpoint files
for file in tmp_path.iterdir():
- if file.name.endswith('.pt'):
+ if file.name.endswith(".pt"):
if not CHECKPOINT_FILENAME_PATTERN.match(file.name):
backup_checkpoints.append(file.name)
@@ -409,8 +504,8 @@ class SpecificException(Exception):
# Test metadata
metadata = clerk.load_checkpoint(backup_checkpoints[0])
assert isinstance(metadata, dict)
- assert 'time' in metadata
- assert 'description' in metadata
+ assert "time" in metadata
+ assert "description" in metadata
# Test clean_backup_checkpoints method
assert (tmp_path / backup_checkpoints[0]).exists()
diff --git a/tests/test_conv4f_net.py b/tests/test_conv4f_net.py
index d7b0f6a..c79fde5 100644
--- a/tests/test_conv4f_net.py
+++ b/tests/test_conv4f_net.py
@@ -12,16 +12,17 @@
@pytest.fixture()
def some_elements_list(sim_params):
"""Returns list with a zero distance FreeSpace, i.e. empty network"""
- h, w = sim_params.axes_size(axs=('H', 'W'))
+ h, w = sim_params.axes_size(axs=("H", "W"))
elements_list = [
elements.DiffractiveLayer(
simulation_parameters=sim_params,
- mask=torch.rand(h, w) * 2 * torch.pi, # mask is not changing during the training!
+ mask=torch.rand(h, w)
+ * 2
+ * torch.pi, # mask is not changing during the training!
),
elements.FreeSpace(
- simulation_parameters=sim_params,
- distance=3.00 * 1e-2, method='AS'
+ simulation_parameters=sim_params, distance=3.00 * 1e-2, method="AS"
),
]
@@ -29,24 +30,28 @@ def some_elements_list(sim_params):
@pytest.mark.parametrize(
- "wf_real, wf_imag, focal_length", [
+ "wf_real, wf_imag, focal_length",
+ [
(1.00, 0.00, 1.00 * 1e-2),
(0.00, 1.00, 2.00 * 1e-2),
(2.50, 1.25, 3.00 * 1e-2),
- ]
+ ],
)
def test_conv4f_net_forward(
- sim_params, some_elements_list, # fixtures
- wf_real, wf_imag, focal_length
+ sim_params, some_elements_list, wf_real, wf_imag, focal_length # fixtures
):
"""Test forward function for a single Wavefront sequence."""
- h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters
+ h, w = sim_params.axes_size(
+ axs=("H", "W")
+ ) # size of a wavefront according to SimulationParameters
test_wavefront = Wavefront(
- torch.ones(size=(h, w), dtype=torch.float64) * wf_real +
- torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j
+ torch.ones(size=(h, w), dtype=torch.float64) * wf_real
+ + torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j
)
- random_diffractive_mask = torch.rand(h, w) * 2 * torch.pi # random mask for a convolution
+ random_diffractive_mask = (
+ torch.rand(h, w) * 2 * torch.pi
+ ) # random mask for a convolution
# NETWORK
conv4f_net = ConvDiffNetwork4F(
@@ -77,7 +82,9 @@ def test_conv4f_net_forward(
def test_conv4f_net_device(sim_params, some_elements_list):
"""Test .to(device) function for a Convolutional Network."""
- h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters
+ h, w = sim_params.axes_size(
+ axs=("H", "W")
+ ) # size of a wavefront according to SimulationParameters
random_diffractive_mask = torch.rand(h, w) # random mask for a convolution
# NETWORK
@@ -86,14 +93,16 @@ def test_conv4f_net_device(sim_params, some_elements_list):
network_elements_list=some_elements_list,
focal_length=1.00 * 1e-2,
conv_phase_mask=random_diffractive_mask,
- device='cpu',
+ device="cpu",
)
- assert conv4f_net.device == torch.device('cpu')
+ assert conv4f_net.device == torch.device("cpu")
- new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is
- new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+ new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if new_device == torch.device(
+ "cpu"
+ ): # if cuda is not available - check if `mps` is
+ new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
new_conv4f_net = conv4f_net.to(new_device)
diff --git a/tests/test_detector.py b/tests/test_detector.py
index a61ca40..b70b55e 100644
--- a/tests/test_detector.py
+++ b/tests/test_detector.py
@@ -13,9 +13,9 @@ def test_detector_types():
detector = Detector(
SimulationParameters(
{
- 'W': torch.linspace(-1e-2/2, 1e-2/2, 5),
- 'H': torch.linspace(-1e-2/2, 1e-2/2, 5),
- 'wavelength': 1e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
+ "wavelength": 1e-6,
}
)
)
@@ -24,10 +24,11 @@ def test_detector_types():
@pytest.mark.parametrize(
- "x_size, y_size, x_nodes, y_nodes, wavelength", [
+ "x_size, y_size, x_nodes, y_nodes, wavelength",
+ [
(10e-2, 10e-2, 10, 10, 1e-6),
(15e-2, 20e-2, 15, 20, 1e-6),
- ]
+ ],
)
def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength):
"""
@@ -42,12 +43,12 @@ def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength):
detector = Detector(
SimulationParameters(
{
- 'W': torch.linspace(-x_size/2, x_size/2, x_nodes),
- 'H': torch.linspace(-y_size/2, y_size/2, y_nodes),
- 'wavelength': wavelength
+ "W": torch.linspace(-x_size / 2, x_size / 2, x_nodes),
+ "H": torch.linspace(-y_size / 2, y_size / 2, y_nodes),
+ "wavelength": wavelength,
}
),
- func='intensity'
+ func="intensity",
)
input_field = torch.rand(size=[y_nodes, x_nodes])
@@ -57,17 +58,18 @@ def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength):
@pytest.mark.parametrize(
- "num_classes, detector_x, expected_mask", [
- (4, 8, [[0, 0, 1, 1, 2, 2, 3, 3]]), # num_classes - even, detector_x - even
+ "num_classes, detector_x, expected_mask",
+ [
+ (4, 8, [[0, 0, 1, 1, 2, 2, 3, 3]]), # num_classes - even, detector_x - even
(2, 4, [[0, 0, 1, 1]]),
- (2, 7, [[0, 0, 0, -1, 1, 1, 1]]), # num_classes - even, detector_x - odd
+ (2, 7, [[0, 0, 0, -1, 1, 1, 1]]), # num_classes - even, detector_x - odd
(4, 7, [[-1, 0, 1, -1, 2, 3, -1]]),
- (3, 8, [[-1, 0, 0, 1, 1, 2, 2, -1]]), # num_classes - odd, detector_x - even
+ (3, 8, [[-1, 0, 0, 1, 1, 2, 2, -1]]), # num_classes - odd, detector_x - even
(3, 10, [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2]]),
- (3, 7, [[0, 0, 1, 1, 1, 2, 2]]), # num_classes - odd, detector_x - odd
+ (3, 7, [[0, 0, 1, 1, 1, 2, 2]]), # num_classes - odd, detector_x - odd
(5, 7, [[-1, 0, 1, 2, 3, 4, -1]]),
(5, 11, [[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4]]),
- ]
+ ],
)
def test_detector_segmentation_strips(num_classes, detector_x, expected_mask):
"""
@@ -87,11 +89,11 @@ def test_detector_segmentation_strips(num_classes, detector_x, expected_mask):
num_classes=num_classes,
simulation_parameters=SimulationParameters(
{
- 'W': torch.linspace(-1e-2/2, 1e-2/2, detector_x),
- 'H': torch.linspace(0, 0, 1),
- 'wavelength': 500e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, detector_x),
+ "H": torch.linspace(0, 0, 1),
+ "wavelength": 500e-6,
}
- )
+ ),
)
assert isinstance(processor, torch.nn.Module)
@@ -100,19 +102,22 @@ def test_detector_segmentation_strips(num_classes, detector_x, expected_mask):
for ind_class in range(num_classes): # check if all classes zones are marked
assert ind_class in processor.segmented_detector
- assert torch.allclose(processor.segmented_detector, torch.tensor(expected_mask, dtype=torch.int32))
+ assert torch.allclose(
+ processor.segmented_detector, torch.tensor(expected_mask, dtype=torch.int32)
+ )
@pytest.mark.parametrize(
- "num_classes, segmented_detector, expected_weights", [
+ "num_classes, segmented_detector, expected_weights",
+ [
(2, [[0, 0, 1, 1, 0, 1, 0, 0]], [[3 / 5, 1.0]]),
(3, [[-1, -1, 0, 0, 0, 1, 2, 0, 1, 2, 2, 2, -1, -1]], [[0.5, 1.0, 0.5]]),
- (4,
- [[-1, -1, 1, 1, -1, -1, 3, 3],
- [0, 0, -1, -1, 2, 2, -1, -1]],
- [[1.0, 1.0, 1.0, 1.0]]
- ),
- ]
+ (
+ 4,
+ [[-1, -1, 1, 1, -1, -1, 3, 3], [0, 0, -1, -1, 2, 2, -1, -1]],
+ [[1.0, 1.0, 1.0, 1.0]],
+ ),
+ ],
)
def test_detector_weight_segments(num_classes, segmented_detector, expected_weights):
"""
@@ -135,9 +140,9 @@ def test_detector_weight_segments(num_classes, segmented_detector, expected_weig
num_classes=num_classes,
simulation_parameters=SimulationParameters(
{
- 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
- 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
- 'wavelength': 500e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
+ "wavelength": 500e-6,
}
),
segmented_detector=segmented_detector_tensor,
@@ -150,24 +155,23 @@ def test_detector_weight_segments(num_classes, segmented_detector, expected_weig
@pytest.mark.parametrize(
- "num_classes, segmented_detector, batch_detector_data, expected_probas", [
+ "num_classes, segmented_detector, batch_detector_data, expected_probas",
+ [
(
- 2,
- [[0, 0], [1, 1]],
- [
- [[0.00, 0.00], [1.00, 0.00]],
- [[0.10, 0.00], [0.00, 0.90]],
- [[0.22, 0.00], [0.30, 0.48]]
- ],
- [
- [0.00, 1.00],
- [0.10, 0.90],
- [0.22, 0.78]
- ]
- ),
- ]
+ 2,
+ [[0, 0], [1, 1]],
+ [
+ [[0.00, 0.00], [1.00, 0.00]],
+ [[0.10, 0.00], [0.00, 0.90]],
+ [[0.22, 0.00], [0.30, 0.48]],
+ ],
+ [[0.00, 1.00], [0.10, 0.90], [0.22, 0.78]],
+ ),
+ ],
)
-def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_data, expected_probas):
+def test_detector_batch_forward(
+ num_classes, segmented_detector, batch_detector_data, expected_probas
+):
"""
Test of a method of DetectorProcessorClf for calculating probabilities for a batch.
...
@@ -192,9 +196,9 @@ def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_
num_classes=num_classes,
simulation_parameters=SimulationParameters(
{
- 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
- 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
- 'wavelength': 500e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
+ "wavelength": 500e-6,
}
),
segmented_detector=segmented_detector_tensor,
@@ -207,20 +211,18 @@ def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_
@pytest.mark.parametrize(
- "num_classes, segments_zone_size, segmented_detector", [
+ "num_classes, segments_zone_size, segmented_detector",
+ [
(
- 2,
- [2, 2],
- [
- [-1, -1, -1, -1],
- [-1, 0, 1, -1],
- [-1, 0, 1, -1],
- [-1, -1, -1, -1]
- ]
- ),
- ]
+ 2,
+ [2, 2],
+ [[-1, -1, -1, -1], [-1, 0, 1, -1], [-1, 0, 1, -1], [-1, -1, -1, -1]],
+ ),
+ ],
)
-def test_detector_segmentation_for_aperture(num_classes, segments_zone_size, segmented_detector):
+def test_detector_segmentation_for_aperture(
+ num_classes, segments_zone_size, segmented_detector
+):
"""
Test of a feature of DetectorProcessorClf that extends segmented detector sizes to match SimulationParameters.
...
@@ -242,9 +244,9 @@ def test_detector_segmentation_for_aperture(num_classes, segments_zone_size, seg
num_classes=num_classes,
simulation_parameters=SimulationParameters(
{
- 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
- 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
- 'wavelength': 500e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]),
+ "wavelength": 500e-6,
}
),
segments_zone_size=torch.Size(segments_zone_size),
@@ -262,20 +264,20 @@ def test_detector_device():
num_classes=2,
simulation_parameters=SimulationParameters(
{
- 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
- 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
- 'wavelength': 500e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5),
+ "wavelength": 500e-6,
}
),
)
- processor_2 = processor.to('cpu')
+ processor_2 = processor.to("cpu")
assert isinstance(processor_2, DetectorProcessorClf)
# available device?
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if device == torch.device('cpu'):
- device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device == torch.device("cpu"):
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
processor_3 = processor.to(device)
device_3 = processor_3.device
diff --git a/tests/test_device.py b/tests/test_device.py
index f4b57e9..e624526 100644
--- a/tests/test_device.py
+++ b/tests/test_device.py
@@ -11,10 +11,10 @@
parameters = "device_type"
-@pytest.mark.parametrize(parameters, [
- torch.device("cuda" if torch.cuda.is_available() else "cpu"),
- torch.device("cpu")
-])
+@pytest.mark.parametrize(
+ parameters,
+ [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")],
+)
def test_devices(device_type: torch.device):
"""A test that checks that all elements belong to the same device
@@ -24,26 +24,30 @@ def test_devices(device_type: torch.device):
device for objects
"""
- ox_size = 15.
- oy_size = 8.
+ ox_size = 15.0
+ oy_size = 8.0
ox_nodes = 1200
oy_nodes = 1100
- wavelength = torch.linspace(330*1e-6, 660*1e-6, 5)
- waist_radius = 2.
- distance = 100.
- focal_length = 100.
- radius = 10.
- height = 4.
- width = 3.
+ wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5)
+ waist_radius = 2.0
+ distance = 100.0
+ focal_length = 100.0
+ radius = 10.0
+ height = 4.0
+ width = 3.0
tensors = []
params = SimulationParameters(
axes={
- 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes).to(device_type), # noqa: E501
- 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes).to(device_type), # noqa: E501
- 'wavelength': wavelength.to(device_type)
- }
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes).to(
+ device_type
+ ), # noqa: E501
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes).to(
+ device_type
+ ), # noqa: E501
+ "wavelength": wavelength.to(device_type),
+ }
).to(device=device_type)
x_linear = params.axes.W
@@ -53,34 +57,24 @@ def test_devices(device_type: torch.device):
wavelength = params.axes.wavelength
tensors.append(wavelength)
- x_grid, y_grid = params.meshgrid(x_axis='W', y_axis='H')
+ x_grid, y_grid = params.meshgrid(x_axis="W", y_axis="H")
tensors.extend([x_grid, y_grid])
gaussian_beam = w.gaussian_beam(
- simulation_parameters=params,
- waist_radius=waist_radius,
- distance=distance
+ simulation_parameters=params, waist_radius=waist_radius, distance=distance
)
tensors.append(gaussian_beam)
- plane_wave = w.plane_wave(
- simulation_parameters=params,
- distance=distance
- )
+ plane_wave = w.plane_wave(simulation_parameters=params, distance=distance)
tensors.append(plane_wave)
- spherical_wave = w.spherical_wave(
- simulation_parameters=params,
- distance=distance
- )
+ spherical_wave = w.spherical_wave(simulation_parameters=params, distance=distance)
tensors.append(spherical_wave)
lens = elements.ThinLens(
- simulation_parameters=params,
- focal_length=focal_length,
- radius=radius
+ simulation_parameters=params, focal_length=focal_length, radius=radius
)
tensors.append(lens.get_transmission_function())
@@ -88,26 +82,20 @@ def test_devices(device_type: torch.device):
tensors.append(lens.reverse(gaussian_beam))
aperture = elements.Aperture(
- simulation_parameters=params,
- mask=torch.zeros(x_grid.shape).to(device_type)
+ simulation_parameters=params, mask=torch.zeros(x_grid.shape).to(device_type)
)
tensors.append(aperture.get_transmission_function())
tensors.append(aperture.forward(gaussian_beam))
rectangular_aperture = elements.RectangularAperture(
- simulation_parameters=params,
- height=height,
- width=width
+ simulation_parameters=params, height=height, width=width
)
tensors.append(rectangular_aperture.get_transmission_function())
tensors.append(rectangular_aperture.forward(gaussian_beam))
- round_aperture = elements.RoundAperture(
- simulation_parameters=params,
- radius=radius
- )
+ round_aperture = elements.RoundAperture(simulation_parameters=params, radius=radius)
tensors.append(round_aperture.get_transmission_function())
tensors.append(round_aperture.forward(gaussian_beam))
@@ -116,7 +104,7 @@ def test_devices(device_type: torch.device):
simulation_parameters=params,
height=height,
width=width,
- mask=torch.ones_like(x_grid)
+ mask=torch.ones_like(x_grid),
)
tensors.append(slm.transmission_function)
@@ -124,8 +112,7 @@ def test_devices(device_type: torch.device):
tensors.append(slm.reverse(gaussian_beam))
layer = elements.DiffractiveLayer(
- simulation_parameters=params,
- mask=torch.zeros_like(x_grid)
+ simulation_parameters=params, mask=torch.zeros_like(x_grid)
)
tensors.append(layer.transmission_function)
@@ -133,29 +120,25 @@ def test_devices(device_type: torch.device):
tensors.append(layer.reverse(gaussian_beam))
free_space_as = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='AS'
+ simulation_parameters=params, distance=distance, method="AS"
)
tensors.append(free_space_as.forward(gaussian_beam))
free_space_fresnel = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
)
tensors.append(free_space_fresnel.forward(gaussian_beam))
free_space_reverse = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
)
tensors.append(free_space_reverse.reverse(gaussian_beam))
nl = elements.NonlinearElement(
- simulation_parameters=params,
- response_function=lambda x: x**2
+ simulation_parameters=params, response_function=lambda x: x**2
)
tensors.append(nl.forward(gaussian_beam))
@@ -163,33 +146,42 @@ def test_devices(device_type: torch.device):
assert all(tensor.device.type == device_type.type for tensor in tensors)
-@pytest.mark.parametrize(parameters, [
- torch.device("cuda" if torch.cuda.is_available() else "cpu"),
- torch.device("cpu")
-])
+@pytest.mark.parametrize(
+ parameters,
+ [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")],
+)
def test_device_setup(device_type: torch.device):
+ """
+ Tests that the optical setup and parameters are correctly moved to the specified device.
+
+ Args:
+ device_type: The torch.device to move the setup and parameters to (CPU or CUDA).
+
+ Returns:
+ None
+ """
- ox_size = 15.
- oy_size = 8.
+ ox_size = 15.0
+ oy_size = 8.0
ox_nodes = 1200
oy_nodes = 1100
- wavelength = torch.linspace(330*1e-6, 660*1e-6, 5)
+ wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5)
# waist_radius = 2.
- distance = 50.
- focal_length = 100.
- radius = 10.
- height = 4.
- width = 3.
+ distance = 50.0
+ focal_length = 100.0
+ radius = 10.0
+ height = 4.0
+ width = 3.0
params = SimulationParameters(
axes={
- 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
- 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
- 'wavelength': wavelength
- }
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength,
+ }
)
- x_grid, _ = params.meshgrid(x_axis='W', y_axis='H')
+ x_grid, _ = params.meshgrid(x_axis="W", y_axis="H")
# gaussian_beam = w.gaussian_beam(
# simulation_parameters=params,
@@ -198,48 +190,36 @@ def test_device_setup(device_type: torch.device):
# )
free_space = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance,
- method="AS"
+ simulation_parameters=params, distance=distance, method="AS"
)
- circle = elements.RoundAperture(
- simulation_parameters=params,
- radius=radius
- )
+ circle = elements.RoundAperture(simulation_parameters=params, radius=radius)
rectangle = elements.RectangularAperture(
- simulation_parameters=params,
- height=height,
- width=width
+ simulation_parameters=params, height=height, width=width
)
aperture = elements.Aperture(
- simulation_parameters=params,
- mask=torch.ones_like(x_grid)
+ simulation_parameters=params, mask=torch.ones_like(x_grid)
)
lens = elements.ThinLens(
- simulation_parameters=params,
- focal_length=distance,
- radius=focal_length
+ simulation_parameters=params, focal_length=distance, radius=focal_length
)
slm = elements.SpatialLightModulator(
simulation_parameters=params,
- mask=torch.tensor([[1., 1.], [1., 1.]]),
+ mask=torch.tensor([[1.0, 1.0], [1.0, 1.0]]),
height=height,
- width=width
+ width=width,
)
nl = elements.NonlinearElement(
- simulation_parameters=params,
- response_function=lambda x: x**2
+ simulation_parameters=params, response_function=lambda x: x**2
)
dl = elements.DiffractiveLayer(
- simulation_parameters=params,
- mask=torch.zeros_like(x_grid)
+ simulation_parameters=params, mask=torch.zeros_like(x_grid)
)
det = detector.Detector(simulation_parameters=params)
@@ -260,8 +240,7 @@ def test_device_setup(device_type: torch.device):
free_space,
dl,
free_space,
- det
-
+ det,
]
)
diff --git a/tests/test_diffraction_peaks.py b/tests/test_diffraction_peaks.py
index 821386b..8112880 100644
--- a/tests/test_diffraction_peaks.py
+++ b/tests/test_diffraction_peaks.py
@@ -29,48 +29,48 @@
(
500, # ox_size
500, # oy_size
- 1000000, # ox_nodes
- 10, # oy_nodes
+ 1000000, # ox_nodes
+ 10, # oy_nodes
1064 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1500, # distance, mm
- 0.1, # width, mm
+ 1500, # distance, mm
+ 0.1, # width, mm
5, # max diffraction order to check
- 0.02 # expected_error
+ 0.02, # expected_error
),
(
500, # ox_size
500, # oy_size
- 1000000, # ox_nodes
- 10, # oy_nodes
+ 1000000, # ox_nodes
+ 10, # oy_nodes
660 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1500, # distance, mm
- 0.1, # width, mm
+ 1500, # distance, mm
+ 0.1, # width, mm
6, # max diffraction order to check
- 0.02 # expected_error
+ 0.02, # expected_error
),
(
500, # ox_size
500, # oy_size
- 1000000, # ox_nodes
- 10, # oy_nodes
+ 1000000, # ox_nodes
+ 10, # oy_nodes
540 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1500, # distance, mm
- 0.1, # width, mm
+ 1500, # distance, mm
+ 0.1, # width, mm
4, # max diffraction order to check
- 0.02 # expected_error
+ 0.02, # expected_error
),
(
500, # ox_size
500, # oy_size
- 1000000, # ox_nodes
- 10, # oy_nodes
+ 1000000, # ox_nodes
+ 10, # oy_nodes
990 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1500, # distance, mm
- 0.1, # width, mm
+ 1500, # distance, mm
+ 0.1, # width, mm
8, # max diffraction order to check
- 0.02 # expected_error
+ 0.02, # expected_error
),
- ]
+ ],
)
def test_diffraction_peaks(
ox_size: float,
@@ -81,7 +81,7 @@ def test_diffraction_peaks(
distance: float,
width: float,
diffraction_order: int,
- expected_error: float
+ expected_error: float,
):
"""Test checking the coincidence of diffraction maxima at diffraction on a
thin slit
@@ -114,46 +114,39 @@ def test_diffraction_peaks(
y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
params = SimulationParameters(
- axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': wavelength_test
- })
+ axes={"W": x_length, "H": y_length, "wavelength": wavelength_test}
+ )
beam = Wavefront.gaussian_beam(
- simulation_parameters=params,
- waist_radius=2.,
- distance=distance
+ simulation_parameters=params, waist_radius=2.0, distance=distance
)
# create rectangular aperture
rectangular_aperture = elements.RectangularAperture(
- simulation_parameters=params,
- height=height,
- width=width
+ simulation_parameters=params, height=height, width=width
)
field_after_aperture = rectangular_aperture(beam)
fs = elements.FreeSpace(
- simulation_parameters=params, distance=distance, method='AS'
+ simulation_parameters=params, distance=distance, method="AS"
)
output_field = fs.forward(field_after_aperture)
intensity_output = output_field.intensity
- amplitude_1d = np.sqrt(intensity_output.detach().numpy())[int(oy_nodes/2)]
+ amplitude_1d = np.sqrt(intensity_output.detach().numpy())[int(oy_nodes / 2)]
def intensity_analytic(coordinates: torch.Tensor) -> np.ndarray:
phi = np.arctan(coordinates / distance)
u = np.pi / wavelength_test * width * np.sin(phi)
- return (np.sin(u) / u)**2 * intensity_output[int(oy_nodes/2), int(ox_nodes/2)] # noqa: E501
+ return (np.sin(u) / u) ** 2 * intensity_output[
+ int(oy_nodes / 2), int(ox_nodes / 2)
+ ] # noqa: E501
def find_maximum(start, end):
result = minimize_scalar(
- lambda x: -intensity_analytic(x),
- bounds=(start, end),
- method='bounded'
+ lambda x: -intensity_analytic(x), bounds=(start, end), method="bounded"
)
return result.x
@@ -168,7 +161,7 @@ def find_maximum(start, end):
# define Gaussian function
def gaussian(x, amp, cen, wid):
- return amp * np.exp(-(x-cen)**2 / (2*wid**2))
+ return amp * np.exp(-((x - cen) ** 2) / (2 * wid**2))
x_max_averaged = np.array([])
diff --git a/tests/test_drnn.py b/tests/test_drnn.py
index 9fb0a40..fc55d89 100644
--- a/tests/test_drnn.py
+++ b/tests/test_drnn.py
@@ -14,9 +14,9 @@ def sim_params():
"""Returns SimulationParameters object."""
return SimulationParameters(
{
- 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, 10),
- 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, 10),
- 'wavelength': 800e-6
+ "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 10),
+ "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 10),
+ "wavelength": 800e-6,
}
)
@@ -24,10 +24,7 @@ def sim_params():
@pytest.fixture()
def zero_free_space(sim_params):
"""Returns FreeSpace with a zero distance!"""
- return FreeSpace(
- simulation_parameters=sim_params,
- distance=0.0, method='AS'
- )
+ return FreeSpace(simulation_parameters=sim_params, distance=0.0, method="AS")
@pytest.fixture()
@@ -39,24 +36,29 @@ def empty_layer(zero_free_space):
@pytest.fixture()
def detector(sim_params):
"""Returns nn.Sequentional with a Detector for RNN detector_layer"""
- return nn.Sequential(
- Detector(sim_params, func='intensity')
- )
+ return nn.Sequential(Detector(sim_params, func="intensity"))
@pytest.mark.parametrize(
- "sequence_len, fusing_coeff, sequence_amplitudes", [
+ "sequence_len, fusing_coeff, sequence_amplitudes",
+ [
(1, 0.30, [1.00]),
(2, 0.25, [0.77, 0.13]),
(3, 0.50, [0.80, 1.50, 2.10]),
- ]
+ ],
)
def test_drnn_forward(
- sim_params, empty_layer, detector, # fixtures
- sequence_len, fusing_coeff, sequence_amplitudes
+ sim_params,
+ empty_layer,
+ detector, # fixtures
+ sequence_len,
+ fusing_coeff,
+ sequence_amplitudes,
):
"""Test forward function for a single Wavefront sequence."""
- h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters
+ h, w = sim_params.axes_size(
+ axs=("H", "W")
+ ) # size of a wavefront according to SimulationParameters
wavefront_seq = Wavefront(
torch.ones(size=(sequence_len, h, w), dtype=torch.float64)
)
@@ -69,46 +71,84 @@ def test_drnn_forward(
for step_ind in range(sequence_len):
input = sequence_amplitudes[step_ind]
hidden = fusing_coeff * hidden + (1 - fusing_coeff) * input
- out_expected_val = hidden ** 2 # intensity output
+ out_expected_val = hidden**2 # intensity output
# empty D-RNN
drnn = DiffractiveRNN(
sim_params,
- sequence_len=sequence_len, fusing_coeff=fusing_coeff,
- read_in_layer=empty_layer, memory_layer=empty_layer,
+ sequence_len=sequence_len,
+ fusing_coeff=fusing_coeff,
+ read_in_layer=empty_layer,
+ memory_layer=empty_layer,
hidden_forward_layer=empty_layer,
- read_out_layer=empty_layer, detector_layer=detector,
+ read_out_layer=empty_layer,
+ detector_layer=detector,
device=torch.get_default_device(),
)
# forward for D-RNN
out_drnn = drnn(wavefront_seq)
assert torch.allclose(
- out_drnn,
- torch.ones(size=(h, w), dtype=torch.float64) * out_expected_val
+ out_drnn, torch.ones(size=(h, w), dtype=torch.float64) * out_expected_val
)
@pytest.mark.parametrize(
- "batch_size, sequence_len, fusing_coeff, sequence_amplitudes", [
- (3, 1, 0.30, [[1.00], [0.40], [0.55],]),
- (3, 2, 0.25, [[0.77, 0.13], [0.10, 1.10], [2.20, 5.00],]),
- (2, 3, 0.50, [[0.80, 1.50, 2.10], [1.00, 2.00, 3.00],]),
- ]
+ "batch_size, sequence_len, fusing_coeff, sequence_amplitudes",
+ [
+ (
+ 3,
+ 1,
+ 0.30,
+ [
+ [1.00],
+ [0.40],
+ [0.55],
+ ],
+ ),
+ (
+ 3,
+ 2,
+ 0.25,
+ [
+ [0.77, 0.13],
+ [0.10, 1.10],
+ [2.20, 5.00],
+ ],
+ ),
+ (
+ 2,
+ 3,
+ 0.50,
+ [
+ [0.80, 1.50, 2.10],
+ [1.00, 2.00, 3.00],
+ ],
+ ),
+ ],
)
def test_drnn_batch_forward(
- sim_params, empty_layer, detector, # fixtures
- batch_size, sequence_len, fusing_coeff, sequence_amplitudes
+ sim_params,
+ empty_layer,
+ detector, # fixtures
+ batch_size,
+ sequence_len,
+ fusing_coeff,
+ sequence_amplitudes,
):
"""Test forward function for a batch of Wavefront sequences."""
- h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters
+ h, w = sim_params.axes_size(
+ axs=("H", "W")
+ ) # size of a wavefront according to SimulationParameters
wavefront_seq_batch = Wavefront(
torch.ones(size=(batch_size, sequence_len, h, w), dtype=torch.float64)
)
for seq_ind in range(batch_size):
for step_ind in range(sequence_len): # set amplitudes for a wavefront sequence
- wavefront_seq_batch[seq_ind, step_ind, :, :] *= sequence_amplitudes[seq_ind][step_ind]
+ wavefront_seq_batch[seq_ind, step_ind, :, :] *= sequence_amplitudes[
+ seq_ind
+ ][step_ind]
# calculate expected values for an empty D-RNN with specified fusing coefficient
out_expected_values = []
@@ -117,16 +157,19 @@ def test_drnn_batch_forward(
for step_ind in range(sequence_len):
input = sequence_amplitudes[seq_ind][step_ind]
hidden = fusing_coeff * hidden + (1 - fusing_coeff) * input
- out_expected_val = hidden ** 2 # intensity output
+ out_expected_val = hidden**2 # intensity output
out_expected_values.append(out_expected_val)
# empty D-RNN
drnn = DiffractiveRNN(
sim_params,
- sequence_len=sequence_len, fusing_coeff=fusing_coeff,
- read_in_layer=empty_layer, memory_layer=empty_layer,
+ sequence_len=sequence_len,
+ fusing_coeff=fusing_coeff,
+ read_in_layer=empty_layer,
+ memory_layer=empty_layer,
hidden_forward_layer=empty_layer,
- read_out_layer=empty_layer, detector_layer=detector,
+ read_out_layer=empty_layer,
+ detector_layer=detector,
device=torch.get_default_device(),
)
# forward for D-RNN
@@ -135,7 +178,7 @@ def test_drnn_batch_forward(
for ind_seq in range(batch_size):
assert torch.allclose(
out_drnn[ind_seq, :, :],
- torch.ones(size=(h, w), dtype=torch.float64) * out_expected_values[ind_seq]
+ torch.ones(size=(h, w), dtype=torch.float64) * out_expected_values[ind_seq],
)
@@ -144,17 +187,22 @@ def test_drnn_device(sim_params, empty_layer, detector):
# empty D-RNN
drnn = DiffractiveRNN(
sim_params,
- sequence_len=3, fusing_coeff=0.5, # some values
- read_in_layer=empty_layer, memory_layer=empty_layer,
+ sequence_len=3,
+ fusing_coeff=0.5, # some values
+ read_in_layer=empty_layer,
+ memory_layer=empty_layer,
hidden_forward_layer=empty_layer,
- read_out_layer=empty_layer, detector_layer=detector,
- device='cpu',
+ read_out_layer=empty_layer,
+ detector_layer=detector,
+ device="cpu",
)
- assert drnn.device == torch.device('cpu')
+ assert drnn.device == torch.device("cpu")
- new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is
- new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+ new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if new_device == torch.device(
+ "cpu"
+ ): # if cuda is not available - check if `mps` is
+ new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
new_drnn = drnn.to(new_device)
diff --git a/tests/test_element.py b/tests/test_element.py
index 7c2c181..e2e2aab 100644
--- a/tests/test_element.py
+++ b/tests/test_element.py
@@ -8,44 +8,77 @@
class ElementToTest(svetlanna.elements.Element):
+ """
+ Represents an element for testing within a simulation.
+
+ This class serves as a basic building block for testing purposes,
+ primarily passing data through without modification.
+ """
+
def __init__(
self,
simulation_parameters: svetlanna.SimulationParameters,
test_parameter,
test_buffer,
) -> None:
+ """
+ Initializes the class with simulation parameters and test data.
+
+ Args:
+ simulation_parameters: The simulation parameters object.
+ test_parameter: The parameter to be processed.
+ test_buffer: The buffer to be created.
+
+ Returns:
+ None
+ """
super().__init__(simulation_parameters)
- self.test_parameter = self.process_parameter(
- 'test_parameter', test_parameter
- )
- self.test_buffer = self.make_buffer(
- 'test_buffer', test_buffer
- )
-
- def forward(
- self,
- incident_wavefront: svetlanna.Wavefront
- ) -> svetlanna.Wavefront:
+ self.test_parameter = self.process_parameter("test_parameter", test_parameter)
+ self.test_buffer = self.make_buffer("test_buffer", test_buffer)
+
+ def forward(self, incident_wavefront: svetlanna.Wavefront) -> svetlanna.Wavefront:
+ """
+ Passes the incident wavefront through the layer.
+
+ This method simply calls the `forward` method of the parent class,
+ effectively passing the input wavefront unchanged.
+
+ Args:
+ incident_wavefront: The incoming wavefront.
+
+ Returns:
+ svetlanna.Wavefront: The transmitted wavefront (identical to the input).
+ """
return super().forward(incident_wavefront)
def test_setattr():
+ """
+ Tests that setattr correctly saves inner storage of parameters.
+
+ This test creates a simulation and an element with a parameter, then
+ asserts that the inner storage of the parameter is saved as expected
+ and accessible through a specific attribute name. It also verifies that
+ the inner parameter is present in the element's parameters dictionary.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- test_parameter = svetlanna.Parameter(10.)
- element = ElementToTest(
- sim_params,
- test_parameter=test_parameter,
- test_buffer=None
- )
+ test_parameter = svetlanna.Parameter(10.0)
+ element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None)
# check if inner storage of the parameter has been saved
- parameter_name = 'test_parameter' + INNER_PARAMETER_SUFFIX
+ parameter_name = "test_parameter" + INNER_PARAMETER_SUFFIX
assert getattr(element, parameter_name) is test_parameter.inner_storage
assert element.test_parameter.inner_parameter in element.parameters()
@@ -53,146 +86,154 @@ def test_setattr():
@pytest.mark.parametrize(
("device",),
[
+ pytest.param("cpu"),
pytest.param(
- 'cpu'
- ),
- pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
def test_make_buffer(device):
+ """
+ Tests the registration and device placement of a buffer.
+
+ Args:
+ device: The device to move the element to ('cpu', 'cuda', or 'mps').
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- test_buffer = torch.tensor(123.)
- element = ElementToTest(
- sim_params,
- test_parameter=None,
- test_buffer=test_buffer
- )
+ test_buffer = torch.tensor(123.0)
+ element = ElementToTest(sim_params, test_parameter=None, test_buffer=test_buffer)
# check if buffer has been registered
- assert hasattr(element, 'test_buffer')
- assert getattr(element, 'test_buffer') in element.buffers()
+ assert hasattr(element, "test_buffer")
+ assert getattr(element, "test_buffer") in element.buffers()
# check if buffer is automatically transferred to device
element.to(device)
- assert getattr(element, 'test_buffer').device.type == device
+ assert getattr(element, "test_buffer").device.type == device
# test if a buffer cannot be registered with a tensor on a device
# distinct from the simulation parameters' device
- if device != 'cpu':
+ if device != "cpu":
with pytest.raises(ValueError):
element = ElementToTest(
- sim_params,
- test_parameter=None,
- test_buffer=test_buffer.to(device)
+ sim_params, test_parameter=None, test_buffer=test_buffer.to(device)
)
@pytest.mark.parametrize(
("device",),
[
+ pytest.param("cpu"),
pytest.param(
- 'cpu'
- ),
- pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
def test_process_parameter(device):
+ """
+ Tests the processing of parameters within the ElementToTest class.
+
+ This test verifies that parameters are correctly registered, transferred to the specified device,
+ and handled appropriately when provided as tensors or directly as nn.Parameters. It also checks
+ for ValueErrors when attempting to register a parameter tensor on a different device than the simulation.
+
+ Args:
+ device: The device (e.g., 'cpu', 'cuda', 'mps') to test with.
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- test_parameter = torch.nn.Parameter(torch.tensor(123.))
- element = ElementToTest(
- sim_params,
- test_parameter=test_parameter,
- test_buffer=None
- )
+ test_parameter = torch.nn.Parameter(torch.tensor(123.0))
+ element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None)
# check if parameter has been registered
- assert hasattr(element, 'test_parameter')
- assert getattr(element, 'test_parameter') in element.parameters()
+ assert hasattr(element, "test_parameter")
+ assert getattr(element, "test_parameter") in element.parameters()
# check if parameter is automatically transferred to device
element.to(device)
- assert getattr(element, 'test_parameter').device.type == device
+ assert getattr(element, "test_parameter").device.type == device
# test tensor as a parameter
- test_parameter = torch.tensor(123.)
- element = ElementToTest(
- sim_params,
- test_parameter=test_parameter,
- test_buffer=None
- )
+ test_parameter = torch.tensor(123.0)
+ element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None)
# check if test_parameter has been registered as a buffer
- assert hasattr(element, 'test_parameter')
- assert getattr(element, 'test_parameter') not in element.parameters()
- assert getattr(element, 'test_parameter') in element.buffers()
+ assert hasattr(element, "test_parameter")
+ assert getattr(element, "test_parameter") not in element.parameters()
+ assert getattr(element, "test_parameter") in element.buffers()
# test if a parameter cannot be registered with a tensor on a device
# distinct from the simulation parameters' device
- if device != 'cpu':
+ if device != "cpu":
with pytest.raises(ValueError):
element = ElementToTest(
- sim_params,
- test_parameter=test_parameter.to(device),
- test_buffer=None
+ sim_params, test_parameter=test_parameter.to(device), test_buffer=None
)
def test_to_specs():
+ """
+ Tests the conversion of an element to specifications.
+
+ This test creates a simulation parameter set and an element with a
+ test parameter, then asserts that the `to_specs` method generates a list
+ containing a single specification for the test parameter, and that this
+ specification contains a representation of type ReprRepr.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- test_parameter = torch.nn.Parameter(torch.tensor(123.))
- element = ElementToTest(
- sim_params,
- test_parameter=test_parameter,
- test_buffer=None
- )
+ test_parameter = torch.nn.Parameter(torch.tensor(123.0))
+ element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None)
specs = list(element.to_specs())
assert len(specs) == 1
- assert specs[0].parameter_name == 'test_parameter'
+ assert specs[0].parameter_name == "test_parameter"
representations = list(specs[0].representations)
assert len(representations) == 1
@@ -200,39 +241,56 @@ def test_to_specs():
def test_make_buffer_pattern():
+ """
+ Tests the creation of a buffer pattern using make_buffer.
+
+ This test instantiates an ElementToTest object with simulation parameters and
+ asserts that calling make_buffer returns an instance of _BufferedValueContainer.
+ It also checks for expected warnings when assigning a buffered value to another attribute.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- element = ElementToTest(
- sim_params,
- test_parameter=None,
- test_buffer=None
- )
+ element = ElementToTest(sim_params, test_parameter=None, test_buffer=None)
- assert isinstance(element.make_buffer('x', None), _BufferedValueContainer)
+ assert isinstance(element.make_buffer("x", None), _BufferedValueContainer)
with pytest.warns(
match="You set the attribute y with an object of internal type _BufferedValueContainer. Make sure this is the intended behavior."
):
- element.y = element.make_buffer('x', None)
+ element.y = element.make_buffer("x", None)
def test_repr_html():
+ """
+ Tests the HTML representation of an element.
+
+ This test instantiates a simulation and an ElementToTest object,
+ then asserts that the _repr_html_() method returns a string.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-10, 10, 100),
- 'H': torch.linspace(-10, 10, 100),
- 'wavelength': 1.,
+ "W": torch.linspace(-10, 10, 100),
+ "H": torch.linspace(-10, 10, 100),
+ "wavelength": 1.0,
}
)
- element = ElementToTest(
- sim_params,
- test_parameter=None,
- test_buffer=None
- )
+ element = ElementToTest(sim_params, test_parameter=None, test_buffer=None)
assert isinstance(element._repr_html_(), str)
diff --git a/tests/test_freespace.py b/tests/test_freespace.py
index 78b0d2f..5cba607 100644
--- a/tests/test_freespace.py
+++ b/tests/test_freespace.py
@@ -16,7 +16,7 @@
"distance_total",
"distance_end",
"expected_error",
- "error_energy"
+ "error_energy",
]
@@ -27,28 +27,30 @@
(
6, # ox_size
6, # oy_size
- 1500, # ox_nodes
- 1600, # oy_nodes
- torch.linspace(330*1e-6, 660*1e-6, 5), # wavelength_test tensor, mm # noqa: E501
- 2., # waist_radius_test, mm
- 300, # distance_total, mm
- 200, # distance_end, mm
- 0.02, # expected_std
- 0.01 # error_energy
+ 1500, # ox_nodes
+ 1600, # oy_nodes
+ torch.linspace(
+ 330 * 1e-6, 660 * 1e-6, 5
+ ), # wavelength_test tensor, mm # noqa: E501
+ 2.0, # waist_radius_test, mm
+ 300, # distance_total, mm
+ 200, # distance_end, mm
+ 0.02, # expected_std
+ 0.01, # error_energy
),
(
6, # ox_size
6, # oy_size
- 1500, # ox_nodes
- 1600, # oy_nodes
+ 1500, # ox_nodes
+ 1600, # oy_nodes
660 * 1e-6, # wavelength_test, mm
- 2., # waist_radius_test, mm
- 300, # distance_total, mm
- 200, # distance_end, mm
- 0.02, # expected_std
- 0.01 # error_energy
- )
- ]
+ 2.0, # waist_radius_test, mm
+ 300, # distance_total, mm
+ 200, # distance_end, mm
+ 0.02, # expected_std
+ 0.01, # error_energy
+ ),
+ ],
)
def test_gaussian_beam_propagation(
ox_size: float,
@@ -60,7 +62,7 @@ def test_gaussian_beam_propagation(
distance_total: float,
distance_end: float,
expected_error: float,
- error_energy: float
+ error_energy: float,
):
"""Test for the free field propagation problem: free propagation of the
Gaussian beam at the arbitrary distance(distance_total). We calculate the
@@ -95,7 +97,7 @@ def test_gaussian_beam_propagation(
x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
- x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy')
+ x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy")
# creating meshgrid
x_grid = x_grid[None, :]
@@ -103,7 +105,7 @@ def test_gaussian_beam_propagation(
# wave_number = 2 * torch.pi / wavelength_test[..., None, None]
- amplitude = 1.
+ amplitude = 1.0
dx = ox_size / ox_nodes
dy = oy_size / oy_nodes
@@ -112,35 +114,50 @@ def test_gaussian_beam_propagation(
wave_number = 2 * torch.pi / wavelength_test
rayleigh_range = torch.pi * (waist_radius_test**2) / wavelength_test
else:
- rayleigh_range = torch.pi * (waist_radius_test**2) / wavelength_test[..., None, None] # noqa: E501
+ rayleigh_range = (
+ torch.pi * (waist_radius_test**2) / wavelength_test[..., None, None]
+ ) # noqa: E501
wave_number = 2 * torch.pi / wavelength_test[..., None, None]
radial_distance_squared = torch.pow(x_grid, 2) + torch.pow(y_grid, 2)
- hyperbolic_relation = waist_radius_test * (1 + (
- distance_total / rayleigh_range)**2)**(1/2)
+ hyperbolic_relation = waist_radius_test * (
+ 1 + (distance_total / rayleigh_range) ** 2
+ ) ** (1 / 2)
- radius_of_curvature = distance_total * (
- 1 + (rayleigh_range / distance_total)**2
- )
+ radius_of_curvature = distance_total * (1 + (rayleigh_range / distance_total) ** 2)
# Gouy phase
gouy_phase = torch.arctan(torch.tensor(distance_total) / rayleigh_range)
# analytical equation for the propagation of the Gaussian beam
- field = amplitude * (waist_radius_test / hyperbolic_relation) * (
- torch.exp(-radial_distance_squared / (hyperbolic_relation)**2) * (
- torch.exp(-1j * (wave_number * distance_total + wave_number * (
- radial_distance_squared) / (2 * radius_of_curvature) - (
- gouy_phase)))))
+ field = (
+ amplitude
+ * (waist_radius_test / hyperbolic_relation)
+ * (
+ torch.exp(-radial_distance_squared / (hyperbolic_relation) ** 2)
+ * (
+ torch.exp(
+ -1j
+ * (
+ wave_number * distance_total
+ + wave_number
+ * (radial_distance_squared)
+ / (2 * radius_of_curvature)
+ - (gouy_phase)
+ )
+ )
+ )
+ )
+ )
intensity_analytic = torch.pow(torch.abs(field), 2)
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
@@ -149,53 +166,45 @@ def test_gaussian_beam_propagation(
field_gb_start = Wavefront.gaussian_beam(
simulation_parameters=params,
distance=distance_start,
- waist_radius=waist_radius_test
+ waist_radius=waist_radius_test,
)
# field on the screen by using Fresnel propagation method
field_end_fresnel = elements.FreeSpace(
- simulation_parameters=params, distance=distance_end, method='fresnel'
+ simulation_parameters=params, distance=distance_end, method="fresnel"
)(field_gb_start)
# field on the screen by using angular spectrum method
field_end_as = elements.FreeSpace(
- simulation_parameters=params, distance=distance_end, method='AS'
+ simulation_parameters=params, distance=distance_end, method="AS"
)(field_gb_start)
intensity_output_fresnel = field_end_fresnel.intensity
intensity_output_as = field_end_as.intensity
- energy_analytic = torch.sum(
- intensity_analytic, dim=(-2, -1)
- ) * dx * dy
- energy_numeric_fresnel = torch.sum(
- intensity_output_fresnel, dim=(-2, -1)
- ) * dx * dy
- energy_numeric_as = torch.sum(
- intensity_output_as, dim=(-2, -1)
- ) * dx * dy
+ energy_analytic = torch.sum(intensity_analytic, dim=(-2, -1)) * dx * dy
+ energy_numeric_fresnel = torch.sum(intensity_output_fresnel, dim=(-2, -1)) * dx * dy
+ energy_numeric_as = torch.sum(intensity_output_as, dim=(-2, -1)) * dx * dy
intensity_difference_fresnel = torch.abs(
intensity_analytic - intensity_output_fresnel
) / (ox_nodes * oy_nodes)
- intensity_difference_as = torch.abs(
- intensity_analytic - intensity_output_as
- ) / (ox_nodes * oy_nodes)
+ intensity_difference_as = torch.abs(intensity_analytic - intensity_output_as) / (
+ ox_nodes * oy_nodes
+ )
error_fresnel, _ = intensity_difference_fresnel.view(
intensity_difference_fresnel.size(0), -1
).max(dim=1)
- error_as, _ = intensity_difference_as.view(
- intensity_difference_as.size(0), -1
- ).max(dim=1)
+ error_as, _ = intensity_difference_as.view(intensity_difference_as.size(0), -1).max(
+ dim=1
+ )
energy_error_fresnel = torch.abs(
(energy_analytic - energy_numeric_fresnel) / energy_analytic
)
- energy_error_as = torch.abs(
- (energy_analytic - energy_numeric_as) / energy_analytic
- )
+ energy_error_as = torch.abs((energy_analytic - energy_numeric_as) / energy_analytic)
assert (error_fresnel <= expected_error).all()
assert (error_as <= expected_error).all()
@@ -211,7 +220,7 @@ def test_gaussian_beam_propagation(
"wavelength_test",
"waist_radius_test",
"distance",
- "expected_error"
+ "expected_error",
]
@@ -221,36 +230,34 @@ def test_gaussian_beam_propagation(
(
6, # ox_size
6, # oy_size
- 1569, # ox_nodes
- 1698, # oy_nodes
+ 1569, # ox_nodes
+ 1698, # oy_nodes
660 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 2., # waist_radius_test, mm
- 300, # distance, mm
- 0.5 # expected relative error
+ 2.0, # waist_radius_test, mm
+ 300, # distance, mm
+ 0.5, # expected relative error
),
-
(
15, # ox_size
8, # oy_size
- 1111, # ox_nodes
- 14070, # oy_nodes
+ 1111, # ox_nodes
+ 14070, # oy_nodes
330 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1., # waist_radius_test, mm
- 50, # distance, mm
- 1.7 # expected relative error
+ 1.0, # waist_radius_test, mm
+ 50, # distance, mm
+ 1.7, # expected relative error
),
-
(
20, # ox_size
23, # oy_size
- 1800, # ox_nodes
- 1032, # oy_nodes
+ 1800, # ox_nodes
+ 1032, # oy_nodes
540 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 4., # waist_radius_test, mm
- 500, # distance, mm
- 0.5 # expected relative error
+ 4.0, # waist_radius_test, mm
+ 500, # distance, mm
+ 0.5, # expected relative error
),
- ]
+ ],
)
def test_gaussian_beam_fwhm(
ox_size: float,
@@ -262,28 +269,42 @@ def test_gaussian_beam_fwhm(
distance: float,
expected_error: float,
):
+ """
+ Tests the FWHM calculation for a Gaussian beam using Fresnel and Angular Spectrum methods.
+
+ Args:
+ ox_size: The size of the x-axis grid.
+ oy_size: The size of the y-axis grid.
+ ox_nodes: The number of nodes in the x-axis grid.
+ oy_nodes: The number of nodes in the y-axis grid.
+ wavelength_test: The wavelength of the light.
+ waist_radius_test: The waist radius of the Gaussian beam.
+ distance: The propagation distance.
+ expected_error: The expected relative error for the FWHM calculation.
+
+ Returns:
+ None. Raises an AssertionError if the calculated relative errors exceed the expected error.
+ """
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
field_gb_start = Wavefront.gaussian_beam(
- simulation_parameters=params,
- distance=0.,
- waist_radius=waist_radius_test
+ simulation_parameters=params, distance=0.0, waist_radius=waist_radius_test
)
# field on the screen by using Fresnel propagation method
field_end_fresnel = elements.FreeSpace(
- simulation_parameters=params, distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
)(field_gb_start)
# field on the screen by using angular spectrum method
field_end_as = elements.FreeSpace(
- simulation_parameters=params, distance=distance, method='AS'
+ simulation_parameters=params, distance=distance, method="AS"
)(field_gb_start)
fwhm_x_as, fwhm_y_as = field_end_as.fwhm(simulation_parameters=params)
@@ -291,27 +312,24 @@ def test_gaussian_beam_fwhm(
simulation_parameters=params
)
- fwhm_analytical = torch.sqrt(
- 2. * torch.log(torch.tensor([2.]))
- ) * waist_radius_test * torch.sqrt(
- torch.tensor([1.]) + (
- distance / (torch.pi * waist_radius_test**2 / wavelength_test)
- )**2
+ fwhm_analytical = (
+ torch.sqrt(2.0 * torch.log(torch.tensor([2.0])))
+ * waist_radius_test
+ * torch.sqrt(
+ torch.tensor([1.0])
+ + (distance / (torch.pi * waist_radius_test**2 / wavelength_test)) ** 2
+ )
)
- relative_error_x_as = torch.abs(
- fwhm_x_as - fwhm_analytical
- ) / fwhm_analytical * 100
- relative_error_y_as = torch.abs(
- fwhm_y_as - fwhm_analytical
- ) / fwhm_analytical * 100
+ relative_error_x_as = torch.abs(fwhm_x_as - fwhm_analytical) / fwhm_analytical * 100
+ relative_error_y_as = torch.abs(fwhm_y_as - fwhm_analytical) / fwhm_analytical * 100
- relative_error_x_fresnel = torch.abs(
- fwhm_x_fresnel - fwhm_analytical
- ) / fwhm_analytical * 100
- relative_error_y_fresnel = torch.abs(
- fwhm_y_fresnel - fwhm_analytical
- ) / fwhm_analytical * 100
+ relative_error_x_fresnel = (
+ torch.abs(fwhm_x_fresnel - fwhm_analytical) / fwhm_analytical * 100
+ )
+ relative_error_y_fresnel = (
+ torch.abs(fwhm_y_fresnel - fwhm_analytical) / fwhm_analytical * 100
+ )
assert (relative_error_x_as <= expected_error).all()
assert (relative_error_y_as <= expected_error).all()
@@ -327,7 +345,7 @@ def test_gaussian_beam_fwhm(
"wavelength_test",
"waist_radius_test",
"distance",
- "expected_error"
+ "expected_error",
]
@@ -337,36 +355,34 @@ def test_gaussian_beam_fwhm(
(
6, # ox_size
6, # oy_size
- 1569, # ox_nodes
- 1698, # oy_nodes
+ 1569, # ox_nodes
+ 1698, # oy_nodes
660 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 2., # waist_radius_test, mm
- 300, # distance, mm
- 0.5 # expected relative error
+ 2.0, # waist_radius_test, mm
+ 300, # distance, mm
+ 0.5, # expected relative error
),
-
(
15, # ox_size
8, # oy_size
- 1111, # ox_nodes
- 14070, # oy_nodes
+ 1111, # ox_nodes
+ 14070, # oy_nodes
330 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 1., # waist_radius_test, mm
- 50, # distance, mm
- 1.7 # expected relative error
+ 1.0, # waist_radius_test, mm
+ 50, # distance, mm
+ 1.7, # expected relative error
),
-
(
20, # ox_size
23, # oy_size
- 1800, # ox_nodes
- 1032, # oy_nodes
+ 1800, # ox_nodes
+ 1032, # oy_nodes
540 * 1e-6, # wavelength_test tensor, mm # noqa: E501
- 4., # waist_radius_test, mm
- 500, # distance, mm
- 0.5 # expected relative error
+ 4.0, # waist_radius_test, mm
+ 500, # distance, mm
+ 0.5, # expected relative error
),
- ]
+ ],
)
def test_gaussian_beam_phase_profile(
ox_size: float,
@@ -378,34 +394,51 @@ def test_gaussian_beam_phase_profile(
distance: float,
expected_error: float,
):
+ """
+ Tests the phase profile of a Gaussian beam propagated using Fresnel and Angular Spectrum methods.
+
+ This test compares the phase profiles obtained from propagating a Gaussian beam
+ using both Fresnel and Angular Spectrum methods to an analytically calculated
+ phase profile. It checks if the standard deviation of the difference between
+ the computed and analytical phases is within a specified tolerance.
+
+ Args:
+ ox_size: The size of the x-axis grid.
+ oy_size: The size of the y-axis grid.
+ ox_nodes: The number of nodes in the x-axis grid.
+ oy_nodes: The number of nodes in the y-axis grid.
+ wavelength_test: The wavelength of the light used for simulation.
+ waist_radius_test: The waist radius of the Gaussian beam.
+ distance: The propagation distance.
+ expected_error: The expected maximum standard deviation of the phase difference.
+
+ Returns:
+ None. Raises an AssertionError if the tests fail.
+ """
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
field_gb_start = Wavefront.gaussian_beam(
- simulation_parameters=params,
- distance=0.,
- waist_radius=waist_radius_test
+ simulation_parameters=params, distance=0.0, waist_radius=waist_radius_test
)
# field on the screen by using Fresnel propagation method
field_end_fresnel = elements.FreeSpace(
- simulation_parameters=params, distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
)(field_gb_start)
# field on the screen by using angular spectrum method
field_end_as = elements.FreeSpace(
- simulation_parameters=params, distance=distance, method='AS'
+ simulation_parameters=params, distance=distance, method="AS"
)(field_gb_start)
total_field = Wavefront.gaussian_beam(
- simulation_parameters=params,
- waist_radius=waist_radius_test,
- distance=distance
+ simulation_parameters=params, waist_radius=waist_radius_test, distance=distance
)
intensity_analytic = total_field.intensity
@@ -417,9 +450,5 @@ def test_gaussian_beam_phase_profile(
output_phase_fresnel = field_end_fresnel.phase * target_region
output_phase_analytical = total_field.phase * target_region
- assert torch.std(
- output_phase_as - output_phase_analytical
- ) <= expected_error
- assert torch.std(
- output_phase_fresnel - output_phase_analytical
- ) <= expected_error
+ assert torch.std(output_phase_as - output_phase_analytical) <= expected_error
+ assert torch.std(output_phase_fresnel - output_phase_analytical) <= expected_error
diff --git a/tests/test_lens.py b/tests/test_lens.py
index f3c8613..29f9148 100644
--- a/tests/test_lens.py
+++ b/tests/test_lens.py
@@ -12,7 +12,7 @@
"wavelength_test",
"focal_length_test",
"radius_test",
- "expected_std"
+ "expected_std",
]
@@ -20,26 +20,30 @@
lens_parameters,
[
(
- 8, # ox_size, mm
- 12, # oy_size, mm
- 1200, # ox_nodes
- 1400, # oy_nodes
- torch.linspace(330 * 1e-6, 1064 * 1e-6, 20), # wavelength_test, tensor # noqa: E501
- 100, # focal_length_test, mm
- 10, # radius_test, mm
- 1e-5 # expected_std
+ 8, # ox_size, mm
+ 12, # oy_size, mm
+ 1200, # ox_nodes
+ 1400, # oy_nodes
+ torch.linspace(
+ 330 * 1e-6, 1064 * 1e-6, 20
+ ), # wavelength_test, tensor # noqa: E501
+ 100, # focal_length_test, mm
+ 10, # radius_test, mm
+ 1e-5, # expected_std
),
(
8, # ox_size, mm
4, # oy_size, mm
- 1100, # ox_nodes
- 1000, # oy_nodes
- torch.linspace(660 * 1e-6, 1600 * 1e-6, 20), # wavelength_test, tensor # noqa: E501
- 200, # focal_length_test, mm
- 15, # radius_test, mm
- 1e-5 # expected_std
- )
- ]
+ 1100, # ox_nodes
+ 1000, # oy_nodes
+ torch.linspace(
+ 660 * 1e-6, 1600 * 1e-6, 20
+ ), # wavelength_test, tensor # noqa: E501
+ 200, # focal_length_test, mm
+ 15, # radius_test, mm
+ 1e-5, # expected_std
+ ),
+ ],
)
def test_lens(
ox_size: float,
@@ -75,17 +79,15 @@ def test_lens(
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
# transmission function of the thin lens as a class method
transmission_function = elements.ThinLens(
- simulation_parameters=params,
- focal_length=focal_length_test,
- radius=radius_test
+ simulation_parameters=params, focal_length=focal_length_test, radius=radius_test
).get_transmission_function()
x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
@@ -102,16 +104,21 @@ def test_lens(
radius_squared = torch.pow(x_grid, 2) + torch.pow(y_grid, 2)
transmission_function_analytic = torch.exp(
- 1j * (-wave_number / (2 * focal_length_test) * radius_squared * (
- radius_squared <= radius_test**2
- ))
+ 1j
+ * (
+ -wave_number
+ / (2 * focal_length_test)
+ * radius_squared
+ * (radius_squared <= radius_test**2)
+ )
)
standard_deviation = torch.std(
- torch.real((1 / 1j) * (
- torch.log(transmission_function) - torch.log(
- transmission_function_analytic
- )
+ torch.real(
+ (1 / 1j)
+ * (
+ torch.log(transmission_function)
+ - torch.log(transmission_function_analytic)
)
)
)
@@ -120,21 +127,29 @@ def test_lens(
def test_reverse():
+ """
+ Tests the reversibility of the ThinLens forward and reverse propagation.
+
+ This test checks if applying the `forward` method followed by the `reverse`
+ method to a wavefront results in the original wavefront, confirming the
+ correctness of the inverse propagation implementation.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
params = SimulationParameters(
{
- 'W': torch.linspace(-10/2, 10/2, 10),
- 'H': torch.linspace(-10/2, 10/2, 10),
- 'wavelength': 1
+ "W": torch.linspace(-10 / 2, 10 / 2, 10),
+ "H": torch.linspace(-10 / 2, 10 / 2, 10),
+ "wavelength": 1,
}
)
- lens = elements.ThinLens(
- simulation_parameters=params,
- focal_length=1
- )
+ lens = elements.ThinLens(simulation_parameters=params, focal_length=1)
# test is reverse(forward(x)) is x, where x is a wavefront
wavefront = svetlanna.Wavefront.plane_wave(params)
- assert torch.allclose(
- lens.reverse(lens.forward(wavefront)), wavefront
- )
+ assert torch.allclose(lens.reverse(lens.forward(wavefront)), wavefront)
diff --git a/tests/test_lightpipes_comparison.py b/tests/test_lightpipes_comparison.py
index 6757570..d31acfa 100644
--- a/tests/test_lightpipes_comparison.py
+++ b/tests/test_lightpipes_comparison.py
@@ -7,14 +7,7 @@
from svetlanna import elements
-parameters = [
- "ox_size",
- "ox_nodes",
- "wavelength",
- "radius",
- "distance",
- "focal_length"
-]
+parameters = ["ox_size", "ox_nodes", "wavelength", "radius", "distance", "focal_length"]
# TODO: fix docstrings
@@ -23,29 +16,29 @@
[
(
25 * lp.mm, # ox_size
- 3000, # ox_nodes
+ 3000, # ox_nodes
1064 * lp.nm, # wavelength, mm
- 2 * lp.mm, # radius, mm
- 2000 * lp.mm, # distance, mm
- 2000 * lp.mm, # focal_length, mm
+ 2 * lp.mm, # radius, mm
+ 2000 * lp.mm, # distance, mm
+ 2000 * lp.mm, # focal_length, mm
),
(
25 * lp.mm, # ox_size
- 3000, # ox_nodes
+ 3000, # ox_nodes
1064 * lp.nm, # wavelength, mm
- 2 * lp.mm, # radius, mm
- 100 * lp.mm, # distance, mm
- 20 * lp.mm, # focal_length, mm
+ 2 * lp.mm, # radius, mm
+ 100 * lp.mm, # distance, mm
+ 20 * lp.mm, # focal_length, mm
),
(
25 * lp.mm, # ox_size
- 100, # ox_nodes
+ 100, # ox_nodes
123 * lp.nm, # wavelength, mm
- 2 * lp.mm, # radius, mm
- 200 * lp.mm, # distance, mm
- 2100 * lp.mm, # focal_length, mm
- )
- ]
+ 2 * lp.mm, # radius, mm
+ 200 * lp.mm, # distance, mm
+ 2100 * lp.mm, # focal_length, mm
+ ),
+ ],
)
def test_circular_aperture(
ox_size: float,
@@ -53,8 +46,28 @@ def test_circular_aperture(
wavelength: float,
radius: float,
distance: float,
- focal_length: float
+ focal_length: float,
):
+ """
+ Tests the circular aperture propagation using LightPipes and SVETlANNa.
+
+ This test compares the field calculated by LightPipes with the field
+ calculated by SVETlANNa for a circular aperture, free space propagation,
+ and a lens. It asserts that the mean absolute difference between the two
+ fields (normalized by their maximum absolute values) is less than 0.01
+ before and after the lens.
+
+ Args:
+ ox_size: The size of the computational grid in x direction.
+ ox_nodes: The number of nodes in the computational grid.
+ wavelength: The wavelength of light.
+ radius: The radius of the circular aperture.
+ distance: The distance to propagate before the lens.
+ focal_length: The focal length of the lens.
+
+ Returns:
+ None. The function asserts that the difference between LightPipes and SVETlANNa results is within tolerance.
+ """
# ----------------------------------
# LightPipes fields calculations
# ----------------------------------
@@ -71,39 +84,21 @@ def test_circular_aperture(
# ----------------------------------
oy_size = ox_size
oy_nodes = ox_nodes
- x_length = torch.linspace(
- -ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64
- )
- y_length = torch.linspace(
- -oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64
- )
+ x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64)
+ y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64)
simulation_parameters = sv.SimulationParameters(
axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': torch.tensor(wavelength, dtype=torch.float64)
+ "W": x_length,
+ "H": y_length,
+ "wavelength": torch.tensor(wavelength, dtype=torch.float64),
}
)
# elements' definitions
- aperture = elements.RoundAperture(
- simulation_parameters,
- radius
- )
- fs1 = elements.FreeSpace(
- simulation_parameters,
- distance,
- method='fresnel'
- )
- lens = elements.ThinLens(
- simulation_parameters,
- focal_length
- )
- fs2 = elements.FreeSpace(
- simulation_parameters,
- focal_length,
- method='fresnel'
- )
+ aperture = elements.RoundAperture(simulation_parameters, radius)
+ fs1 = elements.FreeSpace(simulation_parameters, distance, method="fresnel")
+ lens = elements.ThinLens(simulation_parameters, focal_length)
+ fs2 = elements.FreeSpace(simulation_parameters, focal_length, method="fresnel")
# field calculations
G = sv.Wavefront.plane_wave(simulation_parameters)
@@ -121,12 +116,13 @@ def test_circular_aperture(
# ----------------------------------
# results testing
# ----------------------------------
- assert torch.mean(
- torch.abs(field_before_lens_lp - field_before_lens_sv)
- ) / before_lens_norm < 0.01
+ assert (
+ torch.mean(torch.abs(field_before_lens_lp - field_before_lens_sv))
+ / before_lens_norm
+ < 0.01
+ )
+
+ assert torch.mean(torch.abs(field_output_lp - field_output_sv)) / output_norm < 0.01
- assert torch.mean(
- torch.abs(field_output_lp - field_output_sv)
- ) / output_norm < 0.01
# TODO: ΡΡΠ°Π²Π½ΠΈΡΡ ΠΏΠΈΠΊΠΎΠ²ΡΡ ΠΌΠΎΡΠ½ΠΎΡΡΡ ΠΈ ΠΏΠΎΠ»ΠΎΠΆΠ΅Π½ΠΈΠ΅ ΠΌΠ°ΠΊΡΠΈΠΌΡΠΌΠΎΠ²
diff --git a/tests/test_logging.py b/tests/test_logging.py
index ea5e5fe..721d837 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -9,17 +9,28 @@
@pytest.mark.parametrize(
- 'input', [
- torch.tensor([10., 10.]),
- torch.tensor(20),
- svetlanna.Parameter(11.),
- svetlanna.ConstrainedParameter(11., min_value=-10, max_value=100),
- 123,
- 123.,
- None
- ]
+ "input",
+ [
+ torch.tensor([10.0, 10.0]),
+ torch.tensor(20),
+ svetlanna.Parameter(11.0),
+ svetlanna.ConstrainedParameter(11.0, min_value=-10, max_value=100),
+ 123,
+ 123.0,
+ None,
+ ],
)
def test_agr_short_description(input):
+ """
+ Tests the agr_short_description function with various inputs.
+
+ Args:
+ input: The input to be tested. Can be a torch.Tensor,
+ svetlanna.Parameter, svetlanna.ConstrainedParameter, number, or None.
+
+ Returns:
+ None
+ """
if isinstance(input, torch.Tensor):
# test for torch.Tensor
assert agr_short_description(input) == (
@@ -28,44 +39,61 @@ def test_agr_short_description(input):
)
else:
# test for other types
- assert agr_short_description(input) == f'{type(input)}'
+ assert agr_short_description(input) == f"{type(input)}"
def test_log_message(capfd, caplog):
+ """
+ Tests the log_message function with both 'print' and 'logging' types.
+
+ Args:
+ capfd: A pytest fixture for capturing stdout/stderr.
+ caplog: A pytest fixture for capturing logging messages.
+
+ Returns:
+ None
+ """
# test for 'print' type
- svetlanna.set_debug_logging(False, type='print') # set 'print' type
+ svetlanna.set_debug_logging(False, type="print") # set 'print' type
# log_message prints the message even if mode set to False!
- log_message('test message') # print message
+ log_message("test message") # print message
out, _ = capfd.readouterr() # read stdout
- assert out == 'test message\n'
+ assert out == "test message\n"
# test for 'logging' type
- svetlanna.set_debug_logging(False, type='logging') # set 'logging' type
+ svetlanna.set_debug_logging(False, type="logging") # set 'logging' type
- logger = logging.getLogger('svetlanna.logging') # get logger
+ logger = logging.getLogger("svetlanna.logging") # get logger
logger.setLevel(logging.DEBUG) # set logging level to DEBUG
- log_message('test message') # print message
+ log_message("test message") # print message
assert caplog.record_tuples == [
("svetlanna.logging", logging.DEBUG, "test message")
]
-@pytest.mark.parametrize(
- 'input', [
- 1, 1., (1, 2), tuple()
- ]
-)
-@pytest.mark.parametrize(
- 'output', [
- 1, 1., (1, 2), tuple()
- ]
-)
+@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()])
+@pytest.mark.parametrize("output", [1, 1.0, (1, 2), tuple()])
def test_forward_logging_hook(input, output, capfd):
- svetlanna.set_debug_logging(False, type='print') # set 'print' type
+ """
+ Tests the forward logging hook functionality.
+
+ This test verifies that the forward logging hook does not log anything for
+ modules that are not instances of svetlanna.elements.Element and logs the
+ input/output types when called with an Element-like module.
+
+ Args:
+ input: The input to the forward method.
+ output: The output from the forward method.
+ capfd: A pytest fixture for capturing stdout.
+
+ Returns:
+ None
+ """
+ svetlanna.set_debug_logging(False, type="print") # set 'print' type
# test for random element ignorance
class NotElement(torch.nn.Module):
@@ -74,7 +102,7 @@ class NotElement(torch.nn.Module):
forward_logging_hook(NotElement(), input, output)
out, _ = capfd.readouterr() # read stdout
- assert out == ''
+ assert out == ""
# test for elements
class ElementLike(svetlanna.elements.Element):
@@ -84,48 +112,51 @@ def forward(self, *args, **kwargs):
element = ElementLike(
simulation_parameters=svetlanna.SimulationParameters(
axes={
- 'H': torch.linspace(-1, 1, 10),
- 'W': torch.linspace(-1, 1, 10),
- 'wavelength': 1.
+ "H": torch.linspace(-1, 1, 10),
+ "W": torch.linspace(-1, 1, 10),
+ "wavelength": 1.0,
}
)
)
forward_logging_hook(element, input, output)
- expected_out = 'The forward method of ElementLike was computed'
+ expected_out = "The forward method of ElementLike was computed"
input = input if isinstance(input, tuple) else (input,)
output = output if isinstance(output, tuple) else (output,)
for i, _input in enumerate(input):
- expected_out += f'\n input {i}: {type(_input)}'
+ expected_out += f"\n input {i}: {type(_input)}"
for i, _output in enumerate(output):
- expected_out += f'\n output {i}: {type(_output)}'
+ expected_out += f"\n output {i}: {type(_output)}"
out, _ = capfd.readouterr() # read stdout
- assert out == expected_out + '\n'
+ assert out == expected_out + "\n"
-@pytest.mark.parametrize(
- 'input', [
- 1, 1., (1, 2), tuple()
- ]
-)
-@pytest.mark.parametrize(
- 'type_', [
- 'Parameter', 'Buffer', 'Module'
- ]
-)
+@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()])
+@pytest.mark.parametrize("type_", ["Parameter", "Buffer", "Module"])
def test_register_logging_hook(input, type_, capfd):
- svetlanna.set_debug_logging(False, type='print') # set 'print' type
+ """
+ Tests the registration of logging hooks for different input types and element types.
+
+ Args:
+ input: The input to be logged.
+ type_: The type of element being logged (e.g., 'Parameter', 'Buffer', 'Module').
+ capfd: A pytest fixture for capturing stdout/stderr.
+
+ Returns:
+ None
+ """
+ svetlanna.set_debug_logging(False, type="print") # set 'print' type
# test for random element ignorance
class NotElement(torch.nn.Module):
pass
- register_logging_hook(NotElement(), 'test_name', input, type_)
+ register_logging_hook(NotElement(), "test_name", input, type_)
out, _ = capfd.readouterr() # read stdout
- assert out == ''
+ assert out == ""
# test for elements
class ElementLike(svetlanna.elements.Element):
@@ -135,49 +166,59 @@ def forward(self, *args, **kwargs):
element = ElementLike(
simulation_parameters=svetlanna.SimulationParameters(
axes={
- 'H': torch.linspace(-1, 1, 10),
- 'W': torch.linspace(-1, 1, 10),
- 'wavelength': 1.
+ "H": torch.linspace(-1, 1, 10),
+ "W": torch.linspace(-1, 1, 10),
+ "wavelength": 1.0,
}
)
)
- register_logging_hook(element, 'test_name', input, type_)
+ register_logging_hook(element, "test_name", input, type_)
- expected_out = f'{type_} of {element._get_name()} was registered with name test_name:'
- expected_out += f'\n {type(input)}'
+ expected_out = (
+ f"{type_} of {element._get_name()} was registered with name test_name:"
+ )
+ expected_out += f"\n {type(input)}"
out, _ = capfd.readouterr() # read stdout
- assert out == expected_out + '\n'
+ assert out == expected_out + "\n"
-@pytest.mark.parametrize(
- 'input', [
- 1, 1., (1, 2), tuple()
- ]
-)
-@pytest.mark.parametrize(
- 'output', [
- 1, 1., (1, 2), tuple()
- ]
-)
+@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()])
+@pytest.mark.parametrize("output", [1, 1.0, (1, 2), tuple()])
def test_set_debug_logging(input, output, capfd, caplog):
+ """
+ Tests the set_debug_logging function with different configurations.
+
+ This test verifies that svetlanna.set_debug_logging correctly handles
+ different debug logging types ('print' and 'logging') and enables/disables
+ debugging as expected. It also checks for correct output when debugging is
+ enabled, ensuring the expected messages are printed or logged.
+
+ Args:
+ input: Input values to be passed to the ElementLike forward method.
+ output: Output values returned by the ElementLike forward method.
+ capfd: Pytest fixture for capturing stdout and stderr.
+ caplog: Pytest fixture for capturing log messages.
+
+ Returns:
+ None
+ """
# test wrong type
with pytest.raises(ValueError):
- svetlanna.set_debug_logging(False, type='123') # type: ignore
+ svetlanna.set_debug_logging(False, type="123") # type: ignore
input = input if isinstance(input, tuple) else (input,)
output = output if isinstance(output, tuple) else (output,)
class ElementLike(svetlanna.elements.Element):
def __init__(
- self,
- simulation_parameters: svetlanna.SimulationParameters
+ self, simulation_parameters: svetlanna.SimulationParameters
) -> None:
super().__init__(simulation_parameters)
self.a = torch.nn.Module()
- self.b = svetlanna.Parameter(123.)
- self.register_buffer('c', torch.tensor(123.))
+ self.b = svetlanna.Parameter(123.0)
+ self.register_buffer("c", torch.tensor(123.0))
def forward(self, *args, **kwargs):
return output
@@ -186,9 +227,9 @@ def run_element():
element = ElementLike(
simulation_parameters=svetlanna.SimulationParameters(
axes={
- 'H': torch.linspace(-1, 1, 10),
- 'W': torch.linspace(-1, 1, 10),
- 'wavelength': 1.
+ "H": torch.linspace(-1, 1, 10),
+ "W": torch.linspace(-1, 1, 10),
+ "wavelength": 1.0,
}
)
)
@@ -206,50 +247,44 @@ def run_element():
"Buffer of ElementLike was registered with name c:\n"
"
shape=torch.Size([]), dtype=torch.float32, device=cpu"
)
- expected_output_4 = (
- "The forward method of ElementLike was computed"
- )
+ expected_output_4 = "The forward method of ElementLike was computed"
for i, _input in enumerate(input):
- expected_output_4 += f'\n input {i}: {type(_input)}'
+ expected_output_4 += f"\n input {i}: {type(_input)}"
for i, _output in enumerate(output):
- expected_output_4 += f'\n output {i}: {type(_output)}'
+ expected_output_4 += f"\n output {i}: {type(_output)}"
expected_outputs = [
expected_output_1,
expected_output_2,
expected_output_3,
- expected_output_4
+ expected_output_4,
]
# test for print type
- svetlanna.set_debug_logging(True, type='print')
+ svetlanna.set_debug_logging(True, type="print")
run_element()
out, _ = capfd.readouterr() # read stdout
- assert out == '\n'.join(expected_outputs) + '\n'
+ assert out == "\n".join(expected_outputs) + "\n"
# test for print type, with disabled debug logging
- svetlanna.set_debug_logging(False, type='print')
+ svetlanna.set_debug_logging(False, type="print")
run_element()
out, _ = capfd.readouterr() # read stdout
- assert out == ''
+ assert out == ""
# test for logging type
- svetlanna.set_debug_logging(True, type='logging')
- logger = logging.getLogger('svetlanna.logging') # get logger
+ svetlanna.set_debug_logging(True, type="logging")
+ logger = logging.getLogger("svetlanna.logging") # get logger
logger.setLevel(logging.DEBUG) # set logging level to DEBUG
run_element()
assert caplog.record_tuples == [
- (
- "svetlanna.logging",
- logging.DEBUG,
- message
- ) for message in expected_outputs
+ ("svetlanna.logging", logging.DEBUG, message) for message in expected_outputs
]
caplog.clear() # clear caplog
assert caplog.record_tuples == []
# test for logging type, with disabled debug logging
- svetlanna.set_debug_logging(False, type='logging')
+ svetlanna.set_debug_logging(False, type="logging")
run_element()
assert caplog.record_tuples == []
diff --git a/tests/test_nonlinear_element.py b/tests/test_nonlinear_element.py
index 21f94c6..7c0c8da 100644
--- a/tests/test_nonlinear_element.py
+++ b/tests/test_nonlinear_element.py
@@ -14,12 +14,23 @@
"oy_nodes",
"wavelength_test",
"response_function",
- "response_parameters"
+ "response_parameters",
]
def func(x, a, b):
- return a / 1 + torch.exp(-b*x)
+ """
+ Computes a value based on the given inputs.
+
+ Args:
+ x: The input value.
+ a: A constant value.
+ b: Another constant value.
+
+ Returns:
+ torch.Tensor: The computed result of the formula a / 1 + torch.exp(-b*x).
+ """
+ return a / 1 + torch.exp(-b * x)
@pytest.mark.parametrize(
@@ -27,9 +38,25 @@ def func(x, a, b):
[
(10, 10, 1000, 1200, 1064 * 1e-6, lambda x: x**2, None),
(4, 4, 1300, 1000, 1064 * 1e-6, lambda x: torch.sin(x) + x**3, None),
- (15, 8, 1319, 917, 1e-6 * torch.tensor([330, 660, 1064]), lambda x: torch.sin(x) + x**3, None), # noqa: E501
- (16, 7, 500, 868, 1e-6 * torch.tensor([330, 660, 1064]), func, {"a": 1., "b": 9.}) # noqa: E501
- ]
+ (
+ 15,
+ 8,
+ 1319,
+ 917,
+ 1e-6 * torch.tensor([330, 660, 1064]),
+ lambda x: torch.sin(x) + x**3,
+ None,
+ ), # noqa: E501
+ (
+ 16,
+ 7,
+ 500,
+ 868,
+ 1e-6 * torch.tensor([330, 660, 1064]),
+ func,
+ {"a": 1.0, "b": 9.0},
+ ), # noqa: E501
+ ],
)
def test_nonlinear_element(
ox_size: float,
@@ -38,13 +65,29 @@ def test_nonlinear_element(
oy_nodes: int,
wavelength_test: float,
response_function: Callable[[torch.Tensor], torch.Tensor],
- response_parameters: Dict
+ response_parameters: Dict,
):
+ """
+ Tests the NonlinearElement class with various parameters.
+
+ Args:
+ ox_size: The size of the simulation area in the x-direction.
+ oy_size: The size of the simulation area in the y-direction.
+ ox_nodes: The number of nodes in the x-direction.
+ oy_nodes: The number of nodes in the y-direction.
+ wavelength_test: The wavelength of the incident light.
+ response_function: The nonlinear response function to use.
+ response_parameters: A dictionary of parameters for the response function.
+
+ Returns:
+ None. This method asserts that the output field from the NonlinearElement
+ matches the analytically calculated output field.
+ """
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
@@ -53,7 +96,7 @@ def test_nonlinear_element(
nle = elements.NonlinearElement(
simulation_parameters=params,
response_function=response_function,
- response_parameters=response_parameters
+ response_parameters=response_parameters,
)
incident_amplitude = torch.abs(incident_field)
@@ -65,7 +108,7 @@ def test_nonlinear_element(
output_amplitude = response_function(
incident_amplitude,
response_parameters[keys[0]],
- response_parameters[keys[1]]
+ response_parameters[keys[1]],
)
else:
diff --git a/tests/test_parameters.py b/tests/test_parameters.py
index e034c0e..7806761 100644
--- a/tests/test_parameters.py
+++ b/tests/test_parameters.py
@@ -5,19 +5,21 @@
def test_inner_parameter_storage():
- torch_parameter = torch.nn.Parameter(torch.tensor(1.))
- torch_tensor = torch.tensor(2.)
- sv_parameter = Parameter(torch.tensor(3.))
+ """
+ Tests the inner parameter storage module."""
+ torch_parameter = torch.nn.Parameter(torch.tensor(1.0))
+ torch_tensor = torch.tensor(2.0)
+ sv_parameter = Parameter(torch.tensor(3.0))
sv_bounded_parameter = ConstrainedParameter(
- torch.tensor(4.), min_value=0., max_value=2.
+ torch.tensor(4.0), min_value=0.0, max_value=2.0
)
storage = InnerParameterStorageModule(
{
- 'value1': torch_parameter,
- 'value2': torch_tensor,
- 'value3': sv_parameter,
- 'value4': sv_bounded_parameter,
+ "value1": torch_parameter,
+ "value2": torch_tensor,
+ "value3": sv_parameter,
+ "value4": sv_bounded_parameter,
}
)
@@ -39,18 +41,32 @@ def test_inner_parameter_storage():
with pytest.raises(TypeError):
InnerParameterStorageModule(
{
- 'a': 123, # type: ignore
+ "a": 123, # type: ignore
}
)
@pytest.mark.parametrize(
- "parameter", [
- Parameter(data=123.),
- ConstrainedParameter(data=123., min_value=0, max_value=300)
- ]
+ "parameter",
+ [
+ Parameter(data=123.0),
+ ConstrainedParameter(data=123.0, min_value=0, max_value=300),
+ ],
)
def test_new(parameter: Parameter | ConstrainedParameter):
+ """
+ Tests the properties of a new parameter object.
+
+ This function verifies that the provided parameter is a PyTorch tensor,
+ not a `torch.nn.Parameter`, and behaves correctly with basic tensor operations.
+ It also checks the types of inner attributes.
+
+ Args:
+ parameter: The Parameter or ConstrainedParameter instance to test.
+
+ Returns:
+ None
+ """
# check if parameter is a tensor and not a torch parameter
assert isinstance(parameter, torch.Tensor)
assert not isinstance(parameter, torch.nn.Parameter)
@@ -64,16 +80,30 @@ def test_new(parameter: Parameter | ConstrainedParameter):
@pytest.mark.parametrize(
- "parameter", [
- Parameter(data=123.),
- ConstrainedParameter(data=123., min_value=0, max_value=300)
- ]
+ "parameter",
+ [
+ Parameter(data=123.0),
+ ConstrainedParameter(data=123.0, min_value=0, max_value=300),
+ ],
)
def test_behavior_as_a_tensor(parameter):
- a = 123.
+ """
+ Tests the behavior of the parameter when used as a tensor.
+
+ This tests multiplication and exponentiation operations with a scalar,
+ both directly and using torch functions to ensure proper handling via
+ __torch_function__.
+
+ Args:
+ parameter: The parameter object to test.
+
+ Returns:
+ None
+ """
+ a = 123.0
b = 10
res_mul = torch.tensor(a * b) # a * b
- res_pow = torch.tensor(a ** b) # a + b
+ res_pow = torch.tensor(a**b) # a + b
# test __torch_function__ for args processing
torch.testing.assert_close(parameter * b, res_mul)
@@ -84,30 +114,40 @@ def test_behavior_as_a_tensor(parameter):
def test_bounded_parameter_inner_value():
- data = 2.
- min_value = 0.
- max_value = 5.
+ """
+ Tests the inner parameter value of ConstrainedParameter with and without custom bound functions.
+
+ This test verifies that the inner parameter correctly maps to the constrained data
+ value using both the default sigmoid function and a user-defined bound function.
+ It also checks the `value` property when a custom bound function is provided.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
+ data = 2.0
+ min_value = 0.0
+ max_value = 5.0
# === default bound_func ===
parameter = ConstrainedParameter(
- data=data,
- min_value=min_value,
- max_value=max_value
+ data=data, min_value=min_value, max_value=max_value
)
# test inner parameter value
torch.testing.assert_close(
- (max_value-min_value) * torch.sigmoid(parameter.inner_parameter)
- + min_value,
- torch.tensor(data)
+ (max_value - min_value) * torch.sigmoid(parameter.inner_parameter) + min_value,
+ torch.tensor(data),
)
# === custom bound_func ===
def bound_func(x: torch.Tensor) -> torch.Tensor:
if x < 0:
- return torch.tensor(0.)
+ return torch.tensor(0.0)
if x > 1:
- return torch.tensor(1.)
+ return torch.tensor(1.0)
return x
def inv_bound_func(x: torch.Tensor) -> torch.Tensor:
@@ -118,7 +158,7 @@ def inv_bound_func(x: torch.Tensor) -> torch.Tensor:
min_value=min_value,
max_value=max_value,
bound_func=bound_func,
- inv_bound_func=inv_bound_func
+ inv_bound_func=inv_bound_func,
)
# test `value` property
@@ -126,19 +166,28 @@ def inv_bound_func(x: torch.Tensor) -> torch.Tensor:
# test inner parameter value
torch.testing.assert_close(
- (max_value-min_value) * bound_func(parameter.inner_parameter)
- + min_value,
- torch.tensor(data)
+ (max_value - min_value) * bound_func(parameter.inner_parameter) + min_value,
+ torch.tensor(data),
)
@pytest.mark.parametrize(
- "parameter", [
- Parameter(data=123.),
- ConstrainedParameter(data=123., min_value=0, max_value=300)
- ]
+ "parameter",
+ [
+ Parameter(data=123.0),
+ ConstrainedParameter(data=123.0, min_value=0, max_value=300),
+ ],
)
def test_repr(parameter):
+ """
+ Tests the repr of a parameter.
+
+ Args:
+ parameter: The parameter to test.
+
+ Returns:
+ None: This function only asserts that `repr(parameter)` does not raise an exception.
+ """
assert repr(parameter)
@@ -146,35 +195,42 @@ def test_repr(parameter):
("device",),
[
pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
def test_storage_to_device(device):
- torch_parameter = torch.nn.Parameter(torch.tensor(1.))
- torch_tensor = torch.tensor(2.)
- sv_parameter = Parameter(torch.tensor(3.))
+ """
+ Tests moving an InnerParameterStorageModule to a specified device and back to CPU.
+
+ Args:
+ device: The device to move the storage to (e.g., 'cuda', 'mps').
+
+ Returns:
+ None
+ """
+ torch_parameter = torch.nn.Parameter(torch.tensor(1.0))
+ torch_tensor = torch.tensor(2.0)
+ sv_parameter = Parameter(torch.tensor(3.0))
sv_bounded_parameter = ConstrainedParameter(
- torch.tensor(4.), min_value=0., max_value=2.
+ torch.tensor(4.0), min_value=0.0, max_value=2.0
)
storage = InnerParameterStorageModule(
{
- 'value1': torch_parameter,
- 'value2': torch_tensor,
- 'value3': sv_parameter,
- 'value4': sv_bounded_parameter,
+ "value1": torch_parameter,
+ "value2": torch_tensor,
+ "value3": sv_parameter,
+ "value4": sv_bounded_parameter,
}
)
@@ -185,44 +241,51 @@ def test_storage_to_device(device):
assert storage.value3.device.type == device
assert storage.value4.device.type == device
- storage.to(device='cpu')
+ storage.to(device="cpu")
# test if all values has been transferred to the cpu
- assert storage.value1.device.type == 'cpu'
- assert storage.value2.device.type == 'cpu'
- assert storage.value3.device.type == 'cpu'
- assert storage.value4.device.type == 'cpu'
+ assert storage.value1.device.type == "cpu"
+ assert storage.value2.device.type == "cpu"
+ assert storage.value3.device.type == "cpu"
+ assert storage.value4.device.type == "cpu"
@pytest.mark.parametrize(
("device",),
[
pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
@pytest.mark.parametrize(
- "parameter", [
- Parameter(data=torch.tensor(123., dtype=torch.float32)),
+ "parameter",
+ [
+ Parameter(data=torch.tensor(123.0, dtype=torch.float32)),
ConstrainedParameter(
- data=torch.tensor(123., dtype=torch.float32),
- min_value=0,
- max_value=300
- )
- ]
+ data=torch.tensor(123.0, dtype=torch.float32), min_value=0, max_value=300
+ ),
+ ],
)
def test_parameter_to_device(device, parameter):
+ """
+ Tests that a Parameter or ConstrainedParameter can be moved to the specified device.
+
+ Args:
+ device: The device to move the parameter to (e.g., 'cuda', 'mps').
+ parameter: The Parameter or ConstrainedParameter instance to test.
+
+ Returns:
+ None
+ """
# transferred_parameter = parameter.to(device)
# assert transferred_parameter.device.type == device
# assert transferred_parameter.inner_storage.device.type == device
diff --git a/tests/test_phase_retrieval.py b/tests/test_phase_retrieval.py
index fcd5238..04a7dd9 100644
--- a/tests/test_phase_retrieval.py
+++ b/tests/test_phase_retrieval.py
@@ -10,11 +10,13 @@
def test_retrieve_phase_api(capsys):
+ """
+ Tests the retrieve_phase API with different scenarios."""
params = SimulationParameters(
{
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': 1
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": 1,
}
)
# no initial_phase, no additional options
@@ -30,7 +32,7 @@ def test_retrieve_phase_api(capsys):
Wavefront.plane_wave(params).abs(),
LinearOpticalSetup([]),
Wavefront.plane_wave(params).abs(),
- method='abs' # type: ignore
+ method="abs", # type: ignore
)
# Test disp option for the intensity profile problem type
@@ -38,12 +40,10 @@ def test_retrieve_phase_api(capsys):
Wavefront.plane_wave(params).abs(),
LinearOpticalSetup([]),
Wavefront.plane_wave(params).abs(),
- options={
- 'disp': True
- }
+ options={"disp": True},
)
- captured = capsys.readouterr().out.split('\n')[0]
- assert captured == 'Type of problem: generate intensity profile'
+ captured = capsys.readouterr().out.split("\n")[0]
+ assert captured == "Type of problem: generate intensity profile"
# Test disp option for the phase reconstruction problem type
phase_retrieval.retrieve_phase(
@@ -52,12 +52,10 @@ def test_retrieve_phase_api(capsys):
Wavefront.plane_wave(params).abs(),
target_phase=torch.zeros((10, 10)),
target_region=torch.zeros((10, 10)),
- options={
- 'disp': True
- }
+ options={"disp": True},
)
- captured = capsys.readouterr().out.split('\n')[0]
- assert captured == 'Type of problem: phase reconstruction'
+ captured = capsys.readouterr().out.split("\n")[0]
+ assert captured == "Type of problem: phase reconstruction"
parameters = [
@@ -67,20 +65,20 @@ def test_retrieve_phase_api(capsys):
"oy_nodes",
"wavelength_test",
"waist_radius_test",
- "distance_test"
+ "distance_test",
]
@pytest.mark.parametrize(
parameters,
[
- (10, 10, 200, 200, 0.025, 0.7, 100.),
- (7, 8, 200, 200, 0.02, 0.7, 150.),
- (15, 8, 300, 200, 0.02, 0.5, 120.),
- ]
+ (10, 10, 200, 200, 0.025, 0.7, 100.0),
+ (7, 8, 200, 200, 0.02, 0.7, 150.0),
+ (15, 8, 300, 200, 0.02, 0.5, 120.0),
+ ],
)
-@pytest.mark.parametrize('use_phase_target', [True, False])
-@pytest.mark.parametrize('method', ['HIO', 'GS'])
+@pytest.mark.parametrize("use_phase_target", [True, False])
+@pytest.mark.parametrize("method", ["HIO", "GS"])
def test_phase_retrieval(
ox_size: float,
oy_size: float,
@@ -90,7 +88,7 @@ def test_phase_retrieval(
waist_radius_test: float,
distance_test: float,
use_phase_target: bool,
- method: phase_retrieval.Method
+ method: phase_retrieval.Method,
):
"""Test for phase reconstruction problem and generate target intensity
problem using HIO and Gerchberg-Saxton algorithms on the example of a
@@ -122,32 +120,27 @@ def test_phase_retrieval(
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
- x_grid, y_grid = params.meshgrid('W', 'H')
+ x_grid, y_grid = params.meshgrid("W", "H")
field_before_lens1 = Wavefront.gaussian_beam(
simulation_parameters=params,
distance=0.05 * distance_test,
- waist_radius=waist_radius_test
+ waist_radius=waist_radius_test,
)
intensity_source = field_before_lens1.intensity
- lens1 = elements.ThinLens(
- simulation_parameters=params,
- focal_length=distance_test
- )
+ lens1 = elements.ThinLens(simulation_parameters=params, focal_length=distance_test)
field_after_lens1 = lens1(field_before_lens1)
free_space1 = elements.FreeSpace(
- simulation_parameters=params,
- distance=0.05 * distance_test,
- method='AS'
+ simulation_parameters=params, distance=0.05 * distance_test, method="AS"
)
output_field = free_space1(field_after_lens1)
@@ -159,7 +152,7 @@ def test_phase_retrieval(
# target phase profile for phase reconstruction problem
if use_phase_target:
phase_target = torch.angle(output_field)
- target_region = (x_grid**2 + y_grid ** 2 <= 0.12).float()
+ target_region = (x_grid**2 + y_grid**2 <= 0.12).float()
result_hio = phase_retrieval.retrieve_phase(
source_intensity=intensity_source,
@@ -169,10 +162,7 @@ def test_phase_retrieval(
target_region=target_region,
initial_phase=torch.full_like(intensity_target, 0),
method=method,
- options={
- 'maxiter': 100,
- 'constant_factor': 0.5
- }
+ options={"maxiter": 100, "constant_factor": 0.5},
)
else:
result_hio = phase_retrieval.retrieve_phase(
@@ -181,16 +171,13 @@ def test_phase_retrieval(
target_intensity=intensity_target,
initial_phase=torch.full_like(intensity_target, 0),
method=method,
- options={
- 'maxiter': 100,
- 'constant_factor': 0.5
- }
+ options={"maxiter": 100, "constant_factor": 0.5},
)
errors = result_hio.cost_func_evolution
# test if the error decreases
- assert np.sum(np.diff(errors) < 0) > 0.7 * (len(errors)-1)
+ assert np.sum(np.diff(errors) < 0) > 0.7 * (len(errors) - 1)
assert (errors[0] - errors[-1]) / errors[0] > 0.6
@@ -204,7 +191,7 @@ def test_phase_retrieval(
"waist_radius_test",
"distance_test",
"radius_test",
- "error_energy"
+ "error_energy",
]
@@ -212,10 +199,10 @@ def test_phase_retrieval(
@pytest.mark.parametrize(
parameters_4f,
[
- (10, 10, 1000, 1000, 660 * 1e-6, 0.5, 100., 10., 1e-4),
- (7, 8, 1000, 1000, 1064 * 1e-6, 0.7, 150., 10., 1e-4),
- (15, 8, 1500, 1000, 550 * 1e-6, 0.5, 120., 10., 1e-4)
- ]
+ (10, 10, 1000, 1000, 660 * 1e-6, 0.5, 100.0, 10.0, 1e-4),
+ (7, 8, 1000, 1000, 1064 * 1e-6, 0.7, 150.0, 10.0, 1e-4),
+ (15, 8, 1500, 1000, 550 * 1e-6, 0.5, 120.0, 10.0, 1e-4),
+ ],
)
def test_4f_system(
ox_size: float,
@@ -226,7 +213,7 @@ def test_4f_system(
waist_radius_test: float,
distance_test: float,
radius_test: float,
- error_energy: float
+ error_energy: float,
):
"""Test for phase reconstruction problem using HIO algorithm on the
example of a 4f optical setup
@@ -254,31 +241,29 @@ def test_4f_system(
"""
x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
- x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy')
+ x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy")
dx = ox_size / ox_nodes
dy = oy_size / oy_nodes
params = SimulationParameters(
{
- 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes),
- 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes),
- 'wavelength': wavelength_test
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength_test,
}
)
field_before_lens1 = Wavefront.gaussian_beam(
simulation_parameters=params,
distance=distance_test,
- waist_radius=waist_radius_test
+ waist_radius=waist_radius_test,
)
intensity_source = field_before_lens1.intensity.detach().numpy()
lens1 = elements.ThinLens(
- simulation_parameters=params,
- focal_length=distance_test,
- radius=radius_test
+ simulation_parameters=params, focal_length=distance_test, radius=radius_test
)
field_after_lens1 = lens1.forward(input_field=field_before_lens1)
@@ -286,38 +271,37 @@ def test_4f_system(
free_space1 = elements.FreeSpace(
simulation_parameters=params,
distance=torch.tensor(2 * distance_test),
- method='AS'
+ method="AS",
)
field_before_lens2 = free_space1.forward(input_field=field_after_lens1)
lens2 = elements.ThinLens(
- simulation_parameters=params,
- focal_length=distance_test,
- radius=radius_test
+ simulation_parameters=params, focal_length=distance_test, radius=radius_test
)
field_after_lens2 = lens2.forward(input_field=field_before_lens2)
free_space2 = elements.FreeSpace(
- simulation_parameters=params,
- distance=torch.tensor(distance_test),
- method='AS'
+ simulation_parameters=params, distance=torch.tensor(distance_test), method="AS"
)
output_field = free_space2.forward(input_field=field_after_lens2)
phase_target = (
- torch.angle(output_field) + 2 * torch.pi * (
- torch.angle(output_field) < 0.
- ).float()
- ).detach().numpy()
+ (
+ torch.angle(output_field)
+ + 2 * torch.pi * (torch.angle(output_field) < 0.0).float()
+ )
+ .detach()
+ .numpy()
+ )
intensity_target = output_field.intensity.detach().numpy()
optical_setup = LinearOpticalSetup([free_space1, lens2, free_space2])
- goal = (x_grid**2 + y_grid ** 2 <= 2).float()
+ goal = (x_grid**2 + y_grid**2 <= 2).float()
result_hio = phase_retrieval.retrieve_phase(
source_intensity=torch.tensor(intensity_source),
@@ -326,7 +310,7 @@ def test_4f_system(
target_phase=torch.tensor(phase_target),
target_region=goal,
initial_phase=None,
- method='HIO',
+ method="HIO",
)
phase_reconstruction_hio = result_hio.solution
@@ -335,14 +319,11 @@ def test_4f_system(
mask_reconstruction_hio = phase_reconstruction_hio // step
field_after_slm = elements.SpatialLightModulator(
- simulation_parameters=params,
- mask=mask_reconstruction_hio
+ simulation_parameters=params, mask=mask_reconstruction_hio
).forward(field_before_lens1)
output_field = optical_setup.forward(field_after_slm)
- intensity_target_opt = torch.pow(
- torch.abs(output_field), 2
- ).detach().numpy()
+ intensity_target_opt = torch.pow(torch.abs(output_field), 2).detach().numpy()
energy_reconstruction_hio = np.sum(intensity_target_opt) * dx * dy
energy_true = np.sum(intensity_target) * dx * dy
diff --git a/tests/test_reservoir.py b/tests/test_reservoir.py
index 36380f9..0c47941 100644
--- a/tests/test_reservoir.py
+++ b/tests/test_reservoir.py
@@ -4,24 +4,29 @@
def test_queue():
+ """
+ Tests the functionality of the feedback queue in SimpleReservoir.
+
+ This tests appends to, pops from and drops the feedback queue within a
+ SimpleReservoir instance, verifying correct behavior with different queue lengths
+ relative to the specified delay.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = SimulationParameters(
- {
- 'W': torch.tensor([0]),
- 'H': torch.tensor([0]),
- 'wavelength': 1.
- }
+ {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1.0}
)
reservoir = SimpleReservoir(
sim_params,
- nonlinear_element=DiffractiveLayer(
- sim_params, mask=torch.tensor([[0.]])
- ),
- delay_element=DiffractiveLayer(
- sim_params, mask=torch.tensor([[0.]])
- ),
+ nonlinear_element=DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])),
+ delay_element=DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])),
delay=2,
feedback_gain=1,
- input_gain=1
+ input_gain=1,
)
# feedback queue is empty
@@ -54,20 +59,25 @@ def test_queue():
def test_forward():
+ """
+ Tests the forward pass of the SimpleReservoir.
+
+ This test verifies that the reservoir correctly implements feedback and delay,
+ and that the output matches expectations for both initial iterations and after
+ the delay line is populated.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = SimulationParameters(
- {
- 'W': torch.tensor([0]),
- 'H': torch.tensor([0]),
- 'wavelength': 1.
- }
+ {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1.0}
)
- nonlinear_element = DiffractiveLayer(
- sim_params, mask=torch.tensor([[0.]])
- )
- delay_element = DiffractiveLayer(
- sim_params, mask=torch.tensor([[0.]])
- )
+ nonlinear_element = DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]]))
+ delay_element = DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]]))
feedback_gain = 0.8
input_gain = 0.6
delay = 5
@@ -78,7 +88,7 @@ def test_forward():
delay_element=delay_element,
delay=delay,
feedback_gain=feedback_gain,
- input_gain=input_gain
+ input_gain=input_gain,
)
wf = Wavefront.plane_wave(sim_params)
@@ -102,8 +112,6 @@ def test_forward():
# hard coded very first delay line related contribution
wf_out_expected = nonlinear_element(
- input_gain * wf + feedback_gain * nonlinear_element(
- input_gain * wf
- )
+ input_gain * wf + feedback_gain * nonlinear_element(input_gain * wf)
)
assert torch.allclose(wf_out, wf_out_expected)
diff --git a/tests/test_setup.py b/tests/test_setup.py
index b020740..e4a2d4b 100644
--- a/tests/test_setup.py
+++ b/tests/test_setup.py
@@ -9,30 +9,82 @@
class SimpleElement(Element):
- def __init__(
- self,
- a: Any,
- simulation_parameters: SimulationParameters
- ) -> None:
+ """
+ Represents a simple optical element that scales a wavefront.
+
+ Attributes:
+ a: The scaling factor for the wavefront.
+ simulation_parameters: Parameters used for the simulation.
+
+ Methods:
+ __init__: Initializes the instance with given parameters.
+ forward: Applies a scaling factor to the input wavefront.
+ """
+
+ def __init__(self, a: Any, simulation_parameters: SimulationParameters) -> None:
+ """
+ Initializes the instance with given parameters.
+
+ Args:
+ a: The value for attribute 'a'.
+ simulation_parameters: Parameters used for the simulation.
+
+ Returns:
+ None
+ """
super().__init__(simulation_parameters)
self.a = a
def forward(self, incident_wavefront: Wavefront) -> Wavefront:
+ """
+ Applies a scaling factor to the input wavefront.
+
+ Args:
+ incident_wavefront: The input Wavefront object.
+
+ Returns:
+ Wavefront: A new Wavefront object representing the scaled wavefront.
+ """
return incident_wavefront * self.a
class ReversableSimpleElement(SimpleElement):
+ """
+ Reverses a wavefront using a scaling factor."""
+
def reverse(self, wavefront):
+ """
+ Reverses a wavefront by multiplying it with the scaling factor.
+
+ Args:
+ wavefront: The wavefront to be reversed.
+
+ Returns:
+ A numpy array representing the reversed wavefront.
+ """
return wavefront * self.a
def test_init():
+ """
+ Tests the initialization and forward pass of LinearOpticalSetup.
+
+ This test creates a LinearOpticalSetup with three SimpleElements,
+ verifies that the internal neural network is a torch.nn.Module,
+ and checks if the forward pass correctly applies the element's 'a' value.
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'wavelength': 1
+ "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "wavelength": 1,
}
)
@@ -41,9 +93,7 @@ def test_init():
el2 = SimpleElement(a=a, simulation_parameters=sim_params)
el3 = SimpleElement(a=a, simulation_parameters=sim_params)
- setup = LinearOpticalSetup(elements=[
- el1, el2, el3
- ])
+ setup = LinearOpticalSetup(elements=[el1, el2, el3])
assert isinstance(setup.net, torch.nn.Module)
@@ -53,18 +103,21 @@ def test_init():
def test_init_warning():
+ """
+ Tests that a UserWarning is raised when initializing LinearOpticalSetup with identical simulation parameters.
+ """
sim_params1 = SimulationParameters(
{
- 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'wavelength': 1
+ "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "wavelength": 1,
}
)
sim_params2 = SimulationParameters(
{
- 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'wavelength': 1
+ "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "wavelength": 1,
}
)
@@ -73,50 +126,48 @@ def test_init_warning():
el2 = SimpleElement(a=a, simulation_parameters=sim_params2)
with pytest.warns(UserWarning):
- LinearOpticalSetup(elements=[
- el1, el2
- ])
+ LinearOpticalSetup(elements=[el1, el2])
@pytest.mark.parametrize(
("device",),
[
pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
def test_to_device(device):
+ """
+ Tests that moving the network to a device also moves its parameters.
+
+ Args:
+ device: The device to move the network to ('cuda' or 'mps').
+
+ Returns:
+ None
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'wavelength': 1
+ "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "wavelength": 1,
}
)
- el1 = SimpleElement(
- a=Parameter(2.),
- simulation_parameters=sim_params
- )
+ el1 = SimpleElement(a=Parameter(2.0), simulation_parameters=sim_params)
el2 = SimpleElement(
- a=ConstrainedParameter(
- data=0.5,
- min_value=0,
- max_value=1
- ),
- simulation_parameters=sim_params
+ a=ConstrainedParameter(data=0.5, min_value=0, max_value=1),
+ simulation_parameters=sim_params,
)
setup = LinearOpticalSetup([el1, el2])
@@ -131,33 +182,38 @@ def test_to_device(device):
def test_reverse():
+ """
+ Tests the reverse method of LinearOpticalSetup.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
# test empty setup
setup = LinearOpticalSetup(elements=[])
- wf = torch.Tensor([2., 3.])
+ wf = torch.Tensor([2.0, 3.0])
assert setup.reverse(wf) is wf
# test setup
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
- 'wavelength': 1
+ "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10),
+ "wavelength": 1,
}
)
a = torch.tensor(2)
# test unreversable element
el = SimpleElement(a=a, simulation_parameters=sim_params)
- setup = LinearOpticalSetup(elements=[
- el
- ])
+ setup = LinearOpticalSetup(elements=[el])
with pytest.raises(TypeError):
setup.reverse(wf)
# test reversable element
el = ReversableSimpleElement(a=a, simulation_parameters=sim_params)
- setup = LinearOpticalSetup(elements=[
- el
- ])
+ setup = LinearOpticalSetup(elements=[el])
torch.testing.assert_close(setup.reverse(wf), wf * a)
diff --git a/tests/test_simulation_parameters.py b/tests/test_simulation_parameters.py
index 648ce01..2736d7c 100644
--- a/tests/test_simulation_parameters.py
+++ b/tests/test_simulation_parameters.py
@@ -5,66 +5,82 @@
def test_axes():
+ """
+ Tests the Axes class for correct axis handling and validation."""
# Test required axes are actually required
with pytest.raises(ValueError):
Axes({})
- SimulationParameters({
- 'W': torch.linspace(-1, 1, 10),
- })
+ SimulationParameters(
+ {
+ "W": torch.linspace(-1, 1, 10),
+ }
+ )
with pytest.raises(ValueError):
- Axes({
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- })
- Axes({
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': torch.tensor(312)
- })
+ Axes(
+ {
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ }
+ )
+ Axes(
+ {
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": torch.tensor(312),
+ }
+ )
# Test with wrong H and W axis shape
with pytest.raises(ValueError):
- Axes({
- 'W': torch.tensor([[10.]]), # wrong shape
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': torch.tensor(312)
- })
+ Axes(
+ {
+ "W": torch.tensor([[10.0]]), # wrong shape
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": torch.tensor(312),
+ }
+ )
with pytest.raises(ValueError):
- Axes({
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.tensor([[10.]]), # wrong shape
- 'wavelength': torch.tensor(312)
- })
+ Axes(
+ {
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.tensor([[10.0]]), # wrong shape
+ "wavelength": torch.tensor(312),
+ }
+ )
# Test with wrong additional axes shape
with pytest.raises(ValueError):
- Axes({
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': torch.tensor(312),
- 'pol': torch.tensor([[1.2, 3.4]]), # wrong shape
- })
+ Axes(
+ {
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": torch.tensor(312),
+ "pol": torch.tensor([[1.2, 3.4]]), # wrong shape
+ }
+ )
w_axis = torch.linspace(-1, 1, 10)
- pol_axis = torch.tensor([1., 0.])
- axes = Axes({
- 'W': w_axis,
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': torch.tensor(312),
- 'pol': pol_axis,
- })
+ pol_axis = torch.tensor([1.0, 0.0])
+ axes = Axes(
+ {
+ "W": w_axis,
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": torch.tensor(312),
+ "pol": pol_axis,
+ }
+ )
# Test names of non-scalar axes
- assert axes.names == ('pol', 'H', 'W')
+ assert axes.names == ("pol", "H", "W")
# Test indices
- assert axes.index('pol') == -3
- assert axes.index('H') == -2
- assert axes.index('W') == -1
+ assert axes.index("pol") == -3
+ assert axes.index("H") == -2
+ assert axes.index("W") == -1
with pytest.raises(AxisNotFound):
- axes.index('wavelength') # scalar axis
+ axes.index("wavelength") # scalar axis
with pytest.raises(AxisNotFound):
- axes.index('t') # axis does not exists
+ axes.index("t") # axis does not exists
# Test __getattribute__ for named axes
assert axes.W is w_axis
@@ -76,86 +92,106 @@ def test_axes():
assert axes.W is w_axis
# Test __getitem__
- assert axes['W'] is w_axis
- assert axes['pol'] is pol_axis
- assert axes['wavelength'] == torch.tensor(312)
+ assert axes["W"] is w_axis
+ assert axes["pol"] is pol_axis
+ assert axes["wavelength"] == torch.tensor(312)
with pytest.raises(AxisNotFound):
- axes['t'] # axis does not exists
+ axes["t"] # axis does not exists
# Test disabled __setitem__
with pytest.raises(RuntimeError):
- axes['W'] = w_axis
+ axes["W"] = w_axis
with pytest.raises(RuntimeError):
- axes['pol'] = pol_axis
+ axes["pol"] = pol_axis
with pytest.raises(RuntimeError):
- axes['t'] = 123
+ axes["t"] = 123
# Test __dir__
- assert set(dir(axes)) == {'H', 'W', 'pol', 'wavelength'}
+ assert set(dir(axes)) == {"H", "W", "pol", "wavelength"}
def test_simulation_parameters():
+ """
+ Tests the SimulationParameters class functionality.
+
+ This tests the __getitem__ method, meshgrid generation, and axes size retrieval
+ of the SimulationParameters class with various parameters. It also checks for
+ expected warnings when accessing non-existent axes.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
w_axis = torch.linspace(-1, 2, 13)
h_axis = torch.linspace(-12, -3, 25)
- pol_axis = torch.tensor([1., 0.])
- sim_params = SimulationParameters({
- 'W': w_axis,
- 'H': h_axis,
- 'wavelength': 123.,
- 'pol': pol_axis,
- 't': 0.0
- })
+ pol_axis = torch.tensor([1.0, 0.0])
+ sim_params = SimulationParameters(
+ {"W": w_axis, "H": h_axis, "wavelength": 123.0, "pol": pol_axis, "t": 0.0}
+ )
# Test __getitem__
- assert sim_params['W'] is w_axis
- assert sim_params['pol'] is pol_axis
- assert sim_params['t'] == 0
- assert sim_params['wavelength'] == 123
+ assert sim_params["W"] is w_axis
+ assert sim_params["pol"] is pol_axis
+ assert sim_params["t"] == 0
+ assert sim_params["wavelength"] == 123
# Test meshgrid
- meshgrid_W, meshgrid_H = sim_params.meshgrid('W', 'H')
+ meshgrid_W, meshgrid_H = sim_params.meshgrid("W", "H")
assert torch.allclose(meshgrid_W, w_axis[None, ...])
assert torch.allclose(meshgrid_H, h_axis[..., None])
- meshgrid_W1, meshgrid_W2 = sim_params.meshgrid('W', 'W')
+ meshgrid_W1, meshgrid_W2 = sim_params.meshgrid("W", "W")
assert torch.allclose(meshgrid_W1, w_axis[None, ...])
assert torch.allclose(meshgrid_W2, w_axis[..., None])
- meshgrid_H, meshgrid_wl = sim_params.meshgrid('H', 'wavelength')
+ meshgrid_H, meshgrid_wl = sim_params.meshgrid("H", "wavelength")
assert torch.allclose(meshgrid_H, h_axis[None, ...])
- assert torch.allclose(meshgrid_wl, torch.tensor(123.)[None])
+ assert torch.allclose(meshgrid_wl, torch.tensor(123.0)[None])
# Test axes_size
- assert sim_params.axes_size(('W',)) == torch.Size((13,))
- assert sim_params.axes_size(('wavelength', 'H')) == torch.Size((1, 25))
- assert sim_params.axes_size(('H',)) == torch.Size((25,))
+ assert sim_params.axes_size(("W",)) == torch.Size((13,))
+ assert sim_params.axes_size(("wavelength", "H")) == torch.Size((1, 25))
+ assert sim_params.axes_size(("H",)) == torch.Size((25,))
with pytest.warns(UserWarning):
# non existing axis
- assert sim_params.axes_size(('a', 'H')) == torch.Size((0, 25))
+ assert sim_params.axes_size(("a", "H")) == torch.Size((0, 25))
@pytest.fixture(
- scope='function',
+ scope="function",
params=[
- 'cpu',
+ "cpu",
pytest.param(
- 'cuda',
+ "cuda",
marks=pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="cuda is not available"
- )
+ not torch.cuda.is_available(), reason="cuda is not available"
+ ),
),
pytest.param(
- 'mps',
+ "mps",
marks=pytest.mark.skipif(
- not torch.backends.mps.is_available(),
- reason="mps is not available"
- )
- )
- ]
+ not torch.backends.mps.is_available(), reason="mps is not available"
+ ),
+ ),
+ ],
)
def default_device(request):
+ """
+ Provides a fixture for setting the default PyTorch device.
+
+ This fixture iterates through 'cpu', 'cuda' (if available), and 'mps' (if available)
+ as parameters, temporarily setting the default device to each one within the scope of a test function.
+ It yields the current default device and then restores the original default device after the test completes.
+
+ Args:
+ request: The pytest request object providing access to fixture parameters.
+
+ Returns:
+ str: The currently set default PyTorch device (e.g., 'cpu', 'cuda', or 'mps').
+ """
# Set the default device
old_default_device = torch.get_default_device()
torch.set_default_device(request.param)
@@ -164,23 +200,40 @@ def default_device(request):
def test_device(default_device: torch.device):
- w_axis = torch.linspace(-1, 2, 13, device='cpu')
+ """
+ Tests device placement and transfer for SimulationParameters.
+
+ This method verifies that the SimulationParameters class correctly handles
+ device placement of axis tensors, raises errors when appropriate, and
+ that the `to()` method functions as expected for transferring data between devices.
+
+ Args:
+ default_device: The default device to use for testing.
+
+ Returns:
+ None
+ """
+ w_axis = torch.linspace(-1, 2, 13, device="cpu")
h_axis = torch.linspace(-12, -3, 25)
- if default_device.type != 'cpu':
+ if default_device.type != "cpu":
with pytest.raises(ValueError):
- SimulationParameters({
- 'W': w_axis,
- 'H': h_axis.to(default_device),
- 'wavelength': 123.,
- })
+ SimulationParameters(
+ {
+ "W": w_axis,
+ "H": h_axis.to(default_device),
+ "wavelength": 123.0,
+ }
+ )
# Test if in the following case the axis tensor is located on the device
- sim_params = SimulationParameters({ # type: ignore
- 'W': [1., 2., 3.],
- 'H': [1., 2., 3.],
- 'wavelength': 123.
- })
+ sim_params = SimulationParameters(
+ { # type: ignore
+ "W": [1.0, 2.0, 3.0],
+ "H": [1.0, 2.0, 3.0],
+ "wavelength": 123.0,
+ }
+ )
assert sim_params.axes.W.device == default_device
# Test to() method
@@ -188,10 +241,10 @@ def test_device(default_device: torch.device):
assert transferred_sim_params is sim_params
# Test to('cpu')
- transferred_sim_params = sim_params.to('cpu')
- assert transferred_sim_params.device.type == 'cpu' # type: ignore
+ transferred_sim_params = sim_params.to("cpu")
+ assert transferred_sim_params.device.type == "cpu" # type: ignore
for axis_name in sim_params.axes.names:
- assert transferred_sim_params.axes[axis_name].device.type == 'cpu'
+ assert transferred_sim_params.axes[axis_name].device.type == "cpu"
# And back
transferred_sim_params = transferred_sim_params.to(default_device)
assert transferred_sim_params.device == default_device
diff --git a/tests/test_slm.py b/tests/test_slm.py
index d143434..6df2f5a 100644
--- a/tests/test_slm.py
+++ b/tests/test_slm.py
@@ -15,32 +15,72 @@
"width",
"mode",
"mask",
- "resized_mask"
+ "resized_mask",
]
@pytest.mark.parametrize(
parameters_mask,
[
- (10, 10, 4, 4, 10, 10, "nearest", torch.Tensor([[1., 2.], [3., 4.]]),
- torch.Tensor([
- [1., 1., 2., 2.,],
- [1., 1., 2., 2.,],
- [3., 3., 4., 4.,],
- [3., 3., 4., 4.,]
- ])
- ),
- (15, 8, 6, 6, 8, 15, "nearest", torch.Tensor([[2., 3.], [4., 5.]]),
- torch.Tensor([
- [2., 2., 2., 3., 3., 3.],
- [2., 2., 2., 3., 3., 3.],
- [2., 2., 2., 3., 3., 3.],
- [4., 4., 4., 5., 5., 5.],
- [4., 4., 4., 5., 5., 5.],
- [4., 4., 4., 5., 5., 5.]
- ])
- )
- ]
+ (
+ 10,
+ 10,
+ 4,
+ 4,
+ 10,
+ 10,
+ "nearest",
+ torch.Tensor([[1.0, 2.0], [3.0, 4.0]]),
+ torch.Tensor(
+ [
+ [
+ 1.0,
+ 1.0,
+ 2.0,
+ 2.0,
+ ],
+ [
+ 1.0,
+ 1.0,
+ 2.0,
+ 2.0,
+ ],
+ [
+ 3.0,
+ 3.0,
+ 4.0,
+ 4.0,
+ ],
+ [
+ 3.0,
+ 3.0,
+ 4.0,
+ 4.0,
+ ],
+ ]
+ ),
+ ),
+ (
+ 15,
+ 8,
+ 6,
+ 6,
+ 8,
+ 15,
+ "nearest",
+ torch.Tensor([[2.0, 3.0], [4.0, 5.0]]),
+ torch.Tensor(
+ [
+ [2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
+ [2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
+ [2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
+ [4.0, 4.0, 4.0, 5.0, 5.0, 5.0],
+ [4.0, 4.0, 4.0, 5.0, 5.0, 5.0],
+ [4.0, 4.0, 4.0, 5.0, 5.0, 5.0],
+ ]
+ ),
+ ),
+ ],
)
def test_slm_mask(
ox_size: float,
@@ -51,25 +91,38 @@ def test_slm_mask(
width: float,
mode: str,
mask: torch.Tensor,
- resized_mask: torch.Tensor
+ resized_mask: torch.Tensor,
):
+ """
+ Tests the SpatialLightModulator's resized mask functionality.
+
+ Args:
+ ox_size: The size of the x-axis in simulation units.
+ oy_size: The size of the y-axis in simulation units.
+ ox_nodes: The number of nodes along the x-axis.
+ oy_nodes: The number of nodes along the y-axis.
+ height: The height of the SLM mask.
+ width: The width of the SLM mask.
+ mode: The resizing mode (e.g., "nearest").
+ mask: The input mask tensor.
+ resized_mask: The expected resized mask tensor.
+
+ Returns:
+ None. Raises an AssertionError if the resized masks do not match.
+ """
x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
params = SimulationParameters(
axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': 1064 * 1e-6,
+ "W": x_length,
+ "H": y_length,
+ "wavelength": 1064 * 1e-6,
}
)
slm = elements.SpatialLightModulator(
- simulation_parameters=params,
- mask=mask,
- height=height,
- width=width,
- mode=mode
+ simulation_parameters=params, mask=mask, height=height, width=width, mode=mode
)
slm.get_aperture
resized_mask_slm = slm.resized_mask
@@ -77,14 +130,7 @@ def test_slm_mask(
assert torch.allclose(resized_mask, resized_mask_slm)
-parameters_resize = [
- "ox_size",
- "oy_size",
- "ox_nodes",
- "oy_nodes",
- "mode",
- "mask"
-]
+parameters_resize = ["ox_size", "oy_size", "ox_nodes", "oy_nodes", "mode", "mask"]
@pytest.mark.parametrize(
@@ -95,8 +141,8 @@ def test_slm_mask(
(6, 5, 1570, 632, "bicubic", torch.rand(100, 100)),
(15.8, 8.61, 109, 120, "area", torch.rand(100, 100)),
(19, 7, 1089, 2007, "nearest-exact", torch.rand(100, 100)),
- (15, 8, 300, 400, "nearest-exact", torch.rand(1080, 1920))
- ]
+ (15, 8, 300, 400, "nearest-exact", torch.rand(1080, 1920)),
+ ],
)
def test_slm_resize(
ox_size: float,
@@ -104,16 +150,30 @@ def test_slm_resize(
ox_nodes: int,
oy_nodes: int,
mode: str,
- mask: torch.Tensor
+ mask: torch.Tensor,
):
+ """
+ Tests the resizing functionality of the SpatialLightModulator.
+
+ Args:
+ ox_size: The size of the x-axis.
+ oy_size: The size of the y-axis.
+ ox_nodes: The number of nodes along the x-axis.
+ oy_nodes: The number of nodes along the y-axis.
+ mode: The resizing mode to use (e.g., "nearest", "bilinear").
+ mask: The input mask tensor.
+
+ Returns:
+ None. This function asserts properties of the resized mask and aperture.
+ """
x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
params = SimulationParameters(
axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': 1064 * 1e-6,
+ "W": x_length,
+ "H": y_length,
+ "wavelength": 1064 * 1e-6,
}
)
@@ -122,7 +182,7 @@ def test_slm_resize(
mask=mask,
height=oy_size,
width=ox_size,
- mode=mode
+ mode=mode,
)
aperture = slm.get_aperture
resized_mask = slm.resized_mask
@@ -141,51 +201,87 @@ def test_slm_resize(
"height",
"width",
"location",
- "aperture"
+ "aperture",
]
@pytest.mark.parametrize(
parameters_aperture,
[
- (6, 5, 6, 5, 3, 3, (-1.5, -1),
- torch.tensor([
- [1., 1., 1., 0., 0., 0.],
- [1., 1., 1., 0., 0., 0.],
- [1., 1., 1., 0., 0., 0.],
- [0., 0., 0., 0., 0., 0.],
- [0., 0., 0., 0., 0., 0.]
- ])
- ),
- (6, 5, 6, 5, 3, 3, (-1.5, 1),
- torch.tensor([
- [0., 0., 0., 0., 0., 0.],
- [0., 0., 0., 0., 0., 0.],
- [1., 1., 1., 0., 0., 0.],
- [1., 1., 1., 0., 0., 0.],
- [1., 1., 1., 0., 0., 0.]
- ])
- ),
- (6, 5, 6, 5, 3, 3, (1.5, 1),
- torch.tensor([
- [0., 0., 0., 0., 0., 0.],
- [0., 0., 0., 0., 0., 0.],
- [0., 0., 0., 1., 1., 1.],
- [0., 0., 0., 1., 1., 1.],
- [0., 0., 0., 1., 1., 1.]
- ])
- ),
- (6, 5, 6, 5, 3, 3, (1.5, -1),
- torch.tensor([
- [0., 0., 0., 1., 1., 1.],
- [0., 0., 0., 1., 1., 1.],
- [0., 0., 0., 1., 1., 1.],
- [0., 0., 0., 0., 0., 0.],
- [0., 0., 0., 0., 0., 0.]
- ])
- ),
- (6, 5, 6, 5, 3, 3, (-100, 100), torch.zeros(5, 6))
- ]
+ (
+ 6,
+ 5,
+ 6,
+ 5,
+ 3,
+ 3,
+ (-1.5, -1),
+ torch.tensor(
+ [
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ ]
+ ),
+ ),
+ (
+ 6,
+ 5,
+ 6,
+ 5,
+ 3,
+ 3,
+ (-1.5, 1),
+ torch.tensor(
+ [
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
+ ]
+ ),
+ ),
+ (
+ 6,
+ 5,
+ 6,
+ 5,
+ 3,
+ 3,
+ (1.5, 1),
+ torch.tensor(
+ [
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ ]
+ ),
+ ),
+ (
+ 6,
+ 5,
+ 6,
+ 5,
+ 3,
+ 3,
+ (1.5, -1),
+ torch.tensor(
+ [
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ ]
+ ),
+ ),
+ (6, 5, 6, 5, 3, 3, (-100, 100), torch.zeros(5, 6)),
+ ],
)
def test_slm_aperture(
ox_size: float,
@@ -195,16 +291,32 @@ def test_slm_aperture(
height: float,
width: float,
location: tuple,
- aperture: torch.Tensor
+ aperture: torch.Tensor,
):
+ """
+ Tests the SpatialLightModulator aperture with different parameters.
+
+ Args:
+ ox_size: The size of the x-axis.
+ oy_size: The size of the y-axis.
+ ox_nodes: The number of nodes in the x-axis.
+ oy_nodes: The number of nodes in the y-axis.
+ height: The height of the SLM.
+ width: The width of the SLM.
+ location: The location of the SLM.
+ aperture: The expected aperture tensor.
+
+ Returns:
+ None. Asserts that the calculated aperture matches the expected value.
+ """
x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
params = SimulationParameters(
axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': 1064 * 1e-6,
+ "W": x_length,
+ "H": y_length,
+ "wavelength": 1064 * 1e-6,
}
)
@@ -213,7 +325,7 @@ def test_slm_aperture(
mask=torch.zeros(ox_nodes, oy_nodes),
height=height,
width=width,
- location=location
+ location=location,
)
slm.get_aperture
@@ -230,29 +342,61 @@ def test_slm_aperture(
"location",
"mask",
"wavelength",
- "mode"
+ "mode",
]
@pytest.mark.parametrize(
parameters_propagation,
[
- (10, 10, 1000, 1200, 3., 4., (0., 0.),
- torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6,
- "nearest"
- ),
- (9, 12, 1000, 1200, 3., 4., (-2., 3.),
- torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6,
- "bilinear"
- ),
- (15.8, 8.61, 1920, 1080, 2., 2., (2., 0.),
- torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6,
- "bicubic"
- ),
- (30, 15, 1920*2, 1080*2, 15.8, 8.61, (-1., 1.),
- torch.rand(1080, 1920), torch.linspace(330, 1064, 4) * 1e-6,
- "bicubic"
- ),
+ (
+ 10,
+ 10,
+ 1000,
+ 1200,
+ 3.0,
+ 4.0,
+ (0.0, 0.0),
+ torch.rand(100, 100),
+ torch.linspace(330, 1064, 4) * 1e-6,
+ "nearest",
+ ),
+ (
+ 9,
+ 12,
+ 1000,
+ 1200,
+ 3.0,
+ 4.0,
+ (-2.0, 3.0),
+ torch.rand(100, 100),
+ torch.linspace(330, 1064, 4) * 1e-6,
+ "bilinear",
+ ),
+ (
+ 15.8,
+ 8.61,
+ 1920,
+ 1080,
+ 2.0,
+ 2.0,
+ (2.0, 0.0),
+ torch.rand(100, 100),
+ torch.linspace(330, 1064, 4) * 1e-6,
+ "bicubic",
+ ),
+ (
+ 30,
+ 15,
+ 1920 * 2,
+ 1080 * 2,
+ 15.8,
+ 8.61,
+ (-1.0, 1.0),
+ torch.rand(1080, 1920),
+ torch.linspace(330, 1064, 4) * 1e-6,
+ "bicubic",
+ ),
# (5, 9, 100, 400, 3., 4., (-1., 8.),
# torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6,
# "area"
@@ -261,7 +405,7 @@ def test_slm_aperture(
# torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6,
# "nearest-exact"
# )
- ]
+ ],
)
def test_slm_propagation(
ox_size: float,
@@ -273,17 +417,35 @@ def test_slm_propagation(
location: tuple,
mask: torch.Tensor,
wavelength: float,
- mode: str
+ mode: str,
):
+ """
+ Tests the propagation of a wavefront through an SLM.
+
+ Args:
+ ox_size: The size of the x-axis in meters.
+ oy_size: The size of the y-axis in meters.
+ ox_nodes: The number of nodes along the x-axis.
+ oy_nodes: The number of nodes along the y-axis.
+ height: The height of the SLM in meters.
+ width: The width of the SLM in meters.
+ location: The location of the SLM center in (x, y) coordinates.
+ mask: A 2D tensor representing the mask applied by the SLM.
+ wavelength: The wavelength of light in meters.
+ mode: The interpolation mode used for the SLM ("nearest", "bilinear", etc.).
+
+ Returns:
+ None. Asserts that the output field has the correct size.
+ """
x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes)
y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes)
params = SimulationParameters(
axes={
- 'W': x_length,
- 'H': y_length,
- 'wavelength': wavelength,
+ "W": x_length,
+ "H": y_length,
+ "wavelength": wavelength,
}
)
@@ -293,13 +455,11 @@ def test_slm_propagation(
height=height,
width=width,
location=location,
- mode=mode
+ mode=mode,
)
incident_field = w.Wavefront.gaussian_beam(
- simulation_parameters=params,
- waist_radius=2.,
- distance=100
+ simulation_parameters=params, waist_radius=2.0, distance=100
)
transmitted_field = slm(incident_field)
diff --git a/tests/test_specs.py b/tests/test_specs.py
index 95f324d..166c2ad 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -17,22 +17,48 @@
def test_save_context_get_new_filepath(tmp_path):
+ """
+ Tests the get_new_filepath method of ParameterSaveContext.
+
+ This tests that subsequent calls to `get_new_filepath` with the same extension
+ return unique filenames within the specified directory, incrementing a counter.
+
+ Args:
+ tmp_path: A temporary path for testing purposes.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
# test filename
path = context.get_new_filepath("testext")
- assert Path(tmp_path, 'test_0.testext') == path
+ assert Path(tmp_path, "test_0.testext") == path
path = context.get_new_filepath("testext")
- assert Path(tmp_path, 'test_1.testext') == path
+ assert Path(tmp_path, "test_1.testext") == path
def test_save_context_file(tmp_path):
+ """
+ Tests saving a file within the parameter save context.
+
+ This test creates a ParameterSaveContext, writes data to a new file obtained
+ through the context, and verifies that the data was written correctly and
+ that subsequent calls for new files generate different filenames while
+ remaining in the same directory.
+
+ Args:
+ tmp_path: A temporary path where the test file will be created.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
@@ -43,7 +69,7 @@ def test_save_context_file(tmp_path):
file.write(text.encode())
# check if the test text is written into the file
- with open(path, 'rb') as file:
+ with open(path, "rb") as file:
assert file.readline() == text.encode()
# check if the new file will have another name, but same folder
@@ -53,8 +79,17 @@ def test_save_context_file(tmp_path):
def test_save_context_rel_filepath(tmp_path):
+ """
+ Tests that the relative filepath is correctly computed.
+
+ Args:
+ tmp_path: A temporary path to use for testing.
+
+ Returns:
+ None
+ """
contexts = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
@@ -69,23 +104,28 @@ def test_save_context_rel_filepath(tmp_path):
###############################################################################
-@pytest.mark.usefixtures('tmp_path')
-@pytest.mark.parametrize(
- 'mode', ('1', 'L', 'LA', 'I', 'P', 'RGB', 'RGBA')
-)
+@pytest.mark.usefixtures("tmp_path")
+@pytest.mark.parametrize("mode", ("1", "L", "LA", "I", "P", "RGB", "RGBA"))
def test_image_repr_draw_image(tmp_path, mode):
+ """
+ Tests that drawing an ImageRepr to a file produces the correct image.
+
+ Args:
+ tmp_path: A temporary path for saving the image.
+ mode: The color mode of the image (e.g., '1', 'L', 'RGB').
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
# TODO: mode-based test
image_to_draw = np.array([[1]])
- repr = ImageRepr(
- value=image_to_draw,
- mode=mode
- )
+ repr = ImageRepr(value=image_to_draw, mode=mode)
# draw image to the path
path = context.get_new_filepath("png")
@@ -96,14 +136,21 @@ def test_image_repr_draw_image(tmp_path, mode):
def test_image_repr_to(tmp_path):
+ """
+ Tests the to_str, to_markdown, and to_html methods of ImageRepr.
+
+ Args:
+ tmp_path: A temporary path for ParameterSaveContext.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
- repr = ImageRepr(
- value=np.array([[0.5]])
- )
+ repr = ImageRepr(value=np.array([[0.5]]))
# test for all possible exports
test_out = StringIO()
@@ -125,14 +172,21 @@ def test_image_repr_to(tmp_path):
def test_repr_repr_to(tmp_path):
+ """
+ Tests the to_str, to_markdown, and to_html methods of ReprRepr.
+
+ Args:
+ tmp_path: A temporary path for ParameterSaveContext.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
- repr = ReprRepr(
- value=np.array([[0.5]])
- )
+ repr = ReprRepr(value=np.array([[0.5]]))
# test for all possible exports
test_out = StringIO()
@@ -153,26 +207,36 @@ def test_repr_repr_to(tmp_path):
###############################################################################
-@pytest.mark.usefixtures('tmp_path')
+@pytest.mark.usefixtures("tmp_path")
@pytest.mark.parametrize(
- 'value', (
+ "value",
+ (
np.random.rand(2, 2),
torch.rand(2, 2),
torch.tensor(random.random()),
random.random(),
random.randint(0, 10),
- ConstrainedParameter(10, 0, 20)
- )
+ ConstrainedParameter(10, 0, 20),
+ ),
)
def test_pretty_repr_repr_to(tmp_path, value, monkeypatch):
+ """
+ Tests the conversion of a value to string, markdown, and HTML representations.
+
+ Args:
+ tmp_path: A temporary path for testing file operations (pytest fixture).
+ value: The value to be converted. Can be a numpy array, torch tensor, float, int or ConstrainedParameter object.
+ monkeypatch: A pytest monkeypatch fixture used for mocking imports.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
- repr = PrettyReprRepr(
- value=value
- )
+ repr = PrettyReprRepr(value=value)
# test for all possible exports
test_out = StringIO()
@@ -198,13 +262,13 @@ def import_with_no_svetlanna(name, *args, **kwargs):
raise ImportError
return original_import(name, *args, **kwargs)
- monkeypatch.setattr(builtins, '__import__', import_with_no_svetlanna)
+ monkeypatch.setattr(builtins, "__import__", import_with_no_svetlanna)
# Test if default string is written to the buffer
test_out = StringIO()
repr.to_str(test_out, context)
class_name = value.__class__.__name__
- assert test_out.getvalue() == f'{class_name}\n{value.item()}\n'
+ assert test_out.getvalue() == f"{class_name}\n{value.item()}\n"
###############################################################################
@@ -212,23 +276,32 @@ def import_with_no_svetlanna(name, *args, **kwargs):
###############################################################################
-@pytest.mark.usefixtures('tmp_path')
+@pytest.mark.usefixtures("tmp_path")
@pytest.mark.parametrize(
- 'value', (
+ "value",
+ (
np.random.rand(10, 10),
random.random(),
random.randint(0, 10),
- )
+ ),
)
def test_npy_file_repr_save_to_file(tmp_path, value):
+ """
+ Saves a value to a file using NpyFileRepr and verifies it can be loaded correctly.
+
+ Args:
+ tmp_path: A temporary path for saving the file.
+ value: The value to save (NumPy array or scalar).
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
- repr = NpyFileRepr(
- value=value
- )
+ repr = NpyFileRepr(value=value)
# save the value to a new file
path = context.get_new_filepath("png")
@@ -239,14 +312,21 @@ def test_npy_file_repr_save_to_file(tmp_path, value):
def test_npy_file_repr_to(tmp_path):
+ """
+ Tests the to_str and to_markdown methods of NpyFileRepr.
+
+ Args:
+ tmp_path: A temporary path for creating files.
+
+ Returns:
+ None
+ """
context = ParameterSaveContext(
- parameter_name='test',
+ parameter_name="test",
directory=tmp_path,
)
- repr = NpyFileRepr(
- value=np.array([[0.5]])
- )
+ repr = NpyFileRepr(value=np.array([[0.5]]))
# test for all possible exports
test_out = StringIO()
@@ -264,15 +344,21 @@ def test_npy_file_repr_to(tmp_path):
def test_parameter_specs():
+ """
+ Tests the ParameterSpecs class with a simple example.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
representations = (
ReprRepr(123),
ReprRepr(321),
)
- specs = ParameterSpecs(
- parameter_name='test',
- representations=representations
- )
+ specs = ParameterSpecs(parameter_name="test", representations=representations)
assert specs.representations == representations
@@ -283,18 +369,15 @@ def test_parameter_specs():
def test_subelement_specs():
- specs = [
- ParameterSpecs('test', [])
- ]
+ """
+ Tests that the SubelementSpecs class correctly stores its subelement."""
+ specs = [ParameterSpecs("test", [])]
class Subelement:
def to_specs(self):
return specs
subelement = Subelement()
- subelement_specs = SubelementSpecs(
- 'test_type',
- subelement
- )
+ subelement_specs = SubelementSpecs("test_type", subelement)
assert subelement_specs.subelement is subelement
diff --git a/tests/test_specs_writer.py b/tests/test_specs_writer.py
index ec040a8..e63ded0 100644
--- a/tests/test_specs_writer.py
+++ b/tests/test_specs_writer.py
@@ -18,60 +18,99 @@
class SpecsTestElement(Element):
+ """
+ Tests a set of parameter specifications or subelement specifications."""
def __init__(
self,
simulation_parameters: SimulationParameters,
- test_specs: Iterable[ParameterSpecs | SubelementSpecs]
+ test_specs: Iterable[ParameterSpecs | SubelementSpecs],
) -> None:
+ """
+ Initializes a new instance of the class.
+
+ Args:
+ simulation_parameters: The simulation parameters to use.
+ test_specs: An iterable of parameter specifications or subelement specifications
+ to be tested.
+
+ Returns:
+ None
+ """
super().__init__(simulation_parameters)
self.test_specs = test_specs
def forward(self, incident_wavefront: Wavefront) -> Wavefront:
+ """
+ Passes the wavefront to the next layer.
+
+ This method simply calls the `forward` method of the parent class,
+ effectively passing the incident wavefront along for further processing.
+
+ Args:
+ incident_wavefront: The input wavefront representing the current state
+ of the wave propagation.
+
+ Returns:
+ Wavefront: The output wavefront after being processed by the next layer.
+ """
return super().forward(incident_wavefront)
def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]:
+ """
+ Returns the test specifications.
+
+ Args:
+ None
+
+ Returns:
+ Iterable[ParameterSpecs | SubelementSpecs]: An iterable of parameter or subelement specifications.
+ """
return self.test_specs
def test_context_generator(tmp_path):
+ """
+ Tests the context generator with a sample SpecsTestElement.
+
+ This test verifies that the context generator produces contexts with the
+ correct parameter names, representations, and indices. It also tests
+ the output of writing specs to string, markdown, and HTML formats.
+
+ Args:
+ tmp_path: A temporary path for testing purposes.
+
+ Returns:
+ None
+ """
simulation_parameters = SimulationParameters(
- axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- repr1 = ReprRepr(1.)
- repr2 = ReprRepr(2.)
- repr3 = ReprRepr(3.)
- repr4 = ReprRepr(4.)
+ repr1 = ReprRepr(1.0)
+ repr2 = ReprRepr(2.0)
+ repr3 = ReprRepr(3.0)
+ repr4 = ReprRepr(4.0)
subelement = SpecsTestElement(simulation_parameters, [])
element = SpecsTestElement(
simulation_parameters=simulation_parameters,
test_specs=[
+ ParameterSpecs("test1", [repr1, repr2]),
ParameterSpecs(
- 'test1',
- [
- repr1,
- repr2
- ]
- ),
- ParameterSpecs(
- 'test2',
+ "test2",
[
repr3,
- ]
+ ],
),
ParameterSpecs(
- 'test2', # test for the parameter spec with the same name
+ "test2", # test for the parameter spec with the same name
[
repr4,
- ]
+ ],
),
- SubelementSpecs(
- 'test_type',
- subelement
- )
- ]
+ SubelementSpecs("test_type", subelement),
+ ],
)
subelements: list[SubelementSpecs] = []
@@ -83,10 +122,10 @@ def test_context_generator(tmp_path):
assert subelements[0].subelement is subelement
# test parameter_name attribute
- assert contexts[0].parameter_name.value == 'test1'
- assert contexts[1].parameter_name.value == 'test1'
- assert contexts[2].parameter_name.value == 'test2'
- assert contexts[3].parameter_name.value == 'test2'
+ assert contexts[0].parameter_name.value == "test1"
+ assert contexts[1].parameter_name.value == "test1"
+ assert contexts[2].parameter_name.value == "test2"
+ assert contexts[3].parameter_name.value == "test2"
assert contexts[0].parameter_name.index == 0
assert contexts[1].parameter_name.index == 0
@@ -125,52 +164,81 @@ def test_context_generator(tmp_path):
def test_ElementInTree():
+ """
+ Tests the creation and copying of an ElementInTree object.
+
+ This test verifies that creating a copy of _ElementInTree correctly
+ shares references to immutable attributes (element, element_index, children)
+ while creating a new instance for mutable attributes (subelement_type).
+
+ Parameters:
+ None
+
+ Returns:
+ None
+ """
simulation_parameters = SimulationParameters(
- axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
element = SpecsTestElement(
- simulation_parameters=simulation_parameters,
- test_specs=[]
+ simulation_parameters=simulation_parameters, test_specs=[]
)
- tree_element = _ElementInTree(element, 123, [], 'test_1')
- tree_element_copy = tree_element.create_copy('test_2')
+ tree_element = _ElementInTree(element, 123, [], "test_1")
+ tree_element_copy = tree_element.create_copy("test_2")
assert tree_element_copy.element is tree_element.element
assert tree_element_copy.element_index is tree_element.element_index
assert tree_element_copy.children is tree_element.children
assert tree_element_copy.subelement_type != tree_element.subelement_type
- assert tree_element_copy.subelement_type == 'test_2'
+ assert tree_element_copy.subelement_type == "test_2"
def test_ElementsIterator():
+ """
+ Tests the functionality of the ElementsIterator class.
+
+ This test case creates a complex element structure with nested subelements and
+ verifies that the iterator correctly traverses this structure, yielding each
+ element in the expected order. It also checks if the tree is saved and rebuilt
+ correctly during multiple iterations.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
simulation_parameters = SimulationParameters(
- axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- repr1 = ReprRepr(1.)
+ repr1 = ReprRepr(1.0)
subelement1 = SpecsTestElement(simulation_parameters, [])
subelement2 = SpecsTestElement(simulation_parameters, [])
- subelement3 = SpecsTestElement(simulation_parameters, [
- SubelementSpecs('subelement1', subelement1),
- SubelementSpecs('subelement2', subelement2)
- ])
+ subelement3 = SpecsTestElement(
+ simulation_parameters,
+ [
+ SubelementSpecs("subelement1", subelement1),
+ SubelementSpecs("subelement2", subelement2),
+ ],
+ )
element = SpecsTestElement(
simulation_parameters=simulation_parameters,
test_specs=[
ParameterSpecs(
- 'test1',
+ "test1",
[
repr1,
- ]
+ ],
),
- SubelementSpecs('subelement1_copy', subelement1),
- SubelementSpecs('subelement3', subelement3),
- ]
+ SubelementSpecs("subelement1_copy", subelement1),
+ SubelementSpecs("subelement3", subelement3),
+ ],
)
- elements = _ElementsIterator(element, directory='')
+ elements = _ElementsIterator(element, directory="")
# Test iterator output
iterated_indices = []
@@ -185,9 +253,7 @@ def test_ElementsIterator():
iterated_elements.append(el)
assert iterated_indices == list(range(4))
- assert iterated_elements == [
- element, subelement1, subelement3, subelement2
- ]
+ assert iterated_elements == [element, subelement1, subelement3, subelement2]
# Test if the tree is saved in the iterator
tree = elements.tree
@@ -212,87 +278,109 @@ def test_ElementsIterator():
assert elements.tree == tree
# Test if the tree can be generated automatically
- new_elements = _ElementsIterator(element, directory='')
+ new_elements = _ElementsIterator(element, directory="")
assert new_elements.tree is not tree
assert new_elements.tree == tree
assert new_elements.tree is new_elements.tree
def test_write_tree(tmp_path):
+ """
+ Tests writing the elements tree to both string and markdown formats.
+
+ This test creates a nested structure of simulation elements and then
+ verifies that `write_elements_tree_to_str` and `write_elements_tree_to_markdown`
+ can successfully write this tree to a string stream and produce non-empty output,
+ respectively.
+
+ Args:
+ tmp_path: A temporary path (not directly used in the test but required by the fixture).
+
+ Returns:
+ None
+ """
simulation_parameters = SimulationParameters(
- axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- repr1 = ReprRepr(1.)
+ repr1 = ReprRepr(1.0)
subelement1 = SpecsTestElement(simulation_parameters, [])
subelement2 = SpecsTestElement(simulation_parameters, [])
- subelement3 = SpecsTestElement(simulation_parameters, [
- SubelementSpecs('subelement1', subelement1),
- SubelementSpecs('subelement2', subelement2)
- ])
+ subelement3 = SpecsTestElement(
+ simulation_parameters,
+ [
+ SubelementSpecs("subelement1", subelement1),
+ SubelementSpecs("subelement2", subelement2),
+ ],
+ )
element = SpecsTestElement(
simulation_parameters=simulation_parameters,
test_specs=[
ParameterSpecs(
- 'test1',
+ "test1",
[
repr1,
- ]
+ ],
),
- SubelementSpecs('subelement1_copy', subelement1),
- SubelementSpecs('subelement3', subelement3),
- ]
+ SubelementSpecs("subelement1_copy", subelement1),
+ SubelementSpecs("subelement3", subelement3),
+ ],
)
- elements = _ElementsIterator(element, directory='')
+ elements = _ElementsIterator(element, directory="")
# === test str ===
- stream = StringIO('')
+ stream = StringIO("")
write_elements_tree_to_str(elements.tree, stream)
assert stream.getvalue()
# === test md ===
- stream = StringIO('')
+ stream = StringIO("")
write_elements_tree_to_markdown(elements.tree, stream)
assert stream.getvalue()
def test_write_specs(tmp_path):
+ """
+ Tests the write_specs function with different file formats."""
simulation_parameters = SimulationParameters(
- axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- repr1 = ReprRepr(1.)
+ repr1 = ReprRepr(1.0)
subelement1 = SpecsTestElement(simulation_parameters, [])
subelement2 = SpecsTestElement(simulation_parameters, [])
- subelement3 = SpecsTestElement(simulation_parameters, [
- SubelementSpecs('subelement1', subelement1),
- SubelementSpecs('subelement2', subelement2)
- ])
+ subelement3 = SpecsTestElement(
+ simulation_parameters,
+ [
+ SubelementSpecs("subelement1", subelement1),
+ SubelementSpecs("subelement2", subelement2),
+ ],
+ )
element = SpecsTestElement(
simulation_parameters=simulation_parameters,
test_specs=[
ParameterSpecs(
- 'test1',
+ "test1",
[
repr1,
- ]
+ ],
),
- SubelementSpecs('subelement1_copy', subelement1),
- SubelementSpecs('subelement3', subelement3),
- ]
+ SubelementSpecs("subelement1_copy", subelement1),
+ SubelementSpecs("subelement3", subelement3),
+ ],
)
# === test txt ===
- write_specs(element, filename='test_specs.txt', directory=tmp_path)
- assert Path.exists(tmp_path / 'test_specs.txt')
+ write_specs(element, filename="test_specs.txt", directory=tmp_path)
+ assert Path.exists(tmp_path / "test_specs.txt")
# === test md ===
- write_specs(element, filename='test_specs.md', directory=tmp_path)
- assert Path.exists(tmp_path / 'test_specs.md')
+ write_specs(element, filename="test_specs.md", directory=tmp_path)
+ assert Path.exists(tmp_path / "test_specs.md")
# === test unknown format ===
with pytest.raises(ValueError):
- write_specs(element, filename='test_specs.test', directory=tmp_path)
+ write_specs(element, filename="test_specs.test", directory=tmp_path)
diff --git a/tests/test_types.py b/tests/test_types.py
index d04d544..0415ca1 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -8,10 +8,7 @@
parameters = "default_type"
-@pytest.mark.parametrize(parameters, [
- torch.float64,
- torch.float32
-])
+@pytest.mark.parametrize(parameters, [torch.float64, torch.float32])
def test_types(default_type: torch.dtype):
"""A test that checks that all elements belong to the same data type
@@ -23,17 +20,17 @@ def test_types(default_type: torch.dtype):
torch.set_default_dtype(default_type)
- ox_size = 15.
- oy_size = 8.
+ ox_size = 15.0
+ oy_size = 8.0
ox_nodes = 1200
oy_nodes = 1100
- wavelength = torch.linspace(330*1e-6, 660*1e-6, 5)
- waist_radius = 2.
- distance = 100.
- focal_length = 100.
- radius = 10.
- height = 4.
- width = 3.
+ wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5)
+ waist_radius = 2.0
+ distance = 100.0
+ focal_length = 100.0
+ radius = 10.0
+ height = 4.0
+ width = 3.0
if torch.get_default_dtype() == torch.float64:
default_complex_dtype = torch.complex128
@@ -42,81 +39,60 @@ def test_types(default_type: torch.dtype):
params = SimulationParameters(
axes={
- 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
- 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
- 'wavelength': wavelength
- }
+ "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes),
+ "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes),
+ "wavelength": wavelength,
+ }
)
x_linear = params.axes.W
y_linear = params.axes.H
wavelength = params.axes.wavelength
- x_grid, y_grid = params.meshgrid(x_axis='W', y_axis='H')
+ x_grid, y_grid = params.meshgrid(x_axis="W", y_axis="H")
gaussian_beam = w.gaussian_beam(
- simulation_parameters=params,
- waist_radius=waist_radius,
- distance=distance
+ simulation_parameters=params, waist_radius=waist_radius, distance=distance
)
- plane_wave = w.plane_wave(
- simulation_parameters=params,
- distance=distance
- )
+ plane_wave = w.plane_wave(simulation_parameters=params, distance=distance)
- spherical_wave = w.spherical_wave(
- simulation_parameters=params,
- distance=distance
- )
+ spherical_wave = w.spherical_wave(simulation_parameters=params, distance=distance)
lens = elements.ThinLens(
- simulation_parameters=params,
- focal_length=focal_length,
- radius=radius
+ simulation_parameters=params, focal_length=focal_length, radius=radius
).get_transmission_function()
aperture = elements.Aperture(
- simulation_parameters=params,
- mask=torch.zeros(x_grid.shape)
+ simulation_parameters=params, mask=torch.zeros(x_grid.shape)
).get_transmission_function()
rectangular_aperture = elements.RectangularAperture(
- simulation_parameters=params,
- height=height,
- width=width
+ simulation_parameters=params, height=height, width=width
).get_transmission_function()
round_aperture = elements.RoundAperture(
- simulation_parameters=params,
- radius=radius
+ simulation_parameters=params, radius=radius
).get_transmission_function()
slm = elements.SpatialLightModulator(
- simulation_parameters=params,
- mask=torch.ones_like(x_grid),
- height=8,
- width=9
+ simulation_parameters=params, mask=torch.ones_like(x_grid), height=8, width=9
).transmission_function
layer = elements.DiffractiveLayer(
- simulation_parameters=params,
- mask=torch.zeros(x_grid.shape)
+ simulation_parameters=params, mask=torch.zeros(x_grid.shape)
).transmission_function
free_space_as = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='AS'
+ simulation_parameters=params, distance=distance, method="AS"
)(gaussian_beam)
free_space_fresnel = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
)(gaussian_beam)
free_space_reverse = elements.FreeSpace(
- simulation_parameters=params,
- distance=distance, method='fresnel'
+ simulation_parameters=params, distance=distance, method="fresnel"
).reverse(transmission_wavefront=gaussian_beam)
default_type = torch.get_default_dtype()
diff --git a/tests/test_units.py b/tests/test_units.py
index 9c8eddc..fb25a4f 100644
--- a/tests/test_units.py
+++ b/tests/test_units.py
@@ -5,7 +5,7 @@
@pytest.mark.parametrize(
- 'other',
+ "other",
(
123,
1.234,
@@ -15,27 +15,36 @@
np.array(123),
np.array(1.234),
np.array([[1.23, 4.56]]),
- )
+ ),
)
def test_arithmetics(other):
- torch.testing.assert_close(
- other * ureg.mm, other * ureg.mm.value
- )
- torch.testing.assert_close(
- ureg.mm * other, other * ureg.mm.value
- )
- torch.testing.assert_close(
- other / ureg.mm, other / ureg.mm.value
- )
- torch.testing.assert_close(
- ureg.mm / other, ureg.mm.value / other
- )
- torch.testing.assert_close(
- ureg.mm ** other, ureg.mm.value ** other
- )
+ """
+ Tests arithmetic operations with the unit 'mm'.
+
+ This function checks if basic arithmetic operations (multiplication, division, and exponentiation)
+ between a given value and the 'mm' unit from astropy.units produce the expected results when compared to
+ the underlying numerical value of the unit. It tests both left-hand and right-hand side operations.
+
+ Args:
+ other: The value to perform arithmetic with. Can be an integer, float, torch tensor or numpy array.
+
+ Returns:
+ None: This function only performs assertions and does not return a value.
+ """
+ torch.testing.assert_close(other * ureg.mm, other * ureg.mm.value)
+ torch.testing.assert_close(ureg.mm * other, other * ureg.mm.value)
+ torch.testing.assert_close(other / ureg.mm, other / ureg.mm.value)
+ torch.testing.assert_close(ureg.mm / other, ureg.mm.value / other)
+ torch.testing.assert_close(ureg.mm**other, ureg.mm.value**other)
def test_array_api():
+ """
+ Tests array API compatibility with pint and numpy.
+
+ This tests that adding a pint Quantity to a NumPy array results in a NumPy array,
+ and that attempting to use __array__ with copy=False on a pint unit raises a ValueError.
+ """
assert isinstance(ureg.m + np.array([0.0]), np.ndarray)
with pytest.raises(ValueError):
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index 726b06f..693f694 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -12,38 +12,60 @@
def test_html_element():
+ """
+ Tests that the HTML representation of a FreeSpace element is generated."""
sim_params = svetlanna.SimulationParameters(
- {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS')
+ element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS")
assert element._repr_html_()
def test_default_widget_html_method():
- assert default_widget_html_method(123, 'test', 'element_type', [])
+ """
+ Tests the default widget HTML method.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
+ assert default_widget_html_method(123, "test", "element_type", [])
def test_generate_structure_html():
+ """
+ Tests the generation of HTML structure for a simple simulation element.
+
+ This test creates a basic simulation setup with a FreeSpace element and
+ a nested NoWidgetHTMLElement, then asserts that generate_structure_html
+ returns without errors when given this structure.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
- {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS')
+ element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS")
class NoWidgetHTMLElement:
def to_specs(self):
return []
assert generate_structure_html(
- [
- _ElementInTree(element, 0, [
- _ElementInTree(NoWidgetHTMLElement(), 0, [])
- ])
- ]
+ [_ElementInTree(element, 0, [_ElementInTree(NoWidgetHTMLElement(), 0, [])])]
)
def test_show_structure(monkeypatch):
+ """
+ Tests the show_structure function's behavior with and without IPython."""
import IPython.display
# monkeypatch IPython.display.display
@@ -53,7 +75,7 @@ def set_displayed():
nonlocal displayed
displayed = True
- monkeypatch.setattr(IPython.display, 'display', lambda _: set_displayed())
+ monkeypatch.setattr(IPython.display, "display", lambda _: set_displayed())
# Test if the HTML has been displayed
displayed = False
@@ -79,50 +101,83 @@ def import_with_no_ipython(name, *args, **kwargs):
def test_show_specs():
+ """
+ Tests the show_specs function.
+
+ This test creates a simple simulation setup with a FreeSpace element and
+ verifies that the show_specs function returns a SpecsWidget containing
+ information about the element.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
- {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1}
+ {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1}
)
- element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS')
+ element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS")
widget = show_specs(element)
assert isinstance(widget, SpecsWidget)
assert len(widget.elements) == 1
- assert widget.elements[0]['name'] == 'FreeSpace'
+ assert widget.elements[0]["name"] == "FreeSpace"
def test_draw_wavefront():
+ """
+ Tests the draw_wavefront function with different type combinations.
+
+ This test creates a simple plane wavefront and then calls draw_wavefront
+ with single types and all available types to ensure it functions correctly
+ for various plotting configurations.
+
+ Args:
+ None
+
+ Returns:
+ bool: True if all assertions pass, indicating the function works as expected.
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': 1
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": 1,
}
)
wavefront = svetlanna.Wavefront.plane_wave(sim_params)
# Single type
- types = ('A', 'I', 'phase', 'Re', 'Im')
+ types = ("A", "I", "phase", "Re", "Im")
for t in types:
- assert draw_wavefront(
- wavefront,
- sim_params,
- types_to_plot=(t,)
- )
+ assert draw_wavefront(wavefront, sim_params, types_to_plot=(t,))
# All types
- assert draw_wavefront(
- wavefront,
- sim_params,
- types_to_plot=types
- )
+ assert draw_wavefront(wavefront, sim_params, types_to_plot=types)
def test_show_stepwise_forward():
+ """
+ Tests the show_stepwise_forward function with various elements.
+
+ This test creates a simulation setup with different optical elements,
+ including a valid FreeSpace element, an element that returns None, and
+ an element that returns a tensor instead of an image. It then asserts
+ that the resulting widget is a StepwiseForwardWidget, contains all three
+ elements, and correctly represents their outputs in JSON format.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
sim_params = svetlanna.SimulationParameters(
{
- 'W': torch.linspace(-1, 1, 10),
- 'H': torch.linspace(-1, 1, 10),
- 'wavelength': 1
+ "W": torch.linspace(-1, 1, 10),
+ "H": torch.linspace(-1, 1, 10),
+ "wavelength": 1,
}
)
@@ -135,36 +190,32 @@ def to_specs(self):
class WrongTensorForwardElement(torch.nn.Module):
def forward(self, x):
- return torch.tensor([1, 2, 3.])
+ return torch.tensor([1, 2, 3.0])
def to_specs(self):
return []
- element1 = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS')
+ element1 = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS")
element2 = NoneForwardElement()
element3 = WrongTensorForwardElement()
wavefront = svetlanna.Wavefront.plane_wave(sim_params)
widget = show_stepwise_forward(
- element1,
- element2,
- element3,
- input=wavefront,
- simulation_parameters=sim_params
+ element1, element2, element3, input=wavefront, simulation_parameters=sim_params
)
assert isinstance(widget, StepwiseForwardWidget)
assert len(widget.elements) == 3
element1_json = widget.elements[0]
- assert element1_json['name'] == 'FreeSpace'
- assert element1_json['output_image']
+ assert element1_json["name"] == "FreeSpace"
+ assert element1_json["output_image"]
element2_json = widget.elements[1]
- assert element2_json['name'] == 'NoneForwardElement'
- assert element2_json['output_image'] is None
+ assert element2_json["name"] == "NoneForwardElement"
+ assert element2_json["output_image"] is None
element3_json = widget.elements[2]
- assert element3_json['name'] == 'WrongTensorForwardElement'
- assert element3_json['output_image'][:1] == '\n'
+ assert element3_json["name"] == "WrongTensorForwardElement"
+ assert element3_json["output_image"][:1] == "\n"
diff --git a/tests/test_wavefront.py b/tests/test_wavefront.py
index d576159..e77b9a5 100644
--- a/tests/test_wavefront.py
+++ b/tests/test_wavefront.py
@@ -4,13 +4,23 @@
def test_creation():
- wf = Wavefront(1.)
+ """
+ Tests the creation of Wavefront objects with various inputs.
+
+ This method checks that a Wavefront object can be successfully initialized
+ with different types of input data (float, complex number, list of complex numbers, and torch tensor)
+ and verifies that the resulting object is a PyTorch tensor and an instance of the Wavefront class.
+
+ Returns:
+ None
+ """
+ wf = Wavefront(1.0)
assert isinstance(wf, torch.Tensor)
- wf = Wavefront(1. + 1.j)
+ wf = Wavefront(1.0 + 1.0j)
assert isinstance(wf, torch.Tensor)
- wf = Wavefront([1 + 2.j])
+ wf = Wavefront([1 + 2.0j])
assert isinstance(wf, torch.Tensor)
data = torch.tensor([1, 2, 3])
@@ -20,15 +30,19 @@ def test_creation():
@pytest.mark.parametrize(
- ('a', 'b'), [
- (1., 2.),
- (1., 1.,),
- (-1., 1.3)
- ]
+ ("a", "b"),
+ [
+ (1.0, 2.0),
+ (
+ 1.0,
+ 1.0,
+ ),
+ (-1.0, 1.3),
+ ],
)
def test_intensity(a: float, b: float):
"""Test intensity calculations"""
- wf = Wavefront([a + 1j*b])
+ wf = Wavefront([a + 1j * b])
real_intensity = torch.tensor([a**2 + b**2])
torch.testing.assert_close(wf.intensity, real_intensity)
@@ -43,25 +57,40 @@ def test_intensity(a: float, b: float):
@pytest.mark.parametrize(
- ('r', 'phi'), [
- (1., 0.),
- (1., [1.]),
- (10., [1., 2., 3.])
- ]
+ ("r", "phi"), [(1.0, 0.0), (1.0, [1.0]), (10.0, [1.0, 2.0, 3.0])]
)
def test_phase(r, phi):
+ """
+ Tests that the wavefront phase is correctly initialized.
+
+ Args:
+ r: The radius of the wavefront.
+ phi: The initial phase values.
+
+ Returns:
+ None: This function asserts a condition and does not return a value.
+ """
wf = Wavefront(r * torch.exp(1j * torch.tensor(phi)))
torch.testing.assert_close(wf.phase, torch.tensor(phi))
-@pytest.mark.parametrize('waist_radius', (1, 0.5, 0.2))
+@pytest.mark.parametrize("waist_radius", (1, 0.5, 0.2))
def test_fwhm(waist_radius):
+ """
+ Tests the full width at half maximum (FWHM) calculation for a Gaussian beam.
+
+ Args:
+ waist_radius: The waist radius of the Gaussian beam.
+
+ Returns:
+ None: This function asserts properties of the FWHM and does not return a value.
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-1, 1, 1000),
- 'H': torch.linspace(-1, 1, 1000),
- 'wavelength': 1
+ "W": torch.linspace(-1, 1, 1000),
+ "H": torch.linspace(-1, 1, 1000),
+ "wavelength": 1,
}
)
@@ -73,21 +102,32 @@ def test_fwhm(waist_radius):
assert wf.fwhm(sim_params)[0] == wf.fwhm(sim_params)[1]
torch.testing.assert_close(
torch.tensor(wf.fwhm(sim_params)[0]),
- torch.sqrt(2*torch.log(torch.tensor(2.))) * waist_radius,
+ torch.sqrt(2 * torch.log(torch.tensor(2.0))) * waist_radius,
rtol=0.001,
atol=0.01,
)
-@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4))
-@pytest.mark.parametrize('wavelength', (1.0, torch.tensor([1.23, 20])))
-@pytest.mark.parametrize('initial_phase', (1.0, 123, 2e-4))
+@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4))
+@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20])))
+@pytest.mark.parametrize("initial_phase", (1.0, 123, 2e-4))
def test_plane_wave(distance, wavelength, initial_phase):
+ """
+ Tests the plane_wave method of the Wavefront class.
+
+ Args:
+ distance: The distance to propagate the plane wave.
+ wavelength: The wavelength of the plane wave.
+ initial_phase: The initial phase of the plane wave.
+
+ Returns:
+ None. This function asserts properties of the generated Wavefront object.
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-0.1, 2, 10),
- 'H': torch.linspace(-1, 5, 20),
- 'wavelength': wavelength
+ "W": torch.linspace(-0.1, 2, 10),
+ "H": torch.linspace(-1, 5, 20),
+ "wavelength": wavelength,
}
)
k = 2 * torch.pi / sim_params.axes.wavelength
@@ -99,11 +139,9 @@ def test_plane_wave(distance, wavelength, initial_phase):
assert isinstance(wf, Wavefront)
torch.allclose(
wf.angle(),
- torch.exp(1j * (k * distance + initial_phase)[..., None, None]).angle()
- )
- torch.allclose(
- wf.abs(), torch.tensor(1.)
+ torch.exp(1j * (k * distance + initial_phase)[..., None, None]).angle(),
)
+ torch.allclose(wf.abs(), torch.tensor(1.0))
# x,y propagation
dir_x = 0.1312234
@@ -113,18 +151,18 @@ def test_plane_wave(distance, wavelength, initial_phase):
x = sim_params.axes.W[None, :]
y = sim_params.axes.H[:, None]
wf = Wavefront.plane_wave(
- sim_params, distance=distance, wave_direction=[dir_x, dir_y, 0],
- initial_phase=initial_phase
+ sim_params,
+ distance=distance,
+ wave_direction=[dir_x, dir_y, 0],
+ initial_phase=initial_phase,
)
torch.allclose(
wf.angle(),
- torch.exp(1j * (
- kx[..., None, None] * x + ky[..., None, None] * y + initial_phase
- )).angle()
- )
- torch.allclose(
- wf.abs(), torch.tensor(1.)
+ torch.exp(
+ 1j * (kx[..., None, None] * x + ky[..., None, None] * y + initial_phase)
+ ).angle(),
)
+ torch.allclose(wf.abs(), torch.tensor(1.0))
# Test wrong wave direction
with pytest.raises(ValueError):
@@ -134,22 +172,31 @@ def test_plane_wave(distance, wavelength, initial_phase):
# TODO: Test Gaussian beam against precomputed values
-@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4))
-@pytest.mark.parametrize('waist_radius', (1, 1.23, 1e-4, 1e4))
-@pytest.mark.parametrize('dx', (1.0, 123, 2e-4))
-@pytest.mark.parametrize('dy', (1.0, 123, 2e-4))
-@pytest.mark.parametrize(
- 'wavelength', (
- 1.0,
- torch.tensor([1.23, 20])
- )
-)
+@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4))
+@pytest.mark.parametrize("waist_radius", (1, 1.23, 1e-4, 1e4))
+@pytest.mark.parametrize("dx", (1.0, 123, 2e-4))
+@pytest.mark.parametrize("dy", (1.0, 123, 2e-4))
+@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20])))
def test_gaussian_beam(distance, waist_radius, dx, dy, wavelength):
+ """
+ Tests the gaussian_beam method with various parameters.
+
+ Args:
+ distance: The distance to propagate the beam.
+ waist_radius: The radius of the Gaussian beam at its waist.
+ dx: Offset in x direction.
+ dy: Offset in y direction.
+ wavelength: The wavelength of the light. Can be a float or a torch tensor.
+
+ Returns:
+ None: This test does not return any value; it asserts that the
+ gaussian_beam method runs without errors for given parameters.
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-0.1, 2, 10),
- 'H': torch.linspace(-1, 5, 20),
- 'wavelength': wavelength
+ "W": torch.linspace(-0.1, 2, 10),
+ "H": torch.linspace(-1, 5, 20),
+ "wavelength": wavelength,
}
)
# Stupid test
@@ -159,22 +206,31 @@ def test_gaussian_beam(distance, waist_radius, dx, dy, wavelength):
# TODO: Test spherical wave against precomputed values
-@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4))
-@pytest.mark.parametrize('initial_phase', (1, 1.23, 1e-4, 1e4))
-@pytest.mark.parametrize('dx', (1.0, 123, 2e-4))
-@pytest.mark.parametrize('dy', (1.0, 123, 2e-4))
-@pytest.mark.parametrize(
- 'wavelength', (
- 1.0,
- torch.tensor([1.23, 20])
- )
-)
+@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4))
+@pytest.mark.parametrize("initial_phase", (1, 1.23, 1e-4, 1e4))
+@pytest.mark.parametrize("dx", (1.0, 123, 2e-4))
+@pytest.mark.parametrize("dy", (1.0, 123, 2e-4))
+@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20])))
def test_spherical_wave(distance, initial_phase, dx, dy, wavelength):
+ """
+ Tests the spherical wave function with various parameters.
+
+ Args:
+ distance: The distance from the source of the spherical wave.
+ initial_phase: The initial phase of the wave.
+ dx: The x-coordinate offset.
+ dy: The y-coordinate offset.
+ wavelength: The wavelength of the wave.
+
+ Returns:
+ None: This function does not return a value; it asserts that the
+ spherical_wave function runs without errors for given parameters.
+ """
sim_params = SimulationParameters(
{
- 'W': torch.linspace(-0.1, 2, 10),
- 'H': torch.linspace(-1, 5, 20),
- 'wavelength': wavelength
+ "W": torch.linspace(-0.1, 2, 10),
+ "H": torch.linspace(-1, 5, 20),
+ "wavelength": wavelength,
}
)
# Stupid test
@@ -184,6 +240,20 @@ def test_spherical_wave(distance, initial_phase, dx, dy, wavelength):
def test_wavefront_as_a_tensor():
+ """
+ Tests that Wavefront operations with tensors return a Wavefront object.
+
+ This method creates a Wavefront object from a random tensor and then performs
+ various arithmetic operations (addition, multiplication, division) between the
+ Wavefront object and the original tensor. It asserts that the result of each
+ operation is also a Wavefront object.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
tensor = torch.rand((2, 10, 20))
wf = Wavefront(tensor)
diff --git a/visualization.ipynb b/visualization.ipynb
index 7dc7868..5118778 100644
--- a/visualization.ipynb
+++ b/visualization.ipynb
@@ -1,1071 +1,1071 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "%%html\n",
- ""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "import svetlanna\n",
- "import torch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "simulation_parameters = svetlanna.SimulationParameters(\n",
- " {\n",
- " 'W': torch.linspace(-1, 1, 100),\n",
- " 'H': torch.linspace(-1, 1, 100),\n",
- " 'wavelength': 2e-1\n",
- " }\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "
\n",
- " mask\n",
- "
\n",
- " \n",
- "
\n",
- "
Tensor of size (100x100)\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- " mask_norm\n",
- "
\n",
- " \n",
- "
\n",
- "
6.283185307179586\n",
- "
\n",
- "
\n",
- "
"
- ],
- "text/plain": [
- "DiffractiveLayer()"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/vigos/Documents/GitHub/SVETlANNa/svetlanna/elements/free_space.py:151: UserWarning: The paraxial (near-axis) optics condition required for the Fresnel method is not satisfied. Consider increasing the distance or decreasing the screen size.\n",
- " warn(\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "06fb527359f844baaf74b453f21a3425",
- "version_major": 2,
- "version_minor": 1
- },
- "text/plain": [
- "LinearOpticalSetupWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '\n",
- "
\n",
- " feedback_gain\n",
- "
\n",
- " \n",
- "
\n",
- "
0.1\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- " input_gain\n",
- "
\n",
- " \n",
- "
\n",
- "
0.2\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- " delay\n",
- "
\n",
- " \n",
- "
\n",
- "
10\n",
- "
\n",
- "
\n",
- "
[Nonlinear element] LinearOpticalSetup
[0] ThinLens
\n",
- "
\n",
- " focal_length\n",
- "
\n",
- " \n",
- "
\n",
- "
1\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- " radius\n",
- "
\n",
- " \n",
- "
\n",
- "
inf\n",
- "
\n",
- "
\n",
- "
[1] FreeSpace
\n",
- "
\n",
- " distance\n",
- "
\n",
- " \n",
- "
\n",
- "
1\n",
- "
\n",
- "
\n",
- "
[2] ThinLens
\n",
- "
\n",
- " focal_length\n",
- "
\n",
- " \n",
- "
\n",
- "
1\n",
- "
\n",
- "
\n",
- " \n",
- "
\n",
- " radius\n",
- "
\n",
- " \n",
- "
\n",
- "
inf\n",
- "
\n",
- "
\n",
- "
[Delay element] FreeSpace
\n",
- "
\n",
- " distance\n",
- "
\n",
- " \n",
- "
\n",
- "
1\n",
- "
\n",
- "
\n",
- "
"
- ],
- "text/plain": [
- "SimpleReservoir(\n",
- " (delay_element): FreeSpace()\n",
- ")"
- ]
- },
- "execution_count": 39,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "system1 = svetlanna.LinearOpticalSetup([\n",
- " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n",
- " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
- " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n",
- "])\n",
- "\n",
- "reservoir = svetlanna.elements.reservoir.SimpleReservoir(\n",
- " simulation_parameters,\n",
- " system1,\n",
- " # system1,\n",
- " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
- " 0.1,\n",
- " 0.2,\n",
- " 10\n",
- ")\n",
- "reservoir"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 40,
- "metadata": {},
- "outputs": [],
- "source": [
- "system2 = svetlanna.LinearOpticalSetup([\n",
- " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
- " reservoir,\n",
- " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
- "])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 41,
- "metadata": {},
- "outputs": [],
- "source": [
- "from svetlanna.specs.specs_writer import _ElementsIterator, _ElementInTree\n",
- "from svetlanna.specs import Specsable\n",
- "from dataclasses import dataclass\n",
- "\n",
- "from IPython.core.display import display_html\n",
- "from jinja2 import Environment, FileSystemLoader, select_autoescape\n",
- "\n",
- "jinja_env = Environment(\n",
- " loader=FileSystemLoader(\"templates\"),\n",
- " autoescape=select_autoescape()\n",
- ")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 46,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "
\n",
- " (0) LinearOpticalSetup\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(1) FreeSpace\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- " β\n",
- "
\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(2) SimpleReservoir\n",
- "
\n",
- "
\n",
- " ββββββββ¨ Delay el. β βββββββ\n",
- " ββ ββ\n",
- "ββββββββ¨ Nonlinear el. β βββββ΄ββ\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- " Nonlinear element\n",
- "
\n",
- "
\n",
- "
\n",
- " (3) LinearOpticalSetup\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(4) ThinLens\n",
- "
\n",
- "
\n",
- "ββββ\n",
- "ββββ \n",
- "ββββ\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- " β\n",
- "
\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(5) FreeSpace\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- " β\n",
- "
\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(6) ThinLens\n",
- "
\n",
- "
\n",
- "ββββ\n",
- "ββββ \n",
- "ββββ\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "
\n",
- " Delay element\n",
- "
\n",
- "
\n",
- "
\n",
- "
(7) FreeSpace\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- " β\n",
- "
\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
(8) FreeSpace\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- "
"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "\n",
- "\n",
- "@dataclass(frozen=True, slots=True)\n",
- "class ElementHTML:\n",
- " element_type: str | None\n",
- " html: str\n",
- "\n",
- "\n",
- "def _widget_html_(\n",
- " index: int,\n",
- " name: str,\n",
- " element_type: str | None,\n",
- " subelements: list[ElementHTML]\n",
- ") -> str:\n",
- " return jinja_env.get_template('default_widget.html.jinja').render(\n",
- " index=index, name=name, subelements=subelements\n",
- " )\n",
- "\n",
- "\n",
- "def _ls_widget_html_(\n",
- " index: int,\n",
- " name: str,\n",
- " element_type: str | None,\n",
- " subelements: list[ElementHTML]\n",
- ") -> str:\n",
- " return jinja_env.get_template('linear_setup_widget.html.jinja').render(\n",
- " index=index, name=name, subelements=subelements\n",
- " )\n",
- "\n",
- "\n",
- "def _fs_widget_html_(\n",
- " index: int,\n",
- " name: str,\n",
- " element_type: str | None,\n",
- " subelements: list[ElementHTML]\n",
- ") -> str:\n",
- " return jinja_env.get_template('free_space_widget.html.jinja').render(\n",
- " index=index, name=name, subelements=subelements\n",
- " )\n",
- "\n",
- "\n",
- "def _rs_widget_html_(\n",
- " index: int,\n",
- " name: str,\n",
- " element_type: str | None,\n",
- " subelements: list[ElementHTML]\n",
- ") -> str:\n",
- " return jinja_env.get_template('reservoir_widget.html.jinja').render(\n",
- " index=index, name=name, subelements=subelements\n",
- " )\n",
- "\n",
- "\n",
- "def _l_widget_html_(\n",
- " index: int,\n",
- " name: str,\n",
- " element_type: str | None,\n",
- " subelements: list[ElementHTML]\n",
- ") -> str:\n",
- " return jinja_env.get_template('lens_widget.html.jinja').render(\n",
- " index=index, name=name, subelements=subelements\n",
- " )\n",
- "\n",
- "\n",
- "def _get_widget_html_method(element: Specsable):\n",
- " if hasattr(element, '_widget_html_'):\n",
- " widget_html_method = getattr(element, '_widget_html_')\n",
- " else:\n",
- " widget_html_method = _widget_html_\n",
- "\n",
- " if isinstance(element, svetlanna.LinearOpticalSetup):\n",
- " widget_html_method = _ls_widget_html_\n",
- "\n",
- " if isinstance(element, svetlanna.elements.FreeSpace):\n",
- " widget_html_method = _fs_widget_html_\n",
- " \n",
- " if isinstance(element, svetlanna.elements.SimpleReservoir):\n",
- " widget_html_method = _rs_widget_html_\n",
- " \n",
- " if isinstance(element, svetlanna.elements.ThinLens):\n",
- " widget_html_method = _l_widget_html_\n",
- "\n",
- " return widget_html_method\n",
- "\n",
- "\n",
- "def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:\n",
- " res = []\n",
- "\n",
- " for subelement in subelements:\n",
- " widget_html_method = _get_widget_html_method(subelement.element)\n",
- " try:\n",
- " res.append(\n",
- " ElementHTML(\n",
- " subelement.subelement_type,\n",
- " html=widget_html_method(\n",
- " index=subelement.element_index,\n",
- " name=subelement.element.__class__.__name__,\n",
- " element_type=subelement.subelement_type,\n",
- " subelements=_subelements_html(subelement.children)\n",
- " )\n",
- " )\n",
- " )\n",
- " except Exception as e:\n",
- " pass\n",
- "\n",
- " return res\n",
- "\n",
- "\n",
- "elements = _ElementsIterator(system2, directory='')\n",
- "\n",
- "for _, _, i in elements:\n",
- " for _ in i:\n",
- " pass\n",
- "\n",
- "res = _subelements_html(elements.tree)\n",
- "\n",
- "\n",
- "containered_html = f'{res[0].html}
'\n",
- "display_html(containered_html, raw=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "e = svetlanna.specs.specs_writer.write_specs(system2, filename='specs.md')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[_ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Nonlinear element'),\n",
- " _ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Delay element')]"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "e.tree[0].children[1].children"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "ename": "Exception",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m\n",
- "\u001b[0;31mException\u001b[0m: "
- ]
- }
- ],
- "source": [
- "raise Exception"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Nonlinear element'"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "e._tree[0].children[1].children[0].element_name"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "ElementInTree(element=, element_index=0, children=[ElementInTree(element=FreeSpace(), element_index=1, children=[]), ElementInTree(element=SimpleReservoir(), element_index=2, children=[ElementInTree(element=, element_index=3, children=[ElementInTree(element=ThinLens(), element_index=4, children=[]), ElementInTree(element=FreeSpace(), element_index=5, children=[]), ElementInTree(element=ThinLens(), element_index=6, children=[])]), ElementInTree(element=, element_index=7, children=[ElementInTree(element=ThinLens(), element_index=8, children=[]), ElementInTree(element=FreeSpace(), element_index=9, children=[]), ElementInTree(element=ThinLens(), element_index=10, children=[])])]), ElementInTree(element=FreeSpace(), element_index=11, children=[])])\n"
- ]
- }
- ],
- "source": [
- "print('\\n'.join([str(i) for i in e._tree]))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# torch.set_default_dtype(torch.float32)\n",
- "# Image.fromarray(torch.tensor(a).to(torch.float64).numpy(), mode='L').show()\n",
- "# Image.fromarray(np.uint8(255*torch.tensor(a).numpy()), mode='L').show() # <- works"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "False"
- ]
- },
- "execution_count": 97,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "torch.tensor([[1, 1,], [1, 2]]).size() < torch.tensor([1,]).size()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import svetlanna\n",
- "import svetlanna.elements\n",
- "\n",
- "\n",
- "class A(svetlanna.elements.Element):\n",
- " def __init__(self, simulation_parameters: svetlanna.SimulationParameters) -> None:\n",
- " super().__init__(simulation_parameters)\n",
- "\n",
- " self.a = self.make_buffer('a', self.simulation_parameters.axes.W)\n",
- "\n",
- " def forward(self, input_field: svetlanna.Wavefront) -> svetlanna.Wavefront:\n",
- " pass"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "simulation_parameters = svetlanna.SimulationParameters(\n",
- " {\n",
- " 'W': torch.linspace(-1, 1, 10),\n",
- " 'H': torch.linspace(-1, 1, 10),\n",
- " 'wavelength': 10\n",
- " }\n",
- ")\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "a = A(simulation_parameters=simulation_parameters)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- "A()"
- ]
- },
- "execution_count": 108,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "a.to('mps')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "device(type='cpu')"
- ]
- },
- "execution_count": 114,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "simulation_parameters.axes.W.device"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "device(type='mps', index=0)"
- ]
- },
- "execution_count": 113,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "a.a.device"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "%%html\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import svetlanna\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulation_parameters = svetlanna.SimulationParameters(\n",
+ " {\n",
+ " 'W': torch.linspace(-1, 1, 100),\n",
+ " 'H': torch.linspace(-1, 1, 100),\n",
+ " 'wavelength': 2e-1\n",
+ " }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " mask\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
Tensor of size (100x100)\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " mask_norm\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
6.283185307179586\n",
+ "
\n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "DiffractiveLayer()"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/vigos/Documents/GitHub/SVETlANNa/svetlanna/elements/free_space.py:151: UserWarning: The paraxial (near-axis) optics condition required for the Fresnel method is not satisfied. Consider increasing the distance or decreasing the screen size.\n",
+ " warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "06fb527359f844baaf74b453f21a3425",
+ "version_major": 2,
+ "version_minor": 1
+ },
+ "text/plain": [
+ "LinearOpticalSetupWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '\n",
+ "
\n",
+ " feedback_gain\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
0.1\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " input_gain\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
0.2\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " delay\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
10\n",
+ "
\n",
+ "
\n",
+ "
[Nonlinear element] LinearOpticalSetup
[0] ThinLens
\n",
+ "
\n",
+ " focal_length\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " radius\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
inf\n",
+ "
\n",
+ "
\n",
+ "
[1] FreeSpace
\n",
+ "
\n",
+ " distance\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1\n",
+ "
\n",
+ "
\n",
+ "
[2] ThinLens
\n",
+ "
\n",
+ " focal_length\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1\n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " radius\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
inf\n",
+ "
\n",
+ "
\n",
+ "
[Delay element] FreeSpace
\n",
+ "
\n",
+ " distance\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1\n",
+ "
\n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "SimpleReservoir(\n",
+ " (delay_element): FreeSpace()\n",
+ ")"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "system1 = svetlanna.LinearOpticalSetup([\n",
+ " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n",
+ " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
+ " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n",
+ "])\n",
+ "\n",
+ "reservoir = svetlanna.elements.reservoir.SimpleReservoir(\n",
+ " simulation_parameters,\n",
+ " system1,\n",
+ " # system1,\n",
+ " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
+ " 0.1,\n",
+ " 0.2,\n",
+ " 10\n",
+ ")\n",
+ "reservoir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "system2 = svetlanna.LinearOpticalSetup([\n",
+ " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
+ " reservoir,\n",
+ " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from svetlanna.specs.specs_writer import _ElementsIterator, _ElementInTree\n",
+ "from svetlanna.specs import Specsable\n",
+ "from dataclasses import dataclass\n",
+ "\n",
+ "from IPython.core.display import display_html\n",
+ "from jinja2 import Environment, FileSystemLoader, select_autoescape\n",
+ "\n",
+ "jinja_env = Environment(\n",
+ " loader=FileSystemLoader(\"templates\"),\n",
+ " autoescape=select_autoescape()\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " (0) LinearOpticalSetup\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(1) FreeSpace\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ " β\n",
+ "
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(2) SimpleReservoir\n",
+ "
\n",
+ "
\n",
+ " ββββββββ¨ Delay el. β βββββββ\n",
+ " ββ ββ\n",
+ "ββββββββ¨ Nonlinear el. β βββββ΄ββ\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ " Nonlinear element\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " (3) LinearOpticalSetup\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(4) ThinLens\n",
+ "
\n",
+ "
\n",
+ "ββββ\n",
+ "ββββ \n",
+ "ββββ\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ " β\n",
+ "
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(5) FreeSpace\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ " β\n",
+ "
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(6) ThinLens\n",
+ "
\n",
+ "
\n",
+ "ββββ\n",
+ "ββββ \n",
+ "ββββ\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " Delay element\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(7) FreeSpace\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ " β\n",
+ "
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(8) FreeSpace\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "@dataclass(frozen=True, slots=True)\n",
+ "class ElementHTML:\n",
+ " element_type: str | None\n",
+ " html: str\n",
+ "\n",
+ "\n",
+ "def _widget_html_(\n",
+ " index: int,\n",
+ " name: str,\n",
+ " element_type: str | None,\n",
+ " subelements: list[ElementHTML]\n",
+ ") -> str:\n",
+ " return jinja_env.get_template('default_widget.html.jinja').render(\n",
+ " index=index, name=name, subelements=subelements\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def _ls_widget_html_(\n",
+ " index: int,\n",
+ " name: str,\n",
+ " element_type: str | None,\n",
+ " subelements: list[ElementHTML]\n",
+ ") -> str:\n",
+ " return jinja_env.get_template('linear_setup_widget.html.jinja').render(\n",
+ " index=index, name=name, subelements=subelements\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def _fs_widget_html_(\n",
+ " index: int,\n",
+ " name: str,\n",
+ " element_type: str | None,\n",
+ " subelements: list[ElementHTML]\n",
+ ") -> str:\n",
+ " return jinja_env.get_template('free_space_widget.html.jinja').render(\n",
+ " index=index, name=name, subelements=subelements\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def _rs_widget_html_(\n",
+ " index: int,\n",
+ " name: str,\n",
+ " element_type: str | None,\n",
+ " subelements: list[ElementHTML]\n",
+ ") -> str:\n",
+ " return jinja_env.get_template('reservoir_widget.html.jinja').render(\n",
+ " index=index, name=name, subelements=subelements\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def _l_widget_html_(\n",
+ " index: int,\n",
+ " name: str,\n",
+ " element_type: str | None,\n",
+ " subelements: list[ElementHTML]\n",
+ ") -> str:\n",
+ " return jinja_env.get_template('lens_widget.html.jinja').render(\n",
+ " index=index, name=name, subelements=subelements\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def _get_widget_html_method(element: Specsable):\n",
+ " if hasattr(element, '_widget_html_'):\n",
+ " widget_html_method = getattr(element, '_widget_html_')\n",
+ " else:\n",
+ " widget_html_method = _widget_html_\n",
+ "\n",
+ " if isinstance(element, svetlanna.LinearOpticalSetup):\n",
+ " widget_html_method = _ls_widget_html_\n",
+ "\n",
+ " if isinstance(element, svetlanna.elements.FreeSpace):\n",
+ " widget_html_method = _fs_widget_html_\n",
+ " \n",
+ " if isinstance(element, svetlanna.elements.SimpleReservoir):\n",
+ " widget_html_method = _rs_widget_html_\n",
+ " \n",
+ " if isinstance(element, svetlanna.elements.ThinLens):\n",
+ " widget_html_method = _l_widget_html_\n",
+ "\n",
+ " return widget_html_method\n",
+ "\n",
+ "\n",
+ "def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:\n",
+ " res = []\n",
+ "\n",
+ " for subelement in subelements:\n",
+ " widget_html_method = _get_widget_html_method(subelement.element)\n",
+ " try:\n",
+ " res.append(\n",
+ " ElementHTML(\n",
+ " subelement.subelement_type,\n",
+ " html=widget_html_method(\n",
+ " index=subelement.element_index,\n",
+ " name=subelement.element.__class__.__name__,\n",
+ " element_type=subelement.subelement_type,\n",
+ " subelements=_subelements_html(subelement.children)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " except Exception as e:\n",
+ " pass\n",
+ "\n",
+ " return res\n",
+ "\n",
+ "\n",
+ "elements = _ElementsIterator(system2, directory='')\n",
+ "\n",
+ "for _, _, i in elements:\n",
+ " for _ in i:\n",
+ " pass\n",
+ "\n",
+ "res = _subelements_html(elements.tree)\n",
+ "\n",
+ "\n",
+ "containered_html = f'{res[0].html}
'\n",
+ "display_html(containered_html, raw=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = svetlanna.specs.specs_writer.write_specs(system2, filename='specs.md')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[_ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Nonlinear element'),\n",
+ " _ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Delay element')]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "e.tree[0].children[1].children"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "Exception",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m\n",
+ "\u001b[0;31mException\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "raise Exception"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Nonlinear element'"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "e._tree[0].children[1].children[0].element_name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ElementInTree(element=, element_index=0, children=[ElementInTree(element=FreeSpace(), element_index=1, children=[]), ElementInTree(element=SimpleReservoir(), element_index=2, children=[ElementInTree(element=, element_index=3, children=[ElementInTree(element=ThinLens(), element_index=4, children=[]), ElementInTree(element=FreeSpace(), element_index=5, children=[]), ElementInTree(element=ThinLens(), element_index=6, children=[])]), ElementInTree(element=, element_index=7, children=[ElementInTree(element=ThinLens(), element_index=8, children=[]), ElementInTree(element=FreeSpace(), element_index=9, children=[]), ElementInTree(element=ThinLens(), element_index=10, children=[])])]), ElementInTree(element=FreeSpace(), element_index=11, children=[])])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('\\n'.join([str(i) for i in e._tree]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# torch.set_default_dtype(torch.float32)\n",
+ "# Image.fromarray(torch.tensor(a).to(torch.float64).numpy(), mode='L').show()\n",
+ "# Image.fromarray(np.uint8(255*torch.tensor(a).numpy()), mode='L').show() # <- works"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "execution_count": 97,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.tensor([[1, 1,], [1, 2]]).size() < torch.tensor([1,]).size()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import svetlanna\n",
+ "import svetlanna.elements\n",
+ "\n",
+ "\n",
+ "class A(svetlanna.elements.Element):\n",
+ " def __init__(self, simulation_parameters: svetlanna.SimulationParameters) -> None:\n",
+ " super().__init__(simulation_parameters)\n",
+ "\n",
+ " self.a = self.make_buffer('a', self.simulation_parameters.axes.W)\n",
+ "\n",
+ " def forward(self, input_field: svetlanna.Wavefront) -> svetlanna.Wavefront:\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulation_parameters = svetlanna.SimulationParameters(\n",
+ " {\n",
+ " 'W': torch.linspace(-1, 1, 10),\n",
+ " 'H': torch.linspace(-1, 1, 10),\n",
+ " 'wavelength': 10\n",
+ " }\n",
+ ")\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = A(simulation_parameters=simulation_parameters)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ "A()"
+ ]
+ },
+ "execution_count": 108,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "a.to('mps')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cpu')"
+ ]
+ },
+ "execution_count": 114,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "simulation_parameters.axes.W.device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='mps', index=0)"
+ ]
+ },
+ "execution_count": 113,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "a.a.device"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}