From 4924cf173ac5b0ef3e0664a3da832d10da1d4b01 Mon Sep 17 00:00:00 2001 From: archis Date: Tue, 16 Jun 2026 18:11:38 -0700 Subject: [PATCH] Add a diagnostic data-provider interface to isolate OMEGA specifics Realizes the primary goal from #105 (per @almilder): isolate the OMEGA-specific diagnostic specifics (raw file loading, calibration, background, lineouts) behind a minimal interface so other users can plug in their own diagnostic (e.g. NIF OTS) or supply preprocessed data -- without touching the jax/AD model code. The whole pipeline downstream of data loading already consumes only a plain `(all_data, sa, all_axes)` tuple. This turns that existing seam into an explicit, registerable contract: - `tsadar/data/__init__.py`: `register_data_provider(name)` + `get_data_provider(config)` registry. A provider is any `config -> (all_data, sa, all_axes)` callable. Selected by `config["data"]["provider"]`, defaulting to "omega". - `tsadar/data/omega.py`: the OMEGA loader (moved out of `fitter`) as the reference provider, self-registered as "omega". - `fitter.fit` now resolves data via `get_data_provider(config)(config)`. `load_data_for_fitting` is kept as a thin back-compat shim. Behavior-preserving: configs without a `provider` key resolve to "omega" and follow the exact previous code path. External diagnostics register with `tsadar.data.register_data_provider` and require no jax/AD knowledge. Toward #105. Stacked on the coupling-break PR (#106). Co-Authored-By: Claude Opus 4.8 (1M context) --- tsadar/data/__init__.py | 62 ++++++++++++++++++++++++++++++++++++++++ tsadar/data/omega.py | 38 ++++++++++++++++++++++++ tsadar/inverse/fitter.py | 32 +++++++-------------- 3 files changed, 110 insertions(+), 22 deletions(-) create mode 100644 tsadar/data/omega.py diff --git a/tsadar/data/__init__.py b/tsadar/data/__init__.py index e69de29b..280ce3b0 100644 --- a/tsadar/data/__init__.py +++ b/tsadar/data/__init__.py @@ -0,0 +1,62 @@ +"""Diagnostic data providers. + +A *data provider* maps a config dict to the ``(all_data, sa, all_axes)`` tuple +that the forward/inverse pipeline consumes: + +- ``all_data``: plain dict of arrays (``e_data``, ``i_data``, ``e_amps``, + ``i_amps``, ``noiseE``, ``noiseI``, ...) +- ``sa``: scattering angles / weights +- ``all_axes``: spectral axes + +Everything downstream (``LossFunction``, ``ThomsonScatteringDiagnostic``, +postprocessing) depends only on this contract -- not on how the data was +produced. All OMEGA-specific behaviour (raw file loading, lineouts, background, +throughput, calibration) lives *upstream* of this seam, inside the ``omega`` +provider. + +This is the extension point for adding a new diagnostic (e.g. NIF OTS) or for +supplying preprocessed data instead of raw OMEGA files -- without touching the +jax/AD model code. Implement a ``provider(config) -> (all_data, sa, all_axes)`` +function and register it:: + + from tsadar.data import register_data_provider + + @register_data_provider("nif_ots") + def load_nif(config): + ... + return all_data, sa, all_axes + +then select it with ``config["data"]["provider"] = "nif_ots"`` (default: +``"omega"``). +""" +from typing import Callable, Dict, Tuple + +DataProvider = Callable[[Dict], Tuple[Dict, Dict, Dict]] + +_DATA_PROVIDERS: Dict[str, DataProvider] = {} + + +def register_data_provider(name: str) -> Callable[[DataProvider], DataProvider]: + """Register a ``config -> (all_data, sa, all_axes)`` provider under ``name``.""" + + def _register(fn: DataProvider) -> DataProvider: + _DATA_PROVIDERS[name.casefold()] = fn + return fn + + return _register + + +def get_data_provider(config: Dict) -> DataProvider: + """Return the provider selected by ``config["data"]["provider"]`` (default ``"omega"``).""" + name = config.get("data", {}).get("provider", "omega").casefold() + try: + return _DATA_PROVIDERS[name] + except KeyError: + raise KeyError( + f"Unknown data provider {name!r}. Registered providers: {sorted(_DATA_PROVIDERS)}. " + f"Register one with tsadar.data.register_data_provider." + ) + + +# Import built-in providers for their registration side effects. +from . import omega # noqa: E402,F401 diff --git a/tsadar/data/omega.py b/tsadar/data/omega.py new file mode 100644 index 00000000..b133a172 --- /dev/null +++ b/tsadar/data/omega.py @@ -0,0 +1,38 @@ +"""OMEGA Thomson scattering data provider. + +Loads and preprocesses raw OMEGA data (file loading, background, throughput, +lineouts, calibration) into the ``(all_data, sa, all_axes)`` contract. This is +the reference implementation of the data-provider interface; see +``tsadar.data`` for how to register a provider for another diagnostic. +""" +from typing import Dict, Tuple + +from . import prepare, register_data_provider + + +@register_data_provider("omega") +def load_omega_data(config: Dict) -> Tuple[Dict, Dict, Dict]: + """Load and preprocess OMEGA TS data into ``(all_data, sa, all_axes)``. + + Handles the single-shot case as well as the multiplexed (rotated) angular + case where ``config["data"]["shotnum"]`` is a two-element list. + """ + if isinstance(config["data"]["shotnum"], list): + startCCDsize = config["other"]["CCDsize"] + all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"][0]) + config["other"]["CCDsize"] = startCCDsize + all_data2, _, _ = prepare.prepare_data(config, config["data"]["shotnum"][1]) + all_data.update( + { + "e_data_rot": all_data2["e_data"], + "e_amps_rot": all_data2["e_amps"], + "rot_angle": config["data"]["shot_rot"], + "noiseE_rot": all_data2["noiseE"], + } + ) + + if config["other"]["extraoptions"]["spectype"] != "angular_full": + raise NotImplementedError("Muliplexed data fitting is only availible for angular data") + else: + all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"]) + return all_data, sa, all_axes diff --git a/tsadar/inverse/fitter.py b/tsadar/inverse/fitter.py index 210063e0..b388eaa4 100644 --- a/tsadar/inverse/fitter.py +++ b/tsadar/inverse/fitter.py @@ -7,7 +7,7 @@ from tsadar.inverse.loops import multirun_angular_optax, one_d_loop -from ..data import prepare +from ..data import get_data_provider from . import postprocess @@ -119,8 +119,8 @@ def fit(config) -> Tuple[pd.DataFrame, float]: mlflow.set_tag("status", "preprocessing") config = _validate_inputs_(config) - # prepare data - all_data, sa, all_axes = load_data_for_fitting(config) + # prepare data via the configured diagnostic data provider (default: OMEGA) + all_data, sa, all_axes = get_data_provider(config)(config) sample_indices = np.arange(max(len(all_data["e_data"]), len(all_data["i_data"]))) num_batches = len(sample_indices) // config["optimizer"]["batch_size"] or 1 mlflow.log_metrics({"setup_time": round(time.time() - t1, 2)}) @@ -146,22 +146,10 @@ def fit(config) -> Tuple[pd.DataFrame, float]: def load_data_for_fitting(config): - if isinstance(config["data"]["shotnum"], list): - startCCDsize = config["other"]["CCDsize"] - all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"][0]) - config["other"]["CCDsize"] = startCCDsize - all_data2, _, _ = prepare.prepare_data(config, config["data"]["shotnum"][1]) - all_data.update( - { - "e_data_rot": all_data2["e_data"], - "e_amps_rot": all_data2["e_amps"], - "rot_angle": config["data"]["shot_rot"], - "noiseE_rot": all_data2["noiseE"], - } - ) - - if config["other"]["extraoptions"]["spectype"] != "angular_full": - raise NotImplementedError("Muliplexed data fitting is only availible for angular data") - else: - all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"]) - return all_data, sa, all_axes + """Backwards-compatible shim; dispatches to the configured data provider. + + Prefer ``tsadar.data.get_data_provider(config)(config)`` directly. Selects + the OMEGA loader by default; see ``tsadar.data`` to plug in another + diagnostic or supply preprocessed data. + """ + return get_data_provider(config)(config)