Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tsadar/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Diagnostic data providers.

A *data provider* maps a config dict to the ``(all_data, sa, all_axes)`` tuple
that the forward/inverse pipeline consumes:

- ``all_data``: plain dict of arrays (``e_data``, ``i_data``, ``e_amps``,
``i_amps``, ``noiseE``, ``noiseI``, ...)
- ``sa``: scattering angles / weights
- ``all_axes``: spectral axes

Everything downstream (``LossFunction``, ``ThomsonScatteringDiagnostic``,
postprocessing) depends only on this contract -- not on how the data was
produced. All OMEGA-specific behaviour (raw file loading, lineouts, background,
throughput, calibration) lives *upstream* of this seam, inside the ``omega``
provider.

This is the extension point for adding a new diagnostic (e.g. NIF OTS) or for
supplying preprocessed data instead of raw OMEGA files -- without touching the
jax/AD model code. Implement a ``provider(config) -> (all_data, sa, all_axes)``
function and register it::

from tsadar.data import register_data_provider

@register_data_provider("nif_ots")
def load_nif(config):
...
return all_data, sa, all_axes

then select it with ``config["data"]["provider"] = "nif_ots"`` (default:
``"omega"``).
"""
from typing import Callable, Dict, Tuple

# Signature shared by every data provider: it takes the config dict and returns
# the ``(all_data, sa, all_axes)`` triple described in the module docstring.
DataProvider = Callable[[Dict], Tuple[Dict, Dict, Dict]]

# Module-level registry of providers. Keys are the casefolded provider names
# written by ``register_data_provider``; ``get_data_provider`` reads from it.
_DATA_PROVIDERS: Dict[str, DataProvider] = {}


def register_data_provider(name: str) -> Callable[[DataProvider], DataProvider]:
    """Build a decorator that files a provider in the registry under ``name``.

    The provider is expected to follow the ``config -> (all_data, sa, all_axes)``
    contract. The name is casefolded when the decorator is applied, so lookups
    are case-insensitive. Registering a second provider under the same name
    replaces the earlier one.

    Args:
        name: Identifier the provider will be selectable by.

    Returns:
        A decorator that stores the decorated function in the registry and
        hands the same function back unchanged.
    """

    def decorator(provider: DataProvider) -> DataProvider:
        registry_key = name.casefold()
        _DATA_PROVIDERS[registry_key] = provider
        return provider

    return decorator


def get_data_provider(config: Dict) -> DataProvider:
    """Return the provider selected by ``config["data"]["provider"]``.

    The lookup is case-insensitive: the requested name is casefolded, which
    matches how ``register_data_provider`` stores its keys. A missing ``"data"``
    section, a missing ``"provider"`` key, or an explicit ``None`` for either
    (e.g. a YAML ``null``) selects the default ``"omega"`` provider.

    Args:
        config: Run configuration mapping.

    Returns:
        The provider callable registered under the requested name.

    Raises:
        KeyError: If no provider is registered under the requested name. The
            message lists the names that are currently registered.
    """
    # ``or {}`` also covers a "data" key that is present but set to None.
    data_section = config.get("data") or {}
    requested = data_section.get("provider")
    if requested is None:
        requested = "omega"
    name = requested.casefold()
    try:
        return _DATA_PROVIDERS[name]
    except KeyError:
        # ``from None`` hides the internal dict-lookup KeyError so the traceback
        # shows only the descriptive error below.
        raise KeyError(
            f"Unknown data provider {name!r}. Registered providers: {sorted(_DATA_PROVIDERS)}. "
            f"Register one with tsadar.data.register_data_provider."
        ) from None


# Import built-in providers for their registration side effects.
from . import omega # noqa: E402,F401
38 changes: 38 additions & 0 deletions tsadar/data/omega.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""OMEGA Thomson scattering data provider.

Loads and preprocesses raw OMEGA data (file loading, background, throughput,
lineouts, calibration) into the ``(all_data, sa, all_axes)`` contract. This is
the reference implementation of the data-provider interface; see
``tsadar.data`` for how to register a provider for another diagnostic.
"""
from typing import Dict, Tuple

from . import prepare, register_data_provider


@register_data_provider("omega")
def load_omega_data(config: Dict) -> Tuple[Dict, Dict, Dict]:
    """Load and preprocess OMEGA TS data into ``(all_data, sa, all_axes)``.

    Handles the single-shot case as well as the multiplexed (rotated) angular
    case where ``config["data"]["shotnum"]`` is a two-element list. In the
    multiplexed case the first shot supplies ``all_data``, ``sa`` and
    ``all_axes``; the second shot contributes only ``e_data_rot``,
    ``e_amps_rot`` and ``noiseE_rot``, and ``rot_angle`` is taken from
    ``config["data"]["shot_rot"]``.

    Args:
        config: Run configuration. Reads ``config["data"]["shotnum"]`` and, for
            multiplexed data, ``config["data"]["shot_rot"]``,
            ``config["other"]["CCDsize"]`` and
            ``config["other"]["extraoptions"]["spectype"]``.

    Returns:
        The ``(all_data, sa, all_axes)`` tuple produced by
        ``prepare.prepare_data``, with the rotated-shot entries added to
        ``all_data`` in the multiplexed case.

    Raises:
        NotImplementedError: If ``shotnum`` is a list but the spectype is not
            ``"angular_full"``.
    """
    shotnum = config["data"]["shotnum"]

    if not isinstance(shotnum, list):
        all_data, sa, all_axes = prepare.prepare_data(config, shotnum)
        return all_data, sa, all_axes

    # Reject unsupported configurations before loading anything, so the user
    # does not wait for two shots to be prepared only to hit this error.
    if config["other"]["extraoptions"]["spectype"] != "angular_full":
        raise NotImplementedError("Muliplexed data fitting is only availible for angular data")

    initial_ccd_size = config["other"]["CCDsize"]
    all_data, sa, all_axes = prepare.prepare_data(config, shotnum[0])
    # Restore CCDsize before the second load -- presumably prepare_data
    # modifies it in place; NOTE(review): confirm against prepare.prepare_data.
    config["other"]["CCDsize"] = initial_ccd_size
    rotated_data, _, _ = prepare.prepare_data(config, shotnum[1])

    all_data.update(
        {
            "e_data_rot": rotated_data["e_data"],
            "e_amps_rot": rotated_data["e_amps"],
            "rot_angle": config["data"]["shot_rot"],
            "noiseE_rot": rotated_data["noiseE"],
        }
    )
    return all_data, sa, all_axes
32 changes: 10 additions & 22 deletions tsadar/inverse/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from tsadar.inverse.loops import multirun_angular_optax, one_d_loop

from ..data import prepare
from ..data import get_data_provider
from . import postprocess


Expand Down Expand Up @@ -119,8 +119,8 @@ def fit(config) -> Tuple[pd.DataFrame, float]:
mlflow.set_tag("status", "preprocessing")
config = _validate_inputs_(config)

# prepare data
all_data, sa, all_axes = load_data_for_fitting(config)
# prepare data via the configured diagnostic data provider (default: OMEGA)
all_data, sa, all_axes = get_data_provider(config)(config)
sample_indices = np.arange(max(len(all_data["e_data"]), len(all_data["i_data"])))
num_batches = len(sample_indices) // config["optimizer"]["batch_size"] or 1
mlflow.log_metrics({"setup_time": round(time.time() - t1, 2)})
Expand All @@ -146,22 +146,10 @@ def fit(config) -> Tuple[pd.DataFrame, float]:


def load_data_for_fitting(config):
    """Backwards-compatible shim; dispatches to the configured data provider.

    Prefer ``tsadar.data.get_data_provider(config)(config)`` directly. Selects
    the OMEGA loader by default; see ``tsadar.data`` to plug in another
    diagnostic or supply preprocessed data.

    Args:
        config: Run configuration; ``config["data"]["provider"]`` picks the
            provider.

    Returns:
        The ``(all_data, sa, all_axes)`` tuple returned by the selected
        provider.
    """
    # The OMEGA-specific loading that used to live here is now the "omega"
    # provider, so this function only looks up the provider and calls it.
    return get_data_provider(config)(config)