diff --git a/.github/BUG_ISSUE.md b/.github/BUG_ISSUE.md new file mode 100644 index 0000000..1c0ecea --- /dev/null +++ b/.github/BUG_ISSUE.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Is something not working as expected? +title: "[Bug]: " +labels: bug +assignees: '' + +--- + +Hi! Thank you for taking the time to report a bug with SVETlANNa. + +Additionally, please note that this platform is meant for bugs in SVETlANNa only. +Issues regarding dependencies and libraries should be reported in their respective repositories. + + + +## Expected Behavior + + +## Current Behavior + + +## Possible Solution + + + +## Steps to Reproduce + + +## Context [OPTIONAL] + + diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..1c95f34 --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,69 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..1d48a06 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,72 @@ +# Contributing to SVETlANNa + +We welcome you to [check the existing issues](https://github.com/CompPhysLab/SVETlANNa/issues) for bugs or enhancements to work on. +If you have an idea for an extension to SVETlANNa, please [file a new issue](https://github.com/CompPhysLab/SVETlANNa/issues/new) so we can discuss it. + +Make sure to familiarize yourself with the project layout before making any major contributions. + +## How to contribute + +1. Fork the [project repository](https://github.com/CompPhysLab/SVETlANNa/): click on the 'Fork' button near the top of the page. This creates a copy of the code under your account on the GitHub server. + +2. Clone this copy to your local disk: + + ```bash + git clone git@github.com:YourUsername/SVETlANNa.git + ``` + +3. Create a branch to hold your changes: + + ```bash + git checkout -b my-contribution + ``` + +4. Make sure your local environment is correctly set up for development and that all required project dependencies are installed. + +5. Start making changes on your newly created branch, remembering to + never work on the ``master`` branch! Work on this copy on your + computer using Git to do the version control. + +6. To check that your changes haven’t broken existing tests and that new tests pass, run the tests. + +7. When you're done editing and local testing, run: + + ```bash + git add modified_files + git commit + ``` + + to record your changes in Git, then push them to GitHub with: + + ```bash + git push -u origin my-contribution + ``` + +Finally, go to the web page of your fork of the SVETlANNa repo, and click +'Pull Request' (PR) to send your changes to the maintainers for review. + +When creating your PR, please make sure to enable the "Allow edits from maintainers" option (known as maintainer_can_modify). +This allows the maintainers to make minor changes or improvements to your PR branch if necessary during the review process. + +(If it looks confusing to you, then look up the [Git +documentation](http://git-scm.com/documentation) on the web.) + +## Before submitting your pull request + +Before you submit a pull request for your contribution, please work +through this checklist to make sure that you have done everything +necessary so we can efficiently review and accept your changes. + +If your contribution changes SVETlANNa in any way: + +- Update the [README](https://github.com/CompPhysLab/SVETlANNa/tree/dev/README.md) if anything there has changed. + +If your contribution involves any code changes: + +- Update the [project tests](https://github.com/CompPhysLab/SVETlANNa/tree/dev/tests) to test your code changes. + +- Make sure that your code is properly commented with [docstrings](https://peps.python.org/pep-0257/) and comments explaining your rationale behind non-obvious coding practices. + +## Acknowledgements + +This document guide is based at well-written contributung guide of [TPOT](https://github.com/EpistasisLab/tpot) and [FEDOT](https://github.com/aimclub/FEDOT) frameworks. diff --git a/.github/FEATURE_ISSUE.md b/.github/FEATURE_ISSUE.md new file mode 100644 index 0000000..2273f5d --- /dev/null +++ b/.github/FEATURE_ISSUE.md @@ -0,0 +1,68 @@ +--- +name: Feature request +about: Want us to add any features to SVETlANNa? +title: 'enh: ' +labels: enhancement +assignees: '' + +--- + + + +## Summary + + + +## Motivation + + + +## Guide-level explanation + + + +## Reference-level explanation + + + +## Drawbacks + + + +## Unresolved Questions + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..9423423 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,53 @@ + + + + + + + + + + + +--- + +## Checklist + +Please check all that apply (`x` inside `[ ]`): + +- [ ] I've performed a self-review of my code +- [ ] I've run linters and tests locally before submission +- [ ] I've added tests (if it's a bug, feature or enhancement) +- [ ] I've adjusted the documentation (if it's a feature or enhancement) + +## Summary + + + +## Context + + + +## Additional Notes + +Add any additional context for reviewers (questions, implementation details, suggestions): diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..2852b29 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,17 @@ +name: Black Formatter +'on': +- push +- pull_request +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v4 + - name: Run Black + uses: psf/black@stable + with: + options: --check --diff + src: . + jupyter: 'false' diff --git a/.github/workflows/requirements.yml b/.github/workflows/requirements.yml new file mode 100644 index 0000000..4d5ce6e --- /dev/null +++ b/.github/workflows/requirements.yml @@ -0,0 +1,52 @@ +name: Requirements File +on: + - push + - pull_request +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install pipreqs + run: | + pip install pipreqs + + - name: Generate new requirements.txt + run: | + pipreqs $GITHUB_WORKSPACE --force --savepath requirements_new.txt + + - name: Update requirements.txt if necessary + run: | + git config --local user.email "action@github.com" + git config --local user.name "GitHub Action" + + # Check if requirements.txt exists and compare with the new file + if [ -f requirements.txt ]; then + echo "requirements.txt already exists." + # Compare the existing and new requirements.txt + if ! cmp -s requirements.txt requirements_new.txt; then + echo "Updating requirements.txt." + mv requirements_new.txt requirements.txt + git add requirements.txt + git commit -m "Update requirements.txt" + git push + else + echo "No changes to requirements.txt." + rm requirements_new.txt # УдаляСм Π²Ρ€Π΅ΠΌΠ΅Π½Π½Ρ‹ΠΉ Ρ„Π°ΠΉΠ», Ссли ΠΈΠ·ΠΌΠ΅Π½Π΅Π½ΠΈΠΉ Π½Π΅Ρ‚ + fi + else + echo "requirements.txt does not exist. Creating it." + mv requirements_new.txt requirements.txt + git add requirements.txt + git commit -m "Add requirements.txt" + git push + fi + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml new file mode 100644 index 0000000..b1b381c --- /dev/null +++ b/.github/workflows/unit_tests.yml @@ -0,0 +1,30 @@ +name: Unit Tests +'on': +- push +- pull_request +jobs: + test: + name: Run Tests + runs-on: ${{ matrix.os }} + timeout-minutes: 15 + strategy: + matrix: + os: + - ubuntu-latest + python-version: + - '3.11' + steps: + - name: Checkout repo + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: pip install -r requirements.txt && pip install pytest pytest-cov + - name: Set PYTHONPATH + run: echo "PYTHONPATH=$(pwd)" >> $GITHUB_ENV + - name: Run tests + run: pytest tests/ --cov=. + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7314a74 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +anywidget==0.9.18 +ipython==9.2.0 +Jinja2==3.1.6 +LightPipes==2.1.5 +matplotlib==3.10.3 +numpy==2.2.6 +pandas==2.2.3 +Pillow==11.2.1 +pytest==8.3.5 +scipy==1.15.3 +torch==2.7.0 +tqdm==4.67.1 +traitlets==5.14.3 diff --git a/svetlanna/__init__.py b/svetlanna/__init__.py index e13fe8c..9ab9fdf 100644 --- a/svetlanna/__init__.py +++ b/svetlanna/__init__.py @@ -10,15 +10,15 @@ from . import networks __all__ = [ - 'Parameter', - 'ConstrainedParameter', - 'LinearOpticalSetup', - 'SimulationParameters', - 'Wavefront', - 'set_debug_logging', - 'elements', - 'units', - 'specs', - 'Clerk', - 'networks' + "Parameter", + "ConstrainedParameter", + "LinearOpticalSetup", + "SimulationParameters", + "Wavefront", + "set_debug_logging", + "elements", + "units", + "specs", + "Clerk", + "networks", ] diff --git a/svetlanna/axes_math.py b/svetlanna/axes_math.py index 4960d91..bbdfd4a 100644 --- a/svetlanna/axes_math.py +++ b/svetlanna/axes_math.py @@ -4,10 +4,7 @@ import torch -def _append_slice_generator( - axes_number: int, - new_axes_number: int -): +def _append_slice_generator(axes_number: int, new_axes_number: int): """Yields Ellipsis, then `axes_number` of full slices (`::`), then `new_axes_number - axes_number` of None (new axis). @@ -32,8 +29,7 @@ def _append_slice_generator( @cache def _append_slice( - axes: tuple[str, ...], - new_axes: tuple[str, ...] + axes: tuple[str, ...], new_axes: tuple[str, ...] ) -> tuple[EllipsisType | slice | None, ...]: """ Slice tuple that can be used to add new axes to the end @@ -46,8 +42,7 @@ def _append_slice( @cache def _axes_indices_to_sort( - axes: tuple[str, ...], - new_axes: tuple[str, ...] + axes: tuple[str, ...], new_axes: tuple[str, ...] ) -> tuple[int, ...]: """ Indices of each axis of `axes` in `new_axes`. @@ -77,7 +72,7 @@ def _axes_indices_to_sort( def _swaps_generator( - axes_indices: tuple[int, ...] + axes_indices: tuple[int, ...], ) -> Generator[tuple[int, int], None, None]: """ Generates swaps to sort indices array. @@ -89,19 +84,17 @@ def _swaps_generator( for i in range(L - 1): j_min = i - for j in range(i+1, L): + for j in range(i + 1, L): if indices[j] < indices[j_min]: j_min = j if j_min != i: indices[i], indices[j_min] = indices[j_min], indices[i] - yield -L+i, -L+j_min + yield -L + i, -L + j_min @cache -def _swaps( - axes_indices: tuple[int, ...] -) -> tuple[tuple[int, int], ...]: +def _swaps(axes_indices: tuple[int, ...]) -> tuple[tuple[int, int], ...]: """ Swaps to sort indices array. """ @@ -109,21 +102,16 @@ def _swaps( @cache -def _check_new_axes( - axes: tuple[str, ...], - new_axes: tuple[str, ...] -) -> None: +def _check_new_axes(axes: tuple[str, ...], new_axes: tuple[str, ...]) -> None: """ Check whether `new_axes` contain all names presented in `axes`. """ if not set(new_axes).issuperset(axes): - raise ValueError('new_axes should contain all names in axes') + raise ValueError("new_axes should contain all names in axes") def cast_tensor( - a: torch.Tensor, - axes: tuple[str, ...], - new_axes: tuple[str, ...] + a: torch.Tensor, axes: tuple[str, ...], new_axes: tuple[str, ...] ) -> torch.Tensor: """Cast tensor `a` with axes `(..., a, b, c)` to `(..., *new_axes)`. `new_axes` should contain all axes presented in `axes`. @@ -157,20 +145,15 @@ def cast_tensor( @cache -def _axis_to_tuple( - axis: str | Iterable[str] -) -> tuple[str, ...]: +def _axis_to_tuple(axis: str | Iterable[str]) -> tuple[str, ...]: """Creates tuple of `str` from `str` or `Iterable[str]`.""" if isinstance(axis, str): - axis = (axis, ) + axis = (axis,) return tuple(axis) @cache -def _new_axes( - a_axis: tuple[str, ...], - b_axis: tuple[str, ...] -) -> tuple[str, ...]: +def _new_axes(a_axis: tuple[str, ...], b_axis: tuple[str, ...]) -> tuple[str, ...]: """ Generates tuple with new axes. ``` @@ -208,10 +191,7 @@ def is_scalar(a: torch.Tensor | float) -> bool: return False -def _check_axis( - a: torch.Tensor | float, - a_axis: tuple[str, ...] -): +def _check_axis(a: torch.Tensor | float, a_axis: tuple[str, ...]): """ Check if each axis is unique. Check whether the number of axes not greater than the dimensionality @@ -223,7 +203,9 @@ def _check_axis( if isinstance(a, torch.Tensor) and a.shape: # if a not a scalar if len(a.shape) < len(a_axis): - raise ValueError(f"Number of axes in the tensor ({len(a.shape)}) should be larger than number of provided axes' names ({len(a_axis)})!") + raise ValueError( + f"Number of axes in the tensor ({len(a.shape)}) should be larger than number of provided axes' names ({len(a_axis)})!" + ) def tensor_dot( @@ -231,7 +213,7 @@ def tensor_dot( b: torch.Tensor | float, a_axis: str | Iterable[str], b_axis: str | Iterable[str], - preserve_a_axis: bool = False + preserve_a_axis: bool = False, ) -> tuple[torch.Tensor, tuple[str, ...]]: """Perform tensor dot product. diff --git a/svetlanna/clerk.py b/svetlanna/clerk.py index 27db3f6..3897211 100644 --- a/svetlanna/clerk.py +++ b/svetlanna/clerk.py @@ -42,9 +42,7 @@ class ClerkMode(StrEnum): CHECKPOINT_FILENAME_SUFFIX = ".pt" -CHECKPOINT_FILENAME_PATTERN = re.compile( - f"^\\d+\\{CHECKPOINT_FILENAME_SUFFIX}$" -) +CHECKPOINT_FILENAME_PATTERN = re.compile(f"^\\d+\\{CHECKPOINT_FILENAME_SUFFIX}$") CHECKPOINT_BACKUP_FILENAME_PATTERN = re.compile( f"^backup_\\d{{4}}-\\d{{2}}-\\d{{2}}_\\d{{2}}-\\d{{2}}-\\d{{2}}\\.\\d{{6}}\\{CHECKPOINT_FILENAME_SUFFIX}$" ) @@ -55,6 +53,14 @@ class ClerkMode(StrEnum): class Clerk(Generic[ConditionsType]): + """ + A lightweight alternative to TensorBoard and other logging frameworks + for tracking the training process, storing experiment metadata, + and handling checkpoints. + + The Clerk is not a new concept but a minimal implementation included + in the framework to start training models without any dependencies.""" + def __init__( self, experiment_directory: str, @@ -358,7 +364,7 @@ def load_checkpoint( self, index: str | int, targets: dict[str, StatefulTorchClass] | None = None, - weights_only: bool = True + weights_only: bool = True, ) -> object | None: """Load the checkpoint with a specific index and apply state dicts to checkpoint targets. If the targets are not provided, the checkpoint @@ -471,14 +477,26 @@ def clean_checkpoints(self): file.unlink() def clean_backup_checkpoints(self): - """Remove checkpoints that are matches backup checkpoints name pattern. - """ + """Remove checkpoints that are matches backup checkpoints name pattern.""" for file in self.experiment_directory.iterdir(): filename = file.name if CHECKPOINT_BACKUP_FILENAME_PATTERN.match(filename): file.unlink() def __enter__(self): + """ + Enters the context and prepares the clerk for experiment execution. + + This method checks if the clerk is already in use, creates the experiment + directory, and potentially resumes from a previous checkpoint based on the + configured mode. + + Args: + None + + Returns: + self: The Clerk instance itself, allowing usage with `with` statements. + """ # Check if the clerk is not in use in other context if self._in_use: raise RuntimeError("The clerk is already is used in some other context!") @@ -546,7 +564,7 @@ def __exit__(self, exc_type, exc_value, traceback): "description": "Backup checkpoint", "time": time, } - time_str = time.replace(' ', '_').replace(':', '-') + time_str = time.replace(" ", "_").replace(":", "-") index = f"backup_{time_str}{CHECKPOINT_FILENAME_SUFFIX}" try: diff --git a/svetlanna/detector.py b/svetlanna/detector.py index d95508a..5bd1e46 100644 --- a/svetlanna/detector.py +++ b/svetlanna/detector.py @@ -12,11 +12,8 @@ class Detector(Element): transforms incident field to intensities for further image analysis (2) ... """ - def __init__( - self, - simulation_parameters: SimulationParameters, - func='intensity' - ): + + def __init__(self, simulation_parameters: SimulationParameters, func="intensity"): """ Parameters ---------- @@ -51,7 +48,7 @@ def forward(self, input_field: Wavefront) -> torch.Tensor: """ detector_output = None - if self.func == 'intensity': + if self.func == "intensity": # TODO: add some normalization for intensities? what is with units? detector_output = torch.Tensor( input_field.abs().pow(2) @@ -65,15 +62,16 @@ class DetectorProcessorClf(nn.Module): The necessary layer to solve a classification task. Must be placed after a detector. This layer process an image from the detector and calculates probabilities of belonging to classes. """ + def __init__( - self, - num_classes: int, - simulation_parameters: SimulationParameters, - segmented_detector: torch.Tensor | None = None, - segments_weights: torch.Tensor | None = None, - segments_zone_size: torch.Size | None = None, - segmentation_type: str = 'strips', - device: str | torch.device = torch.get_default_device(), + self, + num_classes: int, + simulation_parameters: SimulationParameters, + segmented_detector: torch.Tensor | None = None, + segments_weights: torch.Tensor | None = None, + segments_zone_size: torch.Size | None = None, + segmentation_type: str = "strips", + device: str | torch.device = torch.get_default_device(), ): """ Parameters @@ -99,7 +97,9 @@ def __init__( """ super().__init__() self.num_classes = num_classes - self.simulation_parameters = simulation_parameters # only to get sizes - devices mustn't match + self.simulation_parameters = ( + simulation_parameters # only to get sizes - devices mustn't match + ) self.__device = device @@ -107,11 +107,15 @@ def __init__( self.segmentation_type = segmentation_type if segmented_detector is not None: # if a detector segmentation is not defined - self.segmented_detector = segmented_detector.int() # markup of a detector by classes zones + self.segmented_detector = ( + segmented_detector.int() + ) # markup of a detector by classes zones else: # detector is not segmented self.segmentation_type = segmentation_type if segments_zone_size is None: - sim_params_size = self.simulation_parameters.axes_size(axs=('H', 'W')) # [H, W] + sim_params_size = self.simulation_parameters.axes_size( + axs=("H", "W") + ) # [H, W] # make a detector segmentation according to self.segmentation_type self.segmented_detector = self.detector_segmentation(sim_params_size) else: @@ -153,12 +157,16 @@ def detector_segmentation(self, detector_shape: torch.Size) -> torch.Tensor: detector_y, detector_x = detector_shape detector_markup = (-1) * torch.ones(size=detector_shape, dtype=torch.int32) - if self.segmentation_type == 'strips': + if self.segmentation_type == "strips": # segments are vertical strips, symmetrically arranged relative to the detector center! # TODO: gaps between strips? check if possible etc. if self.num_classes % 2 == 0: # even number of classes - central_class = 0 # no central class, classes are symmetrically arranged - if detector_x % 2 == 0: # even number of detector "pixels" in x-direction + central_class = ( + 0 # no central class, classes are symmetrically arranged + ) + if ( + detector_x % 2 == 0 + ): # even number of detector "pixels" in x-direction # Strips: |..111222|333444..| x_center_left_ind = int(detector_x // 2) x_center_right_ind = x_center_left_ind @@ -172,24 +180,38 @@ def detector_segmentation(self, detector_shape: torch.Size) -> torch.Tensor: else: # odd number of classes central_class = 1 # there is a central strip strip_width = int(detector_x // self.num_classes) - if detector_x % 2 == 0: # even number of detector "pixels" in x-direction - if strip_width % 2 == 0: # can symmetrically arrange a central class strip + if ( + detector_x % 2 == 0 + ): # even number of detector "pixels" in x-direction + if ( + strip_width % 2 == 0 + ): # can symmetrically arrange a central class strip # Strips: |..111122|223333..| x_center_left_ind = int(detector_x // 2 - strip_width // 2) x_center_right_ind = int(detector_x // 2 + strip_width // 2) else: # should make a center strip of even width # Strips: |.11122|22333.| center_strip_width = strip_width + 1 # becomes even! - x_center_left_ind = int(detector_x // 2 - center_strip_width // 2) - x_center_right_ind = int(detector_x // 2 + center_strip_width // 2) + x_center_left_ind = int( + detector_x // 2 - center_strip_width // 2 + ) + x_center_right_ind = int( + detector_x // 2 + center_strip_width // 2 + ) # update width for other strips except the center one strip_width = int(x_center_left_ind // (self.num_classes // 2)) else: # odd number of detector "pixels" in x-direction - if strip_width % 2 == 0: # should make a center strip of odd width for symmetry + if ( + strip_width % 2 == 0 + ): # should make a center strip of odd width for symmetry # Strips: |11112|2|23333| center_strip_width = strip_width + 1 # becomes odd! - x_center_left_ind = int(detector_x // 2 - center_strip_width // 2) - x_center_right_ind = int(detector_x // 2 + 1 + center_strip_width // 2) + x_center_left_ind = int( + detector_x // 2 - center_strip_width // 2 + ) + x_center_right_ind = int( + detector_x // 2 + 1 + center_strip_width // 2 + ) # update width for other strips except the center one strip_width = int(x_center_left_ind // (self.num_classes // 2)) else: # can symmetrically arrange a central class strip @@ -198,7 +220,9 @@ def detector_segmentation(self, detector_shape: torch.Size) -> torch.Tensor: x_center_right_ind = int(detector_x // 2 + 1 + strip_width // 2) # mask for the central class ind_central_class = int(self.num_classes // 2) - detector_markup[:, x_center_left_ind:x_center_right_ind] = ind_central_class + detector_markup[:, x_center_left_ind:x_center_right_ind] = ( + ind_central_class + ) # fill masks from the detector center (like apertures for each class) # from the center to left @@ -206,18 +230,22 @@ def detector_segmentation(self, detector_shape: torch.Size) -> torch.Tensor: ind_class = int(self.num_classes // 2 - 1 - ind) ind_left_border = x_center_left_ind - strip_width * (ind + 1) ind_right_border = x_center_left_ind - strip_width * ind - assert torch.all(-1 == detector_markup[:, ind_left_border:ind_right_border]).item() + assert torch.all( + -1 == detector_markup[:, ind_left_border:ind_right_border] + ).item() detector_markup[:, ind_left_border:ind_right_border] = ind_class # from the center to right for ind in range(self.num_classes // 2): # right half of the detector ind_class = int(ind + self.num_classes // 2 + central_class) ind_left_border = x_center_right_ind + strip_width * ind ind_right_border = x_center_right_ind + strip_width * (ind + 1) - assert torch.all(-1 == detector_markup[:, ind_left_border:ind_right_border]).item() + assert torch.all( + -1 == detector_markup[:, ind_left_border:ind_right_border] + ).item() detector_markup[:, ind_left_border:ind_right_border] = ind_class # add padding to match simulation parameters Wavefront shape - sim_params_size = self.simulation_parameters.axes_size(axs=('H', 'W')) + sim_params_size = self.simulation_parameters.axes_size(axs=("H", "W")) if not sim_params_size == detector_shape: y_nodes, x_nodes = sim_params_size # goal size y_mask, x_mask = detector_shape # current size @@ -230,8 +258,8 @@ def detector_segmentation(self, detector_shape: torch.Size) -> torch.Tensor: detector_markup = nn.functional.pad( input=detector_markup, pad=(pad_left, pad_right, pad_top, pad_bottom), - mode='constant', - value=-1 + mode="constant", + value=-1, ) # if the detector size matches with sim params @@ -254,7 +282,9 @@ def weight_segments(self) -> torch.Tensor: # TODO: solve the problem with dimensions... classes_areas = torch.zeros(size=(1, self.num_classes)) for ind_class in range(self.num_classes): - classes_areas[0, ind_class] = torch.where(ind_class == self.segmented_detector, 1, 0).sum().item() + classes_areas[0, ind_class] = ( + torch.where(ind_class == self.segmented_detector, 1, 0).sum().item() + ) min_class_area = classes_areas.min().item() return min_class_area / classes_areas @@ -280,14 +310,16 @@ def forward(self, detector_data: torch.Tensor) -> torch.Tensor: # `mask_class` will be on the same device as `self.segmented_detector`! mask_class = torch.where(ind_class == self.segmented_detector, 1, 0) integrals_by_classes[0, ind_class] = ( - detector_data * mask_class - ).sum().item() + (detector_data * mask_class).sum().item() + ) integrals_by_classes = integrals_by_classes * self.segments_weights # TODO: maybe some function like SoftMax? but integrals can be large! return integrals_by_classes / integrals_by_classes.sum().item() - def batch_zone_integral(self, batch_detector_data: torch.Tensor, ind_class: int) -> torch.Tensor: + def batch_zone_integral( + self, batch_detector_data: torch.Tensor, ind_class: int + ) -> torch.Tensor: """ Returns an integral (sum) of a detector data over a selected zone (`ind_class`). ... @@ -314,7 +346,9 @@ def batch_zone_integral(self, batch_detector_data: torch.Tensor, ind_class: int) # TODO: how to process other dimensions? user must define by himself? return class_integral.sum( dim=tuple( - range(1, len(class_integral.size())) # all dimensions except batch_size dimension + range( + 1, len(class_integral.size()) + ) # all dimensions except batch_size dimension ) ) # return.size() = [batch_size] else: # no other dimensions except ['W', 'H'] for each item in the batch @@ -345,16 +379,20 @@ def batch_forward(self, batch_detector_data: torch.Tensor) -> torch.Tensor: batch_size = batch_detector_data.size()[0] # batch size is a 0'th dimension! - integrals_by_classes = torch.zeros(size=(batch_size, self.num_classes)).to(self.__device) + integrals_by_classes = torch.zeros(size=(batch_size, self.num_classes)).to( + self.__device + ) for ind_class in range(self.num_classes): integrals_by_classes[:, ind_class] = ( - self.batch_zone_integral(batch_detector_data, ind_class) * - self.segments_weights[0, ind_class] + self.batch_zone_integral(batch_detector_data, ind_class) + * self.segments_weights[0, ind_class] ) - return integrals_by_classes / torch.unsqueeze(integrals_by_classes.sum(dim=1), 1) + return integrals_by_classes / torch.unsqueeze( + integrals_by_classes.sum(dim=1), 1 + ) - def to(self, device: str | torch.device | int) -> 'DetectorProcessorClf': + def to(self, device: str | torch.device | int) -> "DetectorProcessorClf": if self.__device == torch.device(device): return self @@ -362,9 +400,18 @@ def to(self, device: str | torch.device | int) -> 'DetectorProcessorClf': num_classes=self.num_classes, simulation_parameters=self.simulation_parameters, segmented_detector=self.segmented_detector, - device=device + device=device, ) @property def device(self) -> str | torch.device | int: + """ + Returns the device on which tensors are allocated. + + Args: + None + + Returns: + The device string, torch.device object or integer representing the device. + """ return self.__device diff --git a/svetlanna/elements/__init__.py b/svetlanna/elements/__init__.py index 72ba9a8..a5fc0aa 100644 --- a/svetlanna/elements/__init__.py +++ b/svetlanna/elements/__init__.py @@ -8,15 +8,15 @@ from .reservoir import SimpleReservoir __all__ = [ - 'Element', - 'FreeSpace', - 'Aperture', - 'RoundAperture', - 'RectangularAperture', - 'ThinLens', - 'SpatialLightModulator', - 'DiffractiveLayer', - 'NonlinearElement', - 'FunctionModule', - 'SimpleReservoir' + "Element", + "FreeSpace", + "Aperture", + "RoundAperture", + "RectangularAperture", + "ThinLens", + "SpatialLightModulator", + "DiffractiveLayer", + "NonlinearElement", + "FunctionModule", + "SimpleReservoir", ] diff --git a/svetlanna/elements/aperture.py b/svetlanna/elements/aperture.py index 643e87c..44b1ebc 100644 --- a/svetlanna/elements/aperture.py +++ b/svetlanna/elements/aperture.py @@ -57,17 +57,26 @@ def forward(self, incident_wavefront: Wavefront) -> Wavefront: incident_wavefront, self.get_transmission_function(), self.transmission_function_axes, - self.simulation_parameters + self.simulation_parameters, ) @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_aperture.html.jinja').render( + """ + Renders the HTML for a widget using a Jinja2 template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (optional). + subelements: A list of sub-elements to include in the widget. + + Returns: + str: The rendered HTML string for the widget. + """ + return jinja_env.get_template("widget_aperture.html.jinja").render( index=index, name=name, subelements=subelements ) @@ -79,9 +88,7 @@ class Aperture(MulElement): """ def __init__( - self, - simulation_parameters: SimulationParameters, - mask: OptimizableTensor + self, simulation_parameters: SimulationParameters, mask: OptimizableTensor ): """Aperture of the optical element defined by mask tensor. @@ -94,27 +101,55 @@ def __init__( Each element must be either 0 (blocks light) or 1 (allows light). """ - super().__init__( - simulation_parameters=simulation_parameters - ) + super().__init__(simulation_parameters=simulation_parameters) - self.mask = self.process_parameter('mask', mask) - self._calc_axes = ('H', 'W') + self.mask = self.process_parameter("mask", mask) + self._calc_axes = ("H", "W") @property def transmission_function_axes(self) -> tuple[str, ...]: + """ + Returns the axes used for the transmission function calculation. + + Args: + None + + Returns: + tuple[str, ...]: A tuple of strings representing the axis labels + used in calculating the transmission function. + """ return self._calc_axes def get_transmission_function(self) -> torch.Tensor: + """ + Returns the transmission function (mask). + + Args: + None + + Returns: + torch.Tensor: The mask representing the transmission function. + """ return self.mask def to_specs(self) -> Iterable[ParameterSpecs]: + """ + Returns parameter specifications for the mask. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: An iterable of ParameterSpecs objects, + containing representations of the mask. + """ return [ ParameterSpecs( - 'mask', [ + "mask", + [ PrettyReprRepr(self.mask), ImageRepr(self.mask.numpy(force=True)), - ] + ], ) ] @@ -122,14 +157,11 @@ def to_specs(self) -> Iterable[ParameterSpecs]: # TODO" check docstring class RectangularAperture(MulElement): """A rectangle-shaped aperture with a transmission function taking either - a value of 0 or 1 + a value of 0 or 1 """ def __init__( - self, - simulation_parameters: SimulationParameters, - height: float, - width: float + self, simulation_parameters: SimulationParameters, height: float, width: float ): """Constructor method @@ -142,61 +174,72 @@ def __init__( width : float aperture width """ - super().__init__( - simulation_parameters=simulation_parameters - ) + super().__init__(simulation_parameters=simulation_parameters) - self.height = self.process_parameter('height', height) - self.width = self.process_parameter('width', width) + self.height = self.process_parameter("height", height) + self.width = self.process_parameter("width", width) - _x_grid, _y_grid = self.simulation_parameters.meshgrid( - x_axis='W', y_axis='H' - ) + _x_grid, _y_grid = self.simulation_parameters.meshgrid(x_axis="W", y_axis="H") - self._calc_axes = ('H', 'W') + self._calc_axes = ("H", "W") self._mask = self.make_buffer( - '_mask', + "_mask", ( - ( - torch.abs(_x_grid) <= self.width/2 - ) * ( - torch.abs(_y_grid) <= self.height/2 - ) - ).to(dtype=torch.get_default_dtype()) + (torch.abs(_x_grid) <= self.width / 2) + * (torch.abs(_y_grid) <= self.height / 2) + ).to(dtype=torch.get_default_dtype()), ) @property def transmission_function_axes(self) -> tuple[str, ...]: + """ + Returns the axes used for the transmission function calculation. + + Args: + None + + Returns: + tuple[str, ...]: A tuple of strings representing the axis labels + used in calculating the transmission function. + """ return self._calc_axes def get_transmission_function(self) -> torch.Tensor: + """ + Returns the transmission function (mask). + + Args: + None + + Returns: + torch.Tensor: The mask representing the transmission function. + """ return self._mask def to_specs(self) -> Iterable[ParameterSpecs]: + """ + Returns parameter specifications for height and width. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: An iterable of ParameterSpecs objects, + containing specs for 'height' and 'width'. + """ return [ - ParameterSpecs( - 'height', [ - PrettyReprRepr(self.height) - ] - ), - ParameterSpecs( - 'width', [ - PrettyReprRepr(self.width) - ] - ) + ParameterSpecs("height", [PrettyReprRepr(self.height)]), + ParameterSpecs("width", [PrettyReprRepr(self.width)]), ] # TODO: check docstrings class RoundAperture(MulElement): """A round-shaped aperture with a transmission function taking either - a value of 0 or 1 + a value of 0 or 1 """ - def __init__( - self, - simulation_parameters: SimulationParameters, - radius: float - ): + + def __init__(self, simulation_parameters: SimulationParameters, radius: float): """Constructor method Parameters @@ -206,36 +249,55 @@ def __init__( radius : float Radius of the round-shaped aperture """ - super().__init__( - simulation_parameters=simulation_parameters - ) + super().__init__(simulation_parameters=simulation_parameters) - self.radius = self.process_parameter('radius', radius) + self.radius = self.process_parameter("radius", radius) - _x_grid, _y_grid = self.simulation_parameters.meshgrid( - x_axis='W', y_axis='H' - ) + _x_grid, _y_grid = self.simulation_parameters.meshgrid(x_axis="W", y_axis="H") - self._calc_axes = ('H', 'W') + self._calc_axes = ("H", "W") self._mask = self.make_buffer( - '_mask', - ( - _x_grid**2 + _y_grid**2 <= self.radius**2 - ).to(dtype=torch.get_default_dtype()) + "_mask", + (_x_grid**2 + _y_grid**2 <= self.radius**2).to( + dtype=torch.get_default_dtype() + ), ) @property def transmission_function_axes(self) -> tuple[str, ...]: + """ + Returns the axes used for the transmission function calculation. + + Args: + None + + Returns: + tuple[str, ...]: A tuple of strings representing the axis labels + used in calculating the transmission function. + """ return self._calc_axes def get_transmission_function(self) -> torch.Tensor: + """ + Returns the transmission function (mask). + + Args: + None + + Returns: + torch.Tensor: The mask representing the transmission function. + """ return self._mask def to_specs(self) -> Iterable[ParameterSpecs]: - return [ - ParameterSpecs( - 'radius', [ - PrettyReprRepr(self.radius) - ] - ) - ] + """ + Returns parameter specifications for the radius. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: A list of ParameterSpecs, currently containing only + the specification for the 'radius' parameter. + """ + return [ParameterSpecs("radius", [PrettyReprRepr(self.radius)])] diff --git a/svetlanna/elements/diffractive_layer.py b/svetlanna/elements/diffractive_layer.py index f37d32c..45ae622 100644 --- a/svetlanna/elements/diffractive_layer.py +++ b/svetlanna/elements/diffractive_layer.py @@ -17,7 +17,7 @@ def __init__( self, simulation_parameters: SimulationParameters, mask: OptimizableTensor, - mask_norm: float = 2 * torch.pi + mask_norm: float = 2 * torch.pi, ): """Constructor method @@ -35,14 +35,21 @@ def __init__( super().__init__(simulation_parameters) - self.mask = self.process_parameter('mask', mask) - self.mask_norm = self.process_parameter('mask_norm', mask_norm) + self.mask = self.process_parameter("mask", mask) + self.mask_norm = self.process_parameter("mask_norm", mask_norm) @property def transmission_function(self) -> torch.Tensor: - return torch.exp( - (2j * torch.pi / self.mask_norm) * self.mask - ) + """ + Calculates the transmission function based on the mask and its norm. + + Args: + None + + Returns: + torch.Tensor: The calculated transmission function as a complex-valued tensor. + """ + return torch.exp((2j * torch.pi / self.mask_norm) * self.mask) def forward(self, incident_wavefront: Wavefront) -> Wavefront: """Method that calculates the field after propagating through the SLM @@ -60,8 +67,8 @@ def forward(self, incident_wavefront: Wavefront) -> Wavefront: return mul( incident_wavefront, self.transmission_function, - ('H', 'W'), - self.simulation_parameters + ("H", "W"), + self.simulation_parameters, ) def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: @@ -83,36 +90,57 @@ def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: return mul( transmission_wavefront, torch.conj(self.transmission_function), - ('H', 'W'), - self.simulation_parameters + ("H", "W"), + self.simulation_parameters, ) def to_specs(self) -> Iterable[ParameterSpecs]: + """ + Returns parameter specifications for the mask and normalized mask. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: A list of ParameterSpecs objects, one for the + mask (with both a pretty representation and an image representation) and + one for the normalized mask (with only a pretty representation). + """ mask = self.mask.numpy(force=True) mask_min = mask.min() mask_max = mask.max() return [ ParameterSpecs( - 'mask', [ + "mask", + [ PrettyReprRepr(self.mask), - ImageRepr((255 * (mask - mask_min) / (mask_max - mask_min)).astype('uint8')), - ] + ImageRepr( + (255 * (mask - mask_min) / (mask_max - mask_min)).astype( + "uint8" + ) + ), + ], ), - ParameterSpecs( - 'mask_norm', [ - PrettyReprRepr(self.mask_norm) - ] - ) + ParameterSpecs("mask_norm", [PrettyReprRepr(self.mask_norm)]), ] @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_diffractive_layer.html.jinja').render( + """ + Renders the HTML for a widget using a Jinja2 template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (optional). + subelements: A list of sub-elements to include in the widget. + + Returns: + str: The rendered HTML string for the widget. + """ + return jinja_env.get_template("widget_diffractive_layer.html.jinja").render( index=index, name=name, subelements=subelements ) diff --git a/svetlanna/elements/element.py b/svetlanna/elements/element.py index 86651f4..ba2d2ca 100644 --- a/svetlanna/elements/element.py +++ b/svetlanna/elements/element.py @@ -11,10 +11,10 @@ from warnings import warn -INNER_PARAMETER_SUFFIX = '_svtlnn_inner_parameter' +INNER_PARAMETER_SUFFIX = "_svtlnn_inner_parameter" -_T = TypeVar('_T', Tensor, None) -_V = TypeVar('_V') +_T = TypeVar("_T", Tensor, None) +_V = TypeVar("_V") class _BufferedValueContainer(tuple): @@ -25,9 +25,20 @@ class _BufferedValueContainer(tuple): Inheriting from tuple is used for performance reasons, so the `__slots__`. This approach was identified by GPT as the fastest one. """ + __slots__ = () def __new__(cls, obj: Tensor | None): + """ + Creates a new instance of the class. + + Args: + obj: The object to be wrapped. Can be None. + + Returns: + A new instance of the class initialized with the given object as its + single element tuple. + """ return super().__new__(cls, (obj,)) @@ -35,10 +46,7 @@ def __new__(cls, obj: Tensor | None): class Element(nn.Module, metaclass=ABCMeta): """A class that describes each element of the system""" - def __init__( - self, - simulation_parameters: SimulationParameters - ) -> None: + def __init__(self, simulation_parameters: SimulationParameters) -> None: """A class that describes each element of the system Parameters @@ -54,25 +62,34 @@ def __init__( # TODO: check doctrings @abstractmethod def forward(self, incident_wavefront: Wavefront) -> Wavefront: - """Forward propagation through the optical element""" def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: - """Create specs""" - for (name, parameter) in self.named_parameters(): + for name, parameter in self.named_parameters(): yield ParameterSpecs( - parameter_name=name, - representations=(PrettyReprRepr(value=parameter),) + parameter_name=name, representations=(PrettyReprRepr(value=parameter),) ) def __setattr__( - self, - name: str, - value: Tensor | nn.Module | _BufferedValueContainer + self, name: str, value: Tensor | nn.Module | _BufferedValueContainer ) -> None: + """ + Sets an attribute on the module. + + Handles special cases for BufferedValueContainers and Parameters to ensure + correct storage and behavior within the module's state. + + Args: + name: The name of the attribute to set. + value: The value to assign to the attribute. Can be a Tensor, nn.Module, + or an instance of _BufferedValueContainer. + + Returns: + None + """ if isinstance(value, _BufferedValueContainer): # In the case of pattern `self.x = self.make_buffer('x', x_value)` @@ -91,22 +108,29 @@ def __setattr__( # BoundedParameter and Parameter are handled by pointing # auxiliary attribute on them with a name plus INNER_PARAMETER_SUFFIX if isinstance(value, (ConstrainedParameter, Parameter)): - super().__setattr__( - name + INNER_PARAMETER_SUFFIX, value.inner_storage - ) + super().__setattr__(name + INNER_PARAMETER_SUFFIX, value.inner_storage) return super().__setattr__(name, value) def _repr_html_(self) -> str: - stream = StringIO('') + """ + Generates an HTML representation of the object's specifications. + + This method recursively writes the details of the element and its subelements + to a string stream, formatting them as collapsible HTML details tags. + + Args: + None + + Returns: + str: An HTML string representing the object's specifications. + """ + stream = StringIO("") def write_element_details(element: Specsable): subelements: list[SubelementSpecs] = [] writer_context_generator = context_generator( - element=element, - element_index=0, - directory='', - subelements=subelements + element=element, element_index=0, directory="", subelements=subelements ) # Write element's parameter specs to the stream write_specs_to_html(element, 0, writer_context_generator, stream) @@ -121,32 +145,21 @@ def write_element_details(element: Specsable): ) element_name = subelement.subelement.__class__.__name__ # Write the element's name to the summary tag - stream.write( - f'[{subelement.subelement_type}] {element_name}' - ) + stream.write(f"[{subelement.subelement_type}] {element_name}") # Close summary tag and open a new div for the subelement stream.write( - '' - '
' + "" '
' ) # Repeat the process for the subelement write_element_details(subelement.subelement) # Close the div and the details tags - stream.write( - '
' - '' - ) + stream.write("
" "") write_element_details(self) return stream.getvalue() - def make_buffer( - self, - name: str, - value: _T, - persistent: bool = False - ) -> _T: + def make_buffer(self, name: str, value: _T, persistent: bool = False) -> _T: """Make buffer for internal use. Use case: @@ -179,19 +192,13 @@ def make_buffer( "the simulation parameters device." ) - self.register_buffer( - name, value, persistent=persistent - ) + self.register_buffer(name, value, persistent=persistent) # The instance of _BufferedValueContainer is returned # to support `self.x = self.make_buffer('x', x_value)` pattern return _BufferedValueContainer(self.__getattr__(name)) # type: ignore - def process_parameter( - self, - name: str, - value: _V - ) -> _V: + def process_parameter(self, name: str, value: _V) -> _V: """Process element parameter passed by user. Automatically registers buffer for non-parametric tensors. @@ -216,8 +223,7 @@ def process_parameter( if isinstance(value, Tensor): if value.device != self.simulation_parameters.device: raise ValueError( - f"Parameter {name} must be on " - "the simulation parameters device." + f"Parameter {name} must be on " "the simulation parameters device." ) if isinstance(value, (nn.Parameter, Parameter)): return value @@ -228,5 +234,5 @@ def process_parameter( # === methods below are added for typing only === if TYPE_CHECKING: - def __call__(self, incident_wavefront: Wavefront) -> Wavefront: - ... + + def __call__(self, incident_wavefront: Wavefront) -> Wavefront: ... diff --git a/svetlanna/elements/free_space.py b/svetlanna/elements/free_space.py index e768825..cc4b707 100644 --- a/svetlanna/elements/free_space.py +++ b/svetlanna/elements/free_space.py @@ -20,7 +20,7 @@ def __init__( self, simulation_parameters: SimulationParameters, distance: OptimizableFloat, - method: Literal['fresnel', 'AS'] + method: Literal["fresnel", "AS"], ): """Free space element. @@ -37,14 +37,14 @@ def __init__( """ super().__init__(simulation_parameters) - self.distance = self.process_parameter('distance', distance) - self.method = self.process_parameter('method', method) + self.distance = self.process_parameter("distance", distance) + self.method = self.process_parameter("method", method) # params extracted from SimulationParameters device = self.simulation_parameters.device - self._w_index = self.simulation_parameters.axes.index('W') - self._h_index = self.simulation_parameters.axes.index('H') + self._w_index = self.simulation_parameters.axes.index("W") + self._h_index = self.simulation_parameters.axes.index("H") x_linear = self.simulation_parameters.axes.W y_linear = self.simulation_parameters.axes.H @@ -53,16 +53,12 @@ def __init__( y_nodes = y_linear.shape[0] # Compute spatial grid spacing - dx = (x_linear[1] - x_linear[0]) if x_nodes > 1 else 1. - dy = (y_linear[1] - y_linear[0]) if y_nodes > 1 else 1. + dx = (x_linear[1] - x_linear[0]) if x_nodes > 1 else 1.0 + dy = (y_linear[1] - y_linear[0]) if y_nodes > 1 else 1.0 # Compute wave vectors - kx_linear = 2 * torch.pi * torch.fft.fftfreq( - x_nodes, dx, device=device - ) - ky_linear = 2 * torch.pi * torch.fft.fftfreq( - y_nodes, dy, device=device - ) + kx_linear = 2 * torch.pi * torch.fft.fftfreq(x_nodes, dx, device=device) + ky_linear = 2 * torch.pi * torch.fft.fftfreq(y_nodes, dy, device=device) # Compute wave vectors grids kx_grid = kx_linear[None, :] # shape: (1, 'W') @@ -73,14 +69,11 @@ def __init__( k = 2 * torch.pi / self.simulation_parameters.axes.wavelength # 2) Calculate (kx^2+ky^2) tensor - kx2ky2 = kx_grid ** 2 + ky_grid ** 2 # shape: ('H', 'W') + kx2ky2 = kx_grid**2 + ky_grid**2 # shape: ('H', 'W') # 3) Calculate (kx^2+ky^2) / k^2 relation, relation_axes = tensor_dot( - a=1 / (k ** 2), - b=kx2ky2, - a_axis='wavelength', - b_axis=('H', 'W') + a=1 / (k**2), b=kx2ky2, a_axis="wavelength", b_axis=("H", "W") ) # shape: ('wavelength', 'H', 'W') or ('H', 'W') depending on k shape # TODO: Remove legacy filter @@ -90,23 +83,21 @@ def __init__( # The filter removes contribution of evanescent waves if use_legacy_filter: # TODO: Shouldn't the 88'th string be here? - condition = (relation <= 1) # calculate the low pass filter condition # noqa + condition = relation <= 1 # calculate the low pass filter condition # noqa condition = condition.to(kx_grid) # cast bool to float # Registering Buffer for _low_pass_filter - self._low_pass_filter = self.make_buffer( - '_low_pass_filter', condition - ) + self._low_pass_filter = self.make_buffer("_low_pass_filter", condition) else: self._low_pass_filter = 1 # Reshape wave vector for further calculations - wave_number = k[..., None, None] # shape: ('wavelength', 1, 1) or (1, 1) # noqa + wave_number = k[ + ..., None, None + ] # shape: ('wavelength', 1, 1) or (1, 1) # noqa # Registering Buffer for _wave_number - self._wave_number = self.make_buffer( - '_wave_number', wave_number - ) + self._wave_number = self.make_buffer("_wave_number", wave_number) self._calc_axes = relation_axes # axes tuple used during calculations @@ -116,25 +107,23 @@ def __init__( # or # kz = |k| otherwise wave_number_z = torch.sqrt( - self._wave_number ** 2 - self._low_pass_filter * kx2ky2 + self._wave_number**2 - self._low_pass_filter * kx2ky2 ) else: # kz = sqrt(k^2 - (kx^2 + ky^2)) wave_number_z = torch.sqrt( - self._wave_number ** 2 - kx2ky2 + 0j + self._wave_number**2 - kx2ky2 + 0j ) # 0j is required to convert argument to complex # Registering Buffer for _wave_number_z - self._wave_number_z = self.make_buffer( - '_wave_number_z', wave_number_z - ) + self._wave_number_z = self.make_buffer("_wave_number_z", wave_number_z) # Calculate kz taylored, used by Fresnel approximation - wave_number_z_eff_fresnel = - 0.5 * kx2ky2 / self._wave_number + wave_number_z_eff_fresnel = -0.5 * kx2ky2 / self._wave_number # Registering Buffer for _wave_number_z_eff_fresnel self._wave_number_z_eff_fresnel = self.make_buffer( - '_wave_number_z_eff_fresnel', wave_number_z_eff_fresnel + "_wave_number_z_eff_fresnel", wave_number_z_eff_fresnel ) # Warnings for fulfilling the method criteria @@ -143,35 +132,35 @@ def __init__( # by Kedar Khare, Mansi Butola and Sunaina Rajor Lx = torch.abs(x_linear[-1] - x_linear[0]) Ly = torch.abs(y_linear[-1] - y_linear[0]) - if method == 'AS': + if method == "AS": kx_max = torch.max(torch.abs(kx_linear)) ky_max = torch.max(torch.abs(ky_linear)) - x_condition = kx_max >= k / torch.sqrt(1 + (2*distance / Lx)**2) - y_condition = ky_max >= k / torch.sqrt(1 + (2*distance / Ly)**2) + x_condition = kx_max >= k / torch.sqrt(1 + (2 * distance / Lx) ** 2) + y_condition = ky_max >= k / torch.sqrt(1 + (2 * distance / Ly) ** 2) if not torch.all(x_condition): warn( - 'Aliasing problems may occur in the AS method. ' - 'Consider reducing the distance ' - 'or increasing the Nx*dx product.' + "Aliasing problems may occur in the AS method. " + "Consider reducing the distance " + "or increasing the Nx*dx product." ) if not torch.all(y_condition): warn( - 'Aliasing problems may occur in the AS method. ' - 'Consider reducing the distance ' - 'or increasing the Ny*dy product.' + "Aliasing problems may occur in the AS method. " + "Consider reducing the distance " + "or increasing the Ny*dy product." ) - if method == 'fresnel': + if method == "fresnel": diagonal_squared = Lx**2 + Ly**2 - condition = distance**3 > k / 8 * (diagonal_squared)**2 + condition = distance**3 > k / 8 * (diagonal_squared) ** 2 if not torch.all(condition): warn( - 'The paraxial (near-axis) optics condition ' - 'required for the Fresnel method is not satisfied. ' - 'Consider increasing the distance ' - 'or decreasing the screen size.' + "The paraxial (near-axis) optics condition " + "required for the Fresnel method is not satisfied. " + "Consider increasing the distance " + "or decreasing the screen size." ) def impulse_response_angular_spectrum(self) -> torch.Tensor: @@ -201,10 +190,10 @@ def impulse_response_fresnel(self) -> torch.Tensor: # Fourier image of impulse response function # 0 if k^2 < (kx^2 + ky^2) [if use_legacy_filter] - return self._low_pass_filter * torch.exp( - (1j * self.distance) * self._wave_number_z_eff_fresnel - ) * torch.exp( - (1j * self.distance) * self._wave_number + return ( + self._low_pass_filter + * torch.exp((1j * self.distance) * self._wave_number_z_eff_fresnel) + * torch.exp((1j * self.distance) * self._wave_number) ) def _impulse_response(self) -> torch.Tensor: @@ -216,19 +205,16 @@ def _impulse_response(self) -> torch.Tensor: The impulse response function """ - if self.method == 'AS': + if self.method == "AS": return self.impulse_response_angular_spectrum() - elif self.method == 'fresnel': + elif self.method == "fresnel": return self.impulse_response_fresnel() raise ValueError("Unknown forward propagation method") # TODO: ask for tol parameter, maybe move it to init? - def forward( - self, - incident_wavefront: Wavefront - ) -> Wavefront: + def forward(self, incident_wavefront: Wavefront) -> Wavefront: """Calculates the field after propagating in the free space Parameters @@ -248,8 +234,7 @@ def forward( """ input_field_fft = torch.fft.fft2( - incident_wavefront, - dim=(self._h_index, self._w_index) + incident_wavefront, dim=(self._h_index, self._w_index) ) impulse_response_fft = self._impulse_response() @@ -260,12 +245,11 @@ def forward( b=impulse_response_fft, # example shape: ('wavelength', 'H', 'W') a_axis=self.simulation_parameters.axes.names, b_axis=self._calc_axes, - preserve_a_axis=True # check that the output has the input shape + preserve_a_axis=True, # check that the output has the input shape ) # example output shape: (5, 'wavelength', 1, 'H', 'W') output_field = torch.fft.ifft2( - output_field_fft, - dim=(self._h_index, self._w_index) + output_field_fft, dim=(self._h_index, self._w_index) ) return output_field @@ -287,8 +271,7 @@ def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: """ transmission_field_fft = torch.fft.fft2( - transmission_wavefront, - dim=(self._h_index, self._w_index) + transmission_wavefront, dim=(self._h_index, self._w_index) ) impulse_response_fft = self._impulse_response().conj() @@ -299,32 +282,51 @@ def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: b=impulse_response_fft, # example shape: ('wavelength', 'H', 'W') a_axis=self.simulation_parameters.axes.names, b_axis=self._calc_axes, - preserve_a_axis=True # check that the output has the first input shape # noqa + preserve_a_axis=True, # check that the output has the first input shape # noqa ) # example output shape: (5, 'wavelength', 1, 'H', 'W') incident_field = torch.fft.ifft2( - incident_field_fft, - dim=(self._h_index, self._w_index) + incident_field_fft, dim=(self._h_index, self._w_index) ) return incident_field def to_specs(self) -> Iterable[ParameterSpecs]: + """ + Returns parameter specifications for the distance. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: An iterable of ParameterSpecs, + containing a specification for the 'distance' parameter. + """ return [ ParameterSpecs( - 'distance', [ + "distance", + [ PrettyReprRepr(self.distance), - ] + ], ) ] @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_free_space.html.jinja').render( + """ + Renders the HTML for a free space widget using a Jinja template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (optional). + subelements: A list of sub-elements to include in the widget. + + Returns: + str: The rendered HTML string for the free space widget. + """ + return jinja_env.get_template("widget_free_space.html.jinja").render( index=index, name=name, subelements=subelements ) diff --git a/svetlanna/elements/lens.py b/svetlanna/elements/lens.py index 1883b6f..dad373f 100644 --- a/svetlanna/elements/lens.py +++ b/svetlanna/elements/lens.py @@ -19,7 +19,7 @@ def __init__( self, simulation_parameters: SimulationParameters, focal_length: OptimizableFloat, - radius: float = torch.inf + radius: float = torch.inf, ): """Thin lens element. @@ -36,26 +36,19 @@ def __init__( super().__init__(simulation_parameters) - self.focal_length = self.process_parameter( - 'focal_length', focal_length - ) - self.radius = self.process_parameter( - 'radius', radius - ) + self.focal_length = self.process_parameter("focal_length", focal_length) + self.radius = self.process_parameter("radius", radius) # Compute wave_number as a tensor wave_number, axes = tensor_dot( 2 * torch.pi / self.simulation_parameters.axes.wavelength, torch.tensor([[1]], device=self.simulation_parameters.device), - 'wavelength', - ('H', 'W') + "wavelength", + ("H", "W"), ) # shape: ('wavelength', 1, 1) or (1, 1) # Registering Buffer for _wave_number - self._wave_number = self.make_buffer( - '_wave_number', - wave_number - ) + self._wave_number = self.make_buffer("_wave_number", wave_number) self._calc_axes = axes # axes tuple used during calculations @@ -67,8 +60,7 @@ def __init__( # Registering Buffer for _radius_squared self._radius_squared = self.make_buffer( - '_radius_squared', - x_grid**2 + y_grid**2 + "_radius_squared", x_grid**2 + y_grid**2 ) # Create a mask that acts as an aperture: @@ -78,18 +70,30 @@ def __init__( self._radius_mask = 1.0 else: self._radius_mask = self.make_buffer( - '_radius_mask', + "_radius_mask", (self._radius_squared <= self.radius**2).to( dtype=torch.get_default_dtype() # cast bool to float - ) + ), ) @property def transmission_function(self) -> torch.Tensor: + """ + Calculates the transmission function of the optical system. + + This function computes the complex-valued transmission function based on + the radius mask, radius squared, wave number, and focal length. It uses + an exponential function to represent the phase shift introduced by the + optical element. + + Returns: + torch.Tensor: The calculated transmission function as a PyTorch tensor. + """ return torch.exp( - - 1j * self._radius_mask * self._radius_squared * ( - self._wave_number / (2 * self.focal_length) - ) + -1j + * self._radius_mask + * self._radius_squared + * (self._wave_number / (2 * self.focal_length)) ) def get_transmission_function(self) -> torch.Tensor: @@ -121,7 +125,7 @@ def forward(self, incident_wavefront: Wavefront) -> Wavefront: incident_wavefront, self.transmission_function, self._calc_axes, - self.simulation_parameters + self.simulation_parameters, ) def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: @@ -144,30 +148,47 @@ def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: transmission_wavefront, torch.conj(self.transmission_function), self._calc_axes, - self.simulation_parameters + self.simulation_parameters, ) def to_specs(self) -> Iterable[ParameterSpecs]: + """ + Returns a list of ParameterSpecs for the object's parameters. + + Args: + None + + Returns: + Iterable[ParameterSpecs]: An iterable containing ParameterSpecs objects, + representing the focal length and radius + of the object. + """ return [ ParameterSpecs( - 'focal_length', [ + "focal_length", + [ PrettyReprRepr(self.focal_length), - ] + ], ), - ParameterSpecs( - 'radius', [ - PrettyReprRepr(self.radius) - ] - ) + ParameterSpecs("radius", [PrettyReprRepr(self.radius)]), ] @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_lens.html.jinja').render( + """ + Renders the HTML for a widget using a Jinja2 template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (optional). + subelements: A list of sub-elements to include in the widget. + + Returns: + str: The rendered HTML string for the widget. + """ + return jinja_env.get_template("widget_lens.html.jinja").render( index=index, name=name, subelements=subelements ) diff --git a/svetlanna/elements/nonlinear_element.py b/svetlanna/elements/nonlinear_element.py index 538f59f..e5170e2 100644 --- a/svetlanna/elements/nonlinear_element.py +++ b/svetlanna/elements/nonlinear_element.py @@ -10,9 +10,11 @@ class FunctionModule(torch.nn.Module): """A class for transforming an arbitrary function with multiple parameters. Allows training function parameters """ + def __init__( - self, function: Callable[[torch.Tensor], torch.Tensor], - function_parameters: Dict | None + self, + function: Callable[[torch.Tensor], torch.Tensor], + function_parameters: Dict | None, ) -> None: """Constructor method @@ -35,10 +37,7 @@ def __init__( elif isinstance(value, torch.Tensor): self.register_buffer(name, value) - def forward( - self, - function_argument: torch.Tensor - ) -> torch.Tensor: + def forward(self, function_argument: torch.Tensor) -> torch.Tensor: """forward method for a class inherited from torch.nn.Module Parameters @@ -57,11 +56,8 @@ def forward( return self.function(function_argument) if TYPE_CHECKING: - def __call__( - self, - function_argument: torch.Tensor - ) -> torch.Tensor: - ... + + def __call__(self, function_argument: torch.Tensor) -> torch.Tensor: ... class NonlinearElement(Element): @@ -74,7 +70,7 @@ def __init__( self, simulation_parameters: SimulationParameters, response_function: Callable[[torch.Tensor], torch.Tensor], - response_parameters: Dict | None = None + response_parameters: Dict | None = None, ): """Constructor method @@ -91,10 +87,7 @@ def __init__( super().__init__(simulation_parameters) - self.response_function = FunctionModule( - response_function, - response_parameters - ) + self.response_function = FunctionModule(response_function, response_parameters) def forward(self, incident_wavefront: Wavefront) -> Wavefront: """Method calculating the wavefront after passing a nonlinear optical @@ -110,12 +103,8 @@ def forward(self, incident_wavefront: Wavefront) -> Wavefront: Wavefront Wavefront passing through a nonlinear optical element """ - transformed_amplitude = self.response_function( - torch.abs(incident_wavefront) - ) + transformed_amplitude = self.response_function(torch.abs(incident_wavefront)) # preserve the phase of the incident wavefront # phase = incident_wavefront / torch.abs(incident_wavefront) phase = torch.exp(1j * incident_wavefront.phase) - return Wavefront( - transformed_amplitude * phase - ) + return Wavefront(transformed_amplitude * phase) diff --git a/svetlanna/elements/reservoir.py b/svetlanna/elements/reservoir.py index 3738ba2..383e4fd 100644 --- a/svetlanna/elements/reservoir.py +++ b/svetlanna/elements/reservoir.py @@ -13,14 +13,15 @@ class SimpleReservoir(Element): """Reservoir element.""" + def __init__( self, simulation_parameters: SimulationParameters, - nonlinear_element: Union[Element, 'LinearOpticalSetup'], - delay_element: Union[Element, 'LinearOpticalSetup'], + nonlinear_element: Union[Element, "LinearOpticalSetup"], + delay_element: Union[Element, "LinearOpticalSetup"], feedback_gain: OptimizableFloat, input_gain: OptimizableFloat, - delay: int + delay: int, ) -> None: """Reservoir element. The main idea is explained in https://doi.org/10.1364/OE.20.022783. @@ -55,15 +56,9 @@ def __init__( self.nonlinear_element = nonlinear_element self.delay_element = delay_element - self.feedback_gain = self.process_parameter( - 'feedback_gain', feedback_gain - ) - self.input_gain = self.process_parameter( - 'input_gain', input_gain - ) - self.delay = self.process_parameter( - 'delay', delay - ) + self.feedback_gain = self.process_parameter("feedback_gain", feedback_gain) + self.input_gain = self.process_parameter("input_gain", input_gain) + self.delay = self.process_parameter("delay", delay) # create FIFI queue for delay line self.feedback_queue: deque[Wavefront] = deque(maxlen=self.delay) @@ -93,11 +88,25 @@ def pop_feedback_queue(self) -> None | Wavefront: return self.feedback_queue.popleft() def drop_feedback_queue(self) -> None: - """Clear all elements from the feedback queue. - """ + """Clear all elements from the feedback queue.""" self.feedback_queue.clear() def forward(self, incident_wavefront: Wavefront) -> Wavefront: + """ + Processes an incoming wavefront through a feedback loop. + + This method simulates the propagation of a wavefront through a system + with a delay line and nonlinear elements. It retrieves a delayed element + from a queue, combines it with the current input, processes the result + through a nonlinear function, and adds the output back to the delay line. + + Args: + incident_wavefront: The incoming wavefront signal. + + Returns: + Wavefront: The processed output wavefront after applying feedback and + nonlinear transformations. + """ # get an element from feedback line queue delayed = self.pop_feedback_queue() @@ -108,36 +117,49 @@ def forward(self, incident_wavefront: Wavefront) -> Wavefront: ) else: # if the delay line is empty - output = self.nonlinear_element( - incident_wavefront * self.input_gain - ) + output = self.nonlinear_element(incident_wavefront * self.input_gain) # add output to the delay line self.append_feedback_queue(output) return output def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Returns an iterable of parameter and subelement specifications. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable containing + ParameterSpecs for 'feedback_gain', 'input_gain', and 'delay', as well + as SubelementSpecs for 'nonlinear_element' and 'delay_element'. + """ return ( - ParameterSpecs('feedback_gain', ( - PrettyReprRepr(self.feedback_gain), - )), - ParameterSpecs('input_gain', ( - PrettyReprRepr(self.input_gain), - )), - ParameterSpecs('delay', ( - PrettyReprRepr(self.delay), - )), - SubelementSpecs('Nonlinear element', self.nonlinear_element), - SubelementSpecs('Delay element', self.delay_element) + ParameterSpecs("feedback_gain", (PrettyReprRepr(self.feedback_gain),)), + ParameterSpecs("input_gain", (PrettyReprRepr(self.input_gain),)), + ParameterSpecs("delay", (PrettyReprRepr(self.delay),)), + SubelementSpecs("Nonlinear element", self.nonlinear_element), + SubelementSpecs("Delay element", self.delay_element), ) @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_reservoir.html.jinja').render( + """ + Renders the HTML for a widget using a Jinja2 template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (e.g., 'input', 'select'). Can be None. + subelements: A list of ElementHTML objects representing the sub-elements + of the widget. + + Returns: + str: The rendered HTML string for the widget. + """ + return jinja_env.get_template("widget_reservoir.html.jinja").render( index=index, name=name, subelements=subelements ) diff --git a/svetlanna/elements/slm.py b/svetlanna/elements/slm.py index 6495fa8..f43c85f 100644 --- a/svetlanna/elements/slm.py +++ b/svetlanna/elements/slm.py @@ -24,17 +24,29 @@ def __init__( mask: torch.Tensor, height: OptimizableFloat, width: OptimizableFloat, - location: Tuple = (0., 0.), + location: Tuple = (0.0, 0.0), number_of_levels: int = 256, step_function: Callable[[torch.Tensor], torch.Tensor] = relu, mode: Literal[ - 'nearest', - 'bilinear', - 'bicubic', - 'area', - 'nearest-exact' - ] = 'nearest' + "nearest", "bilinear", "bicubic", "area", "nearest-exact" + ] = "nearest", ): + """ + Initializes a new instance of the class. + + Args: + simulation_parameters: The simulation parameters object. + mask: The mask tensor. + height: The height parameter. + width: The width parameter. + location: The location tuple (x, y). Defaults to (0., 0.). + number_of_levels: The number of levels. Defaults to 256. + step_function: The step function. Defaults to relu. + mode: The interpolation mode. Defaults to 'nearest'. + + Returns: + None + """ super().__init__(simulation_parameters) @@ -48,24 +60,21 @@ def __init__( self.mode = self.process_parameter("mode", mode) self.number_of_levels = self.process_parameter( - "number_of_levels", - number_of_levels + "number_of_levels", number_of_levels ) self.height_resolution, self.width_resolution = self.mask.shape self._device = self.simulation_parameters.device - self._w_index = self.simulation_parameters.axes.index('W') - self._h_index = self.simulation_parameters.axes.index('H') + self._w_index = self.simulation_parameters.axes.index("W") + self._h_index = self.simulation_parameters.axes.index("H") self._x_linear = self.make_buffer( - "_x_linear", - self.simulation_parameters.axes.W + "_x_linear", self.simulation_parameters.axes.W ) self._y_linear = self.make_buffer( - "_y_linear", - self.simulation_parameters.axes.H + "_y_linear", self.simulation_parameters.axes.H ) self._x_grid = self._x_linear[None, :] @@ -73,38 +82,72 @@ def __init__( @property def get_aperture(self) -> torch.Tensor: + """ + Returns the aperture mask as a boolean tensor. - self.aperture = ((torch.abs( - self._x_grid - self.x) <= self.width/2) * (torch.abs( - self._y_grid - self.y) <= self.height/2)).to( - dtype=torch.get_default_dtype() - ) + The aperture is defined by a rectangular region centered at (x, y) + with width and height specified by the object's attributes. + + Args: + None + + Returns: + torch.Tensor: A boolean tensor representing the aperture mask. True values indicate pixels within the aperture, False otherwise. + """ + + self.aperture = ( + (torch.abs(self._x_grid - self.x) <= self.width / 2) + * (torch.abs(self._y_grid - self.y) <= self.height / 2) + ).to(dtype=torch.get_default_dtype()) return self.aperture @property def resized_mask(self) -> torch.Tensor: + """ + Resizes the mask to match the simulation parameters. + + Calculates boundaries for resizing based on the aperture and linear coordinates, + then interpolates the mask to the new dimensions. Includes a warning if the + resized mask is smaller than the original. + + Args: + None + + Returns: + torch.Tensor: The resized mask as a torch tensor. + """ _y_indices, _x_indices = torch.where(self.aperture == 1) - _y_indices, _x_indices = torch.unique(_y_indices), torch.unique(_x_indices) # noqa: E501 + _y_indices, _x_indices = torch.unique(_y_indices), torch.unique( + _x_indices + ) # noqa: E501 self.left_boundary = _x_indices[ torch.argmin( - torch.abs(self._x_linear[_x_indices] - (self.x - self.width / 2)) # noqa: E501 + torch.abs( + self._x_linear[_x_indices] - (self.x - self.width / 2) + ) # noqa: E501 ).item() ] self.right_boundary = _x_indices[ torch.argmin( - torch.abs(self._x_linear[_x_indices] - (self.x + self.width / 2)) # noqa: E501 + torch.abs( + self._x_linear[_x_indices] - (self.x + self.width / 2) + ) # noqa: E501 ).item() ] self.top_boundary = _y_indices[ torch.argmin( - torch.abs(self._y_linear[_y_indices] - (self.y + self.height / 2)) # noqa: E501 + torch.abs( + self._y_linear[_y_indices] - (self.y + self.height / 2) + ) # noqa: E501 ).item() ] self.bottom_boundary = _y_indices[ torch.argmin( - torch.abs(self._y_linear[_y_indices] - (self.y - self.height / 2)) # noqa: E501 + torch.abs( + self._y_linear[_y_indices] - (self.y - self.height / 2) + ) # noqa: E501 ).item() ] @@ -117,26 +160,41 @@ def resized_mask(self) -> torch.Tensor: _resized_mask = interpolate( _resized_mask, size=(y_nodes_interpolate, x_nodes_interpolate), - mode=self.mode + mode=self.mode, ) # delete added dimensions resized_mask = _resized_mask.squeeze(0).squeeze(0) if resized_mask.size() < self.mask.size(): - warnings.warn(f"New mask size {resized_mask.size()} is smaller than the original one {self.mask.size()}! ") + warnings.warn( + f"New mask size {resized_mask.size()} is smaller than the original one {self.mask.size()}! " + ) return resized_mask @property def transmission_function(self) -> torch.Tensor: + """ + Calculates the transmission function of the hologram. + + This method generates a phase mask based on the resized mask and then + computes the transmission function as the exponential of the imaginary + phase mask. + + Args: + None + + Returns: + torch.Tensor: The calculated transmission function. + """ _aperture = self.get_aperture _resized_mask = self.resized_mask indices = ( slice(self.bottom_boundary, self.top_boundary + 1), - slice(self.left_boundary, self.right_boundary + 1) + slice(self.left_boundary, self.right_boundary + 1), ) _phase_mask = _aperture.clone() @@ -155,26 +213,43 @@ def transmission_function(self) -> torch.Tensor: _phase_mask[indices] = quantized_mask - transmission_function = torch.exp( - 1j * _phase_mask - ) + transmission_function = torch.exp(1j * _phase_mask) return transmission_function def forward(self, incident_wavefront: Wavefront) -> Wavefront: + """ + Applies the transmission function to an incoming wavefront. + + Args: + incident_wavefront: The input wavefront representing the incident wave. + + Returns: + Wavefront: The resulting wavefront after applying the transmission + function, modified according to the simulation parameters. + """ return mul( - incident_wavefront, - self.transmission_function, - ('H', 'W'), - self.simulation_parameters - ) + incident_wavefront, + self.transmission_function, + ("H", "W"), + self.simulation_parameters, + ) def reverse(self, transmission_wavefront: Wavefront) -> Wavefront: + """ + Reverses the effect of the transmission function on a wavefront. + + Args: + transmission_wavefront: The wavefront to be reversed. + + Returns: + Wavefront: The reversed wavefront. + """ return mul( - transmission_wavefront, - torch.conj(self.transmission_function), - ('H', 'W'), - self.simulation_parameters - ) + transmission_wavefront, + torch.conj(self.transmission_function), + ("H", "W"), + self.simulation_parameters, + ) diff --git a/svetlanna/logging.py b/svetlanna/logging.py index e22f1c6..d8f7f8b 100644 --- a/svetlanna/logging.py +++ b/svetlanna/logging.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) __handles: None | tuple[RemovableHandle, ...] = None -__logging_type: Literal['logging', 'print'] = 'print' +__logging_type: Literal["logging", "print"] = "print" def agr_short_description(arg: Any) -> str: @@ -30,15 +30,24 @@ def agr_short_description(arg: Any) -> str: description """ if isinstance(arg, Tensor): - return f'{type(arg)} shape={arg.shape}, dtype={arg.dtype}, device={arg.device}' + return f"{type(arg)} shape={arg.shape}, dtype={arg.dtype}, device={arg.device}" else: - return f'{type(arg)}' + return f"{type(arg)}" def log_message(message: str): - if __logging_type == 'logging': + """ + Logs a debug message using the configured logging method. + + Args: + message: The message to be logged. + + Returns: + None + """ + if __logging_type == "logging": logger.debug(message) - elif __logging_type == 'print': + elif __logging_type == "print": print(message) @@ -47,41 +56,50 @@ def forward_logging_hook(module, input, output) -> None: if not isinstance(module, Element): return - args_info = '' + args_info = "" # cast inputs and outputs to tuples input = (input,) if not isinstance(input, tuple) else input output = (output,) if not isinstance(output, tuple) else output for i, _input in enumerate(input): - args_info += f'\n input {i}: {agr_short_description(_input)}' + args_info += f"\n input {i}: {agr_short_description(_input)}" for i, _output in enumerate(output): - args_info += f'\n output {i}: {agr_short_description(_output)}' + args_info += f"\n output {i}: {agr_short_description(_output)}" - log_message( - f'The forward method of {module._get_name()} was computed{args_info}' - ) + log_message(f"The forward method of {module._get_name()} was computed{args_info}") def register_logging_hook( - module, name, value, - type: Literal['Parameter', 'Buffer', 'Module'] + module, name, value, type: Literal["Parameter", "Buffer", "Module"] ) -> None: + """ + Registers a logging hook for a given module attribute. + + This method logs information about the registered attribute (parameter, buffer, or module) + to provide visibility into the model's configuration and state. It only operates on modules that are instances of Element. + + Args: + module: The module to which the attribute belongs. + name: The name of the attribute being registered. + value: The value of the attribute. + type: The type of the attribute ('Parameter', 'Buffer', or 'Module'). + + Returns: + None + """ if not isinstance(module, Element): return - value_info = f'\n {agr_short_description(value)}' + value_info = f"\n {agr_short_description(value)}" log_message( - f'{type} of {module._get_name()} was registered with name {name}:{value_info}' + f"{type} of {module._get_name()} was registered with name {name}:{value_info}" ) -def set_debug_logging( - mode: bool, - type: Literal['logging', 'print'] = 'print' -): +def set_debug_logging(mode: bool, type: Literal["logging", "print"] = "print"): """Enables and disables debug logging. If type is `'print'`, then messages are printed using `print`, if type is `'logging'` the messages are written in the logger @@ -97,27 +115,23 @@ def set_debug_logging( global __handles global __logging_type - if type not in ('logging', 'print'): - raise ValueError( - f"Logging type should be 'logging' or 'print, not {type}" - ) + if type not in ("logging", "print"): + raise ValueError(f"Logging type should be 'logging' or 'print, not {type}") __logging_type = type if mode: if __handles is None: __handles = ( - register_module_forward_hook( - forward_logging_hook - ), + register_module_forward_hook(forward_logging_hook), register_module_parameter_registration_hook( - partial(register_logging_hook, type='Parameter') + partial(register_logging_hook, type="Parameter") ), register_module_buffer_registration_hook( - partial(register_logging_hook, type='Buffer') + partial(register_logging_hook, type="Buffer") ), register_module_module_registration_hook( - partial(register_logging_hook, type='Module') - ) + partial(register_logging_hook, type="Module") + ), ) else: if __handles is not None: diff --git a/svetlanna/networks/autoencoder.py b/svetlanna/networks/autoencoder.py index 16604f0..dde4779 100644 --- a/svetlanna/networks/autoencoder.py +++ b/svetlanna/networks/autoencoder.py @@ -3,6 +3,7 @@ from torch import nn from svetlanna import Wavefront, SimulationParameters from svetlanna.elements import Element + # for visualisation: from svetlanna import LinearOpticalSetup from svetlanna.specs import ParameterSpecs, SubelementSpecs @@ -19,12 +20,12 @@ class LinearAutoencoder(nn.Module): """ def __init__( - self, - sim_params: SimulationParameters, - encoder_elements_list: list[Element] | Iterable[Element], - decoder_elements_list: list[Element] | Iterable[Element], - to_return: Literal['wf', 'amps'] = 'wf', - device: str | torch.device = torch.get_default_device(), + self, + sim_params: SimulationParameters, + encoder_elements_list: list[Element] | Iterable[Element], + decoder_elements_list: list[Element] | Iterable[Element], + to_return: Literal["wf", "amps"] = "wf", + device: str | torch.device = torch.get_default_device(), ): """ Parameters @@ -46,7 +47,7 @@ def __init__( self.sim_params = sim_params self.h, self.w = self.sim_params.axes_size( - axs=('H', 'W') + axs=("H", "W") ) # height and width for a Wavefronts self.__device = torch.device(device) @@ -68,9 +69,9 @@ def encode(self, wavefront_in): wavefront_encoded : Wavefront An encoded input wavefront. """ - if self.to_return == 'wf': + if self.to_return == "wf": return self.encoder(wavefront_in) - if self.to_return == 'amps': + if self.to_return == "amps": return self.encoder(wavefront_in).abs() + 0j def decode(self, wavefront_encoded): @@ -82,9 +83,9 @@ def decode(self, wavefront_encoded): wavefront_decoded : Wavefront A decoded wavefront. """ - if self.to_return == 'wf': + if self.to_return == "wf": return self.decoder(wavefront_encoded) - if self.to_return == 'amps': + if self.to_return == "amps": return self.decoder(wavefront_encoded).abs() + 0j def forward(self, wavefront_in): @@ -106,18 +107,23 @@ def forward(self, wavefront_in): return wavefront_encoded, wavefront_decoded def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Returns the encoder and decoder specifications. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable containing + SubelementSpecs for the encoder and decoder, each holding a + LinearOpticalSetup representing their respective elements. + """ return ( - SubelementSpecs( - 'Encoder', - LinearOpticalSetup(self.encoder_elements) - ), - SubelementSpecs( - 'Decoder', - LinearOpticalSetup(self.decoder_elements) - ), + SubelementSpecs("Encoder", LinearOpticalSetup(self.encoder_elements)), + SubelementSpecs("Decoder", LinearOpticalSetup(self.decoder_elements)), ) - def to(self, device: str | torch.device | int) -> 'LinearAutoencoder': + def to(self, device: str | torch.device | int) -> "LinearAutoencoder": if self.__device == torch.device(device): return self @@ -131,4 +137,13 @@ def to(self, device: str | torch.device | int) -> 'LinearAutoencoder': @property def device(self) -> str | torch.device | int: + """ + Returns the device on which the model is located. + + Args: + None + + Returns: + The device as a string, torch.device object, or integer. + """ return self.__device diff --git a/svetlanna/networks/diffractive_conv.py b/svetlanna/networks/diffractive_conv.py index aa5db5c..9381e1c 100644 --- a/svetlanna/networks/diffractive_conv.py +++ b/svetlanna/networks/diffractive_conv.py @@ -4,24 +4,27 @@ from torch import nn from svetlanna import Wavefront, SimulationParameters, ConstrainedParameter from svetlanna import elements + # for visualisation: from svetlanna import LinearOpticalSetup from svetlanna.specs import ParameterSpecs, SubelementSpecs + class ConvLayer4F(nn.Module): """ Diffractive convolutional layer based on a 4f system. """ + # TODO: Add a custom aperture (defined by a mask) before a DiffractiveLayer? def __init__( - self, - sim_params: SimulationParameters, - focal_length: float, - conv_diffractive_mask: torch.Tensor, - learnable_mask: bool = False, - max_phase: float = 2 * torch.pi, - fs_method: Literal['fresnel', 'AS'] = 'AS', + self, + sim_params: SimulationParameters, + focal_length: float, + conv_diffractive_mask: torch.Tensor, + learnable_mask: bool = False, + max_phase: float = 2 * torch.pi, + fs_method: Literal["fresnel", "AS"] = "AS", ): """ Parameters @@ -62,7 +65,7 @@ def get_free_space(self): return elements.FreeSpace( simulation_parameters=self.sim_params, distance=self.focal_length, # distance is not learnable! - method=self.fs_method + method=self.fs_method, ) def get_thin_lens(self): @@ -83,9 +86,7 @@ def get_diffractive_layer(self): diff_layer = elements.DiffractiveLayer( simulation_parameters=self.sim_params, mask=ConstrainedParameter( - self.conv_diffractive_mask, - min_value=0, - max_value=self.max_phase + self.conv_diffractive_mask, min_value=0, max_value=self.max_phase ), ) else: @@ -97,13 +98,23 @@ def get_diffractive_layer(self): return diff_layer def get_conv_layer_4f(self): + """ + Constructs a 4f system using sequential layers. + + This method creates a sequence of optical elements – free spaces and thin lenses – + arranged in a 4f configuration, including a diffractive layer for convolution + in the Fourier plane. + + Returns: + nn.Sequential: A PyTorch Sequential model representing the 4f system. + """ system_elements = [ self.get_free_space(), # <-- F - self.get_thin_lens(), # <-- ThinLens + self.get_thin_lens(), # <-- ThinLens self.get_free_space(), # <-- F self.get_diffractive_layer(), # <-- convolution in a Fourier plane self.get_free_space(), # <-- F - self.get_thin_lens(), # <-- ThinLens + self.get_thin_lens(), # <-- ThinLens self.get_free_space(), # <-- F ] return nn.Sequential(*system_elements) @@ -132,15 +143,15 @@ class ConvDiffNetwork4F(nn.Module): """ def __init__( - self, - sim_params: SimulationParameters, - network_elements_list: list, - focal_length: float, - conv_phase_mask: torch.Tensor, - learnable_mask: bool = False, - max_phase: float = 2 * torch.pi, - fs_method: Literal['fresnel', 'AS'] = 'AS', - device: str | torch.device = torch.get_default_device(), + self, + sim_params: SimulationParameters, + network_elements_list: list, + focal_length: float, + conv_phase_mask: torch.Tensor, + learnable_mask: bool = False, + max_phase: float = 2 * torch.pi, + fs_method: Literal["fresnel", "AS"] = "AS", + device: str | torch.device = torch.get_default_device(), ): """ Parameters @@ -183,12 +194,14 @@ def __init__( conv_diffractive_mask=self.conv_phase_mask, learnable_mask=self.learnable_mask, max_phase=self.max_phase, - fs_method=self.fs_method + fs_method=self.fs_method, ).to(self.__device) # PART OF THE NETWORK AFTER A 4F CONVOLUTION self.network_elements_list = network_elements_list - self.net_after_conv = nn.Sequential(*self.network_elements_list).to(self.__device) + self.net_after_conv = nn.Sequential(*self.network_elements_list).to( + self.__device + ) def forward(self, wavefront_in): """ @@ -210,18 +223,28 @@ def forward(self, wavefront_in): return result def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Returns the specs for the 4F convolution system and linear setup. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable containing + SubelementSpecs objects representing the 4F Convolution System and + Linear Setup, each with its corresponding LinearOpticalSetup. + """ return ( SubelementSpecs( - '4F Convolution System', - LinearOpticalSetup(list(self.conv_layer.conv_layer_4f)) + "4F Convolution System", + LinearOpticalSetup(list(self.conv_layer.conv_layer_4f)), ), SubelementSpecs( - 'Linear Setup', - LinearOpticalSetup(list(self.net_after_conv)) + "Linear Setup", LinearOpticalSetup(list(self.net_after_conv)) ), ) - def to(self, device: str | torch.device | int) -> 'ConvDiffNetwork4F': + def to(self, device: str | torch.device | int) -> "ConvDiffNetwork4F": if self.__device == torch.device(device): return self @@ -238,4 +261,13 @@ def to(self, device: str | torch.device | int) -> 'ConvDiffNetwork4F': @property def device(self) -> str | torch.device | int: + """ + Returns the device on which the model is located. + + Args: + None + + Returns: + The device as a string, torch.device object, or integer. + """ return self.__device diff --git a/svetlanna/networks/diffractive_rnn.py b/svetlanna/networks/diffractive_rnn.py index aa697f0..c1b07ac 100644 --- a/svetlanna/networks/diffractive_rnn.py +++ b/svetlanna/networks/diffractive_rnn.py @@ -3,6 +3,7 @@ import torch from torch import nn from svetlanna import Wavefront, SimulationParameters + # for visualisation: from svetlanna import LinearOpticalSetup from svetlanna.specs import ParameterSpecs, SubelementSpecs @@ -16,16 +17,16 @@ class DiffractiveRNN(nn.Module): """ def __init__( - self, - sim_params: SimulationParameters, - sequence_len: int, - fusing_coeff: float, - read_in_layer: nn.Sequential, - memory_layer: nn.Sequential, - hidden_forward_layer: nn.Sequential, - read_out_layer: nn.Sequential, - detector_layer: nn.Sequential, - device: str | torch.device = torch.get_default_device(), + self, + sim_params: SimulationParameters, + sequence_len: int, + fusing_coeff: float, + read_in_layer: nn.Sequential, + memory_layer: nn.Sequential, + hidden_forward_layer: nn.Sequential, + read_out_layer: nn.Sequential, + detector_layer: nn.Sequential, + device: str | torch.device = torch.get_default_device(), ): """ sim_params: SimulationParameters @@ -53,7 +54,7 @@ def __init__( self.sim_params = sim_params self.h, self.w = self.sim_params.axes_size( - axs=('H', 'W') + axs=("H", "W") ) # height and width for a wavefronts self.__device = torch.device(device) @@ -79,18 +80,14 @@ def forward(self, subsequence_wf: Wavefront): if len(subsequence_wf.shape) > 3: # if a batch is an input batch_flag = True bs = subsequence_wf.shape[0] - h_prev = Wavefront( - torch.zeros( - size=(bs, self.h, self.w) - ) - ).to(self.__device) # h_{t - 1} - reset hidden for the first input + h_prev = Wavefront(torch.zeros(size=(bs, self.h, self.w))).to( + self.__device + ) # h_{t - 1} - reset hidden for the first input else: batch_flag = False - h_prev = Wavefront( - torch.zeros( - size=(self.h, self.w) - ) - ).to(self.__device) # h_{t - 1} - reset hidden for the first input + h_prev = Wavefront(torch.zeros(size=(self.h, self.w))).to( + self.__device + ) # h_{t - 1} - reset hidden for the first input for frame_ind in range(self.sequence_len): if batch_flag: @@ -110,44 +107,61 @@ def forward(self, subsequence_wf: Wavefront): return out def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Returns a collection of parameter and subelement specifications. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable yielding + ParameterSpecs for parameters like sequence length and fusing coefficient, + and SubelementSpecs for the different layers within the optical setup + (read-in, memory, hidden forward, and read-out). + """ return ( - ParameterSpecs('sequence_len', ( - PrettyReprRepr(self.sequence_len), - )), - ParameterSpecs('fusing_coeff', ( - PrettyReprRepr(self.fusing_coeff), - )), + ParameterSpecs("sequence_len", (PrettyReprRepr(self.sequence_len),)), + ParameterSpecs("fusing_coeff", (PrettyReprRepr(self.fusing_coeff),)), SubelementSpecs( - 'Read-in Layer', - LinearOpticalSetup(list(self.read_in_layer)) + "Read-in Layer", LinearOpticalSetup(list(self.read_in_layer)) ), SubelementSpecs( - 'Memory Layer', - LinearOpticalSetup(list(self.memory_layer)) + "Memory Layer", LinearOpticalSetup(list(self.memory_layer)) ), SubelementSpecs( - 'Hidden Forward Layer', - LinearOpticalSetup(list(self.hidden_forward_layer)) + "Hidden Forward Layer", + LinearOpticalSetup(list(self.hidden_forward_layer)), ), SubelementSpecs( - 'Read-out Layer', - LinearOpticalSetup(list(self.read_out_layer)) + "Read-out Layer", LinearOpticalSetup(list(self.read_out_layer)) ), ) - def to(self, device: str | torch.device | int) -> 'DiffractiveRNN': + def to(self, device: str | torch.device | int) -> "DiffractiveRNN": if self.__device == torch.device(device): return self return DiffractiveRNN( sim_params=self.sim_params, - sequence_len=self.sequence_len, fusing_coeff=self.fusing_coeff, - read_in_layer=self.read_in_layer, memory_layer=self.memory_layer, + sequence_len=self.sequence_len, + fusing_coeff=self.fusing_coeff, + read_in_layer=self.read_in_layer, + memory_layer=self.memory_layer, hidden_forward_layer=self.hidden_forward_layer, - read_out_layer=self.read_out_layer, detector_layer=self.detector_layer, + read_out_layer=self.read_out_layer, + detector_layer=self.detector_layer, device=device, ) @property def device(self) -> str | torch.device | int: + """ + Returns the device on which tensors are allocated. + + Args: + None + + Returns: + The device string, torch.device object or integer representing the device. + """ return self.__device diff --git a/svetlanna/parameters.py b/svetlanna/parameters.py index e72e810..6d2f147 100644 --- a/svetlanna/parameters.py +++ b/svetlanna/parameters.py @@ -5,18 +5,25 @@ class InnerParameterStorageModule(torch.nn.Module): - def __init__( - self, - params_to_store: dict[str, torch.Tensor | torch.nn.Parameter] - ): + """ + Stores parameters (PyTorch tensors or nn.Parameters) for later use.""" + + def __init__(self, params_to_store: dict[str, torch.Tensor | torch.nn.Parameter]): + """ + Initializes the module with parameters to store. + + Args: + params_to_store: A dictionary where keys are parameter names and values + are PyTorch tensors or nn.Parameters to be stored. + + Returns: + None + """ super().__init__() self.params_to_store = {} self.expand(params_to_store) - def expand( - self, - params_to_store: dict[str, torch.Tensor | torch.nn.Parameter] - ): + def expand(self, params_to_store: dict[str, torch.Tensor | torch.nn.Parameter]): """Add more parameters to the storage Parameters @@ -31,9 +38,9 @@ def expand( self.register_buffer(name, value) else: raise TypeError( - 'Parameters should be instances of either torch.Tensor ' - 'or torch.nn.Parameter. ' - 'The type {type(value)} of {name} is not compatible.' + "Parameters should be instances of either torch.Tensor " + "or torch.nn.Parameter. " + "The type {type(value)} of {name} is not compatible." ) self.params_to_store[name] = value @@ -42,16 +49,26 @@ class Parameter(torch.Tensor): """`torch.Parameter` replacement. Added for further feature enrichment. """ + @staticmethod def __new__(cls, *args, **kwargs): + """ + Creates a new instance of the class. + + This method overrides the default object creation process to ensure proper + initialization, leveraging the superclass's __new__ method for Parameter objects. + + Args: + *args: Positional arguments passed to the constructor. + **kwargs: Keyword arguments passed to the constructor. + + Returns: + Parameter: A new instance of the class. + """ # see https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py # noqa: E501 return super(cls, Parameter).__new__(cls) - def __init__( - self, - data: Any, - requires_grad: bool = True - ): + def __init__(self, data: Any, requires_grad: bool = True): """ Parameters ---------- @@ -67,17 +84,30 @@ def __init__( # real parameter that should be optimized self.inner_parameter = torch.nn.Parameter( - data=data, - requires_grad=requires_grad + data=data, requires_grad=requires_grad ) self.inner_storage = InnerParameterStorageModule( - { - 'inner_parameter': self.inner_parameter - } + {"inner_parameter": self.inner_parameter} ) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + Calls a function with inner parameters if applicable. + + This method is used to ensure that when calling functions within the PyTorch + ecosystem, instances of this class are replaced with their `inner_parameter` + attribute for calculations. This is part of extending the Torch Python API. + + Args: + func: The function to call. + types: The types associated with the function (not directly used in implementation). + args: Positional arguments to pass to the function. + kwargs: Keyword arguments to pass to the function. + + Returns: + The result of calling `func` with the modified arguments. + """ # see https://pytorch.org/docs/stable/notes/extending.html#extending-torch-python-api # noqa: E501 # real parameter should be used for any calculations, @@ -86,12 +116,23 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} kwargs = { - k: v.inner_parameter if isinstance(v, cls) else v for k, v in kwargs.items() # noqa: E501 + k: v.inner_parameter if isinstance(v, cls) else v + for k, v in kwargs.items() # noqa: E501 } args = (a.inner_parameter if isinstance(a, cls) else a for a in args) return func(*args, **kwargs) def __repr__(self, *args, **kwargs) -> str: + """ + Returns a string representation of the object. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + str: A string representation of the inner_parameter attribute. + """ return repr(self.inner_parameter) @@ -108,14 +149,28 @@ def sigmoid_inv(x: torch.Tensor) -> torch.Tensor: torch.Tensor the output tensor """ - return torch.log(x/(1-x)) + return torch.log(x / (1 - x)) class ConstrainedParameter(Parameter): - """Constrained parameter - """ + """Constrained parameter""" + @staticmethod def __new__(cls, *args, **kwargs): + """ + Creates a new instance of the class. + + This method overrides the default object creation process to ensure proper + initialization within the ConstrainedParameter context, leveraging the + torch.Tensor base class's instantiation logic. + + Args: + *args: Variable length argument list. Passed to super(). + **kwargs: Arbitrary keyword arguments. Passed to super(). + + Returns: + ConstrainedParameter: A new instance of the ConstrainedParameter class. + """ return super(torch.Tensor, ConstrainedParameter).__new__(cls) def __init__( @@ -125,7 +180,7 @@ def __init__( max_value: Any, bound_func: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, inv_bound_func: Callable[[torch.Tensor], torch.Tensor] = sigmoid_inv, - requires_grad: bool = True + requires_grad: bool = True, ): r""" Parameters @@ -159,10 +214,7 @@ def __init__( b = min_value # m initial_value = inv_bound_func((data - b) / a) - super().__init__( - data=initial_value, - requires_grad=requires_grad - ) + super().__init__(data=initial_value, requires_grad=requires_grad) self.min_value = min_value self.max_value = max_value @@ -171,8 +223,8 @@ def __init__( self.inner_storage.expand( { - 'a': a, - 'b': b, + "a": a, + "b": b, } ) @@ -187,21 +239,49 @@ def value(self) -> torch.Tensor: """ # for inner parameter value y: # x = (M-m) * bound_function( y ) + m = a * bound_function( y ) + b - return self.inner_storage.a * self.bound_func(self.inner_parameter) + self.inner_storage.b + return ( + self.inner_storage.a * self.bound_func(self.inner_parameter) + + self.inner_storage.b + ) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + Applies a function to the underlying values of instances. + + This method is used to ensure that operations are performed on the + actual data stored within instances of this class, rather than the + instances themselves, especially when dealing with nested structures. + + Args: + func: The function to apply. + types: A tuple representing the expected types. (Not directly used in implementation but likely relevant context) + args: Positional arguments to pass to the function. + kwargs: Keyword arguments to pass to the function. + + Returns: + The result of applying the function to the underlying values. + """ # the same as for Parameter class, `instance.value` should be used if kwargs is None: kwargs = {} - kwargs = { - k: v.value if isinstance(v, cls) else v for k, v in kwargs.items() - } + kwargs = {k: v.value if isinstance(v, cls) else v for k, v in kwargs.items()} args = (a.value if isinstance(a, cls) else a for a in args) return func(*args, **kwargs) def __repr__(self) -> str: - return f'Bounded parameter containing:\n{repr(self.value)}' + """ + Returns a string representation of the Bounded parameter. + + Args: + None + + Returns: + str: A string containing the value of the bounded parameter. + The format is 'Bounded parameter containing: + {repr(self.value)}'. + """ + return f"Bounded parameter containing:\n{repr(self.value)}" OptimizableFloat: TypeAlias = float | torch.Tensor | torch.nn.Parameter | Parameter diff --git a/svetlanna/phase_retrieval_problem/__init__.py b/svetlanna/phase_retrieval_problem/__init__.py index 354dccd..0e4d0bd 100644 --- a/svetlanna/phase_retrieval_problem/__init__.py +++ b/svetlanna/phase_retrieval_problem/__init__.py @@ -6,9 +6,9 @@ __all__ = [ - 'retrieve_phase', - 'gerchberg_saxton_algorithm', - 'hybrid_input_output', - 'PhaseRetrievalResult', - 'SetupLike' + "retrieve_phase", + "gerchberg_saxton_algorithm", + "hybrid_input_output", + "PhaseRetrievalResult", + "SetupLike", ] diff --git a/svetlanna/phase_retrieval_problem/algorithms.py b/svetlanna/phase_retrieval_problem/algorithms.py index 32573c4..7c12b3a 100644 --- a/svetlanna/phase_retrieval_problem/algorithms.py +++ b/svetlanna/phase_retrieval_problem/algorithms.py @@ -55,9 +55,7 @@ def gerchberg_saxton_algorithm( source_amplitude = torch.sqrt(source_intensity) target_amplitude = torch.sqrt(target_intensity) - input_field = source_amplitude * torch.exp( - 1j * initial_approximation - ) + input_field = source_amplitude * torch.exp(1j * initial_approximation) number_of_iterations = 0 @@ -65,13 +63,12 @@ def gerchberg_saxton_algorithm( # calculate the output field output_field = forward(input_field) output_field_phase = torch.angle(output_field) - output_intensity = output_field.abs()**2 + output_intensity = output_field.abs() ** 2 # calculate an error error = ( - torch.mean( - torch.abs(output_intensity - target_intensity) - ) / torch.max(target_intensity) + torch.mean(torch.abs(output_intensity - target_intensity)) + / torch.max(target_intensity) ).item() cost_func_evolution.append(error) @@ -82,25 +79,20 @@ def gerchberg_saxton_algorithm( if (target_phase is not None) and (target_region is not None): - updated_phase = target_phase * target_region + ( - 1. - target_region - ) * output_field_phase - - updated_output_field = target_amplitude * torch.exp( - 1j * updated_phase + updated_phase = ( + target_phase * target_region + + (1.0 - target_region) * output_field_phase ) + updated_output_field = target_amplitude * torch.exp(1j * updated_phase) + else: - updated_output_field = target_amplitude * torch.exp( - 1j * output_field_phase - ) + updated_output_field = target_amplitude * torch.exp(1j * output_field_phase) updated_input_field = reverse(updated_output_field) updated_input_field_phase = torch.angle(updated_input_field) - input_field = source_amplitude * torch.exp( - 1j * updated_input_field_phase - ) + input_field = source_amplitude * torch.exp(1j * updated_input_field_phase) number_of_iterations += 1 @@ -123,7 +115,7 @@ def hybrid_input_output( maxiter: int, target_phase: torch.Tensor | None = None, target_region: torch.Tensor | None = None, - constant_factor: float = 0.9 + constant_factor: float = 0.9, ) -> prr.PhaseRetrievalResult: """Hybrid Input-Output(HIO) algorithm for for solving the phase retrieval problem @@ -166,28 +158,23 @@ def hybrid_input_output( source_amplitude = torch.sqrt(source_intensity) target_amplitude = torch.sqrt(target_intensity) - input_field = source_amplitude * torch.exp( - 1j * initial_approximation - ) + input_field = source_amplitude * torch.exp(1j * initial_approximation) number_of_iterations = 0 - support_constrain = ( - target_amplitude > torch.max(target_amplitude) / 10 - ).float() + support_constrain = (target_amplitude > torch.max(target_amplitude) / 10).float() while True: # calculate the output field output_field = forward(input_field) output_field_phase = torch.angle(output_field) - output_intensity = output_field.abs()**2 + output_intensity = output_field.abs() ** 2 # calculate an error error = ( - torch.mean( - torch.abs(output_intensity - target_intensity) - ) / torch.max(target_intensity) + torch.mean(torch.abs(output_intensity - target_intensity)) + / torch.max(target_intensity) ).item() cost_func_evolution.append(error) @@ -198,26 +185,27 @@ def hybrid_input_output( if (target_phase is not None) and (target_region is not None): - updated_phase = target_phase * target_region + ( - 1. - target_region - ) * output_field_phase + updated_phase = ( + target_phase * target_region + + (1.0 - target_region) * output_field_phase + ) - updated_output_field = target_amplitude * torch.exp( - 1j * updated_phase - ) - (1. - support_constrain) * constant_factor * output_field + updated_output_field = ( + target_amplitude * torch.exp(1j * updated_phase) + - (1.0 - support_constrain) * constant_factor * output_field + ) else: - updated_output_field = target_amplitude * torch.exp( - 1j * output_field_phase - ) - (1. - support_constrain) * constant_factor * output_field + updated_output_field = ( + target_amplitude * torch.exp(1j * output_field_phase) + - (1.0 - support_constrain) * constant_factor * output_field + ) updated_input_field = reverse(updated_output_field) updated_input_field_phase = torch.angle(updated_input_field) - input_field = source_amplitude * torch.exp( - 1j * updated_input_field_phase - ) + input_field = source_amplitude * torch.exp(1j * updated_input_field_phase) number_of_iterations += 1 diff --git a/svetlanna/phase_retrieval_problem/phase_retrieval.py b/svetlanna/phase_retrieval_problem/phase_retrieval.py index 62d73ca..378e2fe 100644 --- a/svetlanna/phase_retrieval_problem/phase_retrieval.py +++ b/svetlanna/phase_retrieval_problem/phase_retrieval.py @@ -17,21 +17,33 @@ class SetupLike(Protocol): Protocol : _type_ _description_ """ - def forward(self, input_field: torch.Tensor) -> torch.Tensor: - ... - def reverse(self, transmission_field: torch.Tensor) -> torch.Tensor: - ... + def forward(self, input_field: torch.Tensor) -> torch.Tensor: ... + + def reverse(self, transmission_field: torch.Tensor) -> torch.Tensor: ... class AlgorithmOptions(TypedDict, total=False): + """ + Options for controlling the behavior of an algorithm. + + This class encapsulates various parameters that can be used to tune + the performance and accuracy of an optimization or iterative algorithm. + + Attributes: + tol: Tolerance for convergence criteria. Algorithm stops when change is less than this value. + maxiter: Maximum number of iterations allowed. + constant_factor: A constant factor used in calculations within the algorithm. + disp: Whether to display detailed information during execution. + """ + tol: float maxiter: int constant_factor: float disp: bool -Method: TypeAlias = Literal['GS', 'HIO'] +Method: TypeAlias = Literal["GS", "HIO"] @overload @@ -41,10 +53,9 @@ def retrieve_phase( target_intensity: torch.Tensor, *, initial_phase: torch.Tensor | None = None, - method: Method = 'GS', + method: Method = "GS", options: AlgorithmOptions | None = None -) -> prr.PhaseRetrievalResult: - ... +) -> prr.PhaseRetrievalResult: ... @overload @@ -56,10 +67,9 @@ def retrieve_phase( target_region: torch.Tensor, *, initial_phase: torch.Tensor | None = None, - method: Method = 'GS', + method: Method = "GS", options: AlgorithmOptions | None = None -) -> prr.PhaseRetrievalResult: - ... +) -> prr.PhaseRetrievalResult: ... def retrieve_phase( @@ -70,7 +80,7 @@ def retrieve_phase( target_region: torch.Tensor | None = None, *, initial_phase: torch.Tensor | None = None, - method: Method = 'GS', + method: Method = "GS", options: AlgorithmOptions | None = None ) -> prr.PhaseRetrievalResult: """Function for solving phase retrieval problem: generating target @@ -123,12 +133,12 @@ def retrieve_phase( # read options if options is None: options = {} - tol = options.get('tol', 1e-16) - maxiter = options.get('maxiter', 100) - constant_factor = options.get('constant_factor', 0.9) - disp = options.get('disp', False) + tol = options.get("tol", 1e-16) + maxiter = options.get("maxiter", 100) + constant_factor = options.get("constant_factor", 0.9) + disp = options.get("disp", False) - if method == 'GS': + if method == "GS": result = algorithms.gerchberg_saxton_algorithm( target_intensity=target_intensity, @@ -139,9 +149,9 @@ def retrieve_phase( tol=tol, maxiter=maxiter, target_phase=target_phase, - target_region=target_region + target_region=target_region, ) - elif method == 'HIO': + elif method == "HIO": result = algorithms.hybrid_input_output( target_intensity=target_intensity, source_intensity=source_intensity, @@ -152,18 +162,18 @@ def retrieve_phase( maxiter=maxiter, target_phase=target_phase, target_region=target_region, - constant_factor=constant_factor + constant_factor=constant_factor, ) else: - raise ValueError('Unknown optimization method') + raise ValueError("Unknown optimization method") if disp: if (target_phase is not None) & (target_region is not None): - print('Type of problem: phase reconstruction') + print("Type of problem: phase reconstruction") else: - print('Type of problem: generate intensity profile') - print('Method:' + str(method)) - print('Current cost function value:' + str(result.cost_func)) - print('Number of iteration:' + str(result.number_of_iterations)) + print("Type of problem: generate intensity profile") + print("Method:" + str(method)) + print("Current cost function value:" + str(result.cost_func)) + print("Number of iteration:" + str(result.number_of_iterations)) return result diff --git a/svetlanna/phase_retrieval_problem/phase_retrieval_result.py b/svetlanna/phase_retrieval_problem/phase_retrieval_result.py index c1d94d8..ad0bd76 100644 --- a/svetlanna/phase_retrieval_problem/phase_retrieval_result.py +++ b/svetlanna/phase_retrieval_problem/phase_retrieval_result.py @@ -5,8 +5,7 @@ # TODO: ask for message and status code @dataclass(frozen=True, slots=True) class PhaseRetrievalResult: - """Represents the phase retrieval result - """ + """Represents the phase retrieval result""" solution: torch.Tensor cost_func: float diff --git a/svetlanna/setup.py b/svetlanna/setup.py index 6bd861b..285ba6c 100644 --- a/svetlanna/setup.py +++ b/svetlanna/setup.py @@ -11,6 +11,7 @@ class LinearOpticalSetup(nn.Module): """ A linear optical network composed of Element's """ + def __init__(self, elements: Iterable[Element]) -> None: """ Parameters @@ -37,7 +38,7 @@ def check_sim_params(element: Element) -> bool: "the same SimulationParameters instance." ) - if all((hasattr(el, 'reverse') for el in self.elements)): + if all((hasattr(el, "reverse") for el in self.elements)): class ReverseNet(nn.Module): def forward(self, Ein: Tensor) -> Tensor: @@ -88,16 +89,16 @@ def stepwise_forward(self, input_wavefront: Tensor): # list of wavefronts while propagation of an initial wavefront through the system steps_wavefront = [this_wavefront] # input wavefront is a zeroth step - optical_scheme = '' # string that represents a linear optical setup (schematic) + optical_scheme = "" # string that represents a linear optical setup (schematic) self.net.eval() for ind_element, element in enumerate(self.net): # for visualization in a console element_name = type(element).__name__ - optical_scheme += f'-({ind_element})-> [{ind_element + 1}. {element_name}] ' + optical_scheme += f"-({ind_element})-> [{ind_element + 1}. {element_name}] " # TODO: Replace len(...) with something for Iterable? if ind_element == len(self.net) - 1: - optical_scheme += f'-({ind_element + 1})->' + optical_scheme += f"-({ind_element + 1})->" # element forward this_wavefront = element.forward(this_wavefront) steps_wavefront.append(this_wavefront) # add a wavefront to list of steps @@ -105,25 +106,57 @@ def stepwise_forward(self, input_wavefront: Tensor): return optical_scheme, steps_wavefront def reverse(self, Ein: Tensor) -> Tensor: + """ + Reverses the input tensor using a pre-defined reverse network. + + Args: + Ein: The input tensor to be reversed. + + Returns: + Tensor: The reversed tensor if a reverse network is available. + + Raises: + TypeError: If no reverse network is defined, indicating that + reverse propagation is not possible for this element. + """ if self._reverse_net is not None: return self._reverse_net(Ein) raise TypeError( - 'Reverse propagation is impossible. ' - 'All elements should have reverse method.' + "Reverse propagation is impossible. " + "All elements should have reverse method." ) def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Converts the elements into a sequence of subelement specifications. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable yielding + SubelementSpecs for each element in the 'elements' list, with an index. + """ return ( SubelementSpecs(str(i), element) for i, element in enumerate(self.elements) ) @staticmethod def _widget_html_( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: - return jinja_env.get_template('widget_linear_setup.html.jinja').render( + """ + Renders the HTML for a widget using a Jinja template. + + Args: + index: The index of the widget. + name: The name of the widget. + element_type: The type of element (optional). + subelements: A list of sub-elements to include in the widget. + + Returns: + str: The rendered HTML string for the widget. + """ + return jinja_env.get_template("widget_linear_setup.html.jinja").render( index=index, name=name, subelements=subelements ) diff --git a/svetlanna/simulation_parameters.py b/svetlanna/simulation_parameters.py index 2b3aa51..d794057 100644 --- a/svetlanna/simulation_parameters.py +++ b/svetlanna/simulation_parameters.py @@ -4,30 +4,43 @@ class AxisNotFound(Exception): + """ + Raised when an axis is requested that does not exist.""" + pass _AXES_INNER_ATTRS = tuple( - f'_Axes{i}' for i in ('__axes_dict', '__names', '__names_inversed') + f"_Axes{i}" for i in ("__axes_dict", "__names", "__names_inversed") ) class Axes: """Axes storage""" + def __init__(self, axes: dict[str, torch.Tensor]) -> None: + """ + Initializes the AxisInfo object with a dictionary of axes. + + Args: + axes: A dictionary where keys are axis names (e.g., 'W', 'H') and + values are PyTorch tensors representing the axis values. Must contain + 'W', 'H', and 'wavelength'. + + Returns: + None + """ # TODO: set default values for the new axis if needed (ex. pol = 0) # check if required axes are presented - required_axes = ( - 'W', 'H', 'wavelength' - ) + required_axes = ("W", "H", "wavelength") if not all(name in axes.keys() for name in required_axes): raise ValueError("Axes 'W', 'H', and 'wavelength' are required!") # check if W and H axes are 1-d - if not len(axes['W'].shape) == 1: + if not len(axes["W"].shape) == 1: raise ValueError("'W' axis should be 1-dimensional") - if not len(axes['H'].shape) == 1: + if not len(axes["H"].shape) == 1: raise ValueError("'H' axis should be 1-dimensional") # check if axes are 0- or 1-dimensional @@ -74,9 +87,24 @@ def index(self, name: str) -> int: """ if name in self.__names: return -self.__names_inversed.index(name) - 1 - raise AxisNotFound(f'Axis with name {name} does not exist.') + raise AxisNotFound(f"Axis with name {name} does not exist.") def __getattribute__(self, name: str) -> Any: + """ + Retrieves an attribute from the object. + + This method intercepts attribute access and first checks for inner attributes + defined in _AXES_INNER_ATTRS. If not found there, it looks within the internal + axes dictionary (__axes_dict). If still not found, it falls back to the + default attribute retrieval behavior of the superclass. + + Args: + name: The name of the attribute to retrieve. + + Returns: + The value of the attribute if found, otherwise the result of the + superclass's __getattribute__ method. + """ if name in _AXES_INNER_ATTRS: return super().__getattribute__(name) @@ -89,26 +117,76 @@ def __getattribute__(self, name: str) -> Any: return super().__getattribute__(name) def __setattr__(self, name: str, value: Any) -> None: + """ + Sets an attribute on the object. + + This method intercepts attribute assignments to handle special cases for + inner attributes and axes that already exist. It issues a warning if an + existing axis is being reassigned without modification. + + Args: + name: The name of the attribute to set. + value: The value to assign to the attribute. + + Returns: + None + """ if name in _AXES_INNER_ATTRS: return super().__setattr__(name, value) if name in self.__axes_dict: - warnings.warn(f'Axis {name} has not been changed') + warnings.warn(f"Axis {name} has not been changed") return super().__setattr__(name, value) def __getitem__(self, name: str) -> Any: + """ + Retrieves an axis by its name. + + Args: + name: The name of the axis to retrieve. + + Returns: + The axis object associated with the given name. + + Raises: + AxisNotFound: If no axis with the specified name exists. + """ axes = self.__axes_dict if name in axes: return axes[name] - raise AxisNotFound(f'Axis with name {name} does not exist.') + raise AxisNotFound(f"Axis with name {name} does not exist.") def __setitem__(self, name: str, value: Any) -> None: - raise RuntimeError('Axis can not be changed') + """ + Sets an item in the axis. + + This method is overridden to prevent modification of the axis after creation. + + Args: + name: The name of the item to set. + value: The value to assign to the item. + + Returns: + None + + Raises: + RuntimeError: Always raised, as axis items cannot be changed. + """ + raise RuntimeError("Axis can not be changed") def __dir__(self) -> Iterable[str]: + """ + Returns the names of the axes in the coordinate system. + + Args: + None + + Returns: + Iterable[str]: An iterable of strings representing the axis names. + """ return self.__axes_dict.keys() @@ -116,10 +194,19 @@ class SimulationParameters: """ A class which describes characteristic parameters of the system """ - def __init__( - self, - axes: dict[str, torch.Tensor | float] - ) -> None: + + def __init__(self, axes: dict[str, torch.Tensor | float]) -> None: + """ + Initializes the object with a dictionary of axes. + + Args: + axes: A dictionary where keys are axis names (strings) and values are either + PyTorch tensors or floats representing the axis values. All tensor values + must be on the same device. + + Returns: + None + """ device = None def value_to_tensor(x): @@ -128,7 +215,7 @@ def value_to_tensor(x): if device is None: device = x.device if x.device != device: - raise ValueError('All axes should be on the same device') + raise ValueError("All axes should be on the same device") return x return torch.tensor(x) @@ -146,6 +233,15 @@ def value_to_tensor(x): self.axes = Axes(self.__axes_dict) def __getitem__(self, axis: str) -> torch.Tensor: + """ + Returns the tensor associated with a given axis. + + Args: + axis: The name of the axis to retrieve. + + Returns: + torch.Tensor: The tensor corresponding to the specified axis. + """ return self.axes[axis] def meshgrid(self, x_axis: str, y_axis: str): @@ -167,10 +263,7 @@ def meshgrid(self, x_axis: str, y_axis: str): of the second axis (`y_axis`) and the second dimension corresponds to the cardinality of the first axis (`x_axis`). """ - a, b = torch.meshgrid( - self.axes[x_axis], self.axes[y_axis], - indexing='xy' - ) + a, b = torch.meshgrid(self.axes[x_axis], self.axes[y_axis], indexing="xy") return a.to(self.__device), b.to(self.__device) def axes_size(self, axs: Iterable[str]) -> torch.Size: @@ -205,7 +298,7 @@ def axes_size(self, axs: Iterable[str]) -> torch.Size: return torch.Size(sizes) - def to(self, device: str | torch.device | int) -> 'SimulationParameters': + def to(self, device: str | torch.device | int) -> "SimulationParameters": if self.__device == torch.device(device): return self @@ -216,6 +309,15 @@ def to(self, device: str | torch.device | int) -> 'SimulationParameters': @property def device(self) -> str | torch.device | int: + """ + Returns the device on which tensors are allocated. + + Args: + None + + Returns: + The device string, torch.device object or integer representing the device. + """ return self.__device # def check_wf(self, wf: 'Wavefront'): diff --git a/svetlanna/specs/__init__.py b/svetlanna/specs/__init__.py index 393fc5a..c3d4856 100644 --- a/svetlanna/specs/__init__.py +++ b/svetlanna/specs/__init__.py @@ -6,16 +6,16 @@ __all__ = [ - 'Representation', - 'StrRepresentation', - 'MarkdownRepresentation', - 'HTMLRepresentation', - 'ReprRepr', - 'ImageRepr', - 'NpyFileRepr', - 'PrettyReprRepr', - 'ParameterSpecs', - 'ParameterSaveContext', - 'Specsable', - 'SubelementSpecs' + "Representation", + "StrRepresentation", + "MarkdownRepresentation", + "HTMLRepresentation", + "ReprRepr", + "ImageRepr", + "NpyFileRepr", + "PrettyReprRepr", + "ParameterSpecs", + "ParameterSaveContext", + "Specsable", + "SubelementSpecs", ] diff --git a/svetlanna/specs/specs.py b/svetlanna/specs/specs.py index b65b595..4ef6510 100644 --- a/svetlanna/specs/specs.py +++ b/svetlanna/specs/specs.py @@ -16,11 +16,8 @@ class ParameterSaveContext: """Generates different context managers that can be used to write a parameter data to output stream or file. """ - def __init__( - self, - parameter_name: str, - directory: Path - ): + + def __init__(self, parameter_name: str, directory: Path): """ Parameters ---------- @@ -50,23 +47,18 @@ def get_new_filepath(self, extension: str) -> Path: Path relative path to the file """ - suffix = '.' + extension + suffix = "." + extension # calculate total number of created files with the same extension total_files = len( - list( - filter( - lambda f: f.suffix == suffix, - self._generated_files - ) - ) - ) + list(filter(lambda f: f.suffix == suffix, self._generated_files)) + ) # name of the new file ending with `_.` - file_name = self.parameter_name + f'_{total_files}' + file_name = self.parameter_name + f"_{total_files}" # save filepath of the file - filepath = Path(self._directory, file_name).with_suffix(suffix) + filepath = Path(self._directory, file_name).with_suffix(suffix) self._generated_files.append(filepath) # create a new folder for the file if there is none @@ -103,26 +95,22 @@ def file(self, filepath: Path) -> Generator[BufferedWriter, Any, None]: Generator[BufferedWriter, Any, None] Buffer """ - with open(filepath, mode='wb') as file: + with open(filepath, mode="wb") as file: yield file -ParameterSaveContext_ = TypeVar( - 'ParameterSaveContext_', - bound=ParameterSaveContext -) +ParameterSaveContext_ = TypeVar("ParameterSaveContext_", bound=ParameterSaveContext) class Representation(Generic[ParameterSaveContext_]): """Base class for a parameter representation""" + ... -class MarkdownRepresentation( - Representation[ParameterSaveContext_], - metaclass=ABCMeta -): +class MarkdownRepresentation(Representation[ParameterSaveContext_], metaclass=ABCMeta): """Representation that can be exported to markdown file""" + @abstractmethod def to_markdown(self, out: TextIO, context: ParameterSaveContext_) -> None: """Write the parameter related data to be shown in a markdown file. @@ -137,11 +125,9 @@ def to_markdown(self, out: TextIO, context: ParameterSaveContext_) -> None: """ -class StrRepresentation( - Representation[ParameterSaveContext_], - metaclass=ABCMeta -): +class StrRepresentation(Representation[ParameterSaveContext_], metaclass=ABCMeta): """Representation that can be exported in the text format""" + @abstractmethod def to_str(self, out: TextIO, context: ParameterSaveContext_) -> None: """Write the parameter related data to be shown as a plain text. @@ -156,11 +142,9 @@ def to_str(self, out: TextIO, context: ParameterSaveContext_) -> None: """ -class HTMLRepresentation( - Representation[ParameterSaveContext_], - metaclass=ABCMeta -): +class HTMLRepresentation(Representation[ParameterSaveContext_], metaclass=ABCMeta): """Representation that can be exported to the HTML""" + @abstractmethod def to_html(self, out: TextIO, context: ParameterSaveContext_) -> None: """Write the parameter related data to be shown in a HTML file. @@ -179,12 +163,13 @@ class ImageRepr(StrRepresentation, MarkdownRepresentation, HTMLRepresentation): """Representation of the parameter as an image. Image generation is based on the `pillow` package. """ + def __init__( self, value: Any, - mode: Literal['1', 'L', 'LA', 'I', 'P', 'RGB', 'RGBA'] = 'L', - format: str = 'png', - show_image: bool = True + mode: Literal["1", "L", "LA", "I", "P", "RGB", "RGBA"] = "L", + format: str = "png", + show_image: bool = True, ): """ Parameters @@ -203,11 +188,7 @@ def __init__( self.mode = mode self.show_image = show_image - def draw_image( - self, - context: ParameterSaveContext, - filepath: Path - ) -> Image.Image: + def draw_image(self, context: ParameterSaveContext, filepath: Path) -> Image.Image: """Draw image into the file, using `pillow` package. Parameters @@ -225,23 +206,53 @@ def draw_image( return image def to_str(self, out: TextIO, context: ParameterSaveContext): + """ + Saves the image to a file and writes a message to the output stream. + + Args: + out: The output stream to write to. + context: The parameter save context containing filepath information. + + Returns: + None + """ filepath = context.get_new_filepath(extension=self.format) self.draw_image(context=context, filepath=filepath) - out.write(f'The image is saved to {context.rel_filepath(filepath)}\n') + out.write(f"The image is saved to {context.rel_filepath(filepath)}\n") def to_markdown(self, out: TextIO, context: ParameterSaveContext): + """ + Saves the image and writes a Markdown string to the output. + + Args: + out: The output stream to write to. + context: The parameter save context containing file path information. + + Returns: + None + """ filepath = context.get_new_filepath(extension=self.format) rel_filepath = context.rel_filepath(filepath) self.draw_image(context=context, filepath=filepath) - out.write(f'The image is saved to `{rel_filepath}`\n') + out.write(f"The image is saved to `{rel_filepath}`\n") if self.show_image: - out.write(f'\n![{context.parameter_name}]({rel_filepath})\n\n') + out.write(f"\n![{context.parameter_name}]({rel_filepath})\n\n") def to_html(self, out: TextIO, context: ParameterSaveContext): + """ + Renders the image as an HTML tag. + + Args: + out: The output stream to write the HTML to. + context: The parameter save context. + + Returns: + None + """ image = Image.fromarray(self.value, mode=self.mode) @@ -263,6 +274,7 @@ class ReprRepr(StrRepresentation, MarkdownRepresentation, HTMLRepresentation): """Representation of the parameter as a plain text. The `__repr__` method is used to generate the text. """ + def __init__(self, value: Any): """ Parameters @@ -275,18 +287,51 @@ def __init__(self, value: Any): self.value = value def to_str(self, out: TextIO, context: ParameterSaveContext): - out.write(f'{repr(self.value)}\n') + """ + Writes the string representation of the value to the output stream. + + Args: + out: The output stream to write to. + context: The parameter save context. + + Returns: + None + """ + out.write(f"{repr(self.value)}\n") def to_markdown(self, out: TextIO, context: ParameterSaveContext): - out.write(f'```\n{repr(self.value)}\n```\n') + """ + Writes the value as a code block in Markdown format. + + Args: + out: The output stream to write to. + context: The parameter save context (unused). + + Returns: + None + """ + out.write(f"```\n{repr(self.value)}\n```\n") def to_html(self, out: TextIO, context: Any): + """ + Renders the object to an HTML representation. + + This method simply calls `to_str` to generate a string representation + and writes it to the provided output stream. + + Args: + out: The output stream to write the HTML to. + context: Additional context that might be needed during rendering. + + Returns: + None + """ self.to_str(out, context) class NpyFileRepr(StrRepresentation, MarkdownRepresentation): - """Representation of the parameter as a `.npy` file. - """ + """Representation of the parameter as a `.npy` file.""" + def __init__(self, value: ArrayLike): """ Parameters @@ -311,30 +356,47 @@ def save_to_file(self, context: ParameterSaveContext, filepath: Path): np.save(f, self.value) def to_str(self, out: TextIO, context: ParameterSaveContext): - filepath = context.get_new_filepath(extension='npy') + """ + Saves the numpy array to a file and writes a message to the output stream. + + Args: + out: The output stream to write to. + context: The parameter save context containing filepath information. + + Returns: + None + """ + filepath = context.get_new_filepath(extension="npy") rel_filepath = context.rel_filepath(filepath) self.save_to_file(context, filepath) - out.write(f'The numpy array is saved to {rel_filepath}\n') + out.write(f"The numpy array is saved to {rel_filepath}\n") def to_markdown(self, out: TextIO, context: ParameterSaveContext): - filepath = context.get_new_filepath(extension='npy') + """ + Saves the numpy array to a file and writes a markdown message indicating the filepath. + + Args: + out: The output stream to write the markdown message to. + context: The parameter save context providing filepath information. + + Returns: + None + """ + filepath = context.get_new_filepath(extension="npy") rel_filepath = context.rel_filepath(filepath) self.save_to_file(context, filepath) - out.write(f'The numpy array is saved to `{rel_filepath}`\n') + out.write(f"The numpy array is saved to `{rel_filepath}`\n") class PrettyReprRepr(ReprRepr, HTMLRepresentation): """Same as ReprRepr but with better handling of Parameters and BoundedParameter""" - def __init__( - self, - value: Any, - units: str | None = None - ): + + def __init__(self, value: Any, units: str | None = None): """Representation of the parameter as a plain text. The `__repr__` method is used to generate the text if the `value` is not `torch.Tensor` or `Parameter`. @@ -350,7 +412,18 @@ def __init__( self.units = units def _repr(self) -> str: - units_suffix = '' if self.units is None else f' [{self.units}]' + """ + Returns a string representation of the object. + + Args: + None + + Returns: + str: A string representing the object's value, including its class name, + units (if any), and shape (for tensors). For constrained parameters, + it includes min/max values. + """ + units_suffix = "" if self.units is None else f" [{self.units}]" class_name: str = self.value.__class__.__name__ if isinstance(self.value, torch.Tensor): @@ -368,43 +441,71 @@ def _repr(self) -> str: min_val = self.value.min_value.item() max_val = self.value.max_value.item() - s = f'{class_name}\n' - s += f' ┏ min value {min_val}{units_suffix}\n' - s += f' β”— max value {max_val}{units_suffix}\n' - return s + f'{self.value.item()}{units_suffix}' + s = f"{class_name}\n" + s += f" ┏ min value {min_val}{units_suffix}\n" + s += f" β”— max value {max_val}{units_suffix}\n" + return s + f"{self.value.item()}{units_suffix}" except ImportError: pass - return f'{class_name}\n{self.value.item()}{units_suffix}' + return f"{class_name}\n{self.value.item()}{units_suffix}" # Print shape of the tensor shape_str = "x".join(map(str, shape)) - return f'{class_name} of size ({shape_str}){units_suffix}' + return f"{class_name} of size ({shape_str}){units_suffix}" # If the value is number, it can be directly printed out if isinstance(self.value, numbers.Number): - return f'{self.value}{units_suffix}' + return f"{self.value}{units_suffix}" return repr(self.value) def to_str(self, out: TextIO, context: ParameterSaveContext): - out.write(f'{self._repr()}\n') + """ + Writes the string representation of the object to the output stream. + + Args: + out: The output stream to write to. + context: The parameter save context. + + Returns: + None + """ + out.write(f"{self._repr()}\n") def to_markdown(self, out: TextIO, context: ParameterSaveContext): - out.write(f'```\n{self._repr()}\n```\n') + """ + Writes a markdown representation of the object to the output stream. + + Args: + out: The output stream to write to. + context: The parameter save context (unused). + + Returns: + None + """ + out.write(f"```\n{self._repr()}\n```\n") def to_html(self, out: TextIO, context: ParameterSaveContext): + """ + Writes the object to an HTML file. + + Args: + out: The output stream to write to. + context: The parameter save context. + + Returns: + None + """ self.to_str(out, context) class ParameterSpecs: - """Container with all representations for the parameter. - """ + """Container with all representations for the parameter.""" + def __init__( - self, - parameter_name: str, - representations: Iterable[Representation] + self, parameter_name: str, representations: Iterable[Representation] ) -> None: """ Parameters @@ -419,13 +520,9 @@ def __init__( class SubelementSpecs: - """Container for named subelement - """ - def __init__( - self, - subelement_type: str, - subelement: 'Specsable' - ): + """Container for named subelement""" + + def __init__(self, subelement_type: str, subelement: "Specsable"): """ Parameters ---------- @@ -440,5 +537,5 @@ def __init__( class Specsable(Protocol): """Represents any specsable object""" - def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: - ... + + def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: ... diff --git a/svetlanna/specs/specs_writer.py b/svetlanna/specs/specs_writer.py index a10eeff..7d47490 100644 --- a/svetlanna/specs/specs_writer.py +++ b/svetlanna/specs/specs_writer.py @@ -9,11 +9,23 @@ from io import StringIO -_T = TypeVar('_T') +_T = TypeVar("_T") @dataclass(frozen=True, slots=True) class _IndexedObject(Generic[_T]): + """ + Represents an object with a value and its corresponding index. + + This class is designed to hold data paired with an index, useful for + situations where tracking the original position of an item is important, + such as during sorting or reordering operations. + + Attributes: + value: The data associated with this object. + index: The original index or position of the value. + """ + value: _T index: int @@ -21,6 +33,7 @@ class _IndexedObject(Generic[_T]): @dataclass class _WriterContext: """Storage for additional info within ParameterSaveContext""" + parameter_name: _IndexedObject[str] representation: _IndexedObject[Representation] context: ParameterSaveContext @@ -33,7 +46,7 @@ def context_generator( element: Specsable, element_index: int, directory: str | Path, - subelements: list[SubelementSpecs] + subelements: list[SubelementSpecs], ) -> _WriterContextGenerator: """Generate _WriterContext for the element @@ -52,9 +65,7 @@ def context_generator( _WriterContext context """ - specs_directory = Path( - directory, f'{element_index}_{element.__class__.__name__}' - ) + specs_directory = Path(directory, f"{element_index}_{element.__class__.__name__}") # sort all iterators based on parameter name repr_iterators: dict[str, list[Iterable[Representation]]] = {} @@ -78,23 +89,20 @@ def context_generator( name: itertools.chain(*iters) for name, iters in repr_iterators.items() } - for parameter_index, (parameter_name, representations) in enumerate(parameter_representations.items()): + for parameter_index, (parameter_name, representations) in enumerate( + parameter_representations.items() + ): # create context for parameter context = ParameterSaveContext( - parameter_name=parameter_name, - directory=specs_directory + parameter_name=parameter_name, directory=specs_directory ) for representation_index, representation in enumerate(representations): yield _WriterContext( - parameter_name=_IndexedObject( - parameter_name, parameter_index - ), - representation=_IndexedObject( - representation, representation_index - ), - context=context + parameter_name=_IndexedObject(parameter_name, parameter_index), + representation=_IndexedObject(representation, representation_index), + context=context, ) @@ -104,17 +112,29 @@ def write_specs_to_str( writer_context_generator: _WriterContextGenerator, stream: TextIO, ): + """ + Writes the specifications of an element to a stream. + + Args: + element: The element whose specifications are to be written. + element_index: The index of the element in a list of elements. + writer_context_generator: A generator that yields writer contexts for different representations. + stream: The output stream to write the specifications to. + + Returns: + None + """ # create and write header for the element element_name = element.__class__.__name__ indexed_name = f"({element_index}) {element_name}" if element_index == 0: - element_header = '' + element_header = "" else: - element_header = '\n' + element_header = "\n" - element_header += f'{indexed_name}\n' + element_header += f"{indexed_name}\n" stream.write(element_header) # loop over representations @@ -122,10 +142,10 @@ def write_specs_to_str( # create header for parameter specs if writer_context.parameter_name.index == 0: - specs_header = '' + specs_header = "" else: - specs_header = '\n' - specs_header += f' {writer_context.parameter_name.value}\n' + specs_header = "\n" + specs_header += f" {writer_context.parameter_name.value}\n" # write header for parameter only in the beginning of representations if writer_context.representation.index == 0: @@ -136,22 +156,17 @@ def write_specs_to_str( if isinstance(representation, StrRepresentation): # write separator between two representations if writer_context.representation.index != 0: - stream.write('\n') + stream.write("\n") - _stream = StringIO('') + _stream = StringIO("") - representation.to_str( - out=_stream, - context=writer_context.context - ) + representation.to_str(out=_stream, context=writer_context.context) s = _stream.getvalue() # add spaces at the beginning of each line - new_line_prefix = ' ' * 8 + new_line_prefix = " " * 8 stream.write( - new_line_prefix + new_line_prefix.join( - s.splitlines(keepends=True) - ) + new_line_prefix + new_line_prefix.join(s.splitlines(keepends=True)) ) @@ -161,12 +176,25 @@ def write_specs_to_markdown( writer_context_generator: _WriterContextGenerator, stream: TextIO, ): + """ + Writes specifications for an element to a markdown stream. + + Args: + element: The element whose specs are being written. + element_index: The index of the element in a list of elements. + writer_context_generator: A generator that yields writer contexts + containing parameter and representation information. + stream: The output stream to write markdown to. + + Returns: + None + """ # create and write header for the element element_name = element.__class__.__name__ indexed_name = f"({element_index}) {element_name}" - element_header = '' if element_index == 0 else '\n' + element_header = "" if element_index == 0 else "\n" element_header += f"# {indexed_name}\n" stream.write(element_header) @@ -175,10 +203,10 @@ def write_specs_to_markdown( # create header for parameter specs if writer_context.parameter_name.index == 0: - specs_header = '' + specs_header = "" else: - specs_header = '\n' - specs_header += f'**{writer_context.parameter_name.value}**\n' + specs_header = "\n" + specs_header += f"**{writer_context.parameter_name.value}**\n" # write header for parameter only in the beginning of representations if writer_context.representation.index == 0: @@ -189,12 +217,9 @@ def write_specs_to_markdown( if isinstance(representation, MarkdownRepresentation): # write separator between two representations if writer_context.representation.index != 0: - stream.write('\n') + stream.write("\n") - representation.to_markdown( - out=stream, - context=writer_context.context - ) + representation.to_markdown(out=stream, context=writer_context.context) def write_specs_to_html( @@ -203,6 +228,21 @@ def write_specs_to_html( writer_context_generator: _WriterContextGenerator, stream: TextIO, ): + """ + Writes specifications to an HTML stream. + + Iterates through writer contexts and writes HTML representations of specs, + including parameter headers where appropriate. + + Args: + element: The element whose specs are being written. + element_index: The index of the element. + writer_context_generator: A generator yielding writer contexts for each representation. + stream: The output stream to write the HTML to. + + Returns: + None + """ s = '
' @@ -221,12 +261,9 @@ def write_specs_to_html( representation = writer_context.representation.value if isinstance(representation, HTMLRepresentation): - _stream = StringIO('') + _stream = StringIO("") - representation.to_html( - out=_stream, - context=writer_context.context - ) + representation.to_html(out=_stream, context=writer_context.context) s += f"""
@@ -239,36 +276,82 @@ def write_specs_to_html( @dataclass(frozen=True, slots=True) class _ElementInTree: + """ + Represents an element within a tree structure. + + This class holds information about an element, its index, children, and + potential subelement type. It provides functionality to create copies of the element. + + Attributes: + element: The actual element being represented. + element_index: The index of the element in its parent's list of children. + children: A list of child _ElementInTree objects. + subelement_type: The type of subelement, if applicable. + + Methods: + create_copy(subelement_type): Creates a copy of the current element. + """ + element: Specsable element_index: int - children: list['_ElementInTree'] = field(default_factory=list) + children: list["_ElementInTree"] = field(default_factory=list) subelement_type: str | None = None - def create_copy( - self, subelement_type: str | None - ) -> '_ElementInTree': + def create_copy(self, subelement_type: str | None) -> "_ElementInTree": return _ElementInTree( element=self.element, element_index=self.element_index, children=self.children, - subelement_type=subelement_type + subelement_type=subelement_type, ) class _ElementsIterator: + """ + Iterates through a collection of iterables, yielding elements with their index and writer context. + + This class provides an iterator that handles multiple iterables, subelements, and ensures + each element is written only once while building a tree structure of the iterated elements. + """ + def __init__(self, *iterables: Specsable, directory: str | Path) -> None: + """ + Initializes a new instance of the class. + + Args: + *iterables: The iterables to process. Can be multiple. + directory: The directory associated with the iterables. + + Returns: + None + """ self.iterables = tuple(iterables) self.directory = directory self._iterated: dict[int, _ElementInTree] = {} self._tree: list[_ElementInTree] | None = None def __iter__( - self + self, ) -> Generator[tuple[int, Specsable, _WriterContextGenerator], None, None]: + """ + Iterates through the specsables and yields them with their index and writer context. + + This method traverses a collection of specsables, handling subelements and ensuring + each element is written only once. It maintains an internal state to track iterated + elements and builds a tree structure as it iterates. + + Args: + None + + Returns: + Generator[tuple[int, Specsable, _WriterContextGenerator]]: A generator that yields tuples + containing the index of the element, the element itself (Specsable or SubelementSpecs), + and a writer context generator for the element. + """ def f( specsables: Iterable[Specsable | SubelementSpecs], - parent_children: list[_ElementInTree] + parent_children: list[_ElementInTree], ): for element in specsables: @@ -298,9 +381,7 @@ def f( # Create a new tree element element_in_tree = _ElementInTree( - element, - index, - subelement_type=element_name + element, index, subelement_type=element_name ) self._iterated[id(element)] = element_in_tree parent_children.append(element_in_tree) @@ -334,16 +415,26 @@ def write_elements_tree_to_str( tree: list[_ElementInTree], stream: TextIO, ): - stream.write('\n\nTree:\n') + """ + Writes the elements tree to a string stream. + + Args: + tree: The list of root elements in the tree. + stream: The output stream to write to. + + Returns: + None + """ + stream.write("\n\nTree:\n") def _write_element(tree_level: int, element: _ElementInTree): - stream.write(' ' * (8 * tree_level)) + stream.write(" " * (8 * tree_level)) element_name = element.element.__class__.__name__ - indexed_name = f'({element.element_index}) {element_name}' + indexed_name = f"({element.element_index}) {element_name}" if element.subelement_type is not None: - stream.write(f'[{element.subelement_type}] ') - stream.write(f'{indexed_name}\n') + stream.write(f"[{element.subelement_type}] ") + stream.write(f"{indexed_name}\n") for subelement in element.children: _write_element(tree_level + 1, subelement) @@ -356,16 +447,26 @@ def write_elements_tree_to_markdown( tree: list[_ElementInTree], stream: TextIO, ): - stream.write('\n\n# Tree:\n') + """ + Writes the elements tree to a markdown formatted string. + + Args: + tree: The list of root elements in the tree. + stream: A file-like object to write the markdown output to. + + Returns: + None + """ + stream.write("\n\n# Tree:\n") def _write_element(tree_level: int, element: _ElementInTree): - stream.write(' ' * (4 * tree_level) + '* ') + stream.write(" " * (4 * tree_level) + "* ") element_name = element.element.__class__.__name__ - indexed_name = f'`({element.element_index}) {element_name}`' + indexed_name = f"`({element.element_index}) {element_name}`" if element.subelement_type is not None: - stream.write(f'[{element.subelement_type}] ') - stream.write(f'{indexed_name}\n') + stream.write(f"[{element.subelement_type}] ") + stream.write(f"{indexed_name}\n") for subelement in element.children: _write_element(tree_level + 1, subelement) @@ -376,39 +477,47 @@ def _write_element(tree_level: int, element: _ElementInTree): def write_specs( *iterables: Specsable, - filename: str = 'specs.txt', - directory: str | Path = 'specs', + filename: str = "specs.txt", + directory: str | Path = "specs", ): + """ + Writes specifications from iterables to a file. + + Creates a directory if it doesn't exist and writes the specs to either a + text or markdown file based on the filename extension. + + Args: + *iterables: One or more iterable objects containing specification data. + filename: The name of the output file (e.g., 'specs.txt' or 'specs.md'). + directory: The directory to write the file to. Defaults to 'specs'. + + Returns: + _ElementsIterator: An iterator object representing the written elements and their tree structure. + """ Path.mkdir(Path(directory), parents=True, exist_ok=True) path = Path(directory, filename) elements = _ElementsIterator(*iterables, directory=directory) - with open(path, 'w') as file: - if filename.endswith('.txt'): + with open(path, "w") as file: + if filename.endswith(".txt"): for elemennt_index, element, writer_context_generator in elements: write_specs_to_str( element=element, element_index=elemennt_index, writer_context_generator=writer_context_generator, - stream=file + stream=file, ) - write_elements_tree_to_str( - tree=elements.tree, - stream=file - ) - elif filename.endswith('.md'): + write_elements_tree_to_str(tree=elements.tree, stream=file) + elif filename.endswith(".md"): for elemennt_index, element, writer_context_generator in elements: write_specs_to_markdown( element=element, element_index=elemennt_index, writer_context_generator=writer_context_generator, - stream=file + stream=file, ) - write_elements_tree_to_markdown( - tree=elements.tree, - stream=file - ) + write_elements_tree_to_markdown(tree=elements.tree, stream=file) else: raise ValueError( "Unknown file extension. ' \ diff --git a/svetlanna/transforms.py b/svetlanna/transforms.py index 804fd39..03af916 100644 --- a/svetlanna/transforms.py +++ b/svetlanna/transforms.py @@ -17,6 +17,7 @@ class ToWavefront(nn.Module): (3) modulation_type='amp&phase' (any other str) tensor values transforms to amplitude and phase simultaneously """ + def __init__(self, modulation_type=None): """ Parameters @@ -47,26 +48,32 @@ def forward(self, img_tensor: torch.Tensor) -> Wavefront: # creation of a wavefront based on an image if img_tensor.size()[0] == 1: # only one channel # squeeze 0th channel dimension of image tensor - normalized_tensor = torch.squeeze(img_tensor, 0) # values from 0 to 1, shape=[H, W] + normalized_tensor = torch.squeeze( + img_tensor, 0 + ) # values from 0 to 1, shape=[H, W] else: # more than 1 color channels normalized_tensor = img_tensor # values from 0 to 1, shape=[C, H, W] # TODO: check that in simulation parameters we have the same number of wavelengths? - if self.modulation_type == 'amp': # amplitude modulation + if self.modulation_type == "amp": # amplitude modulation amplitudes = normalized_tensor phases = torch.zeros(size=normalized_tensor.size()) else: # image -> phases from -pi + eps to pi - eps normalized_tensor_fix = normalized_tensor - normalized_tensor_fix[normalized_tensor_fix == 1.] -= self.eps # maximal values - eps - normalized_tensor_fix[normalized_tensor_fix == 0.] += self.eps # 0 + eps + normalized_tensor_fix[ + normalized_tensor_fix == 1.0 + ] -= self.eps # maximal values - eps + normalized_tensor_fix[normalized_tensor_fix == 0.0] += self.eps # 0 + eps # [0, 1] --> [-pi + eps, pi - eps] phases = normalized_tensor_fix * 2 * torch.pi - torch.pi - if self.modulation_type == 'phase': # phase modulation + if self.modulation_type == "phase": # phase modulation # TODO: What is with an amplitude? - amplitudes = torch.ones(size=normalized_tensor.size()) # constant amplitude + amplitudes = torch.ones( + size=normalized_tensor.size() + ) # constant amplitude else: # phase AND amplitude modulation 'amp&phase' amplitudes = normalized_tensor @@ -80,11 +87,9 @@ class GaussModulation(nn.Module): """ Multiplies an amplitude of a Wavefront on a gaussian. """ + def __init__( - self, - sim_params: SimulationParameters, - fwhm_x, fwhm_y, - peak_x=0., peak_y=0. + self, sim_params: SimulationParameters, fwhm_x, fwhm_y, peak_x=0.0, peak_y=0.0 ): """ Parameters @@ -113,12 +118,13 @@ def get_gauss(self): gauss_2d : torch.Tensor A gaussian distribution in a 2D plane. """ - x_grid, y_grid = self.sim_params.meshgrid(x_axis='W', y_axis='H') + x_grid, y_grid = self.sim_params.meshgrid(x_axis="W", y_axis="H") gauss_2d = 1 * torch.exp( - -1 * ( - (x_grid - self.peak_x) ** 2 / 2 / self.sigma_x ** 2 + - (y_grid - self.peak_y) ** 2 / 2 / self.sigma_y ** 2 + -1 + * ( + (x_grid - self.peak_x) ** 2 / 2 / self.sigma_x**2 + + (y_grid - self.peak_y) ** 2 / 2 / self.sigma_y**2 ) ) return gauss_2d @@ -138,11 +144,11 @@ def forward(self, wf: Wavefront) -> Wavefront: wf_gauss : Wavefront A gaussian distribution in a 2D plane. """ - sim_nodes_shape = self.sim_params.axes_size(axs=('H', 'W')) # [H, W] + sim_nodes_shape = self.sim_params.axes_size(axs=("H", "W")) # [H, W] if not wf.size()[-2:] == sim_nodes_shape: warnings.warn( - message='A shape of an input Wavefront does not match with SimulationParameters! Gauss was not applied!' + message="A shape of an input Wavefront does not match with SimulationParameters! Gauss was not applied!" ) wf_gauss = wf else: diff --git a/svetlanna/units.py b/svetlanna/units.py index 31e25e8..15ff192 100644 --- a/svetlanna/units.py +++ b/svetlanna/units.py @@ -22,6 +22,7 @@ class ureg(Enum): var = 10 assert var * ureg.mm == 10*1e-2 """ + Gm = _G Mm = _M km = _k @@ -57,24 +58,83 @@ class ureg(Enum): pHz = _p def __mul__(self, other): + """ + Multiplies the value of this object by another number. + + Args: + other: The number to multiply this object's value by. + + Returns: + float: The result of multiplying the object's value by the other number. + """ return self.value * other def __rmul__(self, other): + """ + Returns the result of multiplying 'other' by the value. + + This method enables multiplication with this object on either side + (e.g., `2 * MyObject` or `MyObject * 2`). It leverages Python's + multiplication operator to achieve this. + + Args: + other: The value to multiply by the object's value. + + Returns: + The result of the multiplication. + """ return other * self.value def __truediv__(self, other): + """ + Divides the value of this object by another. + + Args: + other: The number to divide this object's value by. + + Returns: + float: The result of dividing this object's value by the given number. + """ return self.value / other def __rtruediv__(self, other): + """ + Divides another number by the value of this instance. + + Args: + other: The number to be divided by the instance's value. + + Returns: + float: The result of dividing `other` by the instance's `value`. + """ return other / self.value def __pow__(self, other): - return self.value ** other + """ + Calculates the power of this value. + + Args: + other: The exponent to raise the value to. + + Returns: + float: The result of raising the value to the power of 'other'. + """ + return self.value**other def __array__(self, dtype=None, copy=None): + """ + Returns an array representation of the value. + + Args: + dtype: The desired data type of the returned array. + copy: Whether to allocate a copy of the underlying data. + + Returns: + numpy.ndarray: A NumPy array containing the values. A copy is always created, + so attempting to set `copy=False` will raise a ValueError. + """ import numpy + if copy is False: - raise ValueError( - "`copy=False` isn't supported. A copy is always created." - ) + raise ValueError("`copy=False` isn't supported. A copy is always created.") return numpy.array(self.value, dtype=dtype) diff --git a/svetlanna/visualization/__init__.py b/svetlanna/visualization/__init__.py index f688bf3..80d1c29 100644 --- a/svetlanna/visualization/__init__.py +++ b/svetlanna/visualization/__init__.py @@ -2,9 +2,9 @@ from .widgets import jinja_env, ElementHTML __all__ = [ - 'show_specs', - 'show_structure', - 'show_stepwise_forward', - 'jinja_env', - 'ElementHTML' + "show_specs", + "show_structure", + "show_stepwise_forward", + "jinja_env", + "ElementHTML", ] diff --git a/svetlanna/visualization/widgets.py b/svetlanna/visualization/widgets.py index 75d85f8..c9e547f 100644 --- a/svetlanna/visualization/widgets.py +++ b/svetlanna/visualization/widgets.py @@ -15,52 +15,51 @@ import base64 -STATIC_FOLDER = pathlib.Path(__file__).parent / 'static' -TEMPLATES_FOLDER = pathlib.Path(__file__).parent / 'templates' +STATIC_FOLDER = pathlib.Path(__file__).parent / "static" +TEMPLATES_FOLDER = pathlib.Path(__file__).parent / "templates" jinja_env = Environment( - loader=FileSystemLoader(TEMPLATES_FOLDER), - autoescape=select_autoescape() + loader=FileSystemLoader(TEMPLATES_FOLDER), autoescape=select_autoescape() ) StepwisePlotTypes = Union[ - Literal['A'], - Literal['I'], - Literal['phase'], - Literal['Re'], - Literal['Im'] + Literal["A"], Literal["I"], Literal["phase"], Literal["Re"], Literal["Im"] ] class StepwiseForwardWidget(anywidget.AnyWidget): - _esm = STATIC_FOLDER / 'stepwise_forward_widget.js' - _css = STATIC_FOLDER / 'setup_widget.css' + """ + A widget for stepwise forward selection visualization.""" + + _esm = STATIC_FOLDER / "stepwise_forward_widget.js" + _css = STATIC_FOLDER / "setup_widget.css" elements = traitlets.List([]).tag(sync=True) - structure_html = traitlets.Unicode('').tag(sync=True) + structure_html = traitlets.Unicode("").tag(sync=True) class SpecsWidget(anywidget.AnyWidget): - _esm = STATIC_FOLDER / 'specs_widget.js' - _css = STATIC_FOLDER / 'setup_widget.css' + """ + A widget for displaying and interacting with specifications.""" + + _esm = STATIC_FOLDER / "specs_widget.js" + _css = STATIC_FOLDER / "setup_widget.css" elements = traitlets.List([]).tag(sync=True) - structure_html = traitlets.Unicode('').tag(sync=True) + structure_html = traitlets.Unicode("").tag(sync=True) @dataclass(frozen=True, slots=True) class ElementHTML: """Representation of an element in HTML format.""" + element_type: str | None html: str def default_widget_html_method( - index: int, - name: str, - element_type: str | None, - subelements: list[ElementHTML] + index: int, name: str, element_type: str | None, subelements: list[ElementHTML] ) -> str: """Default `_widget_html_` method used for rendering `Specsable` elements. @@ -81,14 +80,12 @@ def default_widget_html_method( str rendered HTML """ - return jinja_env.get_template('widget_default.html.jinja').render( + return jinja_env.get_template("widget_default.html.jinja").render( index=index, name=name, subelements=subelements ) -def _get_widget_html_method( - element: Specsable -) -> Callable[..., str]: +def _get_widget_html_method(element: Specsable) -> Callable[..., str]: """Returns `_widget_html_` method based on type of element. Parameters @@ -101,8 +98,8 @@ def _get_widget_html_method( Any `_widget_html_` method """ - if hasattr(element, '_widget_html_'): - return getattr(element, '_widget_html_') + if hasattr(element, "_widget_html_"): + return getattr(element, "_widget_html_") return default_widget_html_method @@ -129,15 +126,10 @@ def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]: index=subelement.element_index, name=subelement.element.__class__.__name__, element_type=subelement.subelement_type, - subelements=_subelements_html(subelement.children) + subelements=_subelements_html(subelement.children), ) - res.append( - ElementHTML( - subelement.subelement_type, - html=raw_subelement_html - ) - ) + res.append(ElementHTML(subelement.subelement_type, html=raw_subelement_html)) return res @@ -158,9 +150,9 @@ def generate_structure_html(subelements: list[_ElementInTree]) -> str: elements_html = _subelements_html(subelements) - return jinja_env.get_template( - 'widget_structure_container.html.jinja' - ).render(elements_html=elements_html) + return jinja_env.get_template("widget_structure_container.html.jinja").render( + elements_html=elements_html + ) def show_structure(*specsable: Specsable): @@ -171,7 +163,7 @@ def show_structure(*specsable: Specsable): from IPython.display import HTML, display # Generate HTML - elements = _ElementsIterator(*specsable, directory='') + elements = _ElementsIterator(*specsable, directory="") structure_html = generate_structure_html(elements.tree) # Display HTML @@ -190,22 +182,20 @@ def show_specs(*specsable: Specsable) -> SpecsWidget: The widget """ - elements = _ElementsIterator(*specsable, directory='') + elements = _ElementsIterator(*specsable, directory="") # Prepare elements data for widget elements_json = [] for element_index, element, writer_context_generator in elements: - stream = StringIO('') + stream = StringIO("") # Write element's parameter specs to the stream - write_specs_to_html( - element, element_index, writer_context_generator, stream - ) + write_specs_to_html(element, element_index, writer_context_generator, stream) elements_json.append( { - 'index': element_index, - 'name': element.__class__.__name__, - 'specs_html': stream.getvalue() + "index": element_index, + "name": element.__class__.__name__, + "specs_html": stream.getvalue(), } ) @@ -213,10 +203,7 @@ def show_specs(*specsable: Specsable) -> SpecsWidget: structure_html = generate_structure_html(elements.tree) # Create a widget - widget = SpecsWidget( - structure_html=structure_html, - elements=elements_json - ) + widget = SpecsWidget(structure_html=structure_html, elements=elements_json) return widget @@ -224,7 +211,7 @@ def show_specs(*specsable: Specsable) -> SpecsWidget: def draw_wavefront( wavefront: torch.Tensor, simulation_parameters: SimulationParameters, - types_to_plot: tuple[StepwisePlotTypes, ...] = ('I', 'phase') + types_to_plot: tuple[StepwisePlotTypes, ...] = ("I", "phase"), ) -> bytes: """Show field propagation in the setup via widget. Currently only wavefronts of shape `(W, H)` are supported. @@ -253,16 +240,10 @@ def draw_wavefront( n_plots = len(types_to_plot) - width_to_height = ( - width.max() - width.min() - ) / ( - height.max() - height.min() - ) + width_to_height = (width.max() - width.min()) / (height.max() - height.min()) figure, ax = plt.subplots( - 1, n_plots, - figsize=(2+3*n_plots*width_to_height, 3), - dpi=120 + 1, n_plots, figsize=(2 + 3 * n_plots * width_to_height, 3), dpi=120 ) for i, plot_type in enumerate(types_to_plot): @@ -272,25 +253,17 @@ def draw_wavefront( axes = ax[i] axes = cast(Axes, axes) - if plot_type == 'A': + if plot_type == "A": # Plot the wavefront amplitude - axes.pcolorfast( - width, - height, - wavefront.abs().numpy(force=True) - ) - axes.set_title('Amplitude') + axes.pcolorfast(width, height, wavefront.abs().numpy(force=True)) + axes.set_title("Amplitude") - elif plot_type == 'I': + elif plot_type == "I": # Plot the wavefront intensity - axes.pcolorfast( - width, - height, - (wavefront.abs()**2).numpy(force=True) - ) - axes.set_title('Intensity') + axes.pcolorfast(width, height, (wavefront.abs() ** 2).numpy(force=True)) + axes.set_title("Intensity") - elif plot_type == 'phase': + elif plot_type == "phase": # Plot the wavefront phase axes.pcolorfast( width, @@ -299,27 +272,27 @@ def draw_wavefront( vmin=-torch.pi, vmax=torch.pi, ) - axes.set_title('Phase') + axes.set_title("Phase") - elif plot_type == 'Re': + elif plot_type == "Re": # Plot the wavefront real part axes.pcolorfast( width, height, wavefront.real.numpy(force=True), ) - axes.set_title('Real part') + axes.set_title("Real part") - elif plot_type == 'Im': + elif plot_type == "Im": # Plot the wavefront imaginary part axes.pcolorfast( width, height, wavefront.imag.numpy(force=True), ) - axes.set_title('Imaginary part') + axes.set_title("Imaginary part") - axes.set_aspect('equal') + axes.set_aspect("equal") plt.tight_layout() figure.savefig(stream) @@ -332,7 +305,7 @@ def show_stepwise_forward( *specsable: Specsable, input: torch.Tensor, simulation_parameters: SimulationParameters, - types_to_plot: tuple[StepwisePlotTypes, ...] = ('I', 'phase') + types_to_plot: tuple[StepwisePlotTypes, ...] = ("I", "phase"), ) -> StepwiseForwardWidget: """Display the wavefront propagation through a setup structure using a widget interface. Currently only wavefronts @@ -354,7 +327,7 @@ def show_stepwise_forward( """ elements_to_call = tuple(s for s in specsable) - elements = _ElementsIterator(*elements_to_call, directory='') + elements = _ElementsIterator(*elements_to_call, directory="") outputs = {} @@ -397,19 +370,19 @@ def capture_output_hook(module, args, output): draw_wavefront( wavefront=outputs[element], simulation_parameters=simulation_parameters, - types_to_plot=types_to_plot + types_to_plot=types_to_plot, ) ).decode() except Exception as e: - output_image = f'\n{e}' + output_image = f"\n{e}" else: output_image = None elements_json.append( { - 'index': element_index, - 'name': element.__class__.__name__, - 'output_image': output_image + "index": element_index, + "name": element.__class__.__name__, + "output_image": output_image, } ) @@ -418,8 +391,7 @@ def capture_output_hook(module, args, output): # Create a widget widget = StepwiseForwardWidget( - structure_html=structure_html, - elements=elements_json + structure_html=structure_html, elements=elements_json ) return widget diff --git a/svetlanna/wavefront.py b/svetlanna/wavefront.py index bb0272d..9afea54 100644 --- a/svetlanna/wavefront.py +++ b/svetlanna/wavefront.py @@ -6,8 +6,21 @@ class Wavefront(torch.Tensor): """Class that represents wavefront""" + @staticmethod def __new__(cls, data, *args, **kwargs): + """ + Creates a new Wavefront object from the given data. + + Args: + data: The input data to be converted into a tensor. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Wavefront: A new instance of the Wavefront class with the data + converted to a PyTorch tensor. + """ # see https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py # noqa: E501 data = torch.as_tensor(data) return super(cls, Wavefront).__new__(cls, data) @@ -47,10 +60,7 @@ def phase(self) -> torch.Tensor: res = torch.angle(torch.Tensor(self) + 0.0) return res - def fwhm( - self, - simulation_parameters: SimulationParameters - ) -> tuple[float, float]: + def fwhm(self, simulation_parameters: SimulationParameters) -> tuple[float, float]: """Calculates full width at half maximum of the wavefront Returns @@ -79,9 +89,9 @@ def fwhm( def plane_wave( cls, simulation_parameters: SimulationParameters, - distance: float = 0., + distance: float = 0.0, wave_direction: Any = None, - initial_phase: float = 0. + initial_phase: float = 0.0, ) -> Self: """Generate wavefront of the plane wave @@ -105,25 +115,21 @@ def plane_wave( """ # by default the wave propagates along z direction if wave_direction is None: - wave_direction = [0., 0., 1.] + wave_direction = [0.0, 0.0, 1.0] wave_direction = torch.tensor( - wave_direction, - dtype=torch.float32, - device=simulation_parameters.device + wave_direction, dtype=torch.float32, device=simulation_parameters.device ) if wave_direction.shape != torch.Size([3]): - raise ValueError( - "wave_direction should contain exactly three components" - ) + raise ValueError("wave_direction should contain exactly three components") wave_direction = wave_direction / torch.norm(wave_direction) wave_number = 2 * torch.pi / simulation_parameters.axes.wavelength x = simulation_parameters.axes.W[None, :] y = simulation_parameters.axes.H[:, None] - kxx, axes = tensor_dot(wave_number, x, 'wavelength', ('H', 'W')) - kyy, _ = tensor_dot(wave_number, y, 'wavelength', ('H', 'W')) + kxx, axes = tensor_dot(wave_number, x, "wavelength", ("H", "W")) + kyy, _ = tensor_dot(wave_number, y, "wavelength", ("H", "W")) kzz = wave_number[..., None, None] * distance field = torch.exp(1j * wave_direction[0] * kxx) @@ -137,9 +143,9 @@ def gaussian_beam( cls, simulation_parameters: SimulationParameters, waist_radius: float, - distance: float = 0., - dx: float = 0., - dy: float = 0., + distance: float = 0.0, + dx: float = 0.0, + dy: float = 0.0, ) -> Self: """Generates the Gaussian beam. @@ -164,15 +170,21 @@ def gaussian_beam( wave_number = 2 * torch.pi / simulation_parameters.axes.wavelength - rayleigh_range = torch.pi * (waist_radius**2) / simulation_parameters.axes.wavelength # noqa: E501 + rayleigh_range = ( + torch.pi * (waist_radius**2) / simulation_parameters.axes.wavelength + ) # noqa: E501 x = simulation_parameters.axes.W[None, :] - dx y = simulation_parameters.axes.H[:, None] - dy radial_distance_squared = x**2 + y**2 - hyperbolic_relation = waist_radius * (1 + (distance / rayleigh_range)**2)**(1/2) # noqa: E501 + hyperbolic_relation = waist_radius * (1 + (distance / rayleigh_range) ** 2) ** ( + 1 / 2 + ) # noqa: E501 - inverse_radius_of_curvature = distance / (distance**2 + rayleigh_range**2) # noqa: E501 + inverse_radius_of_curvature = distance / ( + distance**2 + rayleigh_range**2 + ) # noqa: E501 # Gouy phase gouy_phase = torch.arctan(distance / rayleigh_range) @@ -180,54 +192,56 @@ def gaussian_beam( phase1, axes1 = tensor_dot( a=1j * wave_number * inverse_radius_of_curvature / 2, b=radial_distance_squared, - a_axis='wavelength', - b_axis=('H', 'W') + a_axis="wavelength", + b_axis=("H", "W"), ) field = torch.exp(phase1) field, _ = tensor_dot( a=field, b=torch.exp(1j * wave_number * distance), - a_axis=axes1, b_axis='wavelength', preserve_a_axis=True + a_axis=axes1, + b_axis="wavelength", + preserve_a_axis=True, ) field, _ = tensor_dot( a=field, b=torch.exp(-1j * gouy_phase), - a_axis=axes1, b_axis='wavelength', preserve_a_axis=True + a_axis=axes1, + b_axis="wavelength", + preserve_a_axis=True, ) phase2, axes2 = tensor_dot( - a=-1/(hyperbolic_relation)**2, + a=-1 / (hyperbolic_relation) ** 2, b=radial_distance_squared, - a_axis='wavelength', - b_axis=('H', 'W') + a_axis="wavelength", + b_axis=("H", "W"), ) field, axes = tensor_dot( a=field, b=torch.exp(phase2), a_axis=axes1, b_axis=axes2, - preserve_a_axis=True + preserve_a_axis=True, ) field, _ = tensor_dot( a=field, b=waist_radius / hyperbolic_relation, a_axis=axes, - b_axis='wavelength', - preserve_a_axis=True + b_axis="wavelength", + preserve_a_axis=True, ) - return cls( - cast_tensor(field, axes, simulation_parameters.axes.names) - ) + return cls(cast_tensor(field, axes, simulation_parameters.axes.names)) @classmethod def spherical_wave( cls, simulation_parameters: SimulationParameters, distance: float, - initial_phase: float = 0., - dx: float = 0., - dy: float = 0., + initial_phase: float = 0.0, + dx: float = 0.0, + dy: float = 0.0, ) -> Self: """Generate wavefront of the spherical wave @@ -254,22 +268,17 @@ def spherical_wave( x = simulation_parameters.axes.W[None, :] - dx y = simulation_parameters.axes.H[:, None] - dy - radius = torch.sqrt( - (x**2 + y**2) + distance**2 - ) + radius = torch.sqrt((x**2 + y**2) + distance**2) phase, axes = tensor_dot( - a=wave_number, - b=radius, - a_axis='wavelength', - b_axis=('H', 'W') + a=wave_number, b=radius, a_axis="wavelength", b_axis=("H", "W") ) field, _ = tensor_dot( a=torch.exp(1j * (phase + initial_phase)), b=1 / radius, a_axis=axes, - b_axis=('H', 'W'), - preserve_a_axis=True + b_axis=("H", "W"), + preserve_a_axis=True, ) return cls(cast_tensor(field, axes, simulation_parameters.axes.names)) @@ -277,30 +286,25 @@ def spherical_wave( # === methods below are added for typing only === if TYPE_CHECKING: - def __mul__(self, other: Any) -> Self: - ... - def __rmul__(self, other: Any) -> Self: - ... + def __mul__(self, other: Any) -> Self: ... + + def __rmul__(self, other: Any) -> Self: ... - def __add__(self, other: Any) -> Self: - ... + def __add__(self, other: Any) -> Self: ... - def __radd__(self, other: Any) -> Self: - ... + def __radd__(self, other: Any) -> Self: ... - def __truediv__(self, other: Any) -> Self: - ... + def __truediv__(self, other: Any) -> Self: ... - def __rtruediv__(self, other: Any) -> Self: - ... + def __rtruediv__(self, other: Any) -> Self: ... DEFAULT_LAST_AXES_NAMES = ( # 'pol', # 'wavelength', - 'H', - 'W' + "H", + "W", ) @@ -308,7 +312,7 @@ def mul( wf: Wavefront, b: Any, b_axis: str | Iterable[str], - sim_params: SimulationParameters | None = None + sim_params: SimulationParameters | None = None, ) -> Wavefront: """Multiplication of the wavefront and tensor. diff --git a/tests/analytical_solutions.py b/tests/analytical_solutions.py index 1900900..0436447 100644 --- a/tests/analytical_solutions.py +++ b/tests/analytical_solutions.py @@ -8,8 +8,9 @@ class RectangleFresnel: """A class describing the analytical solution for the problem of free - propagation after planar wave passes a rectangular aperture + propagation after planar wave passes a rectangular aperture """ + def __init__( self, distance: float, @@ -19,7 +20,7 @@ def __init__( y_nodes: int, width: float, height: float, - wavelength: torch.Tensor | float + wavelength: torch.Tensor | float, ): """Constructor method @@ -71,24 +72,29 @@ def field(self) -> np.ndarray: x_grid = x_grid[None, :] y_grid = y_grid[None, :] - psi1 = -np.sqrt(wave_number/(np.pi*self.distance))*(self.width/2 - + x_grid) - psi2 = np.sqrt(wave_number/(np.pi*self.distance))*(self.width / 2 - - x_grid) - eta1 = -np.sqrt(wave_number/(np.pi*self.distance))*(self.height / 2 - + y_grid) - eta2 = np.sqrt(wave_number/(np.pi*self.distance))*(self.height / 2 - - y_grid) + psi1 = -np.sqrt(wave_number / (np.pi * self.distance)) * ( + self.width / 2 + x_grid + ) + psi2 = np.sqrt(wave_number / (np.pi * self.distance)) * ( + self.width / 2 - x_grid + ) + eta1 = -np.sqrt(wave_number / (np.pi * self.distance)) * ( + self.height / 2 + y_grid + ) + eta2 = np.sqrt(wave_number / (np.pi * self.distance)) * ( + self.height / 2 - y_grid + ) s_psi1, c_psi1 = sp.special.fresnel(psi1) s_psi2, c_psi2 = sp.special.fresnel(psi2) s_eta1, c_eta1 = sp.special.fresnel(eta1) s_eta2, c_eta2 = sp.special.fresnel(eta2) - field = np.exp(1j * wave_number * self.distance) * (1 / 2j) * ( - (c_psi2 - c_psi1) + 1j * (s_psi2 - s_psi1) - ) * ( - (c_eta2 - c_eta1) + 1j * (s_eta2 - s_eta1) + field = ( + np.exp(1j * wave_number * self.distance) + * (1 / 2j) + * ((c_psi2 - c_psi1) + 1j * (s_psi2 - s_psi1)) + * ((c_eta2 - c_eta1) + 1j * (s_eta2 - s_eta1)) ) # intensity = (1/4)*(np.power((c_psi2 - c_psi1), 2) + @@ -99,13 +105,26 @@ def field(self) -> np.ndarray: return field def intensity(self) -> np.ndarray: + """ + Calculates the intensity of the field. + + Returns the squared magnitude of the electric or magnetic field. + + Args: + None + + Returns: + np.ndarray: The intensity, which is the absolute value of the field + squared. + """ return np.abs(self.field()) ** 2 class CircleFresnel: """A class describing the analytical solution for the problem of free - propagation after planar wave passes a circular aperture aperture + propagation after planar wave passes a circular aperture aperture """ + def __init__( self, distance: float, @@ -115,8 +134,24 @@ def __init__( y_nodes: int, radius: float, wavelength: torch.Tensor | float, - summation_number: int = 50 + summation_number: int = 50, ): + """ + Initializes a new instance of the class. + + Args: + distance: The distance parameter. + x_size: The x size parameter. + y_size: The y size parameter. + x_nodes: The number of nodes in the x dimension. + y_nodes: The number of nodes in the y dimension. + radius: The radius parameter. + wavelength: The wavelength parameter (can be a torch.Tensor or float). + summation_number: The number of summation terms to use, defaults to 50. + + Returns: + None + """ self.distance = distance self.x_size = x_size @@ -128,6 +163,15 @@ def __init__( self.wavelength = wavelength def field(self) -> np.ndarray: + """ + Calculates the complex-valued electromagnetic field distribution. + + Args: + None + + Returns: + np.ndarray: A 2D NumPy array representing the calculated field. + """ x_linear = np.linspace(-self.x_size / 2, self.x_size / 2, self.x_nodes) y_linear = np.linspace(-self.y_size / 2, self.y_size / 2, self.y_nodes) @@ -145,23 +189,34 @@ def field(self) -> np.ndarray: series = np.zeros_like(x_grid, dtype=np.complex128) for n in tqdm(range(self.summation_number)): - series += (( - -1j * radius / (self.radius) - ) ** n) * jv( - n, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance) # noqa: E501 + series += ((-1j * radius / (self.radius)) ** n) * jv( + n, + 2 + * np.pi + * self.radius + * radius + / (self.wavelength * self.distance), # noqa: E501 ) self.field = np.exp(1j * wave_number * self.distance) * ( - 1 - np.exp( - 1j * np.pi * radius**2 / (self.wavelength * self.distance) - ) * np.exp( - 1j * np.pi * self.radius**2 / (self.wavelength * self.distance) - ) * series + 1 + - np.exp(1j * np.pi * radius**2 / (self.wavelength * self.distance)) + * np.exp(1j * np.pi * self.radius**2 / (self.wavelength * self.distance)) + * series ) return self.field def intensity(self) -> np.ndarray: + """ + Calculates the intensity pattern of a circular aperture. + + Args: + None + + Returns: + np.ndarray: A 2D NumPy array representing the calculated intensity pattern. + """ x_linear = np.linspace(-self.x_size / 2, self.x_size / 2, self.x_nodes) y_linear = np.linspace(-self.y_size / 2, self.y_size / 2, self.y_nodes) x_grid, y_grid = np.meshgrid(x_linear, y_linear) @@ -175,9 +230,33 @@ def intensity(self) -> np.ndarray: radius = np.sqrt(x_grid**2 + y_grid**2) - intensity = 1 / (1 + np.exp((radius / self.radius)**2))**2 * ( - 1 + jv(0, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance))**2 - 2*np.cos( - np.pi * self.radius**2/(self.wavelength * self.distance) + np.pi * radius**2 / (self.distance*self.wavelength) - ) * jv(0, 2 * np.pi * self.radius * radius / (self.wavelength * self.distance)) + intensity = ( + 1 + / (1 + np.exp((radius / self.radius) ** 2)) ** 2 + * ( + 1 + + jv( + 0, + 2 + * np.pi + * self.radius + * radius + / (self.wavelength * self.distance), + ) + ** 2 + - 2 + * np.cos( + np.pi * self.radius**2 / (self.wavelength * self.distance) + + np.pi * radius**2 / (self.distance * self.wavelength) + ) + * jv( + 0, + 2 + * np.pi + * self.radius + * radius + / (self.wavelength * self.distance), + ) + ) ) return intensity diff --git a/tests/test_analytic.py b/tests/test_analytic.py index c565411..ce3f6e0 100644 --- a/tests/test_analytic.py +++ b/tests/test_analytic.py @@ -19,7 +19,7 @@ "width_test", "height_test", "expected_error", - "error_energy" + "error_energy", ] @@ -29,52 +29,58 @@ ( 8, # ox_size, mm 8, # oy_size, mm - 1200, # ox_nodes - 1300, # oy_nodes + 1200, # ox_nodes + 1300, # oy_nodes 540 * 1e-6, # wavelength_test, mm - 600, # distance_test, mm + 600, # distance_test, mm 4, # width_test, mm 2, # height_test, mm 0.075, # expected error - 0.05 # error_energy + 0.05, # error_energy ), ( 10, # ox_size, mm 10, # oy_size, mm - 1400, # ox_nodes - 1300, # oy_nodes - torch.linspace(330 * 1e-6, 660 * 1e-6, 5), # wavelength_test tensor, mm # noqa: E501 - 150, # distance_test, mm - 3, # width_test, mm - 3, # height_test, mm + 1400, # ox_nodes + 1300, # oy_nodes + torch.linspace( + 330 * 1e-6, 660 * 1e-6, 5 + ), # wavelength_test tensor, mm # noqa: E501 + 150, # distance_test, mm + 3, # width_test, mm + 3, # height_test, mm 0.065, # expected error - 0.05 # error_energy + 0.05, # error_energy ), ( 8, # ox_size, mm 8, # oy_size, mm - 1200, # ox_nodes - 1300, # oy_nodes - torch.linspace(330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64), # wavelength_test tensor, mm # noqa: E501 - 600, # distance_test, mm + 1200, # ox_nodes + 1300, # oy_nodes + torch.linspace( + 330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64 + ), # wavelength_test tensor, mm # noqa: E501 + 600, # distance_test, mm 2, # width_test, mm 2, # height_test, mm 0.075, # expected error - 0.05 # error_energy + 0.05, # error_energy ), ( 8, # ox_size, mm 8, # oy_size, mm - 1200, # ox_nodes - 1300, # oy_nodes - torch.linspace(330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64), # wavelength_test tensor, mm # noqa: E501 - 600, # distance_test, mm + 1200, # ox_nodes + 1300, # oy_nodes + torch.linspace( + 330 * 1e-6, 660 * 1e-6, 5, dtype=torch.float64 + ), # wavelength_test tensor, mm # noqa: E501 + 600, # distance_test, mm 4, # width_test, mm 2, # height_test, mm 0.075, # expected std - 0.05 # error_energy - ) - ] + 0.05, # error_energy + ), + ], ) def test_rectangle_fresnel( ox_size: float, @@ -86,7 +92,7 @@ def test_rectangle_fresnel( width_test: float, height_test: float, expected_error: float, - error_energy: float + error_energy: float, ): """Test for the free propagation problem on the example of diffraction of the plane wave on the rectangular aperture @@ -117,13 +123,13 @@ def test_rectangle_fresnel( params = SimulationParameters( { - 'W': torch.linspace( - -ox_size/2, ox_size/2, ox_nodes, dtype=torch.float64 + "W": torch.linspace( + -ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64 ), - 'H': torch.linspace( - -oy_size/2, oy_size/2, oy_nodes, dtype=torch.float64 + "H": torch.linspace( + -oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64 ), - 'wavelength': wavelength_test + "wavelength": wavelength_test, } ) @@ -131,30 +137,22 @@ def test_rectangle_fresnel( dy = oy_size / oy_nodes incident_field = Wavefront.plane_wave( - simulation_parameters=params, - distance=distance_test, - wave_direction=[0, 0, 1] + simulation_parameters=params, distance=distance_test, wave_direction=[0, 0, 1] ) # field after the square aperture transmission_field = elements.RectangularAperture( - simulation_parameters=params, - height=height_test, - width=width_test + simulation_parameters=params, height=height_test, width=width_test )(incident_field) # field on the screen by using Fresnel propagation method output_field_fresnel = elements.FreeSpace( - simulation_parameters=params, - distance=distance_test, - method='fresnel' - )(transmission_field) + simulation_parameters=params, distance=distance_test, method="fresnel" + )(transmission_field) # field on the screen by using Angular Spectrum method output_field_as = elements.FreeSpace( - simulation_parameters=params, - distance=distance_test, - method='AS' - )(transmission_field) + simulation_parameters=params, distance=distance_test, method="AS" + )(transmission_field) # intensity distribution on the screen by using Fresnel propagation method intensity_output_fresnel = output_field_fresnel.intensity @@ -170,40 +168,36 @@ def test_rectangle_fresnel( y_nodes=oy_nodes, width=width_test, height=height_test, - wavelength=wavelength_test + wavelength=wavelength_test, ).intensity() if isinstance(intensity_analytic, np.ndarray): intensity_analytic = torch.from_numpy(intensity_analytic) energy_analytic = torch.sum(intensity_analytic, dim=(-2, -1)) * dx * dy - energy_numeric_fresnel = torch.sum( - intensity_output_fresnel, dim=(-2, -1) - ) * dx * dy + energy_numeric_fresnel = torch.sum(intensity_output_fresnel, dim=(-2, -1)) * dx * dy energy_numeric_as = torch.sum(intensity_output_as, dim=(-2, -1)) * dx * dy intensity_difference_fresnel = torch.abs( intensity_analytic - intensity_output_fresnel ) / (ox_nodes * oy_nodes) - intensity_difference_as = torch.abs( - intensity_analytic - intensity_output_as - ) / (ox_nodes * oy_nodes) + intensity_difference_as = torch.abs(intensity_analytic - intensity_output_as) / ( + ox_nodes * oy_nodes + ) error_fresnel, _ = intensity_difference_fresnel.view( intensity_difference_fresnel.size(0), -1 ).max(dim=1) - error_as, _ = intensity_difference_as.view( - intensity_difference_as.size(0), -1 - ).max(dim=1) + error_as, _ = intensity_difference_as.view(intensity_difference_as.size(0), -1).max( + dim=1 + ) energy_error_fresnel = torch.abs( (energy_analytic - energy_numeric_fresnel) / energy_analytic ) - energy_error_as = torch.abs( - (energy_analytic - energy_numeric_as) / energy_analytic - ) + energy_error_as = torch.abs((energy_analytic - energy_numeric_as) / energy_analytic) assert (error_fresnel < expected_error).all() assert (error_as < expected_error).all() diff --git a/tests/test_apertures.py b/tests/test_apertures.py index 1fe6e73..36a07fa 100644 --- a/tests/test_apertures.py +++ b/tests/test_apertures.py @@ -13,14 +13,16 @@ "wavelength_test", "height_test", "width_test", - "expected_std" + "expected_std", ] @pytest.mark.parametrize( rectangle_parameters, - [(10, 10, 1000, 1200, 1064 * 1e-6, 4, 10, 1e-5), - (4, 4, 1300, 1000, 1064 * 1e-6, 3, 1, 1e-5)] + [ + (10, 10, 1000, 1200, 1064 * 1e-6, 4, 10, 1e-5), + (4, 4, 1300, 1000, 1064 * 1e-6, 3, 1, 1e-5), + ], ) def test_rectangle_aperture( ox_size: float, @@ -55,27 +57,27 @@ def test_rectangle_aperture( """ params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) # transmission function of the rectangular aperture as a class method aperture = elements.RectangularAperture( - simulation_parameters=params, - height=height_test, - width=width_test + simulation_parameters=params, height=height_test, width=width_test ) transmission_function = aperture.get_transmission_function() x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) - x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy') + x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy") - transmission_function_analytic = 1 * ( - torch.abs(x_grid) <= width_test / 2 - ) * (torch.abs(y_grid) <= height_test / 2) + transmission_function_analytic = ( + 1 + * (torch.abs(x_grid) <= width_test / 2) + * (torch.abs(y_grid) <= height_test / 2) + ) standard_deviation = torch.std( transmission_function - transmission_function_analytic @@ -85,9 +87,7 @@ def test_rectangle_aperture( # test forward calculations wavefront = svetlanna.Wavefront.plane_wave(params) - torch.testing.assert_close( - aperture(wavefront), transmission_function * wavefront - ) + torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront) round_parameters = [ @@ -97,14 +97,16 @@ def test_rectangle_aperture( "oy_nodes", "wavelength_test", "radius_test", - "expected_std" + "expected_std", ] @pytest.mark.parametrize( round_parameters, - [(10, 15, 1200, 1000, 1064 * 1e-6, 4, 1e-5), - (8, 4, 1000, 1300, 1064 * 1e-6, 2.5, 1e-5)] + [ + (10, 15, 1200, 1000, 1064 * 1e-6, 4, 1e-5), + (8, 4, 1000, 1300, 1064 * 1e-6, 2.5, 1e-5), + ], ) def test_round_aperture( ox_size: float, @@ -137,22 +139,19 @@ def test_round_aperture( params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) # transmission function of the round aperture as a class method - aperture = elements.RoundAperture( - simulation_parameters=params, - radius=radius_test - ) + aperture = elements.RoundAperture(simulation_parameters=params, radius=radius_test) transmission_function = aperture.get_transmission_function() x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) - x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy') + x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy") transmission_function_analytic = 1 * ( torch.pow(x_grid, 2) + torch.pow(y_grid, 2) <= radius_test**2 @@ -166,9 +165,7 @@ def test_round_aperture( # test forward calculations wavefront = svetlanna.Wavefront.plane_wave(params) - torch.testing.assert_close( - aperture(wavefront), transmission_function * wavefront - ) + torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront) arbitrary_parameters = [ @@ -178,14 +175,16 @@ def test_round_aperture( "oy_nodes", "wavelength_test", "mask_test", - "expected_std" + "expected_std", ] @pytest.mark.parametrize( arbitrary_parameters, - [(10, 15, 1200, 1000, 1064 * 1e-6, torch.rand(1000, 1200), 1e-5), - (8, 4, 1100, 1000, 1064 * 1e-6, torch.rand(1000, 1100), 1e-5)] + [ + (10, 15, 1200, 1000, 1064 * 1e-6, torch.rand(1000, 1200), 1e-5), + (8, 4, 1100, 1000, 1064 * 1e-6, torch.rand(1000, 1100), 1e-5), + ], ) def test_aperture( ox_size: float, @@ -218,17 +217,14 @@ def test_aperture( params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) # transmission function for the aperture with arbitrary shape as a # class method - aperture = elements.Aperture( - simulation_parameters=params, - mask=mask_test - ) + aperture = elements.Aperture(simulation_parameters=params, mask=mask_test) transmission_function = aperture.get_transmission_function() transmission_function_analytic = mask_test @@ -241,6 +237,4 @@ def test_aperture( # test forward calculations wavefront = svetlanna.Wavefront.plane_wave(params) - torch.testing.assert_close( - aperture(wavefront), transmission_function * wavefront - ) + torch.testing.assert_close(aperture(wavefront), transmission_function * wavefront) diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 9d1468d..08a973d 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -14,24 +14,26 @@ def empty_encoder_or_decoder(zero_free_space): @pytest.mark.parametrize( - "wf_real, wf_imag", [ + "wf_real, wf_imag", + [ (1.00, 0.00), (0.00, 1.00), (2.50, 1.25), - ] + ], ) def test_autoencoder_forward( - sim_params, empty_encoder_or_decoder, # fixtures - wf_real, wf_imag + sim_params, empty_encoder_or_decoder, wf_real, wf_imag # fixtures ): """Test forward function for a single Wavefront sequence.""" - h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters + h, w = sim_params.axes_size( + axs=("H", "W") + ) # size of a wavefront according to SimulationParameters test_wavefront = Wavefront( - torch.ones(size=(h, w), dtype=torch.float64) * wf_real + - torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j + torch.ones(size=(h, w), dtype=torch.float64) * wf_real + + torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j ) - for to_return in ['wf', 'amps']: + for to_return in ["wf", "amps"]: autoencoder = LinearAutoencoder( sim_params, encoder_elements_list=empty_encoder_or_decoder, @@ -45,9 +47,9 @@ def test_autoencoder_forward( for wf in [wf_encoded, wf_decoded]: assert isinstance(wf, Wavefront) - if to_return == 'wf': + if to_return == "wf": assert torch.allclose(wf, test_wavefront) - if to_return == 'amps': + if to_return == "amps": assert torch.allclose(wf, test_wavefront.abs() + 0j) @@ -58,13 +60,15 @@ def test_autoencoder_device(sim_params, empty_encoder_or_decoder): sim_params, encoder_elements_list=empty_encoder_or_decoder, decoder_elements_list=empty_encoder_or_decoder, - device='cpu', + device="cpu", ) - assert autoencoder.device == torch.device('cpu') + assert autoencoder.device == torch.device("cpu") - new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is - new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') + new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if new_device == torch.device( + "cpu" + ): # if cuda is not available - check if `mps` is + new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") new_autoencoder = autoencoder.to(new_device) diff --git a/tests/test_axes_math.py b/tests/test_axes_math.py index a8e0379..be0b481 100644 --- a/tests/test_axes_math.py +++ b/tests/test_axes_math.py @@ -10,8 +10,8 @@ def test_append_slice(): """Test that append slice""" - axes = ('a',) - new_axes = ('a', 'b') + axes = ("a",) + new_axes = ("a", "b") full_slice = slice(None, None, None) # no additional axes @@ -21,16 +21,16 @@ def test_append_slice(): assert _append_slice(axes, new_axes) == (..., full_slice, None) # two additional axis should be at the end - for new_axes in permutations(('a', 'b', 'c')): + for new_axes in permutations(("a", "b", "c")): assert _append_slice(axes, new_axes) == (..., full_slice, None, None) def test_axes_indices_to_sort(): """Test for `_axes_indices_to_sort` function""" - axes = ('a', 'b') - new_axes = ('b', 'd', 'a', 'c') + axes = ("a", "b") + new_axes = ("b", "d", "a", "c") # axes of the tensor expanded with _append_slice - appended_tensor_axes = ('a', 'b', 'd', 'c') + appended_tensor_axes = ("a", "b", "d", "c") assert _axes_indices_to_sort(axes, new_axes) == tuple( new_axes.index(axis) for axis in appended_tensor_axes @@ -45,51 +45,75 @@ def test_swaps(): # elements swap for i, j in _swaps(new_axes): - new_axes_list[i], new_axes_list[j] \ - = new_axes_list[j], new_axes_list[i] + new_axes_list[i], new_axes_list[j] = new_axes_list[j], new_axes_list[i] # test if new_axes_list is sorted after swapping assert sorted(axes) == new_axes_list def test_cast_tensor(): + """ + Tests the cast_tensor function with various axis configurations. + + This tests checks that the `cast_tensor` function correctly adds, maintains, and swaps axes + of a given tensor while raising ValueErrors when invalid configurations are provided. + + Args: + None + + Returns: + None + """ a = torch.tensor([[1, 2], [3, 4]]) # additional axes - b = cast_tensor(a=a, axes=('a',), new_axes=('a', 'b', 'c')) + b = cast_tensor(a=a, axes=("a",), new_axes=("a", "b", "c")) assert len(b.shape) == 4 assert b.shape[-1] == b.shape[-2] == 1 - b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'b', 'c')) + b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "b", "c")) assert len(b.shape) == 3 assert b.shape[-1] == 1 # same axes test - b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'b')) + b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "b")) assert len(b.shape) == 2 # swap axes test - b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('b', 'a')) + b = cast_tensor(a=a, axes=("a", "b"), new_axes=("b", "a")) assert torch.allclose(a, b.T) with pytest.raises(ValueError): - b = cast_tensor(a=a, axes=('a', 'b'), new_axes=('a', 'c')) + b = cast_tensor(a=a, axes=("a", "b"), new_axes=("a", "c")) def test_axis_to_tuple(): + """ + Tests the _axis_to_tuple function for various inputs and caching behavior. + + This test verifies that _axis_to_tuple correctly converts different input types + (empty tuple, string, tuple of strings) into tuples and also confirms that it + caches results for identical inputs. + + Parameters: + None + + Returns: + None + """ a = _axis_to_tuple(()) - b = _axis_to_tuple('a') - c = _axis_to_tuple(('a', 'b')) + b = _axis_to_tuple("a") + c = _axis_to_tuple(("a", "b")) # test for values assert a == () - assert b == ('a',) - assert c == ('a', 'b') + assert b == ("a",) + assert c == ("a", "b") # check for cache assert a is _axis_to_tuple(()) - assert b is _axis_to_tuple('a') - assert c is _axis_to_tuple(('a', 'b')) + assert b is _axis_to_tuple("a") + assert c is _axis_to_tuple(("a", "b")) def test_new_axes(): @@ -101,38 +125,48 @@ def test_new_axes(): ``` """ - assert _new_axes(('a', 'b'), ('a',)) == ('a', 'b') + assert _new_axes(("a", "b"), ("a",)) == ("a", "b") - assert _new_axes(('a', 'b'), ('c',)) == ('a', 'b', 'c') - assert _new_axes(('a', 'b'), ('c', 'd')) == ('a', 'b', 'c', 'd') + assert _new_axes(("a", "b"), ("c",)) == ("a", "b", "c") + assert _new_axes(("a", "b"), ("c", "d")) == ("a", "b", "c", "d") - assert _new_axes(('a', 'b'), ('a', 'c')) == ('a', 'b', 'c') - assert _new_axes(('a', 'b'), ('c', 'a')) == ('a', 'b', 'c') - assert _new_axes(('a', 'b'), ('b', 'c')) == ('a', 'b', 'c') - assert _new_axes(('a', 'b'), ('c', 'b')) == ('a', 'b', 'c') - assert _new_axes(('a', 'b'), ('c', 'd', 'b', 'e')) \ - == ('a', 'b', 'c', 'd', 'e') + assert _new_axes(("a", "b"), ("a", "c")) == ("a", "b", "c") + assert _new_axes(("a", "b"), ("c", "a")) == ("a", "b", "c") + assert _new_axes(("a", "b"), ("b", "c")) == ("a", "b", "c") + assert _new_axes(("a", "b"), ("c", "b")) == ("a", "b", "c") + assert _new_axes(("a", "b"), ("c", "d", "b", "e")) == ("a", "b", "c", "d", "e") def test_is_scalar(): - assert is_scalar(123.) - assert is_scalar(torch.tensor(123.)) - assert not is_scalar(torch.tensor([123.])) - assert not is_scalar(torch.tensor([123., 123])) - assert not is_scalar(torch.tensor([[123., 123]])) + """ + Tests the is_scalar function with various inputs. + + This tests that single number values and tensors containing a single value are correctly identified as scalar, + while tensors containing multiple values are not. + + Returns: + None + """ + assert is_scalar(123.0) + assert is_scalar(torch.tensor(123.0)) + assert not is_scalar(torch.tensor([123.0])) + assert not is_scalar(torch.tensor([123.0, 123])) + assert not is_scalar(torch.tensor([[123.0, 123]])) def test_check_axis(): + """ + Tests the _check_axis function for various error conditions.""" # test for unique with pytest.raises(ValueError): - _check_axis(torch.tensor([[[123]]]), ('a', 'a', 'b')) + _check_axis(torch.tensor([[[123]]]), ("a", "a", "b")) # test for number of axes in tensor with pytest.raises(ValueError): - _check_axis(torch.tensor([123]), ('a', 'b')) + _check_axis(torch.tensor([123]), ("a", "b")) # and for number of axes in float - assert _check_axis(123, ('a', 'b')) is None + assert _check_axis(123, ("a", "b")) is None def test_tensor_dot(): @@ -140,17 +174,17 @@ def test_tensor_dot(): e = 123 d = 321 # product of a scalar and a scalar - c, c_axis = tensor_dot(d, e, ('a', 'b'), ('c')) + c, c_axis = tensor_dot(d, e, ("a", "b"), ("c")) assert 123 * d == c assert c_axis == () - c, c_axis = tensor_dot(d, e, ('a', 'b'), ('c'), preserve_a_axis=True) + c, c_axis = tensor_dot(d, e, ("a", "b"), ("c"), preserve_a_axis=True) assert e * d == c - assert c_axis == ('a', 'b') + assert c_axis == ("a", "b") # product of a tensor and a scalar - a = torch.tensor([1.]) - b = torch.tensor([[1., 2], [3., 4.]]) + a = torch.tensor([1.0]) + b = torch.tensor([[1.0, 2], [3.0, 4.0]]) c, c_axis = tensor_dot(a, e, (), ()) assert e * a == c @@ -160,25 +194,25 @@ def test_tensor_dot(): assert torch.allclose(e * b, c) assert c_axis == () - c, c_axis = tensor_dot(a, e, ('a',), ()) + c, c_axis = tensor_dot(a, e, ("a",), ()) assert e * a == c - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(b, e, ('a',), ()) + c, c_axis = tensor_dot(b, e, ("a",), ()) assert torch.allclose(e * b, c) - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(a, e, ('a',), ('b', 'c')) + c, c_axis = tensor_dot(a, e, ("a",), ("b", "c")) assert e * a == c - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(b, e, ('a',), ('b', 'c')) + c, c_axis = tensor_dot(b, e, ("a",), ("b", "c")) assert torch.allclose(e * b, c) - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(b, e, ('a', 'd'), ('b', 'c')) + c, c_axis = tensor_dot(b, e, ("a", "d"), ("b", "c")) assert torch.allclose(e * b, c) - assert c_axis == ('a', 'd') + assert c_axis == ("a", "d") # product of a scalar and a tensor c, c_axis = tensor_dot(e, a, (), ()) @@ -189,96 +223,98 @@ def test_tensor_dot(): assert torch.allclose(e * b, c) assert c_axis == () - c, c_axis = tensor_dot(e, a, (), ('a')) + c, c_axis = tensor_dot(e, a, (), ("a")) assert e * a == c - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(e, b, (), ('a')) + c, c_axis = tensor_dot(e, b, (), ("a")) assert torch.allclose(e * b, c) - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(e, a, ('a',), ('a')) + c, c_axis = tensor_dot(e, a, ("a",), ("a")) assert e * a == c - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(e, a, ('a', 'c'), ('a')) + c, c_axis = tensor_dot(e, a, ("a", "c"), ("a")) assert e * a == c - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(e, a, ('a', 'c'), ('a'), preserve_a_axis=True) + c, c_axis = tensor_dot(e, a, ("a", "c"), ("a"), preserve_a_axis=True) assert e * a == c - assert c_axis == ('a', 'c') + assert c_axis == ("a", "c") - c, c_axis = tensor_dot(e, b, ('a', 'c'), ('a'), preserve_a_axis=True) + c, c_axis = tensor_dot(e, b, ("a", "c"), ("a"), preserve_a_axis=True) assert torch.allclose((e * b)[..., None], c) - assert c_axis == ('a', 'c') + assert c_axis == ("a", "c") # product of a tensor and a tensor c, c_axis = tensor_dot(a, b, (), ()) assert torch.allclose((a * b), c) assert c_axis == () - c, c_axis = tensor_dot(a, b, ('a'), ('a', 'b')) + c, c_axis = tensor_dot(a, b, ("a"), ("a", "b")) d = b.clone() d[:] *= a[:] assert torch.allclose(c, d) - assert c_axis == ('a', 'b') + assert c_axis == ("a", "b") - c, c_axis = tensor_dot(a, b, ('a'), ('a')) + c, c_axis = tensor_dot(a, b, ("a"), ("a")) d = b.clone() d[..., :] *= a[:] assert torch.allclose(c, d) - assert c_axis == ('a',) + assert c_axis == ("a",) - c, c_axis = tensor_dot(b, a, ('a', 'b'), ('a')) + c, c_axis = tensor_dot(b, a, ("a", "b"), ("a")) d = b.clone() d[:, ...] *= a[:] assert torch.allclose(c, d) - assert c_axis == ('a', 'b') + assert c_axis == ("a", "b") - c, c_axis = tensor_dot(b, a, ('a', 'b'), ('c')) + c, c_axis = tensor_dot(b, a, ("a", "b"), ("c")) d = b.clone()[..., None] d[..., :] *= a[:] assert torch.allclose(c, d) - assert c_axis == ('a', 'b', 'c') + assert c_axis == ("a", "b", "c") def test_mul(): - wf = Wavefront([[1.+1j]]) + """ + Tests the mul function with various Wavefront objects and tensors.""" + wf = Wavefront([[1.0 + 1j]]) # test wf and non-tensor product assert mul(wf, 123, ()) == wf * 123 # test default axes - wf = Wavefront([[1., 2], [3, 4]]) + wf = Wavefront([[1.0, 2], [3, 4]]) a = torch.tensor([10, 20]) - assert torch.allclose(mul(wf, a, ('H')), wf * a[:, None]) - assert torch.allclose(mul(wf, a, ('W')), wf * a[None, :]) + assert torch.allclose(mul(wf, a, ("H")), wf * a[:, None]) + assert torch.allclose(mul(wf, a, ("W")), wf * a[None, :]) with pytest.raises(AssertionError): - mul(wf, torch.tensor([123]), ('s')) + mul(wf, torch.tensor([123]), ("s")) # test non default axes sim_params1 = SimulationParameters( axes={ - 'H': torch.linspace(-1, 1, 2), - 'W': torch.linspace(-1, 1, 2), - 'wavelength': torch.tensor([1]), + "H": torch.linspace(-1, 1, 2), + "W": torch.linspace(-1, 1, 2), + "wavelength": torch.tensor([1]), } ) - wf1 = Wavefront([[[1., 2], [3, 4]]]) - assert torch.allclose(mul(wf1, 123, 'wavelength', sim_params1), 123 * wf1) + wf1 = Wavefront([[[1.0, 2], [3, 4]]]) + assert torch.allclose(mul(wf1, 123, "wavelength", sim_params1), 123 * wf1) r = wf1 * a[None, :] - assert torch.allclose(mul(wf1, a, 'H', sim_params1), r) + assert torch.allclose(mul(wf1, a, "H", sim_params1), r) # test the same product but with other simulation parameters sim_params2 = SimulationParameters( axes={ - 'wavelength': torch.tensor([1]), - 'W': torch.linspace(-1, 1, 2), - 'H': torch.linspace(-1, 1, 2), + "wavelength": torch.tensor([1]), + "W": torch.linspace(-1, 1, 2), + "H": torch.linspace(-1, 1, 2), } ) wf2 = Wavefront(wf1.swapaxes(0, 2)) - assert torch.allclose(mul(wf2, a, 'H', sim_params2), r.swapaxes(0, 2)) + assert torch.allclose(mul(wf2, a, "H", sim_params2), r.swapaxes(0, 2)) diff --git a/tests/test_clerk.py b/tests/test_clerk.py index fc598d4..5108456 100644 --- a/tests/test_clerk.py +++ b/tests/test_clerk.py @@ -5,14 +5,23 @@ def test_init(tmp_path): + """ + Tests the Clerk initialization and experiment directory validation. + + Args: + tmp_path: A temporary path to be used for testing. + + Returns: + None + """ # Test the experiment directory clerk = Clerk(tmp_path) assert clerk.experiment_directory == tmp_path # Test if the experiment directory is not a directory case - new_path = tmp_path / 'test' + new_path = tmp_path / "test" assert not new_path.exists() - with open(new_path, 'w'): + with open(new_path, "w"): pass with pytest.raises(ValueError): @@ -20,7 +29,16 @@ def test_init(tmp_path): def test_make_experiment_dir(tmp_path): - new_path = tmp_path / 'test' + """ + Tests the _make_experiment_dir method to ensure it creates a directory. + + Args: + tmp_path: A temporary path for testing purposes. + + Returns: + None + """ + new_path = tmp_path / "test" clerk = Clerk(new_path) assert not new_path.exists() @@ -29,12 +47,14 @@ def test_make_experiment_dir(tmp_path): def test_get_log_stream(tmp_path): + """ + Tests the _get_log_stream method.""" clerk = Clerk(tmp_path) - tag = '123' + tag = "123" with clerk._get_log_stream(tag) as stream: # Test if the file was created - assert (tmp_path / (tag + '.jsonl')).exists() + assert (tmp_path / (tag + ".jsonl")).exists() # Test if the stream is not closed after the context is closed assert not stream.closed @@ -44,23 +64,32 @@ def test_get_log_stream(tmp_path): assert stream is stream2 # Test if the same stream is not used for the different tag - other_tag = '312' + other_tag = "312" assert tag != other_tag with clerk._get_log_stream(other_tag) as stream3: assert stream is not stream3 def test_get_log_stream_mode(tmp_path): + """ + Tests that the log stream mode is 'w' for new runs and 'a' for resumed runs. + + Args: + tmp_path: A temporary path to use for the Clerk instance. + + Returns: + None + """ clerk = Clerk(tmp_path) - tag = '123' + tag = "123" # Test if the stream mode is 'w' for 'new_run' mode # By default 'new_run' mode is used with clerk: with clerk._get_log_stream(tag) as stream: assert clerk._mode == ClerkMode.new_run - assert stream.mode == 'w' + assert stream.mode == "w" # Test if the stream mode is 'a' for 'resume' mode # The clerk.begin() should be used to set 'resume' mode @@ -68,13 +97,27 @@ def test_get_log_stream_mode(tmp_path): with clerk._get_log_stream(tag) as stream: assert clerk._mode == ClerkMode.resume - assert stream.mode == 'a' + assert stream.mode == "a" def test_get_log_stream_flushed(tmp_path): + """ + Tests the behavior of log stream flushing within a context manager. + + This test verifies that the flush method is called on the underlying stream + only when explicitly requested via the `flush` parameter to + `_get_log_stream`. It also checks that flush isn't called if the context + manager exits normally without requesting a flush. + + Args: + tmp_path: A temporary path for Clerk initialization. + + Returns: + None + """ # TODO: refactoring clerk = Clerk(tmp_path) - tag = '123' + tag = "123" with clerk._get_log_stream(tag) as stream: pass @@ -99,26 +142,33 @@ def monkey_flush(): def test_conditions(tmp_path): - experiment_dir = tmp_path / 'experiment' + """ + Tests the saving and loading of experiment conditions. + + This method creates a Clerk instance, saves a dictionary of conditions to + a JSON file within an experiment directory, and then verifies that the + directory and file are created correctly. It also loads the conditions + from the saved file and asserts that they match the original conditions. + + Args: + tmp_path: A temporary path for creating the experiment directory. + + Returns: + None + """ + experiment_dir = tmp_path / "experiment" clerk = Clerk(experiment_dir) conditions = { - 'test1': 123, - 'test2': [ - 123, - 10., - 'a' - ], - 'test3': { - 't': 'e', - 's': 't' - } + "test1": 123, + "test2": [123, 10.0, "a"], + "test3": {"t": "e", "s": "t"}, } clerk.save_conditions(conditions) # Test if the folder and the file are created assert experiment_dir.exists() - assert (experiment_dir / 'conditions.json').exists() + assert (experiment_dir / "conditions.json").exists() # Test if when loaded, the conditions are the same new_clerk = Clerk(experiment_dir) @@ -129,13 +179,23 @@ def test_conditions(tmp_path): def test_logs(tmp_path): + """ + Tests the log writing and loading functionality of the Clerk. + + This tests checks that logs cannot be written before a context is active, + that files are created/appended to correctly, and that loaded messages match + the original messages in both regular and resume modes. + + Args: + tmp_path: A temporary path for creating log files. + + Returns: + None + """ clerk = Clerk(tmp_path) - tag = 'test' - messages = [ - {'a': 123, 'b': 321.}, - {'a': 321, 'b': 5423} - ] + tag = "test" + messages = [{"a": 123, "b": 321.0}, {"a": 321, "b": 5423}] # Test if log can't be written before the clerk is used in any context with pytest.raises(RuntimeError): @@ -143,7 +203,7 @@ def test_logs(tmp_path): clerk.write_log(tag, message) # Test if log file does not exist - assert not (tmp_path / (tag + '.jsonl')).exists() + assert not (tmp_path / (tag + ".jsonl")).exists() # Write the logs with clerk: @@ -151,7 +211,7 @@ def test_logs(tmp_path): clerk.write_log(tag, message) # Test if log file is created - assert (tmp_path / (tag + '.jsonl')).exists() + assert (tmp_path / (tag + ".jsonl")).exists() # Test if when loaded, the messages are the same loaded_messages = list(clerk.load_logs(tag)) @@ -159,14 +219,14 @@ def test_logs(tmp_path): assert loaded_messages == messages # Test if in resume mode the logs are appended in existing file - tag2 = 'test2' - assert not (tmp_path / (tag2 + '.jsonl')).exists() + tag2 = "test2" + assert not (tmp_path / (tag2 + ".jsonl")).exists() with clerk.begin(resume=True): for message in messages: clerk.write_log(tag2, message) - assert (tmp_path / (tag2 + '.jsonl')).exists() + assert (tmp_path / (tag2 + ".jsonl")).exists() loaded_messages = clerk.load_logs(tag2) for i, message in enumerate(loaded_messages): @@ -183,13 +243,23 @@ def test_logs(tmp_path): def test_logs_pandas(tmp_path): + """ + Tests loading logs to a Pandas DataFrame. + + This test writes log messages with a specific tag, then loads them into a + Pandas DataFrame and verifies that the loaded data matches the original + messages. + + Args: + tmp_path: A temporary path for storing log files. + + Returns: + None + """ clerk = Clerk(tmp_path) - tag = 'test' - messages = [ - {'a': 123, 'b': 321.}, - {'a': 321, 'b': 5423} - ] + tag = "test" + messages = [{"a": 123, "b": 321.0}, {"a": 321, "b": 5423}] with clerk: for message in messages: @@ -201,8 +271,20 @@ def test_logs_pandas(tmp_path): def test_checkpoints(tmp_path): + """ + Tests the Clerk's checkpointing functionality. + + This tests various scenarios including writing checkpoints, loading them, + cleaning up old checkpoints, and handling metadata and targets. + + Args: + tmp_path: A temporary path for storing checkpoint files. + + Returns: + None + """ clerk = Clerk(tmp_path) - checkpoints_filepath = tmp_path / 'checkpoints.txt' + checkpoints_filepath = tmp_path / "checkpoints.txt" # Test if checkpoint can't be written before # the clerk is used in any context @@ -217,9 +299,7 @@ def test_checkpoints(tmp_path): # Write checkpoint with metadata and no targets with clerk: for i in range(11): - clerk.write_checkpoint(metadata={ - 'i': i - }) + clerk.write_checkpoint(metadata={"i": i}) assert checkpoints_filepath.exists() @@ -228,15 +308,15 @@ def test_checkpoints(tmp_path): first_run_checkpoint_filenames: list[str] = [] with open(checkpoints_filepath) as file: for i, line in enumerate(file.readlines()): - checkpoint_filename = f'{i}.pt' + checkpoint_filename = f"{i}.pt" first_run_checkpoint_filenames.append(checkpoint_filename) - assert line == checkpoint_filename + '\n' + assert line == checkpoint_filename + "\n" assert (tmp_path / checkpoint_filename).exists() metadata = clerk.load_checkpoint(i) - assert metadata == {'i': i} + assert metadata == {"i": i} same_metadata = clerk.load_checkpoint(checkpoint_filename) assert same_metadata == metadata @@ -254,7 +334,7 @@ def test_checkpoints(tmp_path): with open(checkpoints_filepath) as file: assert len(file.readlines()) == 3 for i in range(3): - second_run_checkpoint_filenames.append(f'{i}.pt') + second_run_checkpoint_filenames.append(f"{i}.pt") # Test clean_checkpoints clerk.clean_checkpoints() @@ -274,15 +354,12 @@ class ObjectWithState(torch.nn.Module): def __init__(self) -> None: super().__init__() self.test_parameter: torch.Tensor - self.register_buffer('test_parameter', torch.tensor(0.0)) + self.register_buffer("test_parameter", torch.tensor(0.0)) clerk = Clerk(tmp_path) object1 = ObjectWithState() object2 = ObjectWithState() - clerk.set_checkpoint_targets({ - '1': object1, - '2': object2 - }) + clerk.set_checkpoint_targets({"1": object1, "2": object2}) # Write checkpoint with targets with clerk: @@ -294,16 +371,14 @@ def __init__(self) -> None: for i in range(6): clerk.load_checkpoint(i) assert object1.test_parameter.item() == 0 - assert object2.test_parameter.item() == 2. + 2. * i + assert object2.test_parameter.item() == 2.0 + 2.0 * i # Test load_checkpoint for specific target object1.test_parameter = torch.tensor(123) object2.test_parameter = torch.tensor(321) object3 = ObjectWithState() for i in range(6): - clerk.load_checkpoint(i, targets={ - '2': object3 - }) + clerk.load_checkpoint(i, targets={"2": object3}) # Test if object1 and object2 does not change assert object1.test_parameter.item() == 123 assert object2.test_parameter.item() == 321 @@ -312,10 +387,7 @@ def __init__(self) -> None: # Test if more checkpoints has been saved when resume mode clerk = Clerk(tmp_path) - clerk.set_checkpoint_targets({ - '1': object1, - '2': object2 - }) + clerk.set_checkpoint_targets({"1": object1, "2": object2}) assert object1.test_parameter.item() != 0 assert object2.test_parameter.item() != 16 @@ -333,10 +405,7 @@ def __init__(self) -> None: # Test if resume_load_last_checkpoint can be turned off clerk = Clerk(tmp_path) - clerk.set_checkpoint_targets({ - '1': object1, - '2': object2 - }) + clerk.set_checkpoint_targets({"1": object1, "2": object2}) object1.test_parameter = torch.tensor(123) object2.test_parameter = torch.tensor(321) @@ -351,7 +420,19 @@ def __init__(self) -> None: def test_context(tmp_path): - new_path = tmp_path / 'test' + """ + Tests the Clerk context manager functionality. + + This tests nested context usage, directory creation, and automatic stream closing. + It also verifies exception handling when detaching a stream within a clerk context. + + Args: + tmp_path: A temporary path for testing purposes. + + Returns: + None + """ + new_path = tmp_path / "test" clerk = Clerk(new_path) assert not new_path.exists() @@ -365,7 +446,7 @@ def test_context(tmp_path): assert new_path.exists() # Test if all streams are closed automatically - with clerk._get_log_stream('test', flush=False) as stream: + with clerk._get_log_stream("test", flush=False) as stream: assert not stream.closed with clerk: @@ -374,14 +455,28 @@ def test_context(tmp_path): assert stream.closed with pytest.raises(ExceptionGroup): - with clerk._get_log_stream('test', flush=False) as stream: + with clerk._get_log_stream("test", flush=False) as stream: with clerk: stream.detach() def test_backup_checkpoint(tmp_path): + """ + Tests the backup checkpoint functionality of the Clerk class. + + This test verifies that a backup checkpoint is created when an exception + occurs during a clerk session with autosave enabled, and that it can be + loaded and cleaned up correctly. It also tests the handling of exceptions + during the preparation of checkpoint data. + + Args: + tmp_path: A temporary path for creating test files. + + Returns: + None + """ clerk = Clerk(tmp_path) - checkpoints_filepath = tmp_path / 'checkpoints.txt' + checkpoints_filepath = tmp_path / "checkpoints.txt" class SpecificException(Exception): pass @@ -395,12 +490,12 @@ class SpecificException(Exception): # Test if the backup checkpoint is not in 'checkpoints.txt' with open(checkpoints_filepath) as file: - assert file.readlines() == ['0.pt\n'] + assert file.readlines() == ["0.pt\n"] backup_checkpoints: list[str] = [] # Find backup checkpoint files for file in tmp_path.iterdir(): - if file.name.endswith('.pt'): + if file.name.endswith(".pt"): if not CHECKPOINT_FILENAME_PATTERN.match(file.name): backup_checkpoints.append(file.name) @@ -409,8 +504,8 @@ class SpecificException(Exception): # Test metadata metadata = clerk.load_checkpoint(backup_checkpoints[0]) assert isinstance(metadata, dict) - assert 'time' in metadata - assert 'description' in metadata + assert "time" in metadata + assert "description" in metadata # Test clean_backup_checkpoints method assert (tmp_path / backup_checkpoints[0]).exists() diff --git a/tests/test_conv4f_net.py b/tests/test_conv4f_net.py index d7b0f6a..c79fde5 100644 --- a/tests/test_conv4f_net.py +++ b/tests/test_conv4f_net.py @@ -12,16 +12,17 @@ @pytest.fixture() def some_elements_list(sim_params): """Returns list with a zero distance FreeSpace, i.e. empty network""" - h, w = sim_params.axes_size(axs=('H', 'W')) + h, w = sim_params.axes_size(axs=("H", "W")) elements_list = [ elements.DiffractiveLayer( simulation_parameters=sim_params, - mask=torch.rand(h, w) * 2 * torch.pi, # mask is not changing during the training! + mask=torch.rand(h, w) + * 2 + * torch.pi, # mask is not changing during the training! ), elements.FreeSpace( - simulation_parameters=sim_params, - distance=3.00 * 1e-2, method='AS' + simulation_parameters=sim_params, distance=3.00 * 1e-2, method="AS" ), ] @@ -29,24 +30,28 @@ def some_elements_list(sim_params): @pytest.mark.parametrize( - "wf_real, wf_imag, focal_length", [ + "wf_real, wf_imag, focal_length", + [ (1.00, 0.00, 1.00 * 1e-2), (0.00, 1.00, 2.00 * 1e-2), (2.50, 1.25, 3.00 * 1e-2), - ] + ], ) def test_conv4f_net_forward( - sim_params, some_elements_list, # fixtures - wf_real, wf_imag, focal_length + sim_params, some_elements_list, wf_real, wf_imag, focal_length # fixtures ): """Test forward function for a single Wavefront sequence.""" - h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters + h, w = sim_params.axes_size( + axs=("H", "W") + ) # size of a wavefront according to SimulationParameters test_wavefront = Wavefront( - torch.ones(size=(h, w), dtype=torch.float64) * wf_real + - torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j + torch.ones(size=(h, w), dtype=torch.float64) * wf_real + + torch.ones(size=(h, w), dtype=torch.float64) * wf_imag * 1j ) - random_diffractive_mask = torch.rand(h, w) * 2 * torch.pi # random mask for a convolution + random_diffractive_mask = ( + torch.rand(h, w) * 2 * torch.pi + ) # random mask for a convolution # NETWORK conv4f_net = ConvDiffNetwork4F( @@ -77,7 +82,9 @@ def test_conv4f_net_forward( def test_conv4f_net_device(sim_params, some_elements_list): """Test .to(device) function for a Convolutional Network.""" - h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters + h, w = sim_params.axes_size( + axs=("H", "W") + ) # size of a wavefront according to SimulationParameters random_diffractive_mask = torch.rand(h, w) # random mask for a convolution # NETWORK @@ -86,14 +93,16 @@ def test_conv4f_net_device(sim_params, some_elements_list): network_elements_list=some_elements_list, focal_length=1.00 * 1e-2, conv_phase_mask=random_diffractive_mask, - device='cpu', + device="cpu", ) - assert conv4f_net.device == torch.device('cpu') + assert conv4f_net.device == torch.device("cpu") - new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is - new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') + new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if new_device == torch.device( + "cpu" + ): # if cuda is not available - check if `mps` is + new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") new_conv4f_net = conv4f_net.to(new_device) diff --git a/tests/test_detector.py b/tests/test_detector.py index a61ca40..b70b55e 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -13,9 +13,9 @@ def test_detector_types(): detector = Detector( SimulationParameters( { - 'W': torch.linspace(-1e-2/2, 1e-2/2, 5), - 'H': torch.linspace(-1e-2/2, 1e-2/2, 5), - 'wavelength': 1e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), + "wavelength": 1e-6, } ) ) @@ -24,10 +24,11 @@ def test_detector_types(): @pytest.mark.parametrize( - "x_size, y_size, x_nodes, y_nodes, wavelength", [ + "x_size, y_size, x_nodes, y_nodes, wavelength", + [ (10e-2, 10e-2, 10, 10, 1e-6), (15e-2, 20e-2, 15, 20, 1e-6), - ] + ], ) def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength): """ @@ -42,12 +43,12 @@ def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength): detector = Detector( SimulationParameters( { - 'W': torch.linspace(-x_size/2, x_size/2, x_nodes), - 'H': torch.linspace(-y_size/2, y_size/2, y_nodes), - 'wavelength': wavelength + "W": torch.linspace(-x_size / 2, x_size / 2, x_nodes), + "H": torch.linspace(-y_size / 2, y_size / 2, y_nodes), + "wavelength": wavelength, } ), - func='intensity' + func="intensity", ) input_field = torch.rand(size=[y_nodes, x_nodes]) @@ -57,17 +58,18 @@ def test_detector_intensity(x_size, y_size, x_nodes, y_nodes, wavelength): @pytest.mark.parametrize( - "num_classes, detector_x, expected_mask", [ - (4, 8, [[0, 0, 1, 1, 2, 2, 3, 3]]), # num_classes - even, detector_x - even + "num_classes, detector_x, expected_mask", + [ + (4, 8, [[0, 0, 1, 1, 2, 2, 3, 3]]), # num_classes - even, detector_x - even (2, 4, [[0, 0, 1, 1]]), - (2, 7, [[0, 0, 0, -1, 1, 1, 1]]), # num_classes - even, detector_x - odd + (2, 7, [[0, 0, 0, -1, 1, 1, 1]]), # num_classes - even, detector_x - odd (4, 7, [[-1, 0, 1, -1, 2, 3, -1]]), - (3, 8, [[-1, 0, 0, 1, 1, 2, 2, -1]]), # num_classes - odd, detector_x - even + (3, 8, [[-1, 0, 0, 1, 1, 2, 2, -1]]), # num_classes - odd, detector_x - even (3, 10, [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2]]), - (3, 7, [[0, 0, 1, 1, 1, 2, 2]]), # num_classes - odd, detector_x - odd + (3, 7, [[0, 0, 1, 1, 1, 2, 2]]), # num_classes - odd, detector_x - odd (5, 7, [[-1, 0, 1, 2, 3, 4, -1]]), (5, 11, [[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4]]), - ] + ], ) def test_detector_segmentation_strips(num_classes, detector_x, expected_mask): """ @@ -87,11 +89,11 @@ def test_detector_segmentation_strips(num_classes, detector_x, expected_mask): num_classes=num_classes, simulation_parameters=SimulationParameters( { - 'W': torch.linspace(-1e-2/2, 1e-2/2, detector_x), - 'H': torch.linspace(0, 0, 1), - 'wavelength': 500e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, detector_x), + "H": torch.linspace(0, 0, 1), + "wavelength": 500e-6, } - ) + ), ) assert isinstance(processor, torch.nn.Module) @@ -100,19 +102,22 @@ def test_detector_segmentation_strips(num_classes, detector_x, expected_mask): for ind_class in range(num_classes): # check if all classes zones are marked assert ind_class in processor.segmented_detector - assert torch.allclose(processor.segmented_detector, torch.tensor(expected_mask, dtype=torch.int32)) + assert torch.allclose( + processor.segmented_detector, torch.tensor(expected_mask, dtype=torch.int32) + ) @pytest.mark.parametrize( - "num_classes, segmented_detector, expected_weights", [ + "num_classes, segmented_detector, expected_weights", + [ (2, [[0, 0, 1, 1, 0, 1, 0, 0]], [[3 / 5, 1.0]]), (3, [[-1, -1, 0, 0, 0, 1, 2, 0, 1, 2, 2, 2, -1, -1]], [[0.5, 1.0, 0.5]]), - (4, - [[-1, -1, 1, 1, -1, -1, 3, 3], - [0, 0, -1, -1, 2, 2, -1, -1]], - [[1.0, 1.0, 1.0, 1.0]] - ), - ] + ( + 4, + [[-1, -1, 1, 1, -1, -1, 3, 3], [0, 0, -1, -1, 2, 2, -1, -1]], + [[1.0, 1.0, 1.0, 1.0]], + ), + ], ) def test_detector_weight_segments(num_classes, segmented_detector, expected_weights): """ @@ -135,9 +140,9 @@ def test_detector_weight_segments(num_classes, segmented_detector, expected_weig num_classes=num_classes, simulation_parameters=SimulationParameters( { - 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), - 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), - 'wavelength': 500e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), + "wavelength": 500e-6, } ), segmented_detector=segmented_detector_tensor, @@ -150,24 +155,23 @@ def test_detector_weight_segments(num_classes, segmented_detector, expected_weig @pytest.mark.parametrize( - "num_classes, segmented_detector, batch_detector_data, expected_probas", [ + "num_classes, segmented_detector, batch_detector_data, expected_probas", + [ ( - 2, - [[0, 0], [1, 1]], - [ - [[0.00, 0.00], [1.00, 0.00]], - [[0.10, 0.00], [0.00, 0.90]], - [[0.22, 0.00], [0.30, 0.48]] - ], - [ - [0.00, 1.00], - [0.10, 0.90], - [0.22, 0.78] - ] - ), - ] + 2, + [[0, 0], [1, 1]], + [ + [[0.00, 0.00], [1.00, 0.00]], + [[0.10, 0.00], [0.00, 0.90]], + [[0.22, 0.00], [0.30, 0.48]], + ], + [[0.00, 1.00], [0.10, 0.90], [0.22, 0.78]], + ), + ], ) -def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_data, expected_probas): +def test_detector_batch_forward( + num_classes, segmented_detector, batch_detector_data, expected_probas +): """ Test of a method of DetectorProcessorClf for calculating probabilities for a batch. ... @@ -192,9 +196,9 @@ def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_ num_classes=num_classes, simulation_parameters=SimulationParameters( { - 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), - 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), - 'wavelength': 500e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), + "wavelength": 500e-6, } ), segmented_detector=segmented_detector_tensor, @@ -207,20 +211,18 @@ def test_detector_batch_forward(num_classes, segmented_detector, batch_detector_ @pytest.mark.parametrize( - "num_classes, segments_zone_size, segmented_detector", [ + "num_classes, segments_zone_size, segmented_detector", + [ ( - 2, - [2, 2], - [ - [-1, -1, -1, -1], - [-1, 0, 1, -1], - [-1, 0, 1, -1], - [-1, -1, -1, -1] - ] - ), - ] + 2, + [2, 2], + [[-1, -1, -1, -1], [-1, 0, 1, -1], [-1, 0, 1, -1], [-1, -1, -1, -1]], + ), + ], ) -def test_detector_segmentation_for_aperture(num_classes, segments_zone_size, segmented_detector): +def test_detector_segmentation_for_aperture( + num_classes, segments_zone_size, segmented_detector +): """ Test of a feature of DetectorProcessorClf that extends segmented detector sizes to match SimulationParameters. ... @@ -242,9 +244,9 @@ def test_detector_segmentation_for_aperture(num_classes, segments_zone_size, seg num_classes=num_classes, simulation_parameters=SimulationParameters( { - 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), - 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), - 'wavelength': 500e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[1]), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, all_detector_size[0]), + "wavelength": 500e-6, } ), segments_zone_size=torch.Size(segments_zone_size), @@ -262,20 +264,20 @@ def test_detector_device(): num_classes=2, simulation_parameters=SimulationParameters( { - 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), - 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), - 'wavelength': 500e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 5), + "wavelength": 500e-6, } ), ) - processor_2 = processor.to('cpu') + processor_2 = processor.to("cpu") assert isinstance(processor_2, DetectorProcessorClf) # available device? - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if device == torch.device('cpu'): - device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == torch.device("cpu"): + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") processor_3 = processor.to(device) device_3 = processor_3.device diff --git a/tests/test_device.py b/tests/test_device.py index f4b57e9..e624526 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -11,10 +11,10 @@ parameters = "device_type" -@pytest.mark.parametrize(parameters, [ - torch.device("cuda" if torch.cuda.is_available() else "cpu"), - torch.device("cpu") -]) +@pytest.mark.parametrize( + parameters, + [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")], +) def test_devices(device_type: torch.device): """A test that checks that all elements belong to the same device @@ -24,26 +24,30 @@ def test_devices(device_type: torch.device): device for objects """ - ox_size = 15. - oy_size = 8. + ox_size = 15.0 + oy_size = 8.0 ox_nodes = 1200 oy_nodes = 1100 - wavelength = torch.linspace(330*1e-6, 660*1e-6, 5) - waist_radius = 2. - distance = 100. - focal_length = 100. - radius = 10. - height = 4. - width = 3. + wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5) + waist_radius = 2.0 + distance = 100.0 + focal_length = 100.0 + radius = 10.0 + height = 4.0 + width = 3.0 tensors = [] params = SimulationParameters( axes={ - 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes).to(device_type), # noqa: E501 - 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes).to(device_type), # noqa: E501 - 'wavelength': wavelength.to(device_type) - } + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes).to( + device_type + ), # noqa: E501 + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes).to( + device_type + ), # noqa: E501 + "wavelength": wavelength.to(device_type), + } ).to(device=device_type) x_linear = params.axes.W @@ -53,34 +57,24 @@ def test_devices(device_type: torch.device): wavelength = params.axes.wavelength tensors.append(wavelength) - x_grid, y_grid = params.meshgrid(x_axis='W', y_axis='H') + x_grid, y_grid = params.meshgrid(x_axis="W", y_axis="H") tensors.extend([x_grid, y_grid]) gaussian_beam = w.gaussian_beam( - simulation_parameters=params, - waist_radius=waist_radius, - distance=distance + simulation_parameters=params, waist_radius=waist_radius, distance=distance ) tensors.append(gaussian_beam) - plane_wave = w.plane_wave( - simulation_parameters=params, - distance=distance - ) + plane_wave = w.plane_wave(simulation_parameters=params, distance=distance) tensors.append(plane_wave) - spherical_wave = w.spherical_wave( - simulation_parameters=params, - distance=distance - ) + spherical_wave = w.spherical_wave(simulation_parameters=params, distance=distance) tensors.append(spherical_wave) lens = elements.ThinLens( - simulation_parameters=params, - focal_length=focal_length, - radius=radius + simulation_parameters=params, focal_length=focal_length, radius=radius ) tensors.append(lens.get_transmission_function()) @@ -88,26 +82,20 @@ def test_devices(device_type: torch.device): tensors.append(lens.reverse(gaussian_beam)) aperture = elements.Aperture( - simulation_parameters=params, - mask=torch.zeros(x_grid.shape).to(device_type) + simulation_parameters=params, mask=torch.zeros(x_grid.shape).to(device_type) ) tensors.append(aperture.get_transmission_function()) tensors.append(aperture.forward(gaussian_beam)) rectangular_aperture = elements.RectangularAperture( - simulation_parameters=params, - height=height, - width=width + simulation_parameters=params, height=height, width=width ) tensors.append(rectangular_aperture.get_transmission_function()) tensors.append(rectangular_aperture.forward(gaussian_beam)) - round_aperture = elements.RoundAperture( - simulation_parameters=params, - radius=radius - ) + round_aperture = elements.RoundAperture(simulation_parameters=params, radius=radius) tensors.append(round_aperture.get_transmission_function()) tensors.append(round_aperture.forward(gaussian_beam)) @@ -116,7 +104,7 @@ def test_devices(device_type: torch.device): simulation_parameters=params, height=height, width=width, - mask=torch.ones_like(x_grid) + mask=torch.ones_like(x_grid), ) tensors.append(slm.transmission_function) @@ -124,8 +112,7 @@ def test_devices(device_type: torch.device): tensors.append(slm.reverse(gaussian_beam)) layer = elements.DiffractiveLayer( - simulation_parameters=params, - mask=torch.zeros_like(x_grid) + simulation_parameters=params, mask=torch.zeros_like(x_grid) ) tensors.append(layer.transmission_function) @@ -133,29 +120,25 @@ def test_devices(device_type: torch.device): tensors.append(layer.reverse(gaussian_beam)) free_space_as = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='AS' + simulation_parameters=params, distance=distance, method="AS" ) tensors.append(free_space_as.forward(gaussian_beam)) free_space_fresnel = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" ) tensors.append(free_space_fresnel.forward(gaussian_beam)) free_space_reverse = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" ) tensors.append(free_space_reverse.reverse(gaussian_beam)) nl = elements.NonlinearElement( - simulation_parameters=params, - response_function=lambda x: x**2 + simulation_parameters=params, response_function=lambda x: x**2 ) tensors.append(nl.forward(gaussian_beam)) @@ -163,33 +146,42 @@ def test_devices(device_type: torch.device): assert all(tensor.device.type == device_type.type for tensor in tensors) -@pytest.mark.parametrize(parameters, [ - torch.device("cuda" if torch.cuda.is_available() else "cpu"), - torch.device("cpu") -]) +@pytest.mark.parametrize( + parameters, + [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")], +) def test_device_setup(device_type: torch.device): + """ + Tests that the optical setup and parameters are correctly moved to the specified device. + + Args: + device_type: The torch.device to move the setup and parameters to (CPU or CUDA). + + Returns: + None + """ - ox_size = 15. - oy_size = 8. + ox_size = 15.0 + oy_size = 8.0 ox_nodes = 1200 oy_nodes = 1100 - wavelength = torch.linspace(330*1e-6, 660*1e-6, 5) + wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5) # waist_radius = 2. - distance = 50. - focal_length = 100. - radius = 10. - height = 4. - width = 3. + distance = 50.0 + focal_length = 100.0 + radius = 10.0 + height = 4.0 + width = 3.0 params = SimulationParameters( axes={ - 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), - 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), - 'wavelength': wavelength - } + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength, + } ) - x_grid, _ = params.meshgrid(x_axis='W', y_axis='H') + x_grid, _ = params.meshgrid(x_axis="W", y_axis="H") # gaussian_beam = w.gaussian_beam( # simulation_parameters=params, @@ -198,48 +190,36 @@ def test_device_setup(device_type: torch.device): # ) free_space = elements.FreeSpace( - simulation_parameters=params, - distance=distance, - method="AS" + simulation_parameters=params, distance=distance, method="AS" ) - circle = elements.RoundAperture( - simulation_parameters=params, - radius=radius - ) + circle = elements.RoundAperture(simulation_parameters=params, radius=radius) rectangle = elements.RectangularAperture( - simulation_parameters=params, - height=height, - width=width + simulation_parameters=params, height=height, width=width ) aperture = elements.Aperture( - simulation_parameters=params, - mask=torch.ones_like(x_grid) + simulation_parameters=params, mask=torch.ones_like(x_grid) ) lens = elements.ThinLens( - simulation_parameters=params, - focal_length=distance, - radius=focal_length + simulation_parameters=params, focal_length=distance, radius=focal_length ) slm = elements.SpatialLightModulator( simulation_parameters=params, - mask=torch.tensor([[1., 1.], [1., 1.]]), + mask=torch.tensor([[1.0, 1.0], [1.0, 1.0]]), height=height, - width=width + width=width, ) nl = elements.NonlinearElement( - simulation_parameters=params, - response_function=lambda x: x**2 + simulation_parameters=params, response_function=lambda x: x**2 ) dl = elements.DiffractiveLayer( - simulation_parameters=params, - mask=torch.zeros_like(x_grid) + simulation_parameters=params, mask=torch.zeros_like(x_grid) ) det = detector.Detector(simulation_parameters=params) @@ -260,8 +240,7 @@ def test_device_setup(device_type: torch.device): free_space, dl, free_space, - det - + det, ] ) diff --git a/tests/test_diffraction_peaks.py b/tests/test_diffraction_peaks.py index 821386b..8112880 100644 --- a/tests/test_diffraction_peaks.py +++ b/tests/test_diffraction_peaks.py @@ -29,48 +29,48 @@ ( 500, # ox_size 500, # oy_size - 1000000, # ox_nodes - 10, # oy_nodes + 1000000, # ox_nodes + 10, # oy_nodes 1064 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1500, # distance, mm - 0.1, # width, mm + 1500, # distance, mm + 0.1, # width, mm 5, # max diffraction order to check - 0.02 # expected_error + 0.02, # expected_error ), ( 500, # ox_size 500, # oy_size - 1000000, # ox_nodes - 10, # oy_nodes + 1000000, # ox_nodes + 10, # oy_nodes 660 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1500, # distance, mm - 0.1, # width, mm + 1500, # distance, mm + 0.1, # width, mm 6, # max diffraction order to check - 0.02 # expected_error + 0.02, # expected_error ), ( 500, # ox_size 500, # oy_size - 1000000, # ox_nodes - 10, # oy_nodes + 1000000, # ox_nodes + 10, # oy_nodes 540 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1500, # distance, mm - 0.1, # width, mm + 1500, # distance, mm + 0.1, # width, mm 4, # max diffraction order to check - 0.02 # expected_error + 0.02, # expected_error ), ( 500, # ox_size 500, # oy_size - 1000000, # ox_nodes - 10, # oy_nodes + 1000000, # ox_nodes + 10, # oy_nodes 990 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1500, # distance, mm - 0.1, # width, mm + 1500, # distance, mm + 0.1, # width, mm 8, # max diffraction order to check - 0.02 # expected_error + 0.02, # expected_error ), - ] + ], ) def test_diffraction_peaks( ox_size: float, @@ -81,7 +81,7 @@ def test_diffraction_peaks( distance: float, width: float, diffraction_order: int, - expected_error: float + expected_error: float, ): """Test checking the coincidence of diffraction maxima at diffraction on a thin slit @@ -114,46 +114,39 @@ def test_diffraction_peaks( y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) params = SimulationParameters( - axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': wavelength_test - }) + axes={"W": x_length, "H": y_length, "wavelength": wavelength_test} + ) beam = Wavefront.gaussian_beam( - simulation_parameters=params, - waist_radius=2., - distance=distance + simulation_parameters=params, waist_radius=2.0, distance=distance ) # create rectangular aperture rectangular_aperture = elements.RectangularAperture( - simulation_parameters=params, - height=height, - width=width + simulation_parameters=params, height=height, width=width ) field_after_aperture = rectangular_aperture(beam) fs = elements.FreeSpace( - simulation_parameters=params, distance=distance, method='AS' + simulation_parameters=params, distance=distance, method="AS" ) output_field = fs.forward(field_after_aperture) intensity_output = output_field.intensity - amplitude_1d = np.sqrt(intensity_output.detach().numpy())[int(oy_nodes/2)] + amplitude_1d = np.sqrt(intensity_output.detach().numpy())[int(oy_nodes / 2)] def intensity_analytic(coordinates: torch.Tensor) -> np.ndarray: phi = np.arctan(coordinates / distance) u = np.pi / wavelength_test * width * np.sin(phi) - return (np.sin(u) / u)**2 * intensity_output[int(oy_nodes/2), int(ox_nodes/2)] # noqa: E501 + return (np.sin(u) / u) ** 2 * intensity_output[ + int(oy_nodes / 2), int(ox_nodes / 2) + ] # noqa: E501 def find_maximum(start, end): result = minimize_scalar( - lambda x: -intensity_analytic(x), - bounds=(start, end), - method='bounded' + lambda x: -intensity_analytic(x), bounds=(start, end), method="bounded" ) return result.x @@ -168,7 +161,7 @@ def find_maximum(start, end): # define Gaussian function def gaussian(x, amp, cen, wid): - return amp * np.exp(-(x-cen)**2 / (2*wid**2)) + return amp * np.exp(-((x - cen) ** 2) / (2 * wid**2)) x_max_averaged = np.array([]) diff --git a/tests/test_drnn.py b/tests/test_drnn.py index 9fb0a40..fc55d89 100644 --- a/tests/test_drnn.py +++ b/tests/test_drnn.py @@ -14,9 +14,9 @@ def sim_params(): """Returns SimulationParameters object.""" return SimulationParameters( { - 'W': torch.linspace(-1e-2 / 2, 1e-2 / 2, 10), - 'H': torch.linspace(-1e-2 / 2, 1e-2 / 2, 10), - 'wavelength': 800e-6 + "W": torch.linspace(-1e-2 / 2, 1e-2 / 2, 10), + "H": torch.linspace(-1e-2 / 2, 1e-2 / 2, 10), + "wavelength": 800e-6, } ) @@ -24,10 +24,7 @@ def sim_params(): @pytest.fixture() def zero_free_space(sim_params): """Returns FreeSpace with a zero distance!""" - return FreeSpace( - simulation_parameters=sim_params, - distance=0.0, method='AS' - ) + return FreeSpace(simulation_parameters=sim_params, distance=0.0, method="AS") @pytest.fixture() @@ -39,24 +36,29 @@ def empty_layer(zero_free_space): @pytest.fixture() def detector(sim_params): """Returns nn.Sequentional with a Detector for RNN detector_layer""" - return nn.Sequential( - Detector(sim_params, func='intensity') - ) + return nn.Sequential(Detector(sim_params, func="intensity")) @pytest.mark.parametrize( - "sequence_len, fusing_coeff, sequence_amplitudes", [ + "sequence_len, fusing_coeff, sequence_amplitudes", + [ (1, 0.30, [1.00]), (2, 0.25, [0.77, 0.13]), (3, 0.50, [0.80, 1.50, 2.10]), - ] + ], ) def test_drnn_forward( - sim_params, empty_layer, detector, # fixtures - sequence_len, fusing_coeff, sequence_amplitudes + sim_params, + empty_layer, + detector, # fixtures + sequence_len, + fusing_coeff, + sequence_amplitudes, ): """Test forward function for a single Wavefront sequence.""" - h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters + h, w = sim_params.axes_size( + axs=("H", "W") + ) # size of a wavefront according to SimulationParameters wavefront_seq = Wavefront( torch.ones(size=(sequence_len, h, w), dtype=torch.float64) ) @@ -69,46 +71,84 @@ def test_drnn_forward( for step_ind in range(sequence_len): input = sequence_amplitudes[step_ind] hidden = fusing_coeff * hidden + (1 - fusing_coeff) * input - out_expected_val = hidden ** 2 # intensity output + out_expected_val = hidden**2 # intensity output # empty D-RNN drnn = DiffractiveRNN( sim_params, - sequence_len=sequence_len, fusing_coeff=fusing_coeff, - read_in_layer=empty_layer, memory_layer=empty_layer, + sequence_len=sequence_len, + fusing_coeff=fusing_coeff, + read_in_layer=empty_layer, + memory_layer=empty_layer, hidden_forward_layer=empty_layer, - read_out_layer=empty_layer, detector_layer=detector, + read_out_layer=empty_layer, + detector_layer=detector, device=torch.get_default_device(), ) # forward for D-RNN out_drnn = drnn(wavefront_seq) assert torch.allclose( - out_drnn, - torch.ones(size=(h, w), dtype=torch.float64) * out_expected_val + out_drnn, torch.ones(size=(h, w), dtype=torch.float64) * out_expected_val ) @pytest.mark.parametrize( - "batch_size, sequence_len, fusing_coeff, sequence_amplitudes", [ - (3, 1, 0.30, [[1.00], [0.40], [0.55],]), - (3, 2, 0.25, [[0.77, 0.13], [0.10, 1.10], [2.20, 5.00],]), - (2, 3, 0.50, [[0.80, 1.50, 2.10], [1.00, 2.00, 3.00],]), - ] + "batch_size, sequence_len, fusing_coeff, sequence_amplitudes", + [ + ( + 3, + 1, + 0.30, + [ + [1.00], + [0.40], + [0.55], + ], + ), + ( + 3, + 2, + 0.25, + [ + [0.77, 0.13], + [0.10, 1.10], + [2.20, 5.00], + ], + ), + ( + 2, + 3, + 0.50, + [ + [0.80, 1.50, 2.10], + [1.00, 2.00, 3.00], + ], + ), + ], ) def test_drnn_batch_forward( - sim_params, empty_layer, detector, # fixtures - batch_size, sequence_len, fusing_coeff, sequence_amplitudes + sim_params, + empty_layer, + detector, # fixtures + batch_size, + sequence_len, + fusing_coeff, + sequence_amplitudes, ): """Test forward function for a batch of Wavefront sequences.""" - h, w = sim_params.axes_size(axs=('H', 'W')) # size of a wavefront according to SimulationParameters + h, w = sim_params.axes_size( + axs=("H", "W") + ) # size of a wavefront according to SimulationParameters wavefront_seq_batch = Wavefront( torch.ones(size=(batch_size, sequence_len, h, w), dtype=torch.float64) ) for seq_ind in range(batch_size): for step_ind in range(sequence_len): # set amplitudes for a wavefront sequence - wavefront_seq_batch[seq_ind, step_ind, :, :] *= sequence_amplitudes[seq_ind][step_ind] + wavefront_seq_batch[seq_ind, step_ind, :, :] *= sequence_amplitudes[ + seq_ind + ][step_ind] # calculate expected values for an empty D-RNN with specified fusing coefficient out_expected_values = [] @@ -117,16 +157,19 @@ def test_drnn_batch_forward( for step_ind in range(sequence_len): input = sequence_amplitudes[seq_ind][step_ind] hidden = fusing_coeff * hidden + (1 - fusing_coeff) * input - out_expected_val = hidden ** 2 # intensity output + out_expected_val = hidden**2 # intensity output out_expected_values.append(out_expected_val) # empty D-RNN drnn = DiffractiveRNN( sim_params, - sequence_len=sequence_len, fusing_coeff=fusing_coeff, - read_in_layer=empty_layer, memory_layer=empty_layer, + sequence_len=sequence_len, + fusing_coeff=fusing_coeff, + read_in_layer=empty_layer, + memory_layer=empty_layer, hidden_forward_layer=empty_layer, - read_out_layer=empty_layer, detector_layer=detector, + read_out_layer=empty_layer, + detector_layer=detector, device=torch.get_default_device(), ) # forward for D-RNN @@ -135,7 +178,7 @@ def test_drnn_batch_forward( for ind_seq in range(batch_size): assert torch.allclose( out_drnn[ind_seq, :, :], - torch.ones(size=(h, w), dtype=torch.float64) * out_expected_values[ind_seq] + torch.ones(size=(h, w), dtype=torch.float64) * out_expected_values[ind_seq], ) @@ -144,17 +187,22 @@ def test_drnn_device(sim_params, empty_layer, detector): # empty D-RNN drnn = DiffractiveRNN( sim_params, - sequence_len=3, fusing_coeff=0.5, # some values - read_in_layer=empty_layer, memory_layer=empty_layer, + sequence_len=3, + fusing_coeff=0.5, # some values + read_in_layer=empty_layer, + memory_layer=empty_layer, hidden_forward_layer=empty_layer, - read_out_layer=empty_layer, detector_layer=detector, - device='cpu', + read_out_layer=empty_layer, + detector_layer=detector, + device="cpu", ) - assert drnn.device == torch.device('cpu') + assert drnn.device == torch.device("cpu") - new_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if new_device == torch.device('cpu'): # if cuda is not available - check if `mps` is - new_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') + new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if new_device == torch.device( + "cpu" + ): # if cuda is not available - check if `mps` is + new_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") new_drnn = drnn.to(new_device) diff --git a/tests/test_element.py b/tests/test_element.py index 7c2c181..e2e2aab 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -8,44 +8,77 @@ class ElementToTest(svetlanna.elements.Element): + """ + Represents an element for testing within a simulation. + + This class serves as a basic building block for testing purposes, + primarily passing data through without modification. + """ + def __init__( self, simulation_parameters: svetlanna.SimulationParameters, test_parameter, test_buffer, ) -> None: + """ + Initializes the class with simulation parameters and test data. + + Args: + simulation_parameters: The simulation parameters object. + test_parameter: The parameter to be processed. + test_buffer: The buffer to be created. + + Returns: + None + """ super().__init__(simulation_parameters) - self.test_parameter = self.process_parameter( - 'test_parameter', test_parameter - ) - self.test_buffer = self.make_buffer( - 'test_buffer', test_buffer - ) - - def forward( - self, - incident_wavefront: svetlanna.Wavefront - ) -> svetlanna.Wavefront: + self.test_parameter = self.process_parameter("test_parameter", test_parameter) + self.test_buffer = self.make_buffer("test_buffer", test_buffer) + + def forward(self, incident_wavefront: svetlanna.Wavefront) -> svetlanna.Wavefront: + """ + Passes the incident wavefront through the layer. + + This method simply calls the `forward` method of the parent class, + effectively passing the input wavefront unchanged. + + Args: + incident_wavefront: The incoming wavefront. + + Returns: + svetlanna.Wavefront: The transmitted wavefront (identical to the input). + """ return super().forward(incident_wavefront) def test_setattr(): + """ + Tests that setattr correctly saves inner storage of parameters. + + This test creates a simulation and an element with a parameter, then + asserts that the inner storage of the parameter is saved as expected + and accessible through a specific attribute name. It also verifies that + the inner parameter is present in the element's parameters dictionary. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - test_parameter = svetlanna.Parameter(10.) - element = ElementToTest( - sim_params, - test_parameter=test_parameter, - test_buffer=None - ) + test_parameter = svetlanna.Parameter(10.0) + element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None) # check if inner storage of the parameter has been saved - parameter_name = 'test_parameter' + INNER_PARAMETER_SUFFIX + parameter_name = "test_parameter" + INNER_PARAMETER_SUFFIX assert getattr(element, parameter_name) is test_parameter.inner_storage assert element.test_parameter.inner_parameter in element.parameters() @@ -53,146 +86,154 @@ def test_setattr(): @pytest.mark.parametrize( ("device",), [ + pytest.param("cpu"), pytest.param( - 'cpu' - ), - pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) def test_make_buffer(device): + """ + Tests the registration and device placement of a buffer. + + Args: + device: The device to move the element to ('cpu', 'cuda', or 'mps'). + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - test_buffer = torch.tensor(123.) - element = ElementToTest( - sim_params, - test_parameter=None, - test_buffer=test_buffer - ) + test_buffer = torch.tensor(123.0) + element = ElementToTest(sim_params, test_parameter=None, test_buffer=test_buffer) # check if buffer has been registered - assert hasattr(element, 'test_buffer') - assert getattr(element, 'test_buffer') in element.buffers() + assert hasattr(element, "test_buffer") + assert getattr(element, "test_buffer") in element.buffers() # check if buffer is automatically transferred to device element.to(device) - assert getattr(element, 'test_buffer').device.type == device + assert getattr(element, "test_buffer").device.type == device # test if a buffer cannot be registered with a tensor on a device # distinct from the simulation parameters' device - if device != 'cpu': + if device != "cpu": with pytest.raises(ValueError): element = ElementToTest( - sim_params, - test_parameter=None, - test_buffer=test_buffer.to(device) + sim_params, test_parameter=None, test_buffer=test_buffer.to(device) ) @pytest.mark.parametrize( ("device",), [ + pytest.param("cpu"), pytest.param( - 'cpu' - ), - pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) def test_process_parameter(device): + """ + Tests the processing of parameters within the ElementToTest class. + + This test verifies that parameters are correctly registered, transferred to the specified device, + and handled appropriately when provided as tensors or directly as nn.Parameters. It also checks + for ValueErrors when attempting to register a parameter tensor on a different device than the simulation. + + Args: + device: The device (e.g., 'cpu', 'cuda', 'mps') to test with. + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - test_parameter = torch.nn.Parameter(torch.tensor(123.)) - element = ElementToTest( - sim_params, - test_parameter=test_parameter, - test_buffer=None - ) + test_parameter = torch.nn.Parameter(torch.tensor(123.0)) + element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None) # check if parameter has been registered - assert hasattr(element, 'test_parameter') - assert getattr(element, 'test_parameter') in element.parameters() + assert hasattr(element, "test_parameter") + assert getattr(element, "test_parameter") in element.parameters() # check if parameter is automatically transferred to device element.to(device) - assert getattr(element, 'test_parameter').device.type == device + assert getattr(element, "test_parameter").device.type == device # test tensor as a parameter - test_parameter = torch.tensor(123.) - element = ElementToTest( - sim_params, - test_parameter=test_parameter, - test_buffer=None - ) + test_parameter = torch.tensor(123.0) + element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None) # check if test_parameter has been registered as a buffer - assert hasattr(element, 'test_parameter') - assert getattr(element, 'test_parameter') not in element.parameters() - assert getattr(element, 'test_parameter') in element.buffers() + assert hasattr(element, "test_parameter") + assert getattr(element, "test_parameter") not in element.parameters() + assert getattr(element, "test_parameter") in element.buffers() # test if a parameter cannot be registered with a tensor on a device # distinct from the simulation parameters' device - if device != 'cpu': + if device != "cpu": with pytest.raises(ValueError): element = ElementToTest( - sim_params, - test_parameter=test_parameter.to(device), - test_buffer=None + sim_params, test_parameter=test_parameter.to(device), test_buffer=None ) def test_to_specs(): + """ + Tests the conversion of an element to specifications. + + This test creates a simulation parameter set and an element with a + test parameter, then asserts that the `to_specs` method generates a list + containing a single specification for the test parameter, and that this + specification contains a representation of type ReprRepr. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - test_parameter = torch.nn.Parameter(torch.tensor(123.)) - element = ElementToTest( - sim_params, - test_parameter=test_parameter, - test_buffer=None - ) + test_parameter = torch.nn.Parameter(torch.tensor(123.0)) + element = ElementToTest(sim_params, test_parameter=test_parameter, test_buffer=None) specs = list(element.to_specs()) assert len(specs) == 1 - assert specs[0].parameter_name == 'test_parameter' + assert specs[0].parameter_name == "test_parameter" representations = list(specs[0].representations) assert len(representations) == 1 @@ -200,39 +241,56 @@ def test_to_specs(): def test_make_buffer_pattern(): + """ + Tests the creation of a buffer pattern using make_buffer. + + This test instantiates an ElementToTest object with simulation parameters and + asserts that calling make_buffer returns an instance of _BufferedValueContainer. + It also checks for expected warnings when assigning a buffered value to another attribute. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - element = ElementToTest( - sim_params, - test_parameter=None, - test_buffer=None - ) + element = ElementToTest(sim_params, test_parameter=None, test_buffer=None) - assert isinstance(element.make_buffer('x', None), _BufferedValueContainer) + assert isinstance(element.make_buffer("x", None), _BufferedValueContainer) with pytest.warns( match="You set the attribute y with an object of internal type _BufferedValueContainer. Make sure this is the intended behavior." ): - element.y = element.make_buffer('x', None) + element.y = element.make_buffer("x", None) def test_repr_html(): + """ + Tests the HTML representation of an element. + + This test instantiates a simulation and an ElementToTest object, + then asserts that the _repr_html_() method returns a string. + + Parameters: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-10, 10, 100), - 'H': torch.linspace(-10, 10, 100), - 'wavelength': 1., + "W": torch.linspace(-10, 10, 100), + "H": torch.linspace(-10, 10, 100), + "wavelength": 1.0, } ) - element = ElementToTest( - sim_params, - test_parameter=None, - test_buffer=None - ) + element = ElementToTest(sim_params, test_parameter=None, test_buffer=None) assert isinstance(element._repr_html_(), str) diff --git a/tests/test_freespace.py b/tests/test_freespace.py index 78b0d2f..5cba607 100644 --- a/tests/test_freespace.py +++ b/tests/test_freespace.py @@ -16,7 +16,7 @@ "distance_total", "distance_end", "expected_error", - "error_energy" + "error_energy", ] @@ -27,28 +27,30 @@ ( 6, # ox_size 6, # oy_size - 1500, # ox_nodes - 1600, # oy_nodes - torch.linspace(330*1e-6, 660*1e-6, 5), # wavelength_test tensor, mm # noqa: E501 - 2., # waist_radius_test, mm - 300, # distance_total, mm - 200, # distance_end, mm - 0.02, # expected_std - 0.01 # error_energy + 1500, # ox_nodes + 1600, # oy_nodes + torch.linspace( + 330 * 1e-6, 660 * 1e-6, 5 + ), # wavelength_test tensor, mm # noqa: E501 + 2.0, # waist_radius_test, mm + 300, # distance_total, mm + 200, # distance_end, mm + 0.02, # expected_std + 0.01, # error_energy ), ( 6, # ox_size 6, # oy_size - 1500, # ox_nodes - 1600, # oy_nodes + 1500, # ox_nodes + 1600, # oy_nodes 660 * 1e-6, # wavelength_test, mm - 2., # waist_radius_test, mm - 300, # distance_total, mm - 200, # distance_end, mm - 0.02, # expected_std - 0.01 # error_energy - ) - ] + 2.0, # waist_radius_test, mm + 300, # distance_total, mm + 200, # distance_end, mm + 0.02, # expected_std + 0.01, # error_energy + ), + ], ) def test_gaussian_beam_propagation( ox_size: float, @@ -60,7 +62,7 @@ def test_gaussian_beam_propagation( distance_total: float, distance_end: float, expected_error: float, - error_energy: float + error_energy: float, ): """Test for the free field propagation problem: free propagation of the Gaussian beam at the arbitrary distance(distance_total). We calculate the @@ -95,7 +97,7 @@ def test_gaussian_beam_propagation( x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) - x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy') + x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy") # creating meshgrid x_grid = x_grid[None, :] @@ -103,7 +105,7 @@ def test_gaussian_beam_propagation( # wave_number = 2 * torch.pi / wavelength_test[..., None, None] - amplitude = 1. + amplitude = 1.0 dx = ox_size / ox_nodes dy = oy_size / oy_nodes @@ -112,35 +114,50 @@ def test_gaussian_beam_propagation( wave_number = 2 * torch.pi / wavelength_test rayleigh_range = torch.pi * (waist_radius_test**2) / wavelength_test else: - rayleigh_range = torch.pi * (waist_radius_test**2) / wavelength_test[..., None, None] # noqa: E501 + rayleigh_range = ( + torch.pi * (waist_radius_test**2) / wavelength_test[..., None, None] + ) # noqa: E501 wave_number = 2 * torch.pi / wavelength_test[..., None, None] radial_distance_squared = torch.pow(x_grid, 2) + torch.pow(y_grid, 2) - hyperbolic_relation = waist_radius_test * (1 + ( - distance_total / rayleigh_range)**2)**(1/2) + hyperbolic_relation = waist_radius_test * ( + 1 + (distance_total / rayleigh_range) ** 2 + ) ** (1 / 2) - radius_of_curvature = distance_total * ( - 1 + (rayleigh_range / distance_total)**2 - ) + radius_of_curvature = distance_total * (1 + (rayleigh_range / distance_total) ** 2) # Gouy phase gouy_phase = torch.arctan(torch.tensor(distance_total) / rayleigh_range) # analytical equation for the propagation of the Gaussian beam - field = amplitude * (waist_radius_test / hyperbolic_relation) * ( - torch.exp(-radial_distance_squared / (hyperbolic_relation)**2) * ( - torch.exp(-1j * (wave_number * distance_total + wave_number * ( - radial_distance_squared) / (2 * radius_of_curvature) - ( - gouy_phase))))) + field = ( + amplitude + * (waist_radius_test / hyperbolic_relation) + * ( + torch.exp(-radial_distance_squared / (hyperbolic_relation) ** 2) + * ( + torch.exp( + -1j + * ( + wave_number * distance_total + + wave_number + * (radial_distance_squared) + / (2 * radius_of_curvature) + - (gouy_phase) + ) + ) + ) + ) + ) intensity_analytic = torch.pow(torch.abs(field), 2) params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) @@ -149,53 +166,45 @@ def test_gaussian_beam_propagation( field_gb_start = Wavefront.gaussian_beam( simulation_parameters=params, distance=distance_start, - waist_radius=waist_radius_test + waist_radius=waist_radius_test, ) # field on the screen by using Fresnel propagation method field_end_fresnel = elements.FreeSpace( - simulation_parameters=params, distance=distance_end, method='fresnel' + simulation_parameters=params, distance=distance_end, method="fresnel" )(field_gb_start) # field on the screen by using angular spectrum method field_end_as = elements.FreeSpace( - simulation_parameters=params, distance=distance_end, method='AS' + simulation_parameters=params, distance=distance_end, method="AS" )(field_gb_start) intensity_output_fresnel = field_end_fresnel.intensity intensity_output_as = field_end_as.intensity - energy_analytic = torch.sum( - intensity_analytic, dim=(-2, -1) - ) * dx * dy - energy_numeric_fresnel = torch.sum( - intensity_output_fresnel, dim=(-2, -1) - ) * dx * dy - energy_numeric_as = torch.sum( - intensity_output_as, dim=(-2, -1) - ) * dx * dy + energy_analytic = torch.sum(intensity_analytic, dim=(-2, -1)) * dx * dy + energy_numeric_fresnel = torch.sum(intensity_output_fresnel, dim=(-2, -1)) * dx * dy + energy_numeric_as = torch.sum(intensity_output_as, dim=(-2, -1)) * dx * dy intensity_difference_fresnel = torch.abs( intensity_analytic - intensity_output_fresnel ) / (ox_nodes * oy_nodes) - intensity_difference_as = torch.abs( - intensity_analytic - intensity_output_as - ) / (ox_nodes * oy_nodes) + intensity_difference_as = torch.abs(intensity_analytic - intensity_output_as) / ( + ox_nodes * oy_nodes + ) error_fresnel, _ = intensity_difference_fresnel.view( intensity_difference_fresnel.size(0), -1 ).max(dim=1) - error_as, _ = intensity_difference_as.view( - intensity_difference_as.size(0), -1 - ).max(dim=1) + error_as, _ = intensity_difference_as.view(intensity_difference_as.size(0), -1).max( + dim=1 + ) energy_error_fresnel = torch.abs( (energy_analytic - energy_numeric_fresnel) / energy_analytic ) - energy_error_as = torch.abs( - (energy_analytic - energy_numeric_as) / energy_analytic - ) + energy_error_as = torch.abs((energy_analytic - energy_numeric_as) / energy_analytic) assert (error_fresnel <= expected_error).all() assert (error_as <= expected_error).all() @@ -211,7 +220,7 @@ def test_gaussian_beam_propagation( "wavelength_test", "waist_radius_test", "distance", - "expected_error" + "expected_error", ] @@ -221,36 +230,34 @@ def test_gaussian_beam_propagation( ( 6, # ox_size 6, # oy_size - 1569, # ox_nodes - 1698, # oy_nodes + 1569, # ox_nodes + 1698, # oy_nodes 660 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 2., # waist_radius_test, mm - 300, # distance, mm - 0.5 # expected relative error + 2.0, # waist_radius_test, mm + 300, # distance, mm + 0.5, # expected relative error ), - ( 15, # ox_size 8, # oy_size - 1111, # ox_nodes - 14070, # oy_nodes + 1111, # ox_nodes + 14070, # oy_nodes 330 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1., # waist_radius_test, mm - 50, # distance, mm - 1.7 # expected relative error + 1.0, # waist_radius_test, mm + 50, # distance, mm + 1.7, # expected relative error ), - ( 20, # ox_size 23, # oy_size - 1800, # ox_nodes - 1032, # oy_nodes + 1800, # ox_nodes + 1032, # oy_nodes 540 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 4., # waist_radius_test, mm - 500, # distance, mm - 0.5 # expected relative error + 4.0, # waist_radius_test, mm + 500, # distance, mm + 0.5, # expected relative error ), - ] + ], ) def test_gaussian_beam_fwhm( ox_size: float, @@ -262,28 +269,42 @@ def test_gaussian_beam_fwhm( distance: float, expected_error: float, ): + """ + Tests the FWHM calculation for a Gaussian beam using Fresnel and Angular Spectrum methods. + + Args: + ox_size: The size of the x-axis grid. + oy_size: The size of the y-axis grid. + ox_nodes: The number of nodes in the x-axis grid. + oy_nodes: The number of nodes in the y-axis grid. + wavelength_test: The wavelength of the light. + waist_radius_test: The waist radius of the Gaussian beam. + distance: The propagation distance. + expected_error: The expected relative error for the FWHM calculation. + + Returns: + None. Raises an AssertionError if the calculated relative errors exceed the expected error. + """ params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) field_gb_start = Wavefront.gaussian_beam( - simulation_parameters=params, - distance=0., - waist_radius=waist_radius_test + simulation_parameters=params, distance=0.0, waist_radius=waist_radius_test ) # field on the screen by using Fresnel propagation method field_end_fresnel = elements.FreeSpace( - simulation_parameters=params, distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" )(field_gb_start) # field on the screen by using angular spectrum method field_end_as = elements.FreeSpace( - simulation_parameters=params, distance=distance, method='AS' + simulation_parameters=params, distance=distance, method="AS" )(field_gb_start) fwhm_x_as, fwhm_y_as = field_end_as.fwhm(simulation_parameters=params) @@ -291,27 +312,24 @@ def test_gaussian_beam_fwhm( simulation_parameters=params ) - fwhm_analytical = torch.sqrt( - 2. * torch.log(torch.tensor([2.])) - ) * waist_radius_test * torch.sqrt( - torch.tensor([1.]) + ( - distance / (torch.pi * waist_radius_test**2 / wavelength_test) - )**2 + fwhm_analytical = ( + torch.sqrt(2.0 * torch.log(torch.tensor([2.0]))) + * waist_radius_test + * torch.sqrt( + torch.tensor([1.0]) + + (distance / (torch.pi * waist_radius_test**2 / wavelength_test)) ** 2 + ) ) - relative_error_x_as = torch.abs( - fwhm_x_as - fwhm_analytical - ) / fwhm_analytical * 100 - relative_error_y_as = torch.abs( - fwhm_y_as - fwhm_analytical - ) / fwhm_analytical * 100 + relative_error_x_as = torch.abs(fwhm_x_as - fwhm_analytical) / fwhm_analytical * 100 + relative_error_y_as = torch.abs(fwhm_y_as - fwhm_analytical) / fwhm_analytical * 100 - relative_error_x_fresnel = torch.abs( - fwhm_x_fresnel - fwhm_analytical - ) / fwhm_analytical * 100 - relative_error_y_fresnel = torch.abs( - fwhm_y_fresnel - fwhm_analytical - ) / fwhm_analytical * 100 + relative_error_x_fresnel = ( + torch.abs(fwhm_x_fresnel - fwhm_analytical) / fwhm_analytical * 100 + ) + relative_error_y_fresnel = ( + torch.abs(fwhm_y_fresnel - fwhm_analytical) / fwhm_analytical * 100 + ) assert (relative_error_x_as <= expected_error).all() assert (relative_error_y_as <= expected_error).all() @@ -327,7 +345,7 @@ def test_gaussian_beam_fwhm( "wavelength_test", "waist_radius_test", "distance", - "expected_error" + "expected_error", ] @@ -337,36 +355,34 @@ def test_gaussian_beam_fwhm( ( 6, # ox_size 6, # oy_size - 1569, # ox_nodes - 1698, # oy_nodes + 1569, # ox_nodes + 1698, # oy_nodes 660 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 2., # waist_radius_test, mm - 300, # distance, mm - 0.5 # expected relative error + 2.0, # waist_radius_test, mm + 300, # distance, mm + 0.5, # expected relative error ), - ( 15, # ox_size 8, # oy_size - 1111, # ox_nodes - 14070, # oy_nodes + 1111, # ox_nodes + 14070, # oy_nodes 330 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 1., # waist_radius_test, mm - 50, # distance, mm - 1.7 # expected relative error + 1.0, # waist_radius_test, mm + 50, # distance, mm + 1.7, # expected relative error ), - ( 20, # ox_size 23, # oy_size - 1800, # ox_nodes - 1032, # oy_nodes + 1800, # ox_nodes + 1032, # oy_nodes 540 * 1e-6, # wavelength_test tensor, mm # noqa: E501 - 4., # waist_radius_test, mm - 500, # distance, mm - 0.5 # expected relative error + 4.0, # waist_radius_test, mm + 500, # distance, mm + 0.5, # expected relative error ), - ] + ], ) def test_gaussian_beam_phase_profile( ox_size: float, @@ -378,34 +394,51 @@ def test_gaussian_beam_phase_profile( distance: float, expected_error: float, ): + """ + Tests the phase profile of a Gaussian beam propagated using Fresnel and Angular Spectrum methods. + + This test compares the phase profiles obtained from propagating a Gaussian beam + using both Fresnel and Angular Spectrum methods to an analytically calculated + phase profile. It checks if the standard deviation of the difference between + the computed and analytical phases is within a specified tolerance. + + Args: + ox_size: The size of the x-axis grid. + oy_size: The size of the y-axis grid. + ox_nodes: The number of nodes in the x-axis grid. + oy_nodes: The number of nodes in the y-axis grid. + wavelength_test: The wavelength of the light used for simulation. + waist_radius_test: The waist radius of the Gaussian beam. + distance: The propagation distance. + expected_error: The expected maximum standard deviation of the phase difference. + + Returns: + None. Raises an AssertionError if the tests fail. + """ params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) field_gb_start = Wavefront.gaussian_beam( - simulation_parameters=params, - distance=0., - waist_radius=waist_radius_test + simulation_parameters=params, distance=0.0, waist_radius=waist_radius_test ) # field on the screen by using Fresnel propagation method field_end_fresnel = elements.FreeSpace( - simulation_parameters=params, distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" )(field_gb_start) # field on the screen by using angular spectrum method field_end_as = elements.FreeSpace( - simulation_parameters=params, distance=distance, method='AS' + simulation_parameters=params, distance=distance, method="AS" )(field_gb_start) total_field = Wavefront.gaussian_beam( - simulation_parameters=params, - waist_radius=waist_radius_test, - distance=distance + simulation_parameters=params, waist_radius=waist_radius_test, distance=distance ) intensity_analytic = total_field.intensity @@ -417,9 +450,5 @@ def test_gaussian_beam_phase_profile( output_phase_fresnel = field_end_fresnel.phase * target_region output_phase_analytical = total_field.phase * target_region - assert torch.std( - output_phase_as - output_phase_analytical - ) <= expected_error - assert torch.std( - output_phase_fresnel - output_phase_analytical - ) <= expected_error + assert torch.std(output_phase_as - output_phase_analytical) <= expected_error + assert torch.std(output_phase_fresnel - output_phase_analytical) <= expected_error diff --git a/tests/test_lens.py b/tests/test_lens.py index f3c8613..29f9148 100644 --- a/tests/test_lens.py +++ b/tests/test_lens.py @@ -12,7 +12,7 @@ "wavelength_test", "focal_length_test", "radius_test", - "expected_std" + "expected_std", ] @@ -20,26 +20,30 @@ lens_parameters, [ ( - 8, # ox_size, mm - 12, # oy_size, mm - 1200, # ox_nodes - 1400, # oy_nodes - torch.linspace(330 * 1e-6, 1064 * 1e-6, 20), # wavelength_test, tensor # noqa: E501 - 100, # focal_length_test, mm - 10, # radius_test, mm - 1e-5 # expected_std + 8, # ox_size, mm + 12, # oy_size, mm + 1200, # ox_nodes + 1400, # oy_nodes + torch.linspace( + 330 * 1e-6, 1064 * 1e-6, 20 + ), # wavelength_test, tensor # noqa: E501 + 100, # focal_length_test, mm + 10, # radius_test, mm + 1e-5, # expected_std ), ( 8, # ox_size, mm 4, # oy_size, mm - 1100, # ox_nodes - 1000, # oy_nodes - torch.linspace(660 * 1e-6, 1600 * 1e-6, 20), # wavelength_test, tensor # noqa: E501 - 200, # focal_length_test, mm - 15, # radius_test, mm - 1e-5 # expected_std - ) - ] + 1100, # ox_nodes + 1000, # oy_nodes + torch.linspace( + 660 * 1e-6, 1600 * 1e-6, 20 + ), # wavelength_test, tensor # noqa: E501 + 200, # focal_length_test, mm + 15, # radius_test, mm + 1e-5, # expected_std + ), + ], ) def test_lens( ox_size: float, @@ -75,17 +79,15 @@ def test_lens( params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) # transmission function of the thin lens as a class method transmission_function = elements.ThinLens( - simulation_parameters=params, - focal_length=focal_length_test, - radius=radius_test + simulation_parameters=params, focal_length=focal_length_test, radius=radius_test ).get_transmission_function() x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) @@ -102,16 +104,21 @@ def test_lens( radius_squared = torch.pow(x_grid, 2) + torch.pow(y_grid, 2) transmission_function_analytic = torch.exp( - 1j * (-wave_number / (2 * focal_length_test) * radius_squared * ( - radius_squared <= radius_test**2 - )) + 1j + * ( + -wave_number + / (2 * focal_length_test) + * radius_squared + * (radius_squared <= radius_test**2) + ) ) standard_deviation = torch.std( - torch.real((1 / 1j) * ( - torch.log(transmission_function) - torch.log( - transmission_function_analytic - ) + torch.real( + (1 / 1j) + * ( + torch.log(transmission_function) + - torch.log(transmission_function_analytic) ) ) ) @@ -120,21 +127,29 @@ def test_lens( def test_reverse(): + """ + Tests the reversibility of the ThinLens forward and reverse propagation. + + This test checks if applying the `forward` method followed by the `reverse` + method to a wavefront results in the original wavefront, confirming the + correctness of the inverse propagation implementation. + + Args: + None + + Returns: + None + """ params = SimulationParameters( { - 'W': torch.linspace(-10/2, 10/2, 10), - 'H': torch.linspace(-10/2, 10/2, 10), - 'wavelength': 1 + "W": torch.linspace(-10 / 2, 10 / 2, 10), + "H": torch.linspace(-10 / 2, 10 / 2, 10), + "wavelength": 1, } ) - lens = elements.ThinLens( - simulation_parameters=params, - focal_length=1 - ) + lens = elements.ThinLens(simulation_parameters=params, focal_length=1) # test is reverse(forward(x)) is x, where x is a wavefront wavefront = svetlanna.Wavefront.plane_wave(params) - assert torch.allclose( - lens.reverse(lens.forward(wavefront)), wavefront - ) + assert torch.allclose(lens.reverse(lens.forward(wavefront)), wavefront) diff --git a/tests/test_lightpipes_comparison.py b/tests/test_lightpipes_comparison.py index 6757570..d31acfa 100644 --- a/tests/test_lightpipes_comparison.py +++ b/tests/test_lightpipes_comparison.py @@ -7,14 +7,7 @@ from svetlanna import elements -parameters = [ - "ox_size", - "ox_nodes", - "wavelength", - "radius", - "distance", - "focal_length" -] +parameters = ["ox_size", "ox_nodes", "wavelength", "radius", "distance", "focal_length"] # TODO: fix docstrings @@ -23,29 +16,29 @@ [ ( 25 * lp.mm, # ox_size - 3000, # ox_nodes + 3000, # ox_nodes 1064 * lp.nm, # wavelength, mm - 2 * lp.mm, # radius, mm - 2000 * lp.mm, # distance, mm - 2000 * lp.mm, # focal_length, mm + 2 * lp.mm, # radius, mm + 2000 * lp.mm, # distance, mm + 2000 * lp.mm, # focal_length, mm ), ( 25 * lp.mm, # ox_size - 3000, # ox_nodes + 3000, # ox_nodes 1064 * lp.nm, # wavelength, mm - 2 * lp.mm, # radius, mm - 100 * lp.mm, # distance, mm - 20 * lp.mm, # focal_length, mm + 2 * lp.mm, # radius, mm + 100 * lp.mm, # distance, mm + 20 * lp.mm, # focal_length, mm ), ( 25 * lp.mm, # ox_size - 100, # ox_nodes + 100, # ox_nodes 123 * lp.nm, # wavelength, mm - 2 * lp.mm, # radius, mm - 200 * lp.mm, # distance, mm - 2100 * lp.mm, # focal_length, mm - ) - ] + 2 * lp.mm, # radius, mm + 200 * lp.mm, # distance, mm + 2100 * lp.mm, # focal_length, mm + ), + ], ) def test_circular_aperture( ox_size: float, @@ -53,8 +46,28 @@ def test_circular_aperture( wavelength: float, radius: float, distance: float, - focal_length: float + focal_length: float, ): + """ + Tests the circular aperture propagation using LightPipes and SVETlANNa. + + This test compares the field calculated by LightPipes with the field + calculated by SVETlANNa for a circular aperture, free space propagation, + and a lens. It asserts that the mean absolute difference between the two + fields (normalized by their maximum absolute values) is less than 0.01 + before and after the lens. + + Args: + ox_size: The size of the computational grid in x direction. + ox_nodes: The number of nodes in the computational grid. + wavelength: The wavelength of light. + radius: The radius of the circular aperture. + distance: The distance to propagate before the lens. + focal_length: The focal length of the lens. + + Returns: + None. The function asserts that the difference between LightPipes and SVETlANNa results is within tolerance. + """ # ---------------------------------- # LightPipes fields calculations # ---------------------------------- @@ -71,39 +84,21 @@ def test_circular_aperture( # ---------------------------------- oy_size = ox_size oy_nodes = ox_nodes - x_length = torch.linspace( - -ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64 - ) - y_length = torch.linspace( - -oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64 - ) + x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes, dtype=torch.float64) + y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes, dtype=torch.float64) simulation_parameters = sv.SimulationParameters( axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': torch.tensor(wavelength, dtype=torch.float64) + "W": x_length, + "H": y_length, + "wavelength": torch.tensor(wavelength, dtype=torch.float64), } ) # elements' definitions - aperture = elements.RoundAperture( - simulation_parameters, - radius - ) - fs1 = elements.FreeSpace( - simulation_parameters, - distance, - method='fresnel' - ) - lens = elements.ThinLens( - simulation_parameters, - focal_length - ) - fs2 = elements.FreeSpace( - simulation_parameters, - focal_length, - method='fresnel' - ) + aperture = elements.RoundAperture(simulation_parameters, radius) + fs1 = elements.FreeSpace(simulation_parameters, distance, method="fresnel") + lens = elements.ThinLens(simulation_parameters, focal_length) + fs2 = elements.FreeSpace(simulation_parameters, focal_length, method="fresnel") # field calculations G = sv.Wavefront.plane_wave(simulation_parameters) @@ -121,12 +116,13 @@ def test_circular_aperture( # ---------------------------------- # results testing # ---------------------------------- - assert torch.mean( - torch.abs(field_before_lens_lp - field_before_lens_sv) - ) / before_lens_norm < 0.01 + assert ( + torch.mean(torch.abs(field_before_lens_lp - field_before_lens_sv)) + / before_lens_norm + < 0.01 + ) + + assert torch.mean(torch.abs(field_output_lp - field_output_sv)) / output_norm < 0.01 - assert torch.mean( - torch.abs(field_output_lp - field_output_sv) - ) / output_norm < 0.01 # TODO: ΡΡ€Π°Π²Π½ΠΈΡ‚ΡŒ ΠΏΠΈΠΊΠΎΠ²ΡƒΡŽ ΠΌΠΎΡ‰Π½ΠΎΡΡ‚ΡŒ ΠΈ ΠΏΠΎΠ»ΠΎΠΆΠ΅Π½ΠΈΠ΅ максимумов diff --git a/tests/test_logging.py b/tests/test_logging.py index ea5e5fe..721d837 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -9,17 +9,28 @@ @pytest.mark.parametrize( - 'input', [ - torch.tensor([10., 10.]), - torch.tensor(20), - svetlanna.Parameter(11.), - svetlanna.ConstrainedParameter(11., min_value=-10, max_value=100), - 123, - 123., - None - ] + "input", + [ + torch.tensor([10.0, 10.0]), + torch.tensor(20), + svetlanna.Parameter(11.0), + svetlanna.ConstrainedParameter(11.0, min_value=-10, max_value=100), + 123, + 123.0, + None, + ], ) def test_agr_short_description(input): + """ + Tests the agr_short_description function with various inputs. + + Args: + input: The input to be tested. Can be a torch.Tensor, + svetlanna.Parameter, svetlanna.ConstrainedParameter, number, or None. + + Returns: + None + """ if isinstance(input, torch.Tensor): # test for torch.Tensor assert agr_short_description(input) == ( @@ -28,44 +39,61 @@ def test_agr_short_description(input): ) else: # test for other types - assert agr_short_description(input) == f'{type(input)}' + assert agr_short_description(input) == f"{type(input)}" def test_log_message(capfd, caplog): + """ + Tests the log_message function with both 'print' and 'logging' types. + + Args: + capfd: A pytest fixture for capturing stdout/stderr. + caplog: A pytest fixture for capturing logging messages. + + Returns: + None + """ # test for 'print' type - svetlanna.set_debug_logging(False, type='print') # set 'print' type + svetlanna.set_debug_logging(False, type="print") # set 'print' type # log_message prints the message even if mode set to False! - log_message('test message') # print message + log_message("test message") # print message out, _ = capfd.readouterr() # read stdout - assert out == 'test message\n' + assert out == "test message\n" # test for 'logging' type - svetlanna.set_debug_logging(False, type='logging') # set 'logging' type + svetlanna.set_debug_logging(False, type="logging") # set 'logging' type - logger = logging.getLogger('svetlanna.logging') # get logger + logger = logging.getLogger("svetlanna.logging") # get logger logger.setLevel(logging.DEBUG) # set logging level to DEBUG - log_message('test message') # print message + log_message("test message") # print message assert caplog.record_tuples == [ ("svetlanna.logging", logging.DEBUG, "test message") ] -@pytest.mark.parametrize( - 'input', [ - 1, 1., (1, 2), tuple() - ] -) -@pytest.mark.parametrize( - 'output', [ - 1, 1., (1, 2), tuple() - ] -) +@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()]) +@pytest.mark.parametrize("output", [1, 1.0, (1, 2), tuple()]) def test_forward_logging_hook(input, output, capfd): - svetlanna.set_debug_logging(False, type='print') # set 'print' type + """ + Tests the forward logging hook functionality. + + This test verifies that the forward logging hook does not log anything for + modules that are not instances of svetlanna.elements.Element and logs the + input/output types when called with an Element-like module. + + Args: + input: The input to the forward method. + output: The output from the forward method. + capfd: A pytest fixture for capturing stdout. + + Returns: + None + """ + svetlanna.set_debug_logging(False, type="print") # set 'print' type # test for random element ignorance class NotElement(torch.nn.Module): @@ -74,7 +102,7 @@ class NotElement(torch.nn.Module): forward_logging_hook(NotElement(), input, output) out, _ = capfd.readouterr() # read stdout - assert out == '' + assert out == "" # test for elements class ElementLike(svetlanna.elements.Element): @@ -84,48 +112,51 @@ def forward(self, *args, **kwargs): element = ElementLike( simulation_parameters=svetlanna.SimulationParameters( axes={ - 'H': torch.linspace(-1, 1, 10), - 'W': torch.linspace(-1, 1, 10), - 'wavelength': 1. + "H": torch.linspace(-1, 1, 10), + "W": torch.linspace(-1, 1, 10), + "wavelength": 1.0, } ) ) forward_logging_hook(element, input, output) - expected_out = 'The forward method of ElementLike was computed' + expected_out = "The forward method of ElementLike was computed" input = input if isinstance(input, tuple) else (input,) output = output if isinstance(output, tuple) else (output,) for i, _input in enumerate(input): - expected_out += f'\n input {i}: {type(_input)}' + expected_out += f"\n input {i}: {type(_input)}" for i, _output in enumerate(output): - expected_out += f'\n output {i}: {type(_output)}' + expected_out += f"\n output {i}: {type(_output)}" out, _ = capfd.readouterr() # read stdout - assert out == expected_out + '\n' + assert out == expected_out + "\n" -@pytest.mark.parametrize( - 'input', [ - 1, 1., (1, 2), tuple() - ] -) -@pytest.mark.parametrize( - 'type_', [ - 'Parameter', 'Buffer', 'Module' - ] -) +@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()]) +@pytest.mark.parametrize("type_", ["Parameter", "Buffer", "Module"]) def test_register_logging_hook(input, type_, capfd): - svetlanna.set_debug_logging(False, type='print') # set 'print' type + """ + Tests the registration of logging hooks for different input types and element types. + + Args: + input: The input to be logged. + type_: The type of element being logged (e.g., 'Parameter', 'Buffer', 'Module'). + capfd: A pytest fixture for capturing stdout/stderr. + + Returns: + None + """ + svetlanna.set_debug_logging(False, type="print") # set 'print' type # test for random element ignorance class NotElement(torch.nn.Module): pass - register_logging_hook(NotElement(), 'test_name', input, type_) + register_logging_hook(NotElement(), "test_name", input, type_) out, _ = capfd.readouterr() # read stdout - assert out == '' + assert out == "" # test for elements class ElementLike(svetlanna.elements.Element): @@ -135,49 +166,59 @@ def forward(self, *args, **kwargs): element = ElementLike( simulation_parameters=svetlanna.SimulationParameters( axes={ - 'H': torch.linspace(-1, 1, 10), - 'W': torch.linspace(-1, 1, 10), - 'wavelength': 1. + "H": torch.linspace(-1, 1, 10), + "W": torch.linspace(-1, 1, 10), + "wavelength": 1.0, } ) ) - register_logging_hook(element, 'test_name', input, type_) + register_logging_hook(element, "test_name", input, type_) - expected_out = f'{type_} of {element._get_name()} was registered with name test_name:' - expected_out += f'\n {type(input)}' + expected_out = ( + f"{type_} of {element._get_name()} was registered with name test_name:" + ) + expected_out += f"\n {type(input)}" out, _ = capfd.readouterr() # read stdout - assert out == expected_out + '\n' + assert out == expected_out + "\n" -@pytest.mark.parametrize( - 'input', [ - 1, 1., (1, 2), tuple() - ] -) -@pytest.mark.parametrize( - 'output', [ - 1, 1., (1, 2), tuple() - ] -) +@pytest.mark.parametrize("input", [1, 1.0, (1, 2), tuple()]) +@pytest.mark.parametrize("output", [1, 1.0, (1, 2), tuple()]) def test_set_debug_logging(input, output, capfd, caplog): + """ + Tests the set_debug_logging function with different configurations. + + This test verifies that svetlanna.set_debug_logging correctly handles + different debug logging types ('print' and 'logging') and enables/disables + debugging as expected. It also checks for correct output when debugging is + enabled, ensuring the expected messages are printed or logged. + + Args: + input: Input values to be passed to the ElementLike forward method. + output: Output values returned by the ElementLike forward method. + capfd: Pytest fixture for capturing stdout and stderr. + caplog: Pytest fixture for capturing log messages. + + Returns: + None + """ # test wrong type with pytest.raises(ValueError): - svetlanna.set_debug_logging(False, type='123') # type: ignore + svetlanna.set_debug_logging(False, type="123") # type: ignore input = input if isinstance(input, tuple) else (input,) output = output if isinstance(output, tuple) else (output,) class ElementLike(svetlanna.elements.Element): def __init__( - self, - simulation_parameters: svetlanna.SimulationParameters + self, simulation_parameters: svetlanna.SimulationParameters ) -> None: super().__init__(simulation_parameters) self.a = torch.nn.Module() - self.b = svetlanna.Parameter(123.) - self.register_buffer('c', torch.tensor(123.)) + self.b = svetlanna.Parameter(123.0) + self.register_buffer("c", torch.tensor(123.0)) def forward(self, *args, **kwargs): return output @@ -186,9 +227,9 @@ def run_element(): element = ElementLike( simulation_parameters=svetlanna.SimulationParameters( axes={ - 'H': torch.linspace(-1, 1, 10), - 'W': torch.linspace(-1, 1, 10), - 'wavelength': 1. + "H": torch.linspace(-1, 1, 10), + "W": torch.linspace(-1, 1, 10), + "wavelength": 1.0, } ) ) @@ -206,50 +247,44 @@ def run_element(): "Buffer of ElementLike was registered with name c:\n" " shape=torch.Size([]), dtype=torch.float32, device=cpu" ) - expected_output_4 = ( - "The forward method of ElementLike was computed" - ) + expected_output_4 = "The forward method of ElementLike was computed" for i, _input in enumerate(input): - expected_output_4 += f'\n input {i}: {type(_input)}' + expected_output_4 += f"\n input {i}: {type(_input)}" for i, _output in enumerate(output): - expected_output_4 += f'\n output {i}: {type(_output)}' + expected_output_4 += f"\n output {i}: {type(_output)}" expected_outputs = [ expected_output_1, expected_output_2, expected_output_3, - expected_output_4 + expected_output_4, ] # test for print type - svetlanna.set_debug_logging(True, type='print') + svetlanna.set_debug_logging(True, type="print") run_element() out, _ = capfd.readouterr() # read stdout - assert out == '\n'.join(expected_outputs) + '\n' + assert out == "\n".join(expected_outputs) + "\n" # test for print type, with disabled debug logging - svetlanna.set_debug_logging(False, type='print') + svetlanna.set_debug_logging(False, type="print") run_element() out, _ = capfd.readouterr() # read stdout - assert out == '' + assert out == "" # test for logging type - svetlanna.set_debug_logging(True, type='logging') - logger = logging.getLogger('svetlanna.logging') # get logger + svetlanna.set_debug_logging(True, type="logging") + logger = logging.getLogger("svetlanna.logging") # get logger logger.setLevel(logging.DEBUG) # set logging level to DEBUG run_element() assert caplog.record_tuples == [ - ( - "svetlanna.logging", - logging.DEBUG, - message - ) for message in expected_outputs + ("svetlanna.logging", logging.DEBUG, message) for message in expected_outputs ] caplog.clear() # clear caplog assert caplog.record_tuples == [] # test for logging type, with disabled debug logging - svetlanna.set_debug_logging(False, type='logging') + svetlanna.set_debug_logging(False, type="logging") run_element() assert caplog.record_tuples == [] diff --git a/tests/test_nonlinear_element.py b/tests/test_nonlinear_element.py index 21f94c6..7c0c8da 100644 --- a/tests/test_nonlinear_element.py +++ b/tests/test_nonlinear_element.py @@ -14,12 +14,23 @@ "oy_nodes", "wavelength_test", "response_function", - "response_parameters" + "response_parameters", ] def func(x, a, b): - return a / 1 + torch.exp(-b*x) + """ + Computes a value based on the given inputs. + + Args: + x: The input value. + a: A constant value. + b: Another constant value. + + Returns: + torch.Tensor: The computed result of the formula a / 1 + torch.exp(-b*x). + """ + return a / 1 + torch.exp(-b * x) @pytest.mark.parametrize( @@ -27,9 +38,25 @@ def func(x, a, b): [ (10, 10, 1000, 1200, 1064 * 1e-6, lambda x: x**2, None), (4, 4, 1300, 1000, 1064 * 1e-6, lambda x: torch.sin(x) + x**3, None), - (15, 8, 1319, 917, 1e-6 * torch.tensor([330, 660, 1064]), lambda x: torch.sin(x) + x**3, None), # noqa: E501 - (16, 7, 500, 868, 1e-6 * torch.tensor([330, 660, 1064]), func, {"a": 1., "b": 9.}) # noqa: E501 - ] + ( + 15, + 8, + 1319, + 917, + 1e-6 * torch.tensor([330, 660, 1064]), + lambda x: torch.sin(x) + x**3, + None, + ), # noqa: E501 + ( + 16, + 7, + 500, + 868, + 1e-6 * torch.tensor([330, 660, 1064]), + func, + {"a": 1.0, "b": 9.0}, + ), # noqa: E501 + ], ) def test_nonlinear_element( ox_size: float, @@ -38,13 +65,29 @@ def test_nonlinear_element( oy_nodes: int, wavelength_test: float, response_function: Callable[[torch.Tensor], torch.Tensor], - response_parameters: Dict + response_parameters: Dict, ): + """ + Tests the NonlinearElement class with various parameters. + + Args: + ox_size: The size of the simulation area in the x-direction. + oy_size: The size of the simulation area in the y-direction. + ox_nodes: The number of nodes in the x-direction. + oy_nodes: The number of nodes in the y-direction. + wavelength_test: The wavelength of the incident light. + response_function: The nonlinear response function to use. + response_parameters: A dictionary of parameters for the response function. + + Returns: + None. This method asserts that the output field from the NonlinearElement + matches the analytically calculated output field. + """ params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) @@ -53,7 +96,7 @@ def test_nonlinear_element( nle = elements.NonlinearElement( simulation_parameters=params, response_function=response_function, - response_parameters=response_parameters + response_parameters=response_parameters, ) incident_amplitude = torch.abs(incident_field) @@ -65,7 +108,7 @@ def test_nonlinear_element( output_amplitude = response_function( incident_amplitude, response_parameters[keys[0]], - response_parameters[keys[1]] + response_parameters[keys[1]], ) else: diff --git a/tests/test_parameters.py b/tests/test_parameters.py index e034c0e..7806761 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -5,19 +5,21 @@ def test_inner_parameter_storage(): - torch_parameter = torch.nn.Parameter(torch.tensor(1.)) - torch_tensor = torch.tensor(2.) - sv_parameter = Parameter(torch.tensor(3.)) + """ + Tests the inner parameter storage module.""" + torch_parameter = torch.nn.Parameter(torch.tensor(1.0)) + torch_tensor = torch.tensor(2.0) + sv_parameter = Parameter(torch.tensor(3.0)) sv_bounded_parameter = ConstrainedParameter( - torch.tensor(4.), min_value=0., max_value=2. + torch.tensor(4.0), min_value=0.0, max_value=2.0 ) storage = InnerParameterStorageModule( { - 'value1': torch_parameter, - 'value2': torch_tensor, - 'value3': sv_parameter, - 'value4': sv_bounded_parameter, + "value1": torch_parameter, + "value2": torch_tensor, + "value3": sv_parameter, + "value4": sv_bounded_parameter, } ) @@ -39,18 +41,32 @@ def test_inner_parameter_storage(): with pytest.raises(TypeError): InnerParameterStorageModule( { - 'a': 123, # type: ignore + "a": 123, # type: ignore } ) @pytest.mark.parametrize( - "parameter", [ - Parameter(data=123.), - ConstrainedParameter(data=123., min_value=0, max_value=300) - ] + "parameter", + [ + Parameter(data=123.0), + ConstrainedParameter(data=123.0, min_value=0, max_value=300), + ], ) def test_new(parameter: Parameter | ConstrainedParameter): + """ + Tests the properties of a new parameter object. + + This function verifies that the provided parameter is a PyTorch tensor, + not a `torch.nn.Parameter`, and behaves correctly with basic tensor operations. + It also checks the types of inner attributes. + + Args: + parameter: The Parameter or ConstrainedParameter instance to test. + + Returns: + None + """ # check if parameter is a tensor and not a torch parameter assert isinstance(parameter, torch.Tensor) assert not isinstance(parameter, torch.nn.Parameter) @@ -64,16 +80,30 @@ def test_new(parameter: Parameter | ConstrainedParameter): @pytest.mark.parametrize( - "parameter", [ - Parameter(data=123.), - ConstrainedParameter(data=123., min_value=0, max_value=300) - ] + "parameter", + [ + Parameter(data=123.0), + ConstrainedParameter(data=123.0, min_value=0, max_value=300), + ], ) def test_behavior_as_a_tensor(parameter): - a = 123. + """ + Tests the behavior of the parameter when used as a tensor. + + This tests multiplication and exponentiation operations with a scalar, + both directly and using torch functions to ensure proper handling via + __torch_function__. + + Args: + parameter: The parameter object to test. + + Returns: + None + """ + a = 123.0 b = 10 res_mul = torch.tensor(a * b) # a * b - res_pow = torch.tensor(a ** b) # a + b + res_pow = torch.tensor(a**b) # a + b # test __torch_function__ for args processing torch.testing.assert_close(parameter * b, res_mul) @@ -84,30 +114,40 @@ def test_behavior_as_a_tensor(parameter): def test_bounded_parameter_inner_value(): - data = 2. - min_value = 0. - max_value = 5. + """ + Tests the inner parameter value of ConstrainedParameter with and without custom bound functions. + + This test verifies that the inner parameter correctly maps to the constrained data + value using both the default sigmoid function and a user-defined bound function. + It also checks the `value` property when a custom bound function is provided. + + Args: + None + + Returns: + None + """ + data = 2.0 + min_value = 0.0 + max_value = 5.0 # === default bound_func === parameter = ConstrainedParameter( - data=data, - min_value=min_value, - max_value=max_value + data=data, min_value=min_value, max_value=max_value ) # test inner parameter value torch.testing.assert_close( - (max_value-min_value) * torch.sigmoid(parameter.inner_parameter) - + min_value, - torch.tensor(data) + (max_value - min_value) * torch.sigmoid(parameter.inner_parameter) + min_value, + torch.tensor(data), ) # === custom bound_func === def bound_func(x: torch.Tensor) -> torch.Tensor: if x < 0: - return torch.tensor(0.) + return torch.tensor(0.0) if x > 1: - return torch.tensor(1.) + return torch.tensor(1.0) return x def inv_bound_func(x: torch.Tensor) -> torch.Tensor: @@ -118,7 +158,7 @@ def inv_bound_func(x: torch.Tensor) -> torch.Tensor: min_value=min_value, max_value=max_value, bound_func=bound_func, - inv_bound_func=inv_bound_func + inv_bound_func=inv_bound_func, ) # test `value` property @@ -126,19 +166,28 @@ def inv_bound_func(x: torch.Tensor) -> torch.Tensor: # test inner parameter value torch.testing.assert_close( - (max_value-min_value) * bound_func(parameter.inner_parameter) - + min_value, - torch.tensor(data) + (max_value - min_value) * bound_func(parameter.inner_parameter) + min_value, + torch.tensor(data), ) @pytest.mark.parametrize( - "parameter", [ - Parameter(data=123.), - ConstrainedParameter(data=123., min_value=0, max_value=300) - ] + "parameter", + [ + Parameter(data=123.0), + ConstrainedParameter(data=123.0, min_value=0, max_value=300), + ], ) def test_repr(parameter): + """ + Tests the repr of a parameter. + + Args: + parameter: The parameter to test. + + Returns: + None: This function only asserts that `repr(parameter)` does not raise an exception. + """ assert repr(parameter) @@ -146,35 +195,42 @@ def test_repr(parameter): ("device",), [ pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) def test_storage_to_device(device): - torch_parameter = torch.nn.Parameter(torch.tensor(1.)) - torch_tensor = torch.tensor(2.) - sv_parameter = Parameter(torch.tensor(3.)) + """ + Tests moving an InnerParameterStorageModule to a specified device and back to CPU. + + Args: + device: The device to move the storage to (e.g., 'cuda', 'mps'). + + Returns: + None + """ + torch_parameter = torch.nn.Parameter(torch.tensor(1.0)) + torch_tensor = torch.tensor(2.0) + sv_parameter = Parameter(torch.tensor(3.0)) sv_bounded_parameter = ConstrainedParameter( - torch.tensor(4.), min_value=0., max_value=2. + torch.tensor(4.0), min_value=0.0, max_value=2.0 ) storage = InnerParameterStorageModule( { - 'value1': torch_parameter, - 'value2': torch_tensor, - 'value3': sv_parameter, - 'value4': sv_bounded_parameter, + "value1": torch_parameter, + "value2": torch_tensor, + "value3": sv_parameter, + "value4": sv_bounded_parameter, } ) @@ -185,44 +241,51 @@ def test_storage_to_device(device): assert storage.value3.device.type == device assert storage.value4.device.type == device - storage.to(device='cpu') + storage.to(device="cpu") # test if all values has been transferred to the cpu - assert storage.value1.device.type == 'cpu' - assert storage.value2.device.type == 'cpu' - assert storage.value3.device.type == 'cpu' - assert storage.value4.device.type == 'cpu' + assert storage.value1.device.type == "cpu" + assert storage.value2.device.type == "cpu" + assert storage.value3.device.type == "cpu" + assert storage.value4.device.type == "cpu" @pytest.mark.parametrize( ("device",), [ pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) @pytest.mark.parametrize( - "parameter", [ - Parameter(data=torch.tensor(123., dtype=torch.float32)), + "parameter", + [ + Parameter(data=torch.tensor(123.0, dtype=torch.float32)), ConstrainedParameter( - data=torch.tensor(123., dtype=torch.float32), - min_value=0, - max_value=300 - ) - ] + data=torch.tensor(123.0, dtype=torch.float32), min_value=0, max_value=300 + ), + ], ) def test_parameter_to_device(device, parameter): + """ + Tests that a Parameter or ConstrainedParameter can be moved to the specified device. + + Args: + device: The device to move the parameter to (e.g., 'cuda', 'mps'). + parameter: The Parameter or ConstrainedParameter instance to test. + + Returns: + None + """ # transferred_parameter = parameter.to(device) # assert transferred_parameter.device.type == device # assert transferred_parameter.inner_storage.device.type == device diff --git a/tests/test_phase_retrieval.py b/tests/test_phase_retrieval.py index fcd5238..04a7dd9 100644 --- a/tests/test_phase_retrieval.py +++ b/tests/test_phase_retrieval.py @@ -10,11 +10,13 @@ def test_retrieve_phase_api(capsys): + """ + Tests the retrieve_phase API with different scenarios.""" params = SimulationParameters( { - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - 'wavelength': 1 + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + "wavelength": 1, } ) # no initial_phase, no additional options @@ -30,7 +32,7 @@ def test_retrieve_phase_api(capsys): Wavefront.plane_wave(params).abs(), LinearOpticalSetup([]), Wavefront.plane_wave(params).abs(), - method='abs' # type: ignore + method="abs", # type: ignore ) # Test disp option for the intensity profile problem type @@ -38,12 +40,10 @@ def test_retrieve_phase_api(capsys): Wavefront.plane_wave(params).abs(), LinearOpticalSetup([]), Wavefront.plane_wave(params).abs(), - options={ - 'disp': True - } + options={"disp": True}, ) - captured = capsys.readouterr().out.split('\n')[0] - assert captured == 'Type of problem: generate intensity profile' + captured = capsys.readouterr().out.split("\n")[0] + assert captured == "Type of problem: generate intensity profile" # Test disp option for the phase reconstruction problem type phase_retrieval.retrieve_phase( @@ -52,12 +52,10 @@ def test_retrieve_phase_api(capsys): Wavefront.plane_wave(params).abs(), target_phase=torch.zeros((10, 10)), target_region=torch.zeros((10, 10)), - options={ - 'disp': True - } + options={"disp": True}, ) - captured = capsys.readouterr().out.split('\n')[0] - assert captured == 'Type of problem: phase reconstruction' + captured = capsys.readouterr().out.split("\n")[0] + assert captured == "Type of problem: phase reconstruction" parameters = [ @@ -67,20 +65,20 @@ def test_retrieve_phase_api(capsys): "oy_nodes", "wavelength_test", "waist_radius_test", - "distance_test" + "distance_test", ] @pytest.mark.parametrize( parameters, [ - (10, 10, 200, 200, 0.025, 0.7, 100.), - (7, 8, 200, 200, 0.02, 0.7, 150.), - (15, 8, 300, 200, 0.02, 0.5, 120.), - ] + (10, 10, 200, 200, 0.025, 0.7, 100.0), + (7, 8, 200, 200, 0.02, 0.7, 150.0), + (15, 8, 300, 200, 0.02, 0.5, 120.0), + ], ) -@pytest.mark.parametrize('use_phase_target', [True, False]) -@pytest.mark.parametrize('method', ['HIO', 'GS']) +@pytest.mark.parametrize("use_phase_target", [True, False]) +@pytest.mark.parametrize("method", ["HIO", "GS"]) def test_phase_retrieval( ox_size: float, oy_size: float, @@ -90,7 +88,7 @@ def test_phase_retrieval( waist_radius_test: float, distance_test: float, use_phase_target: bool, - method: phase_retrieval.Method + method: phase_retrieval.Method, ): """Test for phase reconstruction problem and generate target intensity problem using HIO and Gerchberg-Saxton algorithms on the example of a @@ -122,32 +120,27 @@ def test_phase_retrieval( params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) - x_grid, y_grid = params.meshgrid('W', 'H') + x_grid, y_grid = params.meshgrid("W", "H") field_before_lens1 = Wavefront.gaussian_beam( simulation_parameters=params, distance=0.05 * distance_test, - waist_radius=waist_radius_test + waist_radius=waist_radius_test, ) intensity_source = field_before_lens1.intensity - lens1 = elements.ThinLens( - simulation_parameters=params, - focal_length=distance_test - ) + lens1 = elements.ThinLens(simulation_parameters=params, focal_length=distance_test) field_after_lens1 = lens1(field_before_lens1) free_space1 = elements.FreeSpace( - simulation_parameters=params, - distance=0.05 * distance_test, - method='AS' + simulation_parameters=params, distance=0.05 * distance_test, method="AS" ) output_field = free_space1(field_after_lens1) @@ -159,7 +152,7 @@ def test_phase_retrieval( # target phase profile for phase reconstruction problem if use_phase_target: phase_target = torch.angle(output_field) - target_region = (x_grid**2 + y_grid ** 2 <= 0.12).float() + target_region = (x_grid**2 + y_grid**2 <= 0.12).float() result_hio = phase_retrieval.retrieve_phase( source_intensity=intensity_source, @@ -169,10 +162,7 @@ def test_phase_retrieval( target_region=target_region, initial_phase=torch.full_like(intensity_target, 0), method=method, - options={ - 'maxiter': 100, - 'constant_factor': 0.5 - } + options={"maxiter": 100, "constant_factor": 0.5}, ) else: result_hio = phase_retrieval.retrieve_phase( @@ -181,16 +171,13 @@ def test_phase_retrieval( target_intensity=intensity_target, initial_phase=torch.full_like(intensity_target, 0), method=method, - options={ - 'maxiter': 100, - 'constant_factor': 0.5 - } + options={"maxiter": 100, "constant_factor": 0.5}, ) errors = result_hio.cost_func_evolution # test if the error decreases - assert np.sum(np.diff(errors) < 0) > 0.7 * (len(errors)-1) + assert np.sum(np.diff(errors) < 0) > 0.7 * (len(errors) - 1) assert (errors[0] - errors[-1]) / errors[0] > 0.6 @@ -204,7 +191,7 @@ def test_phase_retrieval( "waist_radius_test", "distance_test", "radius_test", - "error_energy" + "error_energy", ] @@ -212,10 +199,10 @@ def test_phase_retrieval( @pytest.mark.parametrize( parameters_4f, [ - (10, 10, 1000, 1000, 660 * 1e-6, 0.5, 100., 10., 1e-4), - (7, 8, 1000, 1000, 1064 * 1e-6, 0.7, 150., 10., 1e-4), - (15, 8, 1500, 1000, 550 * 1e-6, 0.5, 120., 10., 1e-4) - ] + (10, 10, 1000, 1000, 660 * 1e-6, 0.5, 100.0, 10.0, 1e-4), + (7, 8, 1000, 1000, 1064 * 1e-6, 0.7, 150.0, 10.0, 1e-4), + (15, 8, 1500, 1000, 550 * 1e-6, 0.5, 120.0, 10.0, 1e-4), + ], ) def test_4f_system( ox_size: float, @@ -226,7 +213,7 @@ def test_4f_system( waist_radius_test: float, distance_test: float, radius_test: float, - error_energy: float + error_energy: float, ): """Test for phase reconstruction problem using HIO algorithm on the example of a 4f optical setup @@ -254,31 +241,29 @@ def test_4f_system( """ x_linear = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_linear = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) - x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing='xy') + x_grid, y_grid = torch.meshgrid(x_linear, y_linear, indexing="xy") dx = ox_size / ox_nodes dy = oy_size / oy_nodes params = SimulationParameters( { - 'W': torch.linspace(-ox_size/2, ox_size/2, ox_nodes), - 'H': torch.linspace(-oy_size/2, oy_size/2, oy_nodes), - 'wavelength': wavelength_test + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength_test, } ) field_before_lens1 = Wavefront.gaussian_beam( simulation_parameters=params, distance=distance_test, - waist_radius=waist_radius_test + waist_radius=waist_radius_test, ) intensity_source = field_before_lens1.intensity.detach().numpy() lens1 = elements.ThinLens( - simulation_parameters=params, - focal_length=distance_test, - radius=radius_test + simulation_parameters=params, focal_length=distance_test, radius=radius_test ) field_after_lens1 = lens1.forward(input_field=field_before_lens1) @@ -286,38 +271,37 @@ def test_4f_system( free_space1 = elements.FreeSpace( simulation_parameters=params, distance=torch.tensor(2 * distance_test), - method='AS' + method="AS", ) field_before_lens2 = free_space1.forward(input_field=field_after_lens1) lens2 = elements.ThinLens( - simulation_parameters=params, - focal_length=distance_test, - radius=radius_test + simulation_parameters=params, focal_length=distance_test, radius=radius_test ) field_after_lens2 = lens2.forward(input_field=field_before_lens2) free_space2 = elements.FreeSpace( - simulation_parameters=params, - distance=torch.tensor(distance_test), - method='AS' + simulation_parameters=params, distance=torch.tensor(distance_test), method="AS" ) output_field = free_space2.forward(input_field=field_after_lens2) phase_target = ( - torch.angle(output_field) + 2 * torch.pi * ( - torch.angle(output_field) < 0. - ).float() - ).detach().numpy() + ( + torch.angle(output_field) + + 2 * torch.pi * (torch.angle(output_field) < 0.0).float() + ) + .detach() + .numpy() + ) intensity_target = output_field.intensity.detach().numpy() optical_setup = LinearOpticalSetup([free_space1, lens2, free_space2]) - goal = (x_grid**2 + y_grid ** 2 <= 2).float() + goal = (x_grid**2 + y_grid**2 <= 2).float() result_hio = phase_retrieval.retrieve_phase( source_intensity=torch.tensor(intensity_source), @@ -326,7 +310,7 @@ def test_4f_system( target_phase=torch.tensor(phase_target), target_region=goal, initial_phase=None, - method='HIO', + method="HIO", ) phase_reconstruction_hio = result_hio.solution @@ -335,14 +319,11 @@ def test_4f_system( mask_reconstruction_hio = phase_reconstruction_hio // step field_after_slm = elements.SpatialLightModulator( - simulation_parameters=params, - mask=mask_reconstruction_hio + simulation_parameters=params, mask=mask_reconstruction_hio ).forward(field_before_lens1) output_field = optical_setup.forward(field_after_slm) - intensity_target_opt = torch.pow( - torch.abs(output_field), 2 - ).detach().numpy() + intensity_target_opt = torch.pow(torch.abs(output_field), 2).detach().numpy() energy_reconstruction_hio = np.sum(intensity_target_opt) * dx * dy energy_true = np.sum(intensity_target) * dx * dy diff --git a/tests/test_reservoir.py b/tests/test_reservoir.py index 36380f9..0c47941 100644 --- a/tests/test_reservoir.py +++ b/tests/test_reservoir.py @@ -4,24 +4,29 @@ def test_queue(): + """ + Tests the functionality of the feedback queue in SimpleReservoir. + + This tests appends to, pops from and drops the feedback queue within a + SimpleReservoir instance, verifying correct behavior with different queue lengths + relative to the specified delay. + + Args: + None + + Returns: + None + """ sim_params = SimulationParameters( - { - 'W': torch.tensor([0]), - 'H': torch.tensor([0]), - 'wavelength': 1. - } + {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1.0} ) reservoir = SimpleReservoir( sim_params, - nonlinear_element=DiffractiveLayer( - sim_params, mask=torch.tensor([[0.]]) - ), - delay_element=DiffractiveLayer( - sim_params, mask=torch.tensor([[0.]]) - ), + nonlinear_element=DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])), + delay_element=DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])), delay=2, feedback_gain=1, - input_gain=1 + input_gain=1, ) # feedback queue is empty @@ -54,20 +59,25 @@ def test_queue(): def test_forward(): + """ + Tests the forward pass of the SimpleReservoir. + + This test verifies that the reservoir correctly implements feedback and delay, + and that the output matches expectations for both initial iterations and after + the delay line is populated. + + Args: + None + + Returns: + None + """ sim_params = SimulationParameters( - { - 'W': torch.tensor([0]), - 'H': torch.tensor([0]), - 'wavelength': 1. - } + {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1.0} ) - nonlinear_element = DiffractiveLayer( - sim_params, mask=torch.tensor([[0.]]) - ) - delay_element = DiffractiveLayer( - sim_params, mask=torch.tensor([[0.]]) - ) + nonlinear_element = DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])) + delay_element = DiffractiveLayer(sim_params, mask=torch.tensor([[0.0]])) feedback_gain = 0.8 input_gain = 0.6 delay = 5 @@ -78,7 +88,7 @@ def test_forward(): delay_element=delay_element, delay=delay, feedback_gain=feedback_gain, - input_gain=input_gain + input_gain=input_gain, ) wf = Wavefront.plane_wave(sim_params) @@ -102,8 +112,6 @@ def test_forward(): # hard coded very first delay line related contribution wf_out_expected = nonlinear_element( - input_gain * wf + feedback_gain * nonlinear_element( - input_gain * wf - ) + input_gain * wf + feedback_gain * nonlinear_element(input_gain * wf) ) assert torch.allclose(wf_out, wf_out_expected) diff --git a/tests/test_setup.py b/tests/test_setup.py index b020740..e4a2d4b 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -9,30 +9,82 @@ class SimpleElement(Element): - def __init__( - self, - a: Any, - simulation_parameters: SimulationParameters - ) -> None: + """ + Represents a simple optical element that scales a wavefront. + + Attributes: + a: The scaling factor for the wavefront. + simulation_parameters: Parameters used for the simulation. + + Methods: + __init__: Initializes the instance with given parameters. + forward: Applies a scaling factor to the input wavefront. + """ + + def __init__(self, a: Any, simulation_parameters: SimulationParameters) -> None: + """ + Initializes the instance with given parameters. + + Args: + a: The value for attribute 'a'. + simulation_parameters: Parameters used for the simulation. + + Returns: + None + """ super().__init__(simulation_parameters) self.a = a def forward(self, incident_wavefront: Wavefront) -> Wavefront: + """ + Applies a scaling factor to the input wavefront. + + Args: + incident_wavefront: The input Wavefront object. + + Returns: + Wavefront: A new Wavefront object representing the scaled wavefront. + """ return incident_wavefront * self.a class ReversableSimpleElement(SimpleElement): + """ + Reverses a wavefront using a scaling factor.""" + def reverse(self, wavefront): + """ + Reverses a wavefront by multiplying it with the scaling factor. + + Args: + wavefront: The wavefront to be reversed. + + Returns: + A numpy array representing the reversed wavefront. + """ return wavefront * self.a def test_init(): + """ + Tests the initialization and forward pass of LinearOpticalSetup. + + This test creates a LinearOpticalSetup with three SimpleElements, + verifies that the internal neural network is a torch.nn.Module, + and checks if the forward pass correctly applies the element's 'a' value. + + Parameters: + None + + Returns: + None + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'wavelength': 1 + "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "wavelength": 1, } ) @@ -41,9 +93,7 @@ def test_init(): el2 = SimpleElement(a=a, simulation_parameters=sim_params) el3 = SimpleElement(a=a, simulation_parameters=sim_params) - setup = LinearOpticalSetup(elements=[ - el1, el2, el3 - ]) + setup = LinearOpticalSetup(elements=[el1, el2, el3]) assert isinstance(setup.net, torch.nn.Module) @@ -53,18 +103,21 @@ def test_init(): def test_init_warning(): + """ + Tests that a UserWarning is raised when initializing LinearOpticalSetup with identical simulation parameters. + """ sim_params1 = SimulationParameters( { - 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'wavelength': 1 + "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "wavelength": 1, } ) sim_params2 = SimulationParameters( { - 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'wavelength': 1 + "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "wavelength": 1, } ) @@ -73,50 +126,48 @@ def test_init_warning(): el2 = SimpleElement(a=a, simulation_parameters=sim_params2) with pytest.warns(UserWarning): - LinearOpticalSetup(elements=[ - el1, el2 - ]) + LinearOpticalSetup(elements=[el1, el2]) @pytest.mark.parametrize( ("device",), [ pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) def test_to_device(device): + """ + Tests that moving the network to a device also moves its parameters. + + Args: + device: The device to move the network to ('cuda' or 'mps'). + + Returns: + None + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'wavelength': 1 + "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "wavelength": 1, } ) - el1 = SimpleElement( - a=Parameter(2.), - simulation_parameters=sim_params - ) + el1 = SimpleElement(a=Parameter(2.0), simulation_parameters=sim_params) el2 = SimpleElement( - a=ConstrainedParameter( - data=0.5, - min_value=0, - max_value=1 - ), - simulation_parameters=sim_params + a=ConstrainedParameter(data=0.5, min_value=0, max_value=1), + simulation_parameters=sim_params, ) setup = LinearOpticalSetup([el1, el2]) @@ -131,33 +182,38 @@ def test_to_device(device): def test_reverse(): + """ + Tests the reverse method of LinearOpticalSetup. + + Args: + None + + Returns: + None + """ # test empty setup setup = LinearOpticalSetup(elements=[]) - wf = torch.Tensor([2., 3.]) + wf = torch.Tensor([2.0, 3.0]) assert setup.reverse(wf) is wf # test setup sim_params = SimulationParameters( { - 'W': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'H': torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), - 'wavelength': 1 + "W": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "H": torch.linspace(-5 * ureg.mm, 5 * ureg.mm, 10), + "wavelength": 1, } ) a = torch.tensor(2) # test unreversable element el = SimpleElement(a=a, simulation_parameters=sim_params) - setup = LinearOpticalSetup(elements=[ - el - ]) + setup = LinearOpticalSetup(elements=[el]) with pytest.raises(TypeError): setup.reverse(wf) # test reversable element el = ReversableSimpleElement(a=a, simulation_parameters=sim_params) - setup = LinearOpticalSetup(elements=[ - el - ]) + setup = LinearOpticalSetup(elements=[el]) torch.testing.assert_close(setup.reverse(wf), wf * a) diff --git a/tests/test_simulation_parameters.py b/tests/test_simulation_parameters.py index 648ce01..2736d7c 100644 --- a/tests/test_simulation_parameters.py +++ b/tests/test_simulation_parameters.py @@ -5,66 +5,82 @@ def test_axes(): + """ + Tests the Axes class for correct axis handling and validation.""" # Test required axes are actually required with pytest.raises(ValueError): Axes({}) - SimulationParameters({ - 'W': torch.linspace(-1, 1, 10), - }) + SimulationParameters( + { + "W": torch.linspace(-1, 1, 10), + } + ) with pytest.raises(ValueError): - Axes({ - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - }) - Axes({ - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - 'wavelength': torch.tensor(312) - }) + Axes( + { + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + } + ) + Axes( + { + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + "wavelength": torch.tensor(312), + } + ) # Test with wrong H and W axis shape with pytest.raises(ValueError): - Axes({ - 'W': torch.tensor([[10.]]), # wrong shape - 'H': torch.linspace(-1, 1, 10), - 'wavelength': torch.tensor(312) - }) + Axes( + { + "W": torch.tensor([[10.0]]), # wrong shape + "H": torch.linspace(-1, 1, 10), + "wavelength": torch.tensor(312), + } + ) with pytest.raises(ValueError): - Axes({ - 'W': torch.linspace(-1, 1, 10), - 'H': torch.tensor([[10.]]), # wrong shape - 'wavelength': torch.tensor(312) - }) + Axes( + { + "W": torch.linspace(-1, 1, 10), + "H": torch.tensor([[10.0]]), # wrong shape + "wavelength": torch.tensor(312), + } + ) # Test with wrong additional axes shape with pytest.raises(ValueError): - Axes({ - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - 'wavelength': torch.tensor(312), - 'pol': torch.tensor([[1.2, 3.4]]), # wrong shape - }) + Axes( + { + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + "wavelength": torch.tensor(312), + "pol": torch.tensor([[1.2, 3.4]]), # wrong shape + } + ) w_axis = torch.linspace(-1, 1, 10) - pol_axis = torch.tensor([1., 0.]) - axes = Axes({ - 'W': w_axis, - 'H': torch.linspace(-1, 1, 10), - 'wavelength': torch.tensor(312), - 'pol': pol_axis, - }) + pol_axis = torch.tensor([1.0, 0.0]) + axes = Axes( + { + "W": w_axis, + "H": torch.linspace(-1, 1, 10), + "wavelength": torch.tensor(312), + "pol": pol_axis, + } + ) # Test names of non-scalar axes - assert axes.names == ('pol', 'H', 'W') + assert axes.names == ("pol", "H", "W") # Test indices - assert axes.index('pol') == -3 - assert axes.index('H') == -2 - assert axes.index('W') == -1 + assert axes.index("pol") == -3 + assert axes.index("H") == -2 + assert axes.index("W") == -1 with pytest.raises(AxisNotFound): - axes.index('wavelength') # scalar axis + axes.index("wavelength") # scalar axis with pytest.raises(AxisNotFound): - axes.index('t') # axis does not exists + axes.index("t") # axis does not exists # Test __getattribute__ for named axes assert axes.W is w_axis @@ -76,86 +92,106 @@ def test_axes(): assert axes.W is w_axis # Test __getitem__ - assert axes['W'] is w_axis - assert axes['pol'] is pol_axis - assert axes['wavelength'] == torch.tensor(312) + assert axes["W"] is w_axis + assert axes["pol"] is pol_axis + assert axes["wavelength"] == torch.tensor(312) with pytest.raises(AxisNotFound): - axes['t'] # axis does not exists + axes["t"] # axis does not exists # Test disabled __setitem__ with pytest.raises(RuntimeError): - axes['W'] = w_axis + axes["W"] = w_axis with pytest.raises(RuntimeError): - axes['pol'] = pol_axis + axes["pol"] = pol_axis with pytest.raises(RuntimeError): - axes['t'] = 123 + axes["t"] = 123 # Test __dir__ - assert set(dir(axes)) == {'H', 'W', 'pol', 'wavelength'} + assert set(dir(axes)) == {"H", "W", "pol", "wavelength"} def test_simulation_parameters(): + """ + Tests the SimulationParameters class functionality. + + This tests the __getitem__ method, meshgrid generation, and axes size retrieval + of the SimulationParameters class with various parameters. It also checks for + expected warnings when accessing non-existent axes. + + Args: + None + + Returns: + None + """ w_axis = torch.linspace(-1, 2, 13) h_axis = torch.linspace(-12, -3, 25) - pol_axis = torch.tensor([1., 0.]) - sim_params = SimulationParameters({ - 'W': w_axis, - 'H': h_axis, - 'wavelength': 123., - 'pol': pol_axis, - 't': 0.0 - }) + pol_axis = torch.tensor([1.0, 0.0]) + sim_params = SimulationParameters( + {"W": w_axis, "H": h_axis, "wavelength": 123.0, "pol": pol_axis, "t": 0.0} + ) # Test __getitem__ - assert sim_params['W'] is w_axis - assert sim_params['pol'] is pol_axis - assert sim_params['t'] == 0 - assert sim_params['wavelength'] == 123 + assert sim_params["W"] is w_axis + assert sim_params["pol"] is pol_axis + assert sim_params["t"] == 0 + assert sim_params["wavelength"] == 123 # Test meshgrid - meshgrid_W, meshgrid_H = sim_params.meshgrid('W', 'H') + meshgrid_W, meshgrid_H = sim_params.meshgrid("W", "H") assert torch.allclose(meshgrid_W, w_axis[None, ...]) assert torch.allclose(meshgrid_H, h_axis[..., None]) - meshgrid_W1, meshgrid_W2 = sim_params.meshgrid('W', 'W') + meshgrid_W1, meshgrid_W2 = sim_params.meshgrid("W", "W") assert torch.allclose(meshgrid_W1, w_axis[None, ...]) assert torch.allclose(meshgrid_W2, w_axis[..., None]) - meshgrid_H, meshgrid_wl = sim_params.meshgrid('H', 'wavelength') + meshgrid_H, meshgrid_wl = sim_params.meshgrid("H", "wavelength") assert torch.allclose(meshgrid_H, h_axis[None, ...]) - assert torch.allclose(meshgrid_wl, torch.tensor(123.)[None]) + assert torch.allclose(meshgrid_wl, torch.tensor(123.0)[None]) # Test axes_size - assert sim_params.axes_size(('W',)) == torch.Size((13,)) - assert sim_params.axes_size(('wavelength', 'H')) == torch.Size((1, 25)) - assert sim_params.axes_size(('H',)) == torch.Size((25,)) + assert sim_params.axes_size(("W",)) == torch.Size((13,)) + assert sim_params.axes_size(("wavelength", "H")) == torch.Size((1, 25)) + assert sim_params.axes_size(("H",)) == torch.Size((25,)) with pytest.warns(UserWarning): # non existing axis - assert sim_params.axes_size(('a', 'H')) == torch.Size((0, 25)) + assert sim_params.axes_size(("a", "H")) == torch.Size((0, 25)) @pytest.fixture( - scope='function', + scope="function", params=[ - 'cpu', + "cpu", pytest.param( - 'cuda', + "cuda", marks=pytest.mark.skipif( - not torch.cuda.is_available(), - reason="cuda is not available" - ) + not torch.cuda.is_available(), reason="cuda is not available" + ), ), pytest.param( - 'mps', + "mps", marks=pytest.mark.skipif( - not torch.backends.mps.is_available(), - reason="mps is not available" - ) - ) - ] + not torch.backends.mps.is_available(), reason="mps is not available" + ), + ), + ], ) def default_device(request): + """ + Provides a fixture for setting the default PyTorch device. + + This fixture iterates through 'cpu', 'cuda' (if available), and 'mps' (if available) + as parameters, temporarily setting the default device to each one within the scope of a test function. + It yields the current default device and then restores the original default device after the test completes. + + Args: + request: The pytest request object providing access to fixture parameters. + + Returns: + str: The currently set default PyTorch device (e.g., 'cpu', 'cuda', or 'mps'). + """ # Set the default device old_default_device = torch.get_default_device() torch.set_default_device(request.param) @@ -164,23 +200,40 @@ def default_device(request): def test_device(default_device: torch.device): - w_axis = torch.linspace(-1, 2, 13, device='cpu') + """ + Tests device placement and transfer for SimulationParameters. + + This method verifies that the SimulationParameters class correctly handles + device placement of axis tensors, raises errors when appropriate, and + that the `to()` method functions as expected for transferring data between devices. + + Args: + default_device: The default device to use for testing. + + Returns: + None + """ + w_axis = torch.linspace(-1, 2, 13, device="cpu") h_axis = torch.linspace(-12, -3, 25) - if default_device.type != 'cpu': + if default_device.type != "cpu": with pytest.raises(ValueError): - SimulationParameters({ - 'W': w_axis, - 'H': h_axis.to(default_device), - 'wavelength': 123., - }) + SimulationParameters( + { + "W": w_axis, + "H": h_axis.to(default_device), + "wavelength": 123.0, + } + ) # Test if in the following case the axis tensor is located on the device - sim_params = SimulationParameters({ # type: ignore - 'W': [1., 2., 3.], - 'H': [1., 2., 3.], - 'wavelength': 123. - }) + sim_params = SimulationParameters( + { # type: ignore + "W": [1.0, 2.0, 3.0], + "H": [1.0, 2.0, 3.0], + "wavelength": 123.0, + } + ) assert sim_params.axes.W.device == default_device # Test to() method @@ -188,10 +241,10 @@ def test_device(default_device: torch.device): assert transferred_sim_params is sim_params # Test to('cpu') - transferred_sim_params = sim_params.to('cpu') - assert transferred_sim_params.device.type == 'cpu' # type: ignore + transferred_sim_params = sim_params.to("cpu") + assert transferred_sim_params.device.type == "cpu" # type: ignore for axis_name in sim_params.axes.names: - assert transferred_sim_params.axes[axis_name].device.type == 'cpu' + assert transferred_sim_params.axes[axis_name].device.type == "cpu" # And back transferred_sim_params = transferred_sim_params.to(default_device) assert transferred_sim_params.device == default_device diff --git a/tests/test_slm.py b/tests/test_slm.py index d143434..6df2f5a 100644 --- a/tests/test_slm.py +++ b/tests/test_slm.py @@ -15,32 +15,72 @@ "width", "mode", "mask", - "resized_mask" + "resized_mask", ] @pytest.mark.parametrize( parameters_mask, [ - (10, 10, 4, 4, 10, 10, "nearest", torch.Tensor([[1., 2.], [3., 4.]]), - torch.Tensor([ - [1., 1., 2., 2.,], - [1., 1., 2., 2.,], - [3., 3., 4., 4.,], - [3., 3., 4., 4.,] - ]) - ), - (15, 8, 6, 6, 8, 15, "nearest", torch.Tensor([[2., 3.], [4., 5.]]), - torch.Tensor([ - [2., 2., 2., 3., 3., 3.], - [2., 2., 2., 3., 3., 3.], - [2., 2., 2., 3., 3., 3.], - [4., 4., 4., 5., 5., 5.], - [4., 4., 4., 5., 5., 5.], - [4., 4., 4., 5., 5., 5.] - ]) - ) - ] + ( + 10, + 10, + 4, + 4, + 10, + 10, + "nearest", + torch.Tensor([[1.0, 2.0], [3.0, 4.0]]), + torch.Tensor( + [ + [ + 1.0, + 1.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 2.0, + 2.0, + ], + [ + 3.0, + 3.0, + 4.0, + 4.0, + ], + [ + 3.0, + 3.0, + 4.0, + 4.0, + ], + ] + ), + ), + ( + 15, + 8, + 6, + 6, + 8, + 15, + "nearest", + torch.Tensor([[2.0, 3.0], [4.0, 5.0]]), + torch.Tensor( + [ + [2.0, 2.0, 2.0, 3.0, 3.0, 3.0], + [2.0, 2.0, 2.0, 3.0, 3.0, 3.0], + [2.0, 2.0, 2.0, 3.0, 3.0, 3.0], + [4.0, 4.0, 4.0, 5.0, 5.0, 5.0], + [4.0, 4.0, 4.0, 5.0, 5.0, 5.0], + [4.0, 4.0, 4.0, 5.0, 5.0, 5.0], + ] + ), + ), + ], ) def test_slm_mask( ox_size: float, @@ -51,25 +91,38 @@ def test_slm_mask( width: float, mode: str, mask: torch.Tensor, - resized_mask: torch.Tensor + resized_mask: torch.Tensor, ): + """ + Tests the SpatialLightModulator's resized mask functionality. + + Args: + ox_size: The size of the x-axis in simulation units. + oy_size: The size of the y-axis in simulation units. + ox_nodes: The number of nodes along the x-axis. + oy_nodes: The number of nodes along the y-axis. + height: The height of the SLM mask. + width: The width of the SLM mask. + mode: The resizing mode (e.g., "nearest"). + mask: The input mask tensor. + resized_mask: The expected resized mask tensor. + + Returns: + None. Raises an AssertionError if the resized masks do not match. + """ x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) params = SimulationParameters( axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': 1064 * 1e-6, + "W": x_length, + "H": y_length, + "wavelength": 1064 * 1e-6, } ) slm = elements.SpatialLightModulator( - simulation_parameters=params, - mask=mask, - height=height, - width=width, - mode=mode + simulation_parameters=params, mask=mask, height=height, width=width, mode=mode ) slm.get_aperture resized_mask_slm = slm.resized_mask @@ -77,14 +130,7 @@ def test_slm_mask( assert torch.allclose(resized_mask, resized_mask_slm) -parameters_resize = [ - "ox_size", - "oy_size", - "ox_nodes", - "oy_nodes", - "mode", - "mask" -] +parameters_resize = ["ox_size", "oy_size", "ox_nodes", "oy_nodes", "mode", "mask"] @pytest.mark.parametrize( @@ -95,8 +141,8 @@ def test_slm_mask( (6, 5, 1570, 632, "bicubic", torch.rand(100, 100)), (15.8, 8.61, 109, 120, "area", torch.rand(100, 100)), (19, 7, 1089, 2007, "nearest-exact", torch.rand(100, 100)), - (15, 8, 300, 400, "nearest-exact", torch.rand(1080, 1920)) - ] + (15, 8, 300, 400, "nearest-exact", torch.rand(1080, 1920)), + ], ) def test_slm_resize( ox_size: float, @@ -104,16 +150,30 @@ def test_slm_resize( ox_nodes: int, oy_nodes: int, mode: str, - mask: torch.Tensor + mask: torch.Tensor, ): + """ + Tests the resizing functionality of the SpatialLightModulator. + + Args: + ox_size: The size of the x-axis. + oy_size: The size of the y-axis. + ox_nodes: The number of nodes along the x-axis. + oy_nodes: The number of nodes along the y-axis. + mode: The resizing mode to use (e.g., "nearest", "bilinear"). + mask: The input mask tensor. + + Returns: + None. This function asserts properties of the resized mask and aperture. + """ x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) params = SimulationParameters( axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': 1064 * 1e-6, + "W": x_length, + "H": y_length, + "wavelength": 1064 * 1e-6, } ) @@ -122,7 +182,7 @@ def test_slm_resize( mask=mask, height=oy_size, width=ox_size, - mode=mode + mode=mode, ) aperture = slm.get_aperture resized_mask = slm.resized_mask @@ -141,51 +201,87 @@ def test_slm_resize( "height", "width", "location", - "aperture" + "aperture", ] @pytest.mark.parametrize( parameters_aperture, [ - (6, 5, 6, 5, 3, 3, (-1.5, -1), - torch.tensor([ - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.], - [0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0.] - ]) - ), - (6, 5, 6, 5, 3, 3, (-1.5, 1), - torch.tensor([ - [0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.] - ]) - ), - (6, 5, 6, 5, 3, 3, (1.5, 1), - torch.tensor([ - [0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0.], - [0., 0., 0., 1., 1., 1.], - [0., 0., 0., 1., 1., 1.], - [0., 0., 0., 1., 1., 1.] - ]) - ), - (6, 5, 6, 5, 3, 3, (1.5, -1), - torch.tensor([ - [0., 0., 0., 1., 1., 1.], - [0., 0., 0., 1., 1., 1.], - [0., 0., 0., 1., 1., 1.], - [0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0.] - ]) - ), - (6, 5, 6, 5, 3, 3, (-100, 100), torch.zeros(5, 6)) - ] + ( + 6, + 5, + 6, + 5, + 3, + 3, + (-1.5, -1), + torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + ), + ( + 6, + 5, + 6, + 5, + 3, + 3, + (-1.5, 1), + torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + ] + ), + ), + ( + 6, + 5, + 6, + 5, + 3, + 3, + (1.5, 1), + torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + ] + ), + ), + ( + 6, + 5, + 6, + 5, + 3, + 3, + (1.5, -1), + torch.tensor( + [ + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + ), + (6, 5, 6, 5, 3, 3, (-100, 100), torch.zeros(5, 6)), + ], ) def test_slm_aperture( ox_size: float, @@ -195,16 +291,32 @@ def test_slm_aperture( height: float, width: float, location: tuple, - aperture: torch.Tensor + aperture: torch.Tensor, ): + """ + Tests the SpatialLightModulator aperture with different parameters. + + Args: + ox_size: The size of the x-axis. + oy_size: The size of the y-axis. + ox_nodes: The number of nodes in the x-axis. + oy_nodes: The number of nodes in the y-axis. + height: The height of the SLM. + width: The width of the SLM. + location: The location of the SLM. + aperture: The expected aperture tensor. + + Returns: + None. Asserts that the calculated aperture matches the expected value. + """ x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) params = SimulationParameters( axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': 1064 * 1e-6, + "W": x_length, + "H": y_length, + "wavelength": 1064 * 1e-6, } ) @@ -213,7 +325,7 @@ def test_slm_aperture( mask=torch.zeros(ox_nodes, oy_nodes), height=height, width=width, - location=location + location=location, ) slm.get_aperture @@ -230,29 +342,61 @@ def test_slm_aperture( "location", "mask", "wavelength", - "mode" + "mode", ] @pytest.mark.parametrize( parameters_propagation, [ - (10, 10, 1000, 1200, 3., 4., (0., 0.), - torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6, - "nearest" - ), - (9, 12, 1000, 1200, 3., 4., (-2., 3.), - torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6, - "bilinear" - ), - (15.8, 8.61, 1920, 1080, 2., 2., (2., 0.), - torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6, - "bicubic" - ), - (30, 15, 1920*2, 1080*2, 15.8, 8.61, (-1., 1.), - torch.rand(1080, 1920), torch.linspace(330, 1064, 4) * 1e-6, - "bicubic" - ), + ( + 10, + 10, + 1000, + 1200, + 3.0, + 4.0, + (0.0, 0.0), + torch.rand(100, 100), + torch.linspace(330, 1064, 4) * 1e-6, + "nearest", + ), + ( + 9, + 12, + 1000, + 1200, + 3.0, + 4.0, + (-2.0, 3.0), + torch.rand(100, 100), + torch.linspace(330, 1064, 4) * 1e-6, + "bilinear", + ), + ( + 15.8, + 8.61, + 1920, + 1080, + 2.0, + 2.0, + (2.0, 0.0), + torch.rand(100, 100), + torch.linspace(330, 1064, 4) * 1e-6, + "bicubic", + ), + ( + 30, + 15, + 1920 * 2, + 1080 * 2, + 15.8, + 8.61, + (-1.0, 1.0), + torch.rand(1080, 1920), + torch.linspace(330, 1064, 4) * 1e-6, + "bicubic", + ), # (5, 9, 100, 400, 3., 4., (-1., 8.), # torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6, # "area" @@ -261,7 +405,7 @@ def test_slm_aperture( # torch.rand(100, 100), torch.linspace(330, 1064, 4) * 1e-6, # "nearest-exact" # ) - ] + ], ) def test_slm_propagation( ox_size: float, @@ -273,17 +417,35 @@ def test_slm_propagation( location: tuple, mask: torch.Tensor, wavelength: float, - mode: str + mode: str, ): + """ + Tests the propagation of a wavefront through an SLM. + + Args: + ox_size: The size of the x-axis in meters. + oy_size: The size of the y-axis in meters. + ox_nodes: The number of nodes along the x-axis. + oy_nodes: The number of nodes along the y-axis. + height: The height of the SLM in meters. + width: The width of the SLM in meters. + location: The location of the SLM center in (x, y) coordinates. + mask: A 2D tensor representing the mask applied by the SLM. + wavelength: The wavelength of light in meters. + mode: The interpolation mode used for the SLM ("nearest", "bilinear", etc.). + + Returns: + None. Asserts that the output field has the correct size. + """ x_length = torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes) y_length = torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes) params = SimulationParameters( axes={ - 'W': x_length, - 'H': y_length, - 'wavelength': wavelength, + "W": x_length, + "H": y_length, + "wavelength": wavelength, } ) @@ -293,13 +455,11 @@ def test_slm_propagation( height=height, width=width, location=location, - mode=mode + mode=mode, ) incident_field = w.Wavefront.gaussian_beam( - simulation_parameters=params, - waist_radius=2., - distance=100 + simulation_parameters=params, waist_radius=2.0, distance=100 ) transmitted_field = slm(incident_field) diff --git a/tests/test_specs.py b/tests/test_specs.py index 95f324d..166c2ad 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -17,22 +17,48 @@ def test_save_context_get_new_filepath(tmp_path): + """ + Tests the get_new_filepath method of ParameterSaveContext. + + This tests that subsequent calls to `get_new_filepath` with the same extension + return unique filenames within the specified directory, incrementing a counter. + + Args: + tmp_path: A temporary path for testing purposes. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) # test filename path = context.get_new_filepath("testext") - assert Path(tmp_path, 'test_0.testext') == path + assert Path(tmp_path, "test_0.testext") == path path = context.get_new_filepath("testext") - assert Path(tmp_path, 'test_1.testext') == path + assert Path(tmp_path, "test_1.testext") == path def test_save_context_file(tmp_path): + """ + Tests saving a file within the parameter save context. + + This test creates a ParameterSaveContext, writes data to a new file obtained + through the context, and verifies that the data was written correctly and + that subsequent calls for new files generate different filenames while + remaining in the same directory. + + Args: + tmp_path: A temporary path where the test file will be created. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) @@ -43,7 +69,7 @@ def test_save_context_file(tmp_path): file.write(text.encode()) # check if the test text is written into the file - with open(path, 'rb') as file: + with open(path, "rb") as file: assert file.readline() == text.encode() # check if the new file will have another name, but same folder @@ -53,8 +79,17 @@ def test_save_context_file(tmp_path): def test_save_context_rel_filepath(tmp_path): + """ + Tests that the relative filepath is correctly computed. + + Args: + tmp_path: A temporary path to use for testing. + + Returns: + None + """ contexts = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) @@ -69,23 +104,28 @@ def test_save_context_rel_filepath(tmp_path): ############################################################################### -@pytest.mark.usefixtures('tmp_path') -@pytest.mark.parametrize( - 'mode', ('1', 'L', 'LA', 'I', 'P', 'RGB', 'RGBA') -) +@pytest.mark.usefixtures("tmp_path") +@pytest.mark.parametrize("mode", ("1", "L", "LA", "I", "P", "RGB", "RGBA")) def test_image_repr_draw_image(tmp_path, mode): + """ + Tests that drawing an ImageRepr to a file produces the correct image. + + Args: + tmp_path: A temporary path for saving the image. + mode: The color mode of the image (e.g., '1', 'L', 'RGB'). + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) # TODO: mode-based test image_to_draw = np.array([[1]]) - repr = ImageRepr( - value=image_to_draw, - mode=mode - ) + repr = ImageRepr(value=image_to_draw, mode=mode) # draw image to the path path = context.get_new_filepath("png") @@ -96,14 +136,21 @@ def test_image_repr_draw_image(tmp_path, mode): def test_image_repr_to(tmp_path): + """ + Tests the to_str, to_markdown, and to_html methods of ImageRepr. + + Args: + tmp_path: A temporary path for ParameterSaveContext. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) - repr = ImageRepr( - value=np.array([[0.5]]) - ) + repr = ImageRepr(value=np.array([[0.5]])) # test for all possible exports test_out = StringIO() @@ -125,14 +172,21 @@ def test_image_repr_to(tmp_path): def test_repr_repr_to(tmp_path): + """ + Tests the to_str, to_markdown, and to_html methods of ReprRepr. + + Args: + tmp_path: A temporary path for ParameterSaveContext. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) - repr = ReprRepr( - value=np.array([[0.5]]) - ) + repr = ReprRepr(value=np.array([[0.5]])) # test for all possible exports test_out = StringIO() @@ -153,26 +207,36 @@ def test_repr_repr_to(tmp_path): ############################################################################### -@pytest.mark.usefixtures('tmp_path') +@pytest.mark.usefixtures("tmp_path") @pytest.mark.parametrize( - 'value', ( + "value", + ( np.random.rand(2, 2), torch.rand(2, 2), torch.tensor(random.random()), random.random(), random.randint(0, 10), - ConstrainedParameter(10, 0, 20) - ) + ConstrainedParameter(10, 0, 20), + ), ) def test_pretty_repr_repr_to(tmp_path, value, monkeypatch): + """ + Tests the conversion of a value to string, markdown, and HTML representations. + + Args: + tmp_path: A temporary path for testing file operations (pytest fixture). + value: The value to be converted. Can be a numpy array, torch tensor, float, int or ConstrainedParameter object. + monkeypatch: A pytest monkeypatch fixture used for mocking imports. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) - repr = PrettyReprRepr( - value=value - ) + repr = PrettyReprRepr(value=value) # test for all possible exports test_out = StringIO() @@ -198,13 +262,13 @@ def import_with_no_svetlanna(name, *args, **kwargs): raise ImportError return original_import(name, *args, **kwargs) - monkeypatch.setattr(builtins, '__import__', import_with_no_svetlanna) + monkeypatch.setattr(builtins, "__import__", import_with_no_svetlanna) # Test if default string is written to the buffer test_out = StringIO() repr.to_str(test_out, context) class_name = value.__class__.__name__ - assert test_out.getvalue() == f'{class_name}\n{value.item()}\n' + assert test_out.getvalue() == f"{class_name}\n{value.item()}\n" ############################################################################### @@ -212,23 +276,32 @@ def import_with_no_svetlanna(name, *args, **kwargs): ############################################################################### -@pytest.mark.usefixtures('tmp_path') +@pytest.mark.usefixtures("tmp_path") @pytest.mark.parametrize( - 'value', ( + "value", + ( np.random.rand(10, 10), random.random(), random.randint(0, 10), - ) + ), ) def test_npy_file_repr_save_to_file(tmp_path, value): + """ + Saves a value to a file using NpyFileRepr and verifies it can be loaded correctly. + + Args: + tmp_path: A temporary path for saving the file. + value: The value to save (NumPy array or scalar). + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) - repr = NpyFileRepr( - value=value - ) + repr = NpyFileRepr(value=value) # save the value to a new file path = context.get_new_filepath("png") @@ -239,14 +312,21 @@ def test_npy_file_repr_save_to_file(tmp_path, value): def test_npy_file_repr_to(tmp_path): + """ + Tests the to_str and to_markdown methods of NpyFileRepr. + + Args: + tmp_path: A temporary path for creating files. + + Returns: + None + """ context = ParameterSaveContext( - parameter_name='test', + parameter_name="test", directory=tmp_path, ) - repr = NpyFileRepr( - value=np.array([[0.5]]) - ) + repr = NpyFileRepr(value=np.array([[0.5]])) # test for all possible exports test_out = StringIO() @@ -264,15 +344,21 @@ def test_npy_file_repr_to(tmp_path): def test_parameter_specs(): + """ + Tests the ParameterSpecs class with a simple example. + + Args: + None + + Returns: + None + """ representations = ( ReprRepr(123), ReprRepr(321), ) - specs = ParameterSpecs( - parameter_name='test', - representations=representations - ) + specs = ParameterSpecs(parameter_name="test", representations=representations) assert specs.representations == representations @@ -283,18 +369,15 @@ def test_parameter_specs(): def test_subelement_specs(): - specs = [ - ParameterSpecs('test', []) - ] + """ + Tests that the SubelementSpecs class correctly stores its subelement.""" + specs = [ParameterSpecs("test", [])] class Subelement: def to_specs(self): return specs subelement = Subelement() - subelement_specs = SubelementSpecs( - 'test_type', - subelement - ) + subelement_specs = SubelementSpecs("test_type", subelement) assert subelement_specs.subelement is subelement diff --git a/tests/test_specs_writer.py b/tests/test_specs_writer.py index ec040a8..e63ded0 100644 --- a/tests/test_specs_writer.py +++ b/tests/test_specs_writer.py @@ -18,60 +18,99 @@ class SpecsTestElement(Element): + """ + Tests a set of parameter specifications or subelement specifications.""" def __init__( self, simulation_parameters: SimulationParameters, - test_specs: Iterable[ParameterSpecs | SubelementSpecs] + test_specs: Iterable[ParameterSpecs | SubelementSpecs], ) -> None: + """ + Initializes a new instance of the class. + + Args: + simulation_parameters: The simulation parameters to use. + test_specs: An iterable of parameter specifications or subelement specifications + to be tested. + + Returns: + None + """ super().__init__(simulation_parameters) self.test_specs = test_specs def forward(self, incident_wavefront: Wavefront) -> Wavefront: + """ + Passes the wavefront to the next layer. + + This method simply calls the `forward` method of the parent class, + effectively passing the incident wavefront along for further processing. + + Args: + incident_wavefront: The input wavefront representing the current state + of the wave propagation. + + Returns: + Wavefront: The output wavefront after being processed by the next layer. + """ return super().forward(incident_wavefront) def to_specs(self) -> Iterable[ParameterSpecs | SubelementSpecs]: + """ + Returns the test specifications. + + Args: + None + + Returns: + Iterable[ParameterSpecs | SubelementSpecs]: An iterable of parameter or subelement specifications. + """ return self.test_specs def test_context_generator(tmp_path): + """ + Tests the context generator with a sample SpecsTestElement. + + This test verifies that the context generator produces contexts with the + correct parameter names, representations, and indices. It also tests + the output of writing specs to string, markdown, and HTML formats. + + Args: + tmp_path: A temporary path for testing purposes. + + Returns: + None + """ simulation_parameters = SimulationParameters( - axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - repr1 = ReprRepr(1.) - repr2 = ReprRepr(2.) - repr3 = ReprRepr(3.) - repr4 = ReprRepr(4.) + repr1 = ReprRepr(1.0) + repr2 = ReprRepr(2.0) + repr3 = ReprRepr(3.0) + repr4 = ReprRepr(4.0) subelement = SpecsTestElement(simulation_parameters, []) element = SpecsTestElement( simulation_parameters=simulation_parameters, test_specs=[ + ParameterSpecs("test1", [repr1, repr2]), ParameterSpecs( - 'test1', - [ - repr1, - repr2 - ] - ), - ParameterSpecs( - 'test2', + "test2", [ repr3, - ] + ], ), ParameterSpecs( - 'test2', # test for the parameter spec with the same name + "test2", # test for the parameter spec with the same name [ repr4, - ] + ], ), - SubelementSpecs( - 'test_type', - subelement - ) - ] + SubelementSpecs("test_type", subelement), + ], ) subelements: list[SubelementSpecs] = [] @@ -83,10 +122,10 @@ def test_context_generator(tmp_path): assert subelements[0].subelement is subelement # test parameter_name attribute - assert contexts[0].parameter_name.value == 'test1' - assert contexts[1].parameter_name.value == 'test1' - assert contexts[2].parameter_name.value == 'test2' - assert contexts[3].parameter_name.value == 'test2' + assert contexts[0].parameter_name.value == "test1" + assert contexts[1].parameter_name.value == "test1" + assert contexts[2].parameter_name.value == "test2" + assert contexts[3].parameter_name.value == "test2" assert contexts[0].parameter_name.index == 0 assert contexts[1].parameter_name.index == 0 @@ -125,52 +164,81 @@ def test_context_generator(tmp_path): def test_ElementInTree(): + """ + Tests the creation and copying of an ElementInTree object. + + This test verifies that creating a copy of _ElementInTree correctly + shares references to immutable attributes (element, element_index, children) + while creating a new instance for mutable attributes (subelement_type). + + Parameters: + None + + Returns: + None + """ simulation_parameters = SimulationParameters( - axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) element = SpecsTestElement( - simulation_parameters=simulation_parameters, - test_specs=[] + simulation_parameters=simulation_parameters, test_specs=[] ) - tree_element = _ElementInTree(element, 123, [], 'test_1') - tree_element_copy = tree_element.create_copy('test_2') + tree_element = _ElementInTree(element, 123, [], "test_1") + tree_element_copy = tree_element.create_copy("test_2") assert tree_element_copy.element is tree_element.element assert tree_element_copy.element_index is tree_element.element_index assert tree_element_copy.children is tree_element.children assert tree_element_copy.subelement_type != tree_element.subelement_type - assert tree_element_copy.subelement_type == 'test_2' + assert tree_element_copy.subelement_type == "test_2" def test_ElementsIterator(): + """ + Tests the functionality of the ElementsIterator class. + + This test case creates a complex element structure with nested subelements and + verifies that the iterator correctly traverses this structure, yielding each + element in the expected order. It also checks if the tree is saved and rebuilt + correctly during multiple iterations. + + Args: + None + + Returns: + None + """ simulation_parameters = SimulationParameters( - axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - repr1 = ReprRepr(1.) + repr1 = ReprRepr(1.0) subelement1 = SpecsTestElement(simulation_parameters, []) subelement2 = SpecsTestElement(simulation_parameters, []) - subelement3 = SpecsTestElement(simulation_parameters, [ - SubelementSpecs('subelement1', subelement1), - SubelementSpecs('subelement2', subelement2) - ]) + subelement3 = SpecsTestElement( + simulation_parameters, + [ + SubelementSpecs("subelement1", subelement1), + SubelementSpecs("subelement2", subelement2), + ], + ) element = SpecsTestElement( simulation_parameters=simulation_parameters, test_specs=[ ParameterSpecs( - 'test1', + "test1", [ repr1, - ] + ], ), - SubelementSpecs('subelement1_copy', subelement1), - SubelementSpecs('subelement3', subelement3), - ] + SubelementSpecs("subelement1_copy", subelement1), + SubelementSpecs("subelement3", subelement3), + ], ) - elements = _ElementsIterator(element, directory='') + elements = _ElementsIterator(element, directory="") # Test iterator output iterated_indices = [] @@ -185,9 +253,7 @@ def test_ElementsIterator(): iterated_elements.append(el) assert iterated_indices == list(range(4)) - assert iterated_elements == [ - element, subelement1, subelement3, subelement2 - ] + assert iterated_elements == [element, subelement1, subelement3, subelement2] # Test if the tree is saved in the iterator tree = elements.tree @@ -212,87 +278,109 @@ def test_ElementsIterator(): assert elements.tree == tree # Test if the tree can be generated automatically - new_elements = _ElementsIterator(element, directory='') + new_elements = _ElementsIterator(element, directory="") assert new_elements.tree is not tree assert new_elements.tree == tree assert new_elements.tree is new_elements.tree def test_write_tree(tmp_path): + """ + Tests writing the elements tree to both string and markdown formats. + + This test creates a nested structure of simulation elements and then + verifies that `write_elements_tree_to_str` and `write_elements_tree_to_markdown` + can successfully write this tree to a string stream and produce non-empty output, + respectively. + + Args: + tmp_path: A temporary path (not directly used in the test but required by the fixture). + + Returns: + None + """ simulation_parameters = SimulationParameters( - axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - repr1 = ReprRepr(1.) + repr1 = ReprRepr(1.0) subelement1 = SpecsTestElement(simulation_parameters, []) subelement2 = SpecsTestElement(simulation_parameters, []) - subelement3 = SpecsTestElement(simulation_parameters, [ - SubelementSpecs('subelement1', subelement1), - SubelementSpecs('subelement2', subelement2) - ]) + subelement3 = SpecsTestElement( + simulation_parameters, + [ + SubelementSpecs("subelement1", subelement1), + SubelementSpecs("subelement2", subelement2), + ], + ) element = SpecsTestElement( simulation_parameters=simulation_parameters, test_specs=[ ParameterSpecs( - 'test1', + "test1", [ repr1, - ] + ], ), - SubelementSpecs('subelement1_copy', subelement1), - SubelementSpecs('subelement3', subelement3), - ] + SubelementSpecs("subelement1_copy", subelement1), + SubelementSpecs("subelement3", subelement3), + ], ) - elements = _ElementsIterator(element, directory='') + elements = _ElementsIterator(element, directory="") # === test str === - stream = StringIO('') + stream = StringIO("") write_elements_tree_to_str(elements.tree, stream) assert stream.getvalue() # === test md === - stream = StringIO('') + stream = StringIO("") write_elements_tree_to_markdown(elements.tree, stream) assert stream.getvalue() def test_write_specs(tmp_path): + """ + Tests the write_specs function with different file formats.""" simulation_parameters = SimulationParameters( - axes={'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + axes={"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - repr1 = ReprRepr(1.) + repr1 = ReprRepr(1.0) subelement1 = SpecsTestElement(simulation_parameters, []) subelement2 = SpecsTestElement(simulation_parameters, []) - subelement3 = SpecsTestElement(simulation_parameters, [ - SubelementSpecs('subelement1', subelement1), - SubelementSpecs('subelement2', subelement2) - ]) + subelement3 = SpecsTestElement( + simulation_parameters, + [ + SubelementSpecs("subelement1", subelement1), + SubelementSpecs("subelement2", subelement2), + ], + ) element = SpecsTestElement( simulation_parameters=simulation_parameters, test_specs=[ ParameterSpecs( - 'test1', + "test1", [ repr1, - ] + ], ), - SubelementSpecs('subelement1_copy', subelement1), - SubelementSpecs('subelement3', subelement3), - ] + SubelementSpecs("subelement1_copy", subelement1), + SubelementSpecs("subelement3", subelement3), + ], ) # === test txt === - write_specs(element, filename='test_specs.txt', directory=tmp_path) - assert Path.exists(tmp_path / 'test_specs.txt') + write_specs(element, filename="test_specs.txt", directory=tmp_path) + assert Path.exists(tmp_path / "test_specs.txt") # === test md === - write_specs(element, filename='test_specs.md', directory=tmp_path) - assert Path.exists(tmp_path / 'test_specs.md') + write_specs(element, filename="test_specs.md", directory=tmp_path) + assert Path.exists(tmp_path / "test_specs.md") # === test unknown format === with pytest.raises(ValueError): - write_specs(element, filename='test_specs.test', directory=tmp_path) + write_specs(element, filename="test_specs.test", directory=tmp_path) diff --git a/tests/test_types.py b/tests/test_types.py index d04d544..0415ca1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -8,10 +8,7 @@ parameters = "default_type" -@pytest.mark.parametrize(parameters, [ - torch.float64, - torch.float32 -]) +@pytest.mark.parametrize(parameters, [torch.float64, torch.float32]) def test_types(default_type: torch.dtype): """A test that checks that all elements belong to the same data type @@ -23,17 +20,17 @@ def test_types(default_type: torch.dtype): torch.set_default_dtype(default_type) - ox_size = 15. - oy_size = 8. + ox_size = 15.0 + oy_size = 8.0 ox_nodes = 1200 oy_nodes = 1100 - wavelength = torch.linspace(330*1e-6, 660*1e-6, 5) - waist_radius = 2. - distance = 100. - focal_length = 100. - radius = 10. - height = 4. - width = 3. + wavelength = torch.linspace(330 * 1e-6, 660 * 1e-6, 5) + waist_radius = 2.0 + distance = 100.0 + focal_length = 100.0 + radius = 10.0 + height = 4.0 + width = 3.0 if torch.get_default_dtype() == torch.float64: default_complex_dtype = torch.complex128 @@ -42,81 +39,60 @@ def test_types(default_type: torch.dtype): params = SimulationParameters( axes={ - 'W': torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), - 'H': torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), - 'wavelength': wavelength - } + "W": torch.linspace(-ox_size / 2, ox_size / 2, ox_nodes), + "H": torch.linspace(-oy_size / 2, oy_size / 2, oy_nodes), + "wavelength": wavelength, + } ) x_linear = params.axes.W y_linear = params.axes.H wavelength = params.axes.wavelength - x_grid, y_grid = params.meshgrid(x_axis='W', y_axis='H') + x_grid, y_grid = params.meshgrid(x_axis="W", y_axis="H") gaussian_beam = w.gaussian_beam( - simulation_parameters=params, - waist_radius=waist_radius, - distance=distance + simulation_parameters=params, waist_radius=waist_radius, distance=distance ) - plane_wave = w.plane_wave( - simulation_parameters=params, - distance=distance - ) + plane_wave = w.plane_wave(simulation_parameters=params, distance=distance) - spherical_wave = w.spherical_wave( - simulation_parameters=params, - distance=distance - ) + spherical_wave = w.spherical_wave(simulation_parameters=params, distance=distance) lens = elements.ThinLens( - simulation_parameters=params, - focal_length=focal_length, - radius=radius + simulation_parameters=params, focal_length=focal_length, radius=radius ).get_transmission_function() aperture = elements.Aperture( - simulation_parameters=params, - mask=torch.zeros(x_grid.shape) + simulation_parameters=params, mask=torch.zeros(x_grid.shape) ).get_transmission_function() rectangular_aperture = elements.RectangularAperture( - simulation_parameters=params, - height=height, - width=width + simulation_parameters=params, height=height, width=width ).get_transmission_function() round_aperture = elements.RoundAperture( - simulation_parameters=params, - radius=radius + simulation_parameters=params, radius=radius ).get_transmission_function() slm = elements.SpatialLightModulator( - simulation_parameters=params, - mask=torch.ones_like(x_grid), - height=8, - width=9 + simulation_parameters=params, mask=torch.ones_like(x_grid), height=8, width=9 ).transmission_function layer = elements.DiffractiveLayer( - simulation_parameters=params, - mask=torch.zeros(x_grid.shape) + simulation_parameters=params, mask=torch.zeros(x_grid.shape) ).transmission_function free_space_as = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='AS' + simulation_parameters=params, distance=distance, method="AS" )(gaussian_beam) free_space_fresnel = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" )(gaussian_beam) free_space_reverse = elements.FreeSpace( - simulation_parameters=params, - distance=distance, method='fresnel' + simulation_parameters=params, distance=distance, method="fresnel" ).reverse(transmission_wavefront=gaussian_beam) default_type = torch.get_default_dtype() diff --git a/tests/test_units.py b/tests/test_units.py index 9c8eddc..fb25a4f 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize( - 'other', + "other", ( 123, 1.234, @@ -15,27 +15,36 @@ np.array(123), np.array(1.234), np.array([[1.23, 4.56]]), - ) + ), ) def test_arithmetics(other): - torch.testing.assert_close( - other * ureg.mm, other * ureg.mm.value - ) - torch.testing.assert_close( - ureg.mm * other, other * ureg.mm.value - ) - torch.testing.assert_close( - other / ureg.mm, other / ureg.mm.value - ) - torch.testing.assert_close( - ureg.mm / other, ureg.mm.value / other - ) - torch.testing.assert_close( - ureg.mm ** other, ureg.mm.value ** other - ) + """ + Tests arithmetic operations with the unit 'mm'. + + This function checks if basic arithmetic operations (multiplication, division, and exponentiation) + between a given value and the 'mm' unit from astropy.units produce the expected results when compared to + the underlying numerical value of the unit. It tests both left-hand and right-hand side operations. + + Args: + other: The value to perform arithmetic with. Can be an integer, float, torch tensor or numpy array. + + Returns: + None: This function only performs assertions and does not return a value. + """ + torch.testing.assert_close(other * ureg.mm, other * ureg.mm.value) + torch.testing.assert_close(ureg.mm * other, other * ureg.mm.value) + torch.testing.assert_close(other / ureg.mm, other / ureg.mm.value) + torch.testing.assert_close(ureg.mm / other, ureg.mm.value / other) + torch.testing.assert_close(ureg.mm**other, ureg.mm.value**other) def test_array_api(): + """ + Tests array API compatibility with pint and numpy. + + This tests that adding a pint Quantity to a NumPy array results in a NumPy array, + and that attempting to use __array__ with copy=False on a pint unit raises a ValueError. + """ assert isinstance(ureg.m + np.array([0.0]), np.ndarray) with pytest.raises(ValueError): diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 726b06f..693f694 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -12,38 +12,60 @@ def test_html_element(): + """ + Tests that the HTML representation of a FreeSpace element is generated.""" sim_params = svetlanna.SimulationParameters( - {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS') + element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS") assert element._repr_html_() def test_default_widget_html_method(): - assert default_widget_html_method(123, 'test', 'element_type', []) + """ + Tests the default widget HTML method. + + Args: + None + + Returns: + None + """ + assert default_widget_html_method(123, "test", "element_type", []) def test_generate_structure_html(): + """ + Tests the generation of HTML structure for a simple simulation element. + + This test creates a basic simulation setup with a FreeSpace element and + a nested NoWidgetHTMLElement, then asserts that generate_structure_html + returns without errors when given this structure. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( - {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS') + element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS") class NoWidgetHTMLElement: def to_specs(self): return [] assert generate_structure_html( - [ - _ElementInTree(element, 0, [ - _ElementInTree(NoWidgetHTMLElement(), 0, []) - ]) - ] + [_ElementInTree(element, 0, [_ElementInTree(NoWidgetHTMLElement(), 0, [])])] ) def test_show_structure(monkeypatch): + """ + Tests the show_structure function's behavior with and without IPython.""" import IPython.display # monkeypatch IPython.display.display @@ -53,7 +75,7 @@ def set_displayed(): nonlocal displayed displayed = True - monkeypatch.setattr(IPython.display, 'display', lambda _: set_displayed()) + monkeypatch.setattr(IPython.display, "display", lambda _: set_displayed()) # Test if the HTML has been displayed displayed = False @@ -79,50 +101,83 @@ def import_with_no_ipython(name, *args, **kwargs): def test_show_specs(): + """ + Tests the show_specs function. + + This test creates a simple simulation setup with a FreeSpace element and + verifies that the show_specs function returns a SpecsWidget containing + information about the element. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( - {'W': torch.tensor([0]), 'H': torch.tensor([0]), 'wavelength': 1} + {"W": torch.tensor([0]), "H": torch.tensor([0]), "wavelength": 1} ) - element = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS') + element = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS") widget = show_specs(element) assert isinstance(widget, SpecsWidget) assert len(widget.elements) == 1 - assert widget.elements[0]['name'] == 'FreeSpace' + assert widget.elements[0]["name"] == "FreeSpace" def test_draw_wavefront(): + """ + Tests the draw_wavefront function with different type combinations. + + This test creates a simple plane wavefront and then calls draw_wavefront + with single types and all available types to ensure it functions correctly + for various plotting configurations. + + Args: + None + + Returns: + bool: True if all assertions pass, indicating the function works as expected. + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - 'wavelength': 1 + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + "wavelength": 1, } ) wavefront = svetlanna.Wavefront.plane_wave(sim_params) # Single type - types = ('A', 'I', 'phase', 'Re', 'Im') + types = ("A", "I", "phase", "Re", "Im") for t in types: - assert draw_wavefront( - wavefront, - sim_params, - types_to_plot=(t,) - ) + assert draw_wavefront(wavefront, sim_params, types_to_plot=(t,)) # All types - assert draw_wavefront( - wavefront, - sim_params, - types_to_plot=types - ) + assert draw_wavefront(wavefront, sim_params, types_to_plot=types) def test_show_stepwise_forward(): + """ + Tests the show_stepwise_forward function with various elements. + + This test creates a simulation setup with different optical elements, + including a valid FreeSpace element, an element that returns None, and + an element that returns a tensor instead of an image. It then asserts + that the resulting widget is a StepwiseForwardWidget, contains all three + elements, and correctly represents their outputs in JSON format. + + Args: + None + + Returns: + None + """ sim_params = svetlanna.SimulationParameters( { - 'W': torch.linspace(-1, 1, 10), - 'H': torch.linspace(-1, 1, 10), - 'wavelength': 1 + "W": torch.linspace(-1, 1, 10), + "H": torch.linspace(-1, 1, 10), + "wavelength": 1, } ) @@ -135,36 +190,32 @@ def to_specs(self): class WrongTensorForwardElement(torch.nn.Module): def forward(self, x): - return torch.tensor([1, 2, 3.]) + return torch.tensor([1, 2, 3.0]) def to_specs(self): return [] - element1 = svetlanna.elements.FreeSpace(sim_params, distance=1, method='AS') + element1 = svetlanna.elements.FreeSpace(sim_params, distance=1, method="AS") element2 = NoneForwardElement() element3 = WrongTensorForwardElement() wavefront = svetlanna.Wavefront.plane_wave(sim_params) widget = show_stepwise_forward( - element1, - element2, - element3, - input=wavefront, - simulation_parameters=sim_params + element1, element2, element3, input=wavefront, simulation_parameters=sim_params ) assert isinstance(widget, StepwiseForwardWidget) assert len(widget.elements) == 3 element1_json = widget.elements[0] - assert element1_json['name'] == 'FreeSpace' - assert element1_json['output_image'] + assert element1_json["name"] == "FreeSpace" + assert element1_json["output_image"] element2_json = widget.elements[1] - assert element2_json['name'] == 'NoneForwardElement' - assert element2_json['output_image'] is None + assert element2_json["name"] == "NoneForwardElement" + assert element2_json["output_image"] is None element3_json = widget.elements[2] - assert element3_json['name'] == 'WrongTensorForwardElement' - assert element3_json['output_image'][:1] == '\n' + assert element3_json["name"] == "WrongTensorForwardElement" + assert element3_json["output_image"][:1] == "\n" diff --git a/tests/test_wavefront.py b/tests/test_wavefront.py index d576159..e77b9a5 100644 --- a/tests/test_wavefront.py +++ b/tests/test_wavefront.py @@ -4,13 +4,23 @@ def test_creation(): - wf = Wavefront(1.) + """ + Tests the creation of Wavefront objects with various inputs. + + This method checks that a Wavefront object can be successfully initialized + with different types of input data (float, complex number, list of complex numbers, and torch tensor) + and verifies that the resulting object is a PyTorch tensor and an instance of the Wavefront class. + + Returns: + None + """ + wf = Wavefront(1.0) assert isinstance(wf, torch.Tensor) - wf = Wavefront(1. + 1.j) + wf = Wavefront(1.0 + 1.0j) assert isinstance(wf, torch.Tensor) - wf = Wavefront([1 + 2.j]) + wf = Wavefront([1 + 2.0j]) assert isinstance(wf, torch.Tensor) data = torch.tensor([1, 2, 3]) @@ -20,15 +30,19 @@ def test_creation(): @pytest.mark.parametrize( - ('a', 'b'), [ - (1., 2.), - (1., 1.,), - (-1., 1.3) - ] + ("a", "b"), + [ + (1.0, 2.0), + ( + 1.0, + 1.0, + ), + (-1.0, 1.3), + ], ) def test_intensity(a: float, b: float): """Test intensity calculations""" - wf = Wavefront([a + 1j*b]) + wf = Wavefront([a + 1j * b]) real_intensity = torch.tensor([a**2 + b**2]) torch.testing.assert_close(wf.intensity, real_intensity) @@ -43,25 +57,40 @@ def test_intensity(a: float, b: float): @pytest.mark.parametrize( - ('r', 'phi'), [ - (1., 0.), - (1., [1.]), - (10., [1., 2., 3.]) - ] + ("r", "phi"), [(1.0, 0.0), (1.0, [1.0]), (10.0, [1.0, 2.0, 3.0])] ) def test_phase(r, phi): + """ + Tests that the wavefront phase is correctly initialized. + + Args: + r: The radius of the wavefront. + phi: The initial phase values. + + Returns: + None: This function asserts a condition and does not return a value. + """ wf = Wavefront(r * torch.exp(1j * torch.tensor(phi))) torch.testing.assert_close(wf.phase, torch.tensor(phi)) -@pytest.mark.parametrize('waist_radius', (1, 0.5, 0.2)) +@pytest.mark.parametrize("waist_radius", (1, 0.5, 0.2)) def test_fwhm(waist_radius): + """ + Tests the full width at half maximum (FWHM) calculation for a Gaussian beam. + + Args: + waist_radius: The waist radius of the Gaussian beam. + + Returns: + None: This function asserts properties of the FWHM and does not return a value. + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-1, 1, 1000), - 'H': torch.linspace(-1, 1, 1000), - 'wavelength': 1 + "W": torch.linspace(-1, 1, 1000), + "H": torch.linspace(-1, 1, 1000), + "wavelength": 1, } ) @@ -73,21 +102,32 @@ def test_fwhm(waist_radius): assert wf.fwhm(sim_params)[0] == wf.fwhm(sim_params)[1] torch.testing.assert_close( torch.tensor(wf.fwhm(sim_params)[0]), - torch.sqrt(2*torch.log(torch.tensor(2.))) * waist_radius, + torch.sqrt(2 * torch.log(torch.tensor(2.0))) * waist_radius, rtol=0.001, atol=0.01, ) -@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4)) -@pytest.mark.parametrize('wavelength', (1.0, torch.tensor([1.23, 20]))) -@pytest.mark.parametrize('initial_phase', (1.0, 123, 2e-4)) +@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4)) +@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20]))) +@pytest.mark.parametrize("initial_phase", (1.0, 123, 2e-4)) def test_plane_wave(distance, wavelength, initial_phase): + """ + Tests the plane_wave method of the Wavefront class. + + Args: + distance: The distance to propagate the plane wave. + wavelength: The wavelength of the plane wave. + initial_phase: The initial phase of the plane wave. + + Returns: + None. This function asserts properties of the generated Wavefront object. + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-0.1, 2, 10), - 'H': torch.linspace(-1, 5, 20), - 'wavelength': wavelength + "W": torch.linspace(-0.1, 2, 10), + "H": torch.linspace(-1, 5, 20), + "wavelength": wavelength, } ) k = 2 * torch.pi / sim_params.axes.wavelength @@ -99,11 +139,9 @@ def test_plane_wave(distance, wavelength, initial_phase): assert isinstance(wf, Wavefront) torch.allclose( wf.angle(), - torch.exp(1j * (k * distance + initial_phase)[..., None, None]).angle() - ) - torch.allclose( - wf.abs(), torch.tensor(1.) + torch.exp(1j * (k * distance + initial_phase)[..., None, None]).angle(), ) + torch.allclose(wf.abs(), torch.tensor(1.0)) # x,y propagation dir_x = 0.1312234 @@ -113,18 +151,18 @@ def test_plane_wave(distance, wavelength, initial_phase): x = sim_params.axes.W[None, :] y = sim_params.axes.H[:, None] wf = Wavefront.plane_wave( - sim_params, distance=distance, wave_direction=[dir_x, dir_y, 0], - initial_phase=initial_phase + sim_params, + distance=distance, + wave_direction=[dir_x, dir_y, 0], + initial_phase=initial_phase, ) torch.allclose( wf.angle(), - torch.exp(1j * ( - kx[..., None, None] * x + ky[..., None, None] * y + initial_phase - )).angle() - ) - torch.allclose( - wf.abs(), torch.tensor(1.) + torch.exp( + 1j * (kx[..., None, None] * x + ky[..., None, None] * y + initial_phase) + ).angle(), ) + torch.allclose(wf.abs(), torch.tensor(1.0)) # Test wrong wave direction with pytest.raises(ValueError): @@ -134,22 +172,31 @@ def test_plane_wave(distance, wavelength, initial_phase): # TODO: Test Gaussian beam against precomputed values -@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4)) -@pytest.mark.parametrize('waist_radius', (1, 1.23, 1e-4, 1e4)) -@pytest.mark.parametrize('dx', (1.0, 123, 2e-4)) -@pytest.mark.parametrize('dy', (1.0, 123, 2e-4)) -@pytest.mark.parametrize( - 'wavelength', ( - 1.0, - torch.tensor([1.23, 20]) - ) -) +@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4)) +@pytest.mark.parametrize("waist_radius", (1, 1.23, 1e-4, 1e4)) +@pytest.mark.parametrize("dx", (1.0, 123, 2e-4)) +@pytest.mark.parametrize("dy", (1.0, 123, 2e-4)) +@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20]))) def test_gaussian_beam(distance, waist_radius, dx, dy, wavelength): + """ + Tests the gaussian_beam method with various parameters. + + Args: + distance: The distance to propagate the beam. + waist_radius: The radius of the Gaussian beam at its waist. + dx: Offset in x direction. + dy: Offset in y direction. + wavelength: The wavelength of the light. Can be a float or a torch tensor. + + Returns: + None: This test does not return any value; it asserts that the + gaussian_beam method runs without errors for given parameters. + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-0.1, 2, 10), - 'H': torch.linspace(-1, 5, 20), - 'wavelength': wavelength + "W": torch.linspace(-0.1, 2, 10), + "H": torch.linspace(-1, 5, 20), + "wavelength": wavelength, } ) # Stupid test @@ -159,22 +206,31 @@ def test_gaussian_beam(distance, waist_radius, dx, dy, wavelength): # TODO: Test spherical wave against precomputed values -@pytest.mark.parametrize('distance', (1, 1.23, 1e-4, 1e4)) -@pytest.mark.parametrize('initial_phase', (1, 1.23, 1e-4, 1e4)) -@pytest.mark.parametrize('dx', (1.0, 123, 2e-4)) -@pytest.mark.parametrize('dy', (1.0, 123, 2e-4)) -@pytest.mark.parametrize( - 'wavelength', ( - 1.0, - torch.tensor([1.23, 20]) - ) -) +@pytest.mark.parametrize("distance", (1, 1.23, 1e-4, 1e4)) +@pytest.mark.parametrize("initial_phase", (1, 1.23, 1e-4, 1e4)) +@pytest.mark.parametrize("dx", (1.0, 123, 2e-4)) +@pytest.mark.parametrize("dy", (1.0, 123, 2e-4)) +@pytest.mark.parametrize("wavelength", (1.0, torch.tensor([1.23, 20]))) def test_spherical_wave(distance, initial_phase, dx, dy, wavelength): + """ + Tests the spherical wave function with various parameters. + + Args: + distance: The distance from the source of the spherical wave. + initial_phase: The initial phase of the wave. + dx: The x-coordinate offset. + dy: The y-coordinate offset. + wavelength: The wavelength of the wave. + + Returns: + None: This function does not return a value; it asserts that the + spherical_wave function runs without errors for given parameters. + """ sim_params = SimulationParameters( { - 'W': torch.linspace(-0.1, 2, 10), - 'H': torch.linspace(-1, 5, 20), - 'wavelength': wavelength + "W": torch.linspace(-0.1, 2, 10), + "H": torch.linspace(-1, 5, 20), + "wavelength": wavelength, } ) # Stupid test @@ -184,6 +240,20 @@ def test_spherical_wave(distance, initial_phase, dx, dy, wavelength): def test_wavefront_as_a_tensor(): + """ + Tests that Wavefront operations with tensors return a Wavefront object. + + This method creates a Wavefront object from a random tensor and then performs + various arithmetic operations (addition, multiplication, division) between the + Wavefront object and the original tensor. It asserts that the result of each + operation is also a Wavefront object. + + Args: + None + + Returns: + None + """ tensor = torch.rand((2, 10, 20)) wf = Wavefront(tensor) diff --git a/visualization.ipynb b/visualization.ipynb index 7dc7868..5118778 100644 --- a/visualization.ipynb +++ b/visualization.ipynb @@ -1,1071 +1,1071 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "%%html\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import svetlanna\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "simulation_parameters = svetlanna.SimulationParameters(\n", - " {\n", - " 'W': torch.linspace(-1, 1, 100),\n", - " 'H': torch.linspace(-1, 1, 100),\n", - " 'wavelength': 2e-1\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
\n", - " mask\n", - "
\n", - " \n", - "
\n", - "
Tensor of size (100x100)\n",
-       "
\n", - "
\n", - " \n", - "
\n", - "
\n",
-       "\n",
-       "
\n", - "
\n", - " \n", - "
\n", - " mask_norm\n", - "
\n", - " \n", - "
\n", - "
6.283185307179586\n",
-       "
\n", - "
\n", - "
" - ], - "text/plain": [ - "DiffractiveLayer()" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100)))" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/vigos/Documents/GitHub/SVETlANNa/svetlanna/elements/free_space.py:151: UserWarning: The paraxial (near-axis) optics condition required for the Fresnel method is not satisfied. Consider increasing the distance or decreasing the screen size.\n", - " warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "06fb527359f844baaf74b453f21a3425", - "version_major": 2, - "version_minor": 1 - }, - "text/plain": [ - "LinearOpticalSetupWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '
\n", - "
\n", - " feedback_gain\n", - "
\n", - " \n", - "
\n", - "
0.1\n",
-       "
\n", - "
\n", - " \n", - "
\n", - " input_gain\n", - "
\n", - " \n", - "
\n", - "
0.2\n",
-       "
\n", - "
\n", - " \n", - "
\n", - " delay\n", - "
\n", - " \n", - "
\n", - "
10\n",
-       "
\n", - "
\n", - "
[Nonlinear element] LinearOpticalSetup
[0] ThinLens
\n", - "
\n", - " focal_length\n", - "
\n", - " \n", - "
\n", - "
1\n",
-       "
\n", - "
\n", - " \n", - "
\n", - " radius\n", - "
\n", - " \n", - "
\n", - "
inf\n",
-       "
\n", - "
\n", - "
[1] FreeSpace
\n", - "
\n", - " distance\n", - "
\n", - " \n", - "
\n", - "
1\n",
-       "
\n", - "
\n", - "
[2] ThinLens
\n", - "
\n", - " focal_length\n", - "
\n", - " \n", - "
\n", - "
1\n",
-       "
\n", - "
\n", - " \n", - "
\n", - " radius\n", - "
\n", - " \n", - "
\n", - "
inf\n",
-       "
\n", - "
\n", - "
[Delay element] FreeSpace
\n", - "
\n", - " distance\n", - "
\n", - " \n", - "
\n", - "
1\n",
-       "
\n", - "
\n", - "
" - ], - "text/plain": [ - "SimpleReservoir(\n", - " (delay_element): FreeSpace()\n", - ")" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "system1 = svetlanna.LinearOpticalSetup([\n", - " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n", - " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", - " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n", - "])\n", - "\n", - "reservoir = svetlanna.elements.reservoir.SimpleReservoir(\n", - " simulation_parameters,\n", - " system1,\n", - " # system1,\n", - " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", - " 0.1,\n", - " 0.2,\n", - " 10\n", - ")\n", - "reservoir" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "system2 = svetlanna.LinearOpticalSetup([\n", - " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", - " reservoir,\n", - " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [], - "source": [ - "from svetlanna.specs.specs_writer import _ElementsIterator, _ElementInTree\n", - "from svetlanna.specs import Specsable\n", - "from dataclasses import dataclass\n", - "\n", - "from IPython.core.display import display_html\n", - "from jinja2 import Environment, FileSystemLoader, select_autoescape\n", - "\n", - "jinja_env = Environment(\n", - " loader=FileSystemLoader(\"templates\"),\n", - " autoescape=select_autoescape()\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
\n", - " (0) LinearOpticalSetup\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (1) FreeSpace\n", - "
─→┃  ┃─→
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - " β†’\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (2) SimpleReservoir\n", - "
\n", - "
\n",
-       "  β”Œβ”€β”€β”€β†β”€β”€β”¨ Delay el. ┠──←───┐\n",
-       "  │↓                        │↑\n",
-       "β†’β”€βŠ•β”€β†’β”€β”€β”¨ Nonlinear el. ┠──→─┴─→\n",
-       "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - " Nonlinear element\n", - "
\n", - "
\n", - "
\n", - " (3) LinearOpticalSetup\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (4) ThinLens\n", - "
\n", - "
\n",
-       "β”€β†’β”ƒβ†˜\n",
-       "─→┃→ \n",
-       "─→┃↗\n",
-       "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - " β†’\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (5) FreeSpace\n", - "
─→┃  ┃─→
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - " β†’\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (6) ThinLens\n", - "
\n", - "
\n",
-       "β”€β†’β”ƒβ†˜\n",
-       "─→┃→ \n",
-       "─→┃↗\n",
-       "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "
\n", - " Delay element\n", - "
\n", - "
\n", - "
\n", - " (7) FreeSpace\n", - "
─→┃  ┃─→
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - " β†’\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " (8) FreeSpace\n", - "
─→┃  ┃─→
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "
\n", - "
\n", - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "\n", - "\n", - "@dataclass(frozen=True, slots=True)\n", - "class ElementHTML:\n", - " element_type: str | None\n", - " html: str\n", - "\n", - "\n", - "def _widget_html_(\n", - " index: int,\n", - " name: str,\n", - " element_type: str | None,\n", - " subelements: list[ElementHTML]\n", - ") -> str:\n", - " return jinja_env.get_template('default_widget.html.jinja').render(\n", - " index=index, name=name, subelements=subelements\n", - " )\n", - "\n", - "\n", - "def _ls_widget_html_(\n", - " index: int,\n", - " name: str,\n", - " element_type: str | None,\n", - " subelements: list[ElementHTML]\n", - ") -> str:\n", - " return jinja_env.get_template('linear_setup_widget.html.jinja').render(\n", - " index=index, name=name, subelements=subelements\n", - " )\n", - "\n", - "\n", - "def _fs_widget_html_(\n", - " index: int,\n", - " name: str,\n", - " element_type: str | None,\n", - " subelements: list[ElementHTML]\n", - ") -> str:\n", - " return jinja_env.get_template('free_space_widget.html.jinja').render(\n", - " index=index, name=name, subelements=subelements\n", - " )\n", - "\n", - "\n", - "def _rs_widget_html_(\n", - " index: int,\n", - " name: str,\n", - " element_type: str | None,\n", - " subelements: list[ElementHTML]\n", - ") -> str:\n", - " return jinja_env.get_template('reservoir_widget.html.jinja').render(\n", - " index=index, name=name, subelements=subelements\n", - " )\n", - "\n", - "\n", - "def _l_widget_html_(\n", - " index: int,\n", - " name: str,\n", - " element_type: str | None,\n", - " subelements: list[ElementHTML]\n", - ") -> str:\n", - " return jinja_env.get_template('lens_widget.html.jinja').render(\n", - " index=index, name=name, subelements=subelements\n", - " )\n", - "\n", - "\n", - "def _get_widget_html_method(element: Specsable):\n", - " if hasattr(element, '_widget_html_'):\n", - " widget_html_method = getattr(element, '_widget_html_')\n", - " else:\n", - " widget_html_method = _widget_html_\n", - "\n", - " if isinstance(element, svetlanna.LinearOpticalSetup):\n", - " widget_html_method = _ls_widget_html_\n", - "\n", - " if isinstance(element, svetlanna.elements.FreeSpace):\n", - " widget_html_method = _fs_widget_html_\n", - " \n", - " if isinstance(element, svetlanna.elements.SimpleReservoir):\n", - " widget_html_method = _rs_widget_html_\n", - " \n", - " if isinstance(element, svetlanna.elements.ThinLens):\n", - " widget_html_method = _l_widget_html_\n", - "\n", - " return widget_html_method\n", - "\n", - "\n", - "def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:\n", - " res = []\n", - "\n", - " for subelement in subelements:\n", - " widget_html_method = _get_widget_html_method(subelement.element)\n", - " try:\n", - " res.append(\n", - " ElementHTML(\n", - " subelement.subelement_type,\n", - " html=widget_html_method(\n", - " index=subelement.element_index,\n", - " name=subelement.element.__class__.__name__,\n", - " element_type=subelement.subelement_type,\n", - " subelements=_subelements_html(subelement.children)\n", - " )\n", - " )\n", - " )\n", - " except Exception as e:\n", - " pass\n", - "\n", - " return res\n", - "\n", - "\n", - "elements = _ElementsIterator(system2, directory='')\n", - "\n", - "for _, _, i in elements:\n", - " for _ in i:\n", - " pass\n", - "\n", - "res = _subelements_html(elements.tree)\n", - "\n", - "\n", - "containered_html = f'
{res[0].html}
'\n", - "display_html(containered_html, raw=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "e = svetlanna.specs.specs_writer.write_specs(system2, filename='specs.md')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[_ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Nonlinear element'),\n", - " _ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Delay element')]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e.tree[0].children[1].children" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "ename": "Exception", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m\n", - "\u001b[0;31mException\u001b[0m: " - ] - } - ], - "source": [ - "raise Exception" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'Nonlinear element'" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e._tree[0].children[1].children[0].element_name" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ElementInTree(element=, element_index=0, children=[ElementInTree(element=FreeSpace(), element_index=1, children=[]), ElementInTree(element=SimpleReservoir(), element_index=2, children=[ElementInTree(element=, element_index=3, children=[ElementInTree(element=ThinLens(), element_index=4, children=[]), ElementInTree(element=FreeSpace(), element_index=5, children=[]), ElementInTree(element=ThinLens(), element_index=6, children=[])]), ElementInTree(element=, element_index=7, children=[ElementInTree(element=ThinLens(), element_index=8, children=[]), ElementInTree(element=FreeSpace(), element_index=9, children=[]), ElementInTree(element=ThinLens(), element_index=10, children=[])])]), ElementInTree(element=FreeSpace(), element_index=11, children=[])])\n" - ] - } - ], - "source": [ - "print('\\n'.join([str(i) for i in e._tree]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# torch.set_default_dtype(torch.float32)\n", - "# Image.fromarray(torch.tensor(a).to(torch.float64).numpy(), mode='L').show()\n", - "# Image.fromarray(np.uint8(255*torch.tensor(a).numpy()), mode='L').show() # <- works" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 97, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.tensor([[1, 1,], [1, 2]]).size() < torch.tensor([1,]).size()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import svetlanna\n", - "import svetlanna.elements\n", - "\n", - "\n", - "class A(svetlanna.elements.Element):\n", - " def __init__(self, simulation_parameters: svetlanna.SimulationParameters) -> None:\n", - " super().__init__(simulation_parameters)\n", - "\n", - " self.a = self.make_buffer('a', self.simulation_parameters.axes.W)\n", - "\n", - " def forward(self, input_field: svetlanna.Wavefront) -> svetlanna.Wavefront:\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "simulation_parameters = svetlanna.SimulationParameters(\n", - " {\n", - " 'W': torch.linspace(-1, 1, 10),\n", - " 'H': torch.linspace(-1, 1, 10),\n", - " 'wavelength': 10\n", - " }\n", - ")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a = A(simulation_parameters=simulation_parameters)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "A()" - ] - }, - "execution_count": 108, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.to('mps')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cpu')" - ] - }, - "execution_count": 114, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "simulation_parameters.axes.W.device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='mps', index=0)" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.a.device" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%html\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import svetlanna\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "simulation_parameters = svetlanna.SimulationParameters(\n", + " {\n", + " 'W': torch.linspace(-1, 1, 100),\n", + " 'H': torch.linspace(-1, 1, 100),\n", + " 'wavelength': 2e-1\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + " mask\n", + "
\n", + " \n", + "
\n", + "
Tensor of size (100x100)\n",
+       "
\n", + "
\n", + " \n", + "
\n", + "
\n",
+       "\n",
+       "
\n", + "
\n", + " \n", + "
\n", + " mask_norm\n", + "
\n", + " \n", + "
\n", + "
6.283185307179586\n",
+       "
\n", + "
\n", + "
" + ], + "text/plain": [ + "DiffractiveLayer()" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100)))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/vigos/Documents/GitHub/SVETlANNa/svetlanna/elements/free_space.py:151: UserWarning: The paraxial (near-axis) optics condition required for the Fresnel method is not satisfied. Consider increasing the distance or decreasing the screen size.\n", + " warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "06fb527359f844baaf74b453f21a3425", + "version_major": 2, + "version_minor": 1 + }, + "text/plain": [ + "LinearOpticalSetupWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '
\n", + "
\n", + " feedback_gain\n", + "
\n", + " \n", + "
\n", + "
0.1\n",
+       "
\n", + "
\n", + " \n", + "
\n", + " input_gain\n", + "
\n", + " \n", + "
\n", + "
0.2\n",
+       "
\n", + "
\n", + " \n", + "
\n", + " delay\n", + "
\n", + " \n", + "
\n", + "
10\n",
+       "
\n", + "
\n", + "
[Nonlinear element] LinearOpticalSetup
[0] ThinLens
\n", + "
\n", + " focal_length\n", + "
\n", + " \n", + "
\n", + "
1\n",
+       "
\n", + "
\n", + " \n", + "
\n", + " radius\n", + "
\n", + " \n", + "
\n", + "
inf\n",
+       "
\n", + "
\n", + "
[1] FreeSpace
\n", + "
\n", + " distance\n", + "
\n", + " \n", + "
\n", + "
1\n",
+       "
\n", + "
\n", + "
[2] ThinLens
\n", + "
\n", + " focal_length\n", + "
\n", + " \n", + "
\n", + "
1\n",
+       "
\n", + "
\n", + " \n", + "
\n", + " radius\n", + "
\n", + " \n", + "
\n", + "
inf\n",
+       "
\n", + "
\n", + "
[Delay element] FreeSpace
\n", + "
\n", + " distance\n", + "
\n", + " \n", + "
\n", + "
1\n",
+       "
\n", + "
\n", + "
" + ], + "text/plain": [ + "SimpleReservoir(\n", + " (delay_element): FreeSpace()\n", + ")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "system1 = svetlanna.LinearOpticalSetup([\n", + " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n", + " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", + " svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),\n", + "])\n", + "\n", + "reservoir = svetlanna.elements.reservoir.SimpleReservoir(\n", + " simulation_parameters,\n", + " system1,\n", + " # system1,\n", + " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", + " 0.1,\n", + " 0.2,\n", + " 10\n", + ")\n", + "reservoir" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "system2 = svetlanna.LinearOpticalSetup([\n", + " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", + " reservoir,\n", + " svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "from svetlanna.specs.specs_writer import _ElementsIterator, _ElementInTree\n", + "from svetlanna.specs import Specsable\n", + "from dataclasses import dataclass\n", + "\n", + "from IPython.core.display import display_html\n", + "from jinja2 import Environment, FileSystemLoader, select_autoescape\n", + "\n", + "jinja_env = Environment(\n", + " loader=FileSystemLoader(\"templates\"),\n", + " autoescape=select_autoescape()\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + " (0) LinearOpticalSetup\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (1) FreeSpace\n", + "
─→┃  ┃─→
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + " β†’\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (2) SimpleReservoir\n", + "
\n", + "
\n",
+       "  β”Œβ”€β”€β”€β†β”€β”€β”¨ Delay el. ┠──←───┐\n",
+       "  │↓                        │↑\n",
+       "β†’β”€βŠ•β”€β†’β”€β”€β”¨ Nonlinear el. ┠──→─┴─→\n",
+       "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + " Nonlinear element\n", + "
\n", + "
\n", + "
\n", + " (3) LinearOpticalSetup\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (4) ThinLens\n", + "
\n", + "
\n",
+       "β”€β†’β”ƒβ†˜\n",
+       "─→┃→ \n",
+       "─→┃↗\n",
+       "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + " β†’\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (5) FreeSpace\n", + "
─→┃  ┃─→
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + " β†’\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (6) ThinLens\n", + "
\n", + "
\n",
+       "β”€β†’β”ƒβ†˜\n",
+       "─→┃→ \n",
+       "─→┃↗\n",
+       "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " Delay element\n", + "
\n", + "
\n", + "
\n", + " (7) FreeSpace\n", + "
─→┃  ┃─→
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + " β†’\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " (8) FreeSpace\n", + "
─→┃  ┃─→
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "
\n", + "
\n", + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "\n", + "@dataclass(frozen=True, slots=True)\n", + "class ElementHTML:\n", + " element_type: str | None\n", + " html: str\n", + "\n", + "\n", + "def _widget_html_(\n", + " index: int,\n", + " name: str,\n", + " element_type: str | None,\n", + " subelements: list[ElementHTML]\n", + ") -> str:\n", + " return jinja_env.get_template('default_widget.html.jinja').render(\n", + " index=index, name=name, subelements=subelements\n", + " )\n", + "\n", + "\n", + "def _ls_widget_html_(\n", + " index: int,\n", + " name: str,\n", + " element_type: str | None,\n", + " subelements: list[ElementHTML]\n", + ") -> str:\n", + " return jinja_env.get_template('linear_setup_widget.html.jinja').render(\n", + " index=index, name=name, subelements=subelements\n", + " )\n", + "\n", + "\n", + "def _fs_widget_html_(\n", + " index: int,\n", + " name: str,\n", + " element_type: str | None,\n", + " subelements: list[ElementHTML]\n", + ") -> str:\n", + " return jinja_env.get_template('free_space_widget.html.jinja').render(\n", + " index=index, name=name, subelements=subelements\n", + " )\n", + "\n", + "\n", + "def _rs_widget_html_(\n", + " index: int,\n", + " name: str,\n", + " element_type: str | None,\n", + " subelements: list[ElementHTML]\n", + ") -> str:\n", + " return jinja_env.get_template('reservoir_widget.html.jinja').render(\n", + " index=index, name=name, subelements=subelements\n", + " )\n", + "\n", + "\n", + "def _l_widget_html_(\n", + " index: int,\n", + " name: str,\n", + " element_type: str | None,\n", + " subelements: list[ElementHTML]\n", + ") -> str:\n", + " return jinja_env.get_template('lens_widget.html.jinja').render(\n", + " index=index, name=name, subelements=subelements\n", + " )\n", + "\n", + "\n", + "def _get_widget_html_method(element: Specsable):\n", + " if hasattr(element, '_widget_html_'):\n", + " widget_html_method = getattr(element, '_widget_html_')\n", + " else:\n", + " widget_html_method = _widget_html_\n", + "\n", + " if isinstance(element, svetlanna.LinearOpticalSetup):\n", + " widget_html_method = _ls_widget_html_\n", + "\n", + " if isinstance(element, svetlanna.elements.FreeSpace):\n", + " widget_html_method = _fs_widget_html_\n", + " \n", + " if isinstance(element, svetlanna.elements.SimpleReservoir):\n", + " widget_html_method = _rs_widget_html_\n", + " \n", + " if isinstance(element, svetlanna.elements.ThinLens):\n", + " widget_html_method = _l_widget_html_\n", + "\n", + " return widget_html_method\n", + "\n", + "\n", + "def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:\n", + " res = []\n", + "\n", + " for subelement in subelements:\n", + " widget_html_method = _get_widget_html_method(subelement.element)\n", + " try:\n", + " res.append(\n", + " ElementHTML(\n", + " subelement.subelement_type,\n", + " html=widget_html_method(\n", + " index=subelement.element_index,\n", + " name=subelement.element.__class__.__name__,\n", + " element_type=subelement.subelement_type,\n", + " subelements=_subelements_html(subelement.children)\n", + " )\n", + " )\n", + " )\n", + " except Exception as e:\n", + " pass\n", + "\n", + " return res\n", + "\n", + "\n", + "elements = _ElementsIterator(system2, directory='')\n", + "\n", + "for _, _, i in elements:\n", + " for _ in i:\n", + " pass\n", + "\n", + "res = _subelements_html(elements.tree)\n", + "\n", + "\n", + "containered_html = f'
{res[0].html}
'\n", + "display_html(containered_html, raw=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "e = svetlanna.specs.specs_writer.write_specs(system2, filename='specs.md')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[_ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Nonlinear element'),\n", + " _ElementInTree(element=, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Delay element')]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e.tree[0].children[1].children" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "Exception", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m\n", + "\u001b[0;31mException\u001b[0m: " + ] + } + ], + "source": [ + "raise Exception" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Nonlinear element'" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e._tree[0].children[1].children[0].element_name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ElementInTree(element=, element_index=0, children=[ElementInTree(element=FreeSpace(), element_index=1, children=[]), ElementInTree(element=SimpleReservoir(), element_index=2, children=[ElementInTree(element=, element_index=3, children=[ElementInTree(element=ThinLens(), element_index=4, children=[]), ElementInTree(element=FreeSpace(), element_index=5, children=[]), ElementInTree(element=ThinLens(), element_index=6, children=[])]), ElementInTree(element=, element_index=7, children=[ElementInTree(element=ThinLens(), element_index=8, children=[]), ElementInTree(element=FreeSpace(), element_index=9, children=[]), ElementInTree(element=ThinLens(), element_index=10, children=[])])]), ElementInTree(element=FreeSpace(), element_index=11, children=[])])\n" + ] + } + ], + "source": [ + "print('\\n'.join([str(i) for i in e._tree]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# torch.set_default_dtype(torch.float32)\n", + "# Image.fromarray(torch.tensor(a).to(torch.float64).numpy(), mode='L').show()\n", + "# Image.fromarray(np.uint8(255*torch.tensor(a).numpy()), mode='L').show() # <- works" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor([[1, 1,], [1, 2]]).size() < torch.tensor([1,]).size()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import svetlanna\n", + "import svetlanna.elements\n", + "\n", + "\n", + "class A(svetlanna.elements.Element):\n", + " def __init__(self, simulation_parameters: svetlanna.SimulationParameters) -> None:\n", + " super().__init__(simulation_parameters)\n", + "\n", + " self.a = self.make_buffer('a', self.simulation_parameters.axes.W)\n", + "\n", + " def forward(self, input_field: svetlanna.Wavefront) -> svetlanna.Wavefront:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "simulation_parameters = svetlanna.SimulationParameters(\n", + " {\n", + " 'W': torch.linspace(-1, 1, 10),\n", + " 'H': torch.linspace(-1, 1, 10),\n", + " 'wavelength': 10\n", + " }\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = A(simulation_parameters=simulation_parameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "A()" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.to('mps')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cpu')" + ] + }, + "execution_count": 114, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "simulation_parameters.axes.W.device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='mps', index=0)" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.a.device" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}