Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 33 additions & 30 deletions ratapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import importlib
import os
import pathlib
import warnings
from collections.abc import Callable

import numpy as np

import ratapi
import ratapi.wrappers
from ratapi.rat_core import Checks, Control, NameStore, ProblemDefinition
from ratapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions
from ratapi.utils.enums import Calculations, Languages, LayerModels, Procedures, TypeOptions

parameter_field = {
"parameters": "params",
Expand Down Expand Up @@ -137,13 +138,13 @@ def make_input(project: ratapi.Project, controls: ratapi.Controls) -> tuple[Prob
The controls object used in the compiled RAT code.

"""
problem = make_problem(project)
problem = make_problem(project, controls.procedure != Procedures.Calculate)
cpp_controls = make_controls(controls)

return problem, cpp_controls


def make_problem(project: ratapi.Project) -> ProblemDefinition:
def make_problem(project: ratapi.Project, validate_range: bool = False) -> ProblemDefinition:
"""Construct the problem input required for the compiled RAT code.

Parameters
Expand Down Expand Up @@ -351,40 +352,42 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition:
problem.domainContrastLayers = [
domain_contrast_model if domain_contrast_model else [] for domain_contrast_model in domain_contrast_models
]
problem.fitParams = [
param.value
for class_list in ratapi.project.parameter_class_lists
for param in getattr(project, class_list)
if param.fit
]
problem.fitLimits = [
[param.min, param.max]
for class_list in ratapi.project.parameter_class_lists
for param in getattr(project, class_list)
if param.fit
]
problem.priorNames = [
param.name for class_list in ratapi.project.parameter_class_lists for param in getattr(project, class_list)
]
problem.priorValues = [
[prior_id[param.prior_type], param.mu, param.sigma]
for class_list in ratapi.project.parameter_class_lists
for param in getattr(project, class_list)
]

fit_params = []
fit_limits = []
prior_names = []
prior_values = []
problem.checks = Checks()
for class_list in ratapi.project.parameter_class_lists:
field = parameter_field[class_list]
check_list = []
for param in getattr(project, class_list):
prior_names.append(param.name)
prior_values.append([prior_id[param.prior_type], param.mu, param.sigma])
check_list.append(int(param.fit))
if param.fit:
if validate_range and (param.max - param.min) < 1e-10:
warnings.warn(
f'{class_list.replace("_", " ").title()} "{param.name}" was removed from the '
"fit because its range is too small (< 1e-10).",
stacklevel=2,
)
check_list[-1] = 0
else:
fit_params.append(param.value)
fit_limits.append([param.min, param.max])
setattr(problem.checks, field, check_list)
problem.fitParams = fit_params
problem.fitLimits = fit_limits
problem.priorNames = prior_names
problem.priorValues = prior_values

# Names
problem.names = NameStore()
for class_list in ratapi.project.parameter_class_lists:
setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)])
problem.names.contrasts = [contrast.name for contrast in project.contrasts]

# Checks
problem.checks = Checks()
for class_list in ratapi.project.parameter_class_lists:
setattr(
problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)]
)

check_indices(problem)

return problem
Expand Down
4 changes: 3 additions & 1 deletion ratapi/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def run(project, controls):
# Update parameter values in project
for class_list in ratapi.project.parameter_class_lists:
for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])):
getattr(project, class_list)[index].value = value
param = getattr(project, class_list)[index]
param.fit = bool(getattr(problem_definition.checks, parameter_field[class_list])[index])
param.value = value

controls.delete_IPC()

Expand Down
13 changes: 13 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ def test_make_problem(test_project, test_problem, request) -> None:
check_problem_equal(problem, test_problem)


def test_make_problem_validate_range(request) -> None:
"""The problem should not contain fitted parameters with small range."""
test_project = request.getfixturevalue("standard_layers_project")

test_project.scalefactors.set_fields(0, min=10, value=10, max=10, fit=True)
problem = make_problem(test_project)
assert problem.checks.scalefactors[0] == 1

with pytest.warns(UserWarning, match="was removed from the fit because its range is too small \(< 1e-10\)"):
problem = make_problem(test_project, True)
assert problem.checks.scalefactors[0] == 0


@pytest.mark.parametrize("test_problem", ["standard_layers_problem", "custom_xy_problem", "domains_problem"])
class TestCheckIndices:
"""Tests for check_indices over a set of three test problems."""
Expand Down
Loading