From 7863a5a9e99e80e966baede9bc78f3ae7aea66b8 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Tue, 26 May 2026 11:45:45 -0500 Subject: [PATCH] Optimize gRASPA batch setup with CIF caching and parallel processing Cache cell size computation per CIF so it is read once instead of once per (T, P) combination. Process CIFs in parallel using ThreadPoolExecutor. Add --workers CLI flag and max_workers YAML config option. --- scripts/graspa/isotherm_config.yaml | 1 + scripts/graspa/setup_isotherms.py | 9 ++- src/matkit/cli.py | 9 ++- src/matkit/graspa/graspa.py | 102 +++++++++++++++++++++------- 4 files changed, 93 insertions(+), 28 deletions(-) diff --git a/scripts/graspa/isotherm_config.yaml b/scripts/graspa/isotherm_config.yaml index 86e86b4..3c0f364 100644 --- a/scripts/graspa/isotherm_config.yaml +++ b/scripts/graspa/isotherm_config.yaml @@ -40,3 +40,4 @@ pressure_unit: bar # bar, kPa, atm, or Pa cutoff: 12.8 cycles: 1000 +max_workers: 4 # parallel threads for CIF processing (null = auto) diff --git a/scripts/graspa/setup_isotherms.py b/scripts/graspa/setup_isotherms.py index 4167833..f499538 100644 --- a/scripts/graspa/setup_isotherms.py +++ b/scripts/graspa/setup_isotherms.py @@ -44,9 +44,11 @@ def main(): print(f"CIF dir: {cfg['cif_dir']}") print(f"Adsorbates: {cfg['adsorbates']}") print(f"Temperatures: {cfg['temperatures']}") - print(f"Pressures: {len(pressures_pa)} points " - f"({cfg['pressures'][0]}-{cfg['pressures'][-1]} " - f"{cfg.get('pressure_unit', 'Pa')})") + print( + f"Pressures: {len(pressures_pa)} points " + f"({cfg['pressures'][0]}-{cfg['pressures'][-1]} " + f"{cfg.get('pressure_unit', 'Pa')})" + ) manifest = setup_batch( cif_dir=cfg["cif_dir"], @@ -56,6 +58,7 @@ def main(): pressures=pressures_pa, cutoff=cfg.get("cutoff", 12.8), n_cycle=cfg.get("cycles", 1000), + max_workers=cfg.get("max_workers"), ) print(f"\nSet up {len(manifest)} simulations in {cfg['outdir']}") diff --git a/src/matkit/cli.py b/src/matkit/cli.py index 929df79..966b352 100644 --- a/src/matkit/cli.py +++ b/src/matkit/cli.py @@ -97,8 +97,14 @@ def graspa_setup(cif, outdir, adsorbate, temp, pressure, cutoff, cycles): ) @click.option("--cutoff", default=12.8, help="Cutoff radius in Angstrom.") @click.option("--cycles", default=1000, help="Number of MC cycles.") +@click.option( + "--workers", + default=None, + type=int, + help="Max parallel threads for CIF processing.", +) def graspa_batch_setup( - cif_dir, outdir, adsorbate, temp, pressure, cutoff, cycles + cif_dir, outdir, adsorbate, temp, pressure, cutoff, cycles, workers ): """Set up gRASPA simulations for all CIF x T x P.""" from matkit.graspa import setup_batch @@ -114,6 +120,7 @@ def graspa_batch_setup( pressures=list(pressure), cutoff=cutoff, n_cycle=cycles, + max_workers=workers, ) click.echo(f"Set up {len(manifest)} simulations in {outdir}") click.echo(f"Manifest written to {outdir}/simulations.jsonl") diff --git a/src/matkit/graspa/graspa.py b/src/matkit/graspa/graspa.py index b9be117..1d9d408 100644 --- a/src/matkit/graspa/graspa.py +++ b/src/matkit/graspa/graspa.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from concurrent.futures import ThreadPoolExecutor from itertools import product from pathlib import Path import shutil @@ -163,6 +164,7 @@ def setup_simulation( cutoff: float = 12.8, n_cycle: int = 1000, template_dir: str = "template", + cell_size: list[int] | None = None, ) -> bool: """Set up a gRASPA GCMC simulation. @@ -179,6 +181,8 @@ def setup_simulation( cutoff: Van der Waals cutoff radius in Angstrom. n_cycle: Number of Monte Carlo cycles. template_dir: Template subdirectory name under files/. + cell_size: Pre-computed unit cell dimensions [uc_x, uc_y, uc_z]. + If provided, skips reading the CIF to calculate cell size. Returns: True on success. @@ -203,9 +207,12 @@ def setup_simulation( else: shutil.copy2(item, outdir) - # Read CIF to get cell size - atoms = ase_read(cifpath) - uc_x, uc_y, uc_z = calculate_cell_size(atoms) + # Use pre-computed cell size or read CIF + if cell_size is not None: + uc_x, uc_y, uc_z = cell_size + else: + atoms = ase_read(cifpath) + uc_x, uc_y, uc_z = calculate_cell_size(atoms) # Read template and replace placeholders input_path = outdir / "simulation.input" @@ -236,6 +243,50 @@ def setup_simulation( return True +def _setup_single_cif( + cif: Path, + out_path: Path, + adsorbates: list[dict], + temperatures: list[float], + pressures: list[float], + cutoff: float, + n_cycle: int, + template_dir: str, +) -> list[dict]: + """Set up all T x P simulations for a single CIF file. + + Reads the CIF once to compute cell size, then creates a simulation + directory for each (temperature, pressure) combination. + """ + atoms = ase_read(cif) + cell_size = calculate_cell_size(atoms) + + entries = [] + for temp, pres in product(temperatures, pressures): + sim_dir = out_path / cif.stem / f"T{temp}_P{pres:g}" + setup_simulation( + cif=str(cif), + outpath=str(sim_dir), + adsorbates=adsorbates, + temperature=temp, + pressure=pres, + cutoff=cutoff, + n_cycle=n_cycle, + template_dir=template_dir, + cell_size=cell_size, + ) + entries.append( + { + "sim_dir": str(sim_dir), + "cif": cif.name, + "temperature": temp, + "pressure": pres, + "adsorbates": [ad["MoleculeName"] for ad in adsorbates], + } + ) + return entries + + def setup_batch( cif_dir: str, outpath: str, @@ -245,11 +296,15 @@ def setup_batch( cutoff: float = 12.8, n_cycle: int = 1000, template_dir: str = "template", + max_workers: int | None = None, ) -> list[dict]: """Set up gRASPA simulations for all CIF x T x P. Discovers all .cif files in cif_dir and creates a simulation directory - for each (CIF, temperature, pressure) combination using setup_simulation(). + for each (CIF, temperature, pressure) combination. Each CIF is read + once to compute cell size, then all T x P combinations reuse the + cached result. CIFs are processed in parallel using threads. + Writes a simulations.jsonl manifest to outpath. Args: @@ -261,6 +316,8 @@ def setup_batch( cutoff: Van der Waals cutoff radius in Angstrom. n_cycle: Number of Monte Carlo cycles. template_dir: Template subdirectory name under files/. + max_workers: Max threads for parallel CIF processing. + Defaults to None (lets ThreadPoolExecutor choose). Returns: List of manifest dicts, each with keys: sim_dir, cif, @@ -280,26 +337,23 @@ def setup_batch( raise ValueError(f"No .cif files found in {cif_dir}") manifest = [] - for cif, temp, pres in product(cif_files, temperatures, pressures): - sim_dir = out_path / cif.stem / f"T{temp}_P{pres:g}" - setup_simulation( - cif=str(cif), - outpath=str(sim_dir), - adsorbates=adsorbates, - temperature=temp, - pressure=pres, - cutoff=cutoff, - n_cycle=n_cycle, - template_dir=template_dir, - ) - entry = { - "sim_dir": str(sim_dir), - "cif": cif.name, - "temperature": temp, - "pressure": pres, - "adsorbates": [ad["MoleculeName"] for ad in adsorbates], - } - manifest.append(entry) + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [ + pool.submit( + _setup_single_cif, + cif, + out_path, + adsorbates, + temperatures, + pressures, + cutoff, + n_cycle, + template_dir, + ) + for cif in cif_files + ] + for future in futures: + manifest.extend(future.result()) manifest_path = out_path / "simulations.jsonl" with manifest_path.open("w") as f: