diff --git a/ratapi/inputs.py b/ratapi/inputs.py index bdfadda..736a6cb 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -3,6 +3,7 @@ import importlib import os import pathlib +import warnings from collections.abc import Callable import numpy as np @@ -10,7 +11,7 @@ import ratapi import ratapi.wrappers from ratapi.rat_core import Checks, Control, NameStore, ProblemDefinition -from ratapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions +from ratapi.utils.enums import Calculations, Languages, LayerModels, Procedures, TypeOptions parameter_field = { "parameters": "params", @@ -137,13 +138,13 @@ def make_input(project: ratapi.Project, controls: ratapi.Controls) -> tuple[Prob The controls object used in the compiled RAT code. """ - problem = make_problem(project) + problem = make_problem(project, controls.procedure != Procedures.Calculate) cpp_controls = make_controls(controls) return problem, cpp_controls -def make_problem(project: ratapi.Project) -> ProblemDefinition: +def make_problem(project: ratapi.Project, validate_range: bool = False) -> ProblemDefinition: """Construct the problem input required for the compiled RAT code. Parameters @@ -351,26 +352,35 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition: problem.domainContrastLayers = [ domain_contrast_model if domain_contrast_model else [] for domain_contrast_model in domain_contrast_models ] - problem.fitParams = [ - param.value - for class_list in ratapi.project.parameter_class_lists - for param in getattr(project, class_list) - if param.fit - ] - problem.fitLimits = [ - [param.min, param.max] - for class_list in ratapi.project.parameter_class_lists - for param in getattr(project, class_list) - if param.fit - ] - problem.priorNames = [ - param.name for class_list in ratapi.project.parameter_class_lists for param in getattr(project, class_list) - ] - problem.priorValues = [ - [prior_id[param.prior_type], param.mu, param.sigma] - for class_list in ratapi.project.parameter_class_lists - for param in getattr(project, class_list) - ] + + fit_params = [] + fit_limits = [] + prior_names = [] + prior_values = [] + problem.checks = Checks() + for class_list in ratapi.project.parameter_class_lists: + field = parameter_field[class_list] + check_list = [] + for param in getattr(project, class_list): + prior_names.append(param.name) + prior_values.append([prior_id[param.prior_type], param.mu, param.sigma]) + check_list.append(int(param.fit)) + if param.fit: + if validate_range and (param.max - param.min) < 1e-10: + warnings.warn( + f'{class_list.replace("_", " ").title()} "{param.name}" was removed from the ' + "fit because its range is too small (< 1e-10).", + stacklevel=2, + ) + check_list[-1] = 0 + else: + fit_params.append(param.value) + fit_limits.append([param.min, param.max]) + setattr(problem.checks, field, check_list) + problem.fitParams = fit_params + problem.fitLimits = fit_limits + problem.priorNames = prior_names + problem.priorValues = prior_values # Names problem.names = NameStore() @@ -378,13 +388,6 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition: setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)]) problem.names.contrasts = [contrast.name for contrast in project.contrasts] - # Checks - problem.checks = Checks() - for class_list in ratapi.project.parameter_class_lists: - setattr( - problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)] - ) - check_indices(problem) return problem diff --git a/ratapi/run.py b/ratapi/run.py index d973d06..9a99224 100644 --- a/ratapi/run.py +++ b/ratapi/run.py @@ -130,7 +130,9 @@ def run(project, controls): # Update parameter values in project for class_list in ratapi.project.parameter_class_lists: for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])): - getattr(project, class_list)[index].value = value + param = getattr(project, class_list)[index] + param.fit = bool(getattr(problem_definition.checks, parameter_field[class_list])[index]) + param.value = value controls.delete_IPC() diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 242d748..23dc2ad 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -483,6 +483,19 @@ def test_make_problem(test_project, test_problem, request) -> None: check_problem_equal(problem, test_problem) +def test_make_problem_validate_range(request) -> None: + """The problem should not contain fitted parameters with small range.""" + test_project = request.getfixturevalue("standard_layers_project") + + test_project.scalefactors.set_fields(0, min=10, value=10, max=10, fit=True) + problem = make_problem(test_project) + assert problem.checks.scalefactors[0] == 1 + + with pytest.warns(UserWarning, match="was removed from the fit because its range is too small \(< 1e-10\)"): + problem = make_problem(test_project, True) + assert problem.checks.scalefactors[0] == 0 + + @pytest.mark.parametrize("test_problem", ["standard_layers_problem", "custom_xy_problem", "domains_problem"]) class TestCheckIndices: """Tests for check_indices over a set of three test problems."""