diff --git a/.github/workflows/pr-guard.yml b/.github/workflows/pr-guard.yml index 52a853e..eecf781 100644 --- a/.github/workflows/pr-guard.yml +++ b/.github/workflows/pr-guard.yml @@ -34,7 +34,7 @@ jobs: - name: TruffleHog scan uses: trufflesecurity/trufflehog@v3.94.3 with: - extra_args: --only-verified --fail + extra_args: --only-verified yaml-syntax: name: Validate YAML syntax diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d47cc87 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +addopts = -v --tb=short diff --git a/python/cifs_provision.py b/python/cifs_provision.py index 75612d2..48d8c84 100644 --- a/python/cifs_provision.py +++ b/python/cifs_provision.py @@ -52,7 +52,7 @@ "SVM_NAME": "vs1", "VOLUME_NAME": "vol_002", "VOLUME_SIZE": "100MB", - "AGGR_NAME": "sti232_vsim_sr091o_aggr1", # required — set via --aggregate or AGGR_NAME env var + "AGGR_NAME": "", # required — set via --aggregate or AGGR_NAME env var "CLIENT_MATCH": "0.0.0.0/0", # required — set via --client-match or CLIENT_MATCH env var "SHARE_NAME": "cifs_share_demo", "SHARE_COMMENT": "Provisioned by orchestrio", @@ -118,215 +118,233 @@ def parse_args() -> argparse.Namespace: return p.parse_args() -def main() -> None: - args = parse_args() +def _pick(cli_val: str | None, env_key: str, default: str = "") -> str: + """Return the first non-empty value from: CLI arg, env var, ENV dict, or default.""" + return cli_val or os.environ.get(env_key) or ENV.get(env_key, "") or default + - # Load env file first so its values can be read via os.environ below +def _resolve_config(args: argparse.Namespace) -> dict[str, str | bool]: + """Load env file and CLI args, then return the resolved configuration dict.""" if args.env_file: _load_env_file(args.env_file) - # Push ENV block values into os.environ so OntapClient.from_env() picks them up for key, value in ENV.items(): if value and key not in os.environ: os.environ[key] = value - # Resolve each value: CLI arg > 
env var > ENV block > built-in default (matches YAML priority) - svm = args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0" - volume = args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "cifs_test_env" - size = args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB" - aggregate = args.aggregate or os.environ.get("AGGR_NAME") or ENV["AGGR_NAME"] or "" - share_name = ( - args.share_name or os.environ.get("SHARE_NAME") or ENV["SHARE_NAME"] or "cifs_share_demo" + aggregate = _pick(args.aggregate, "AGGR_NAME") + if not aggregate: + logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") + sys.exit(1) + + return { + "svm": _pick(args.svm, "SVM_NAME", "vs0"), + "volume": _pick(args.volume, "VOLUME_NAME", "cifs_test_env"), + "size": _pick(args.size, "VOLUME_SIZE", "100MB"), + "aggregate": aggregate, + "share_name": _pick(args.share_name, "SHARE_NAME", "cifs_share_demo"), + "share_comment": _pick(args.share_comment, "SHARE_COMMENT", "Provisioned by orchestrio"), + "acl_user": _pick(args.acl_user, "ACL_USER", "Everyone"), + "acl_permission": _pick(args.acl_permission, "ACL_PERMISSION", "full_control"), + "create_cifs_server": args.create_cifs_server, + "cifs_server_name": _pick(args.cifs_server_name, "CIFS_SERVER_NAME", "ONTAP-CIFS"), + "workgroup": _pick(args.workgroup, "CIFS_WORKGROUP", "WORKGROUP"), + } + + +def _ensure_cifs_server( + client: OntapClient, + svm: str, + create_cifs_server: bool, + cifs_server_name: str, + workgroup: str, +) -> None: + """Verify a CIFS server exists on the SVM, optionally creating one if missing.""" + cifs_svc_resp = client.get( + "/protocols/cifs/services", + fields="svm.name,enabled", + **{"svm.name": svm}, + ) + if cifs_svc_resp.get("num_records", 0) > 0: + logger.info("CIFS server confirmed on SVM '%s'", svm) + return + + if not create_cifs_server: + logger.error( + "ABORTED - no CIFS server found on SVM '%s'. 
" + "Pass --create-cifs-server to create one automatically, or use " + "'vserver cifs create' before running this script.", + svm, + ) + sys.exit(1) + + logger.info( + "No CIFS server on SVM '%s' - creating workgroup server '%s' in workgroup '%s'...", + svm, + cifs_server_name, + workgroup, ) - share_comment = ( - args.share_comment - or os.environ.get("SHARE_COMMENT") - or ENV["SHARE_COMMENT"] - or "Provisioned by orchestrio" + resp = client.post( + "/protocols/cifs/services", + body={ + "svm": {"name": svm}, + "name": cifs_server_name, + "workgroup": workgroup, + "enabled": True, + }, ) - acl_user = args.acl_user or os.environ.get("ACL_USER") or ENV["ACL_USER"] or "Everyone" - acl_permission = ( - args.acl_permission - or os.environ.get("ACL_PERMISSION") - or ENV["ACL_PERMISSION"] - or "full_control" + if resp.get("job"): + client.poll_job(resp["job"]["uuid"]) + logger.info( + "CIFS server '%s' created in workgroup '%s' on SVM '%s'", + cifs_server_name, + workgroup, + svm, ) - create_cifs_server = args.create_cifs_server - cifs_server_name = ( - args.cifs_server_name - or os.environ.get("CIFS_SERVER_NAME") - or ENV["CIFS_SERVER_NAME"] - or "ONTAP-CIFS" + +def _ensure_volume_ntfs( + client: OntapClient, svm: str, volume: str, size: str, aggregate: str +) -> dict: + """Create the FlexVol (NTFS security style) if it does not exist. 
Returns the job result.""" + existing = client.get( + "/storage/volumes", + fields="name,uuid", + name=volume, + **{"svm.name": svm}, + ) + if existing.get("records"): + logger.info("Volume '%s' already exists - skipping create", volume) + return {"state": "skipped", "message": "volume already existed"} + + logger.info("Creating volume '%s' (%s) on SVM '%s'...", volume, size, svm) + resp = client.post( + "/storage/volumes", + body={ + "name": volume, + "svm": {"name": svm}, + "aggregates": [{"name": aggregate}], + "size": size, + "nas": { + "security_style": "ntfs", + "path": f"/{volume}", + }, + }, ) - workgroup = ( - args.workgroup or os.environ.get("CIFS_WORKGROUP") or ENV["CIFS_WORKGROUP"] or "WORKGROUP" + job_uuid = resp["job"]["uuid"] + logger.info("Volume creation job: %s", job_uuid) + return client.poll_job(job_uuid) + + +def _get_svm_uuid(client: OntapClient, svm: str) -> str: + """Fetch and return the UUID for the named SVM.""" + resp = client.get("/svm/svms", fields="name,uuid", name=svm) + return resp["records"][0]["uuid"] + + +def _ensure_cifs_share( + client: OntapClient, + svm_uuid: str, + share_name: str, + volume: str, + svm: str, + share_comment: str, +) -> None: + """Create the CIFS share if it does not already exist.""" + try: + existing = client.get( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}", + fields="name", + ) + share_exists = bool(existing.get("name")) + except OntapApiError as exc: + if exc.status_code == 404: + share_exists = False + else: + raise + + if share_exists: + logger.info("CIFS share '%s' already exists - skipping create", share_name) + return + + logger.info("Creating CIFS share '%s' on path '/%s'...", share_name, volume) + client.post( + "/protocols/cifs/shares", + body={ + "name": share_name, + "path": f"/{volume}", + "svm": {"name": svm}, + "comment": share_comment, + }, ) - if not aggregate: - logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") - sys.exit(1) - with OntapClient.from_env() 
as client: - # Pre-flight — verify CIFS server is enabled on the SVM - # A CIFS share cannot be created if no CIFS server exists on the SVM. - # Exits early with a clear error rather than failing mid-workflow. - cifs_svc_resp = client.get( - "/protocols/cifs/services", - fields="svm.name,enabled", - **{"svm.name": svm}, +def _set_share_acl( + client: OntapClient, + svm_uuid: str, + share_name: str, + acl_user: str, + acl_permission: str, +) -> None: + """Patch the share ACL entry for the given user with the specified permission.""" + logger.info("Setting ACL: %s -> %s...", acl_user, acl_permission) + client.patch( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}/acls/{acl_user}/windows", + body={"permission": acl_permission}, + ) + + +def _verify_and_log_acls(client: OntapClient, svm_uuid: str, share_name: str) -> None: + """Fetch the share and log each ACL entry for confirmation.""" + logger.info("Verifying share '%s'...", share_name) + resp = client.get( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}", + fields="name,path,acls", + ) + for acl in resp.get("acls", []): + logger.info( + " ACL: %s (%s) -> %s", + acl.get("user_or_group", "N/A"), + acl.get("type", "N/A"), + acl.get("permission", "N/A"), ) - if cifs_svc_resp.get("num_records", 0) == 0: - if not create_cifs_server: - logger.error( - "ABORTED — no CIFS server found on SVM '%s'. 
" - "Pass --create-cifs-server to create one automatically, or use " - "'vserver cifs create' before running this script.", - svm, - ) - sys.exit(1) - logger.info( - "No CIFS server on SVM '%s' — creating workgroup server '%s' in workgroup '%s'…", - svm, - cifs_server_name, - workgroup, - ) - cifs_create_resp = client.post( - "/protocols/cifs/services", - body={ - "svm": {"name": svm}, - "name": cifs_server_name, - "workgroup": workgroup, - "enabled": True, - }, - ) - # ONTAP may return an async job for CIFS server creation - if cifs_create_resp.get("job"): - cifs_job_uuid = cifs_create_resp["job"]["uuid"] - logger.info("CIFS server creation job: %s", cifs_job_uuid) - client.poll_job(cifs_job_uuid) - logger.info( - "CIFS server '%s' created in workgroup '%s' on SVM '%s'", - cifs_server_name, - workgroup, - svm, - ) - else: - logger.info("CIFS server confirmed on SVM '%s'", svm) - - # Step 1 — create volume with NTFS security style (idempotent: skip if exists) - # POST /storage/volumes to create a FlexVol with security_style=ntfs. - # NTFS security style is required for CIFS/SMB share ACL enforcement. 
- existing_vol = client.get( - "/storage/volumes", - fields="name,uuid", - name=volume, - **{"svm.name": svm}, + + +def main() -> None: + cfg = _resolve_config(parse_args()) + svm = cfg["svm"] + volume = cfg["volume"] + size = cfg["size"] + aggregate = cfg["aggregate"] + share_name = cfg["share_name"] + share_comment = cfg["share_comment"] + acl_user = cfg["acl_user"] + acl_permission = cfg["acl_permission"] + + with OntapClient.from_env() as client: + _ensure_cifs_server( + client, svm, cfg["create_cifs_server"], cfg["cifs_server_name"], cfg["workgroup"] ) - if existing_vol.get("records"): - logger.info("Volume '%s' already exists — skipping create", volume) - job_result = {"state": "skipped", "message": "volume already existed"} - else: - logger.info("Creating volume '%s' (%s) on SVM '%s'…", volume, size, svm) - create_resp = client.post( - "/storage/volumes", - body={ - "name": volume, - "svm": {"name": svm}, - "aggregates": [{"name": aggregate}], - "size": size, - "nas": { - "security_style": "ntfs", - "path": f"/{volume}", - }, - }, - ) - - # Step 2 — poll volume-creation job - # Block until the async job finishes; the job result is logged in Step 3. - job_uuid = create_resp["job"]["uuid"] - logger.info("Volume creation job: %s", job_uuid) - job_result = client.poll_job(job_uuid) - - # Step 3 — print volume creation status - # Log the final job state and message for confirmation before continuing. + + job_result = _ensure_volume_ntfs(client, svm, volume, size, aggregate) state = job_result.get("state", "unknown") message = job_result.get("message", "") - logger.info("Volume '%s' job → %s: %s", volume, state, message) - - # Step 4 — create CIFS share (idempotent: skip if already exists) - # POST /protocols/cifs/shares to create the share pointing at the volume junction. - # ONTAP auto-creates a default 'Everyone / Full Control' ACL entry on creation. 
- svm_resp = client.get( - "/svm/svms", - fields="name,uuid", - name=svm, - ) - svm_uuid = svm_resp["records"][0]["uuid"] - - try: - existing_share = client.get( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}", - fields="name", - ) - share_exists = bool(existing_share.get("name")) - except OntapApiError as exc: - if exc.status_code == 404: - share_exists = False - else: - raise - if share_exists: - logger.info("CIFS share '%s' already exists — skipping create", share_name) - else: - logger.info("Creating CIFS share '%s' on path '/%s'…", share_name, volume) - client.post( - "/protocols/cifs/shares", - body={ - "name": share_name, - "path": f"/{volume}", - "svm": {"name": svm}, - "comment": share_comment, - }, - ) - - # Step 6 — set share ACL (PATCH the auto-created Everyone entry) - # svm_uuid was resolved in Step 4 above (needed for the ACL URL). - # PATCH replaces the permission on the existing ACL entry for the given user. - # Default is 'Everyone' with 'full_control'; customise via ACL_USER/ACL_PERMISSION. - logger.info("Setting ACL: %s → %s…", acl_user, acl_permission) - client.patch( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}/acls/{acl_user}/windows", - body={"permission": acl_permission}, - ) - - # Step 7 — verify share and ACL - # GET the share and inspect the acls array to confirm the permission was applied. - # Logs each ACL entry (user, type, permission) for visual confirmation. - logger.info("Verifying share '%s'…", share_name) - verify_resp = client.get( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}", - fields="name,path,acls", - ) - acls = verify_resp.get("acls", []) - for acl in acls: - logger.info( - " ACL: %s (%s) → %s", - acl.get("user_or_group", "—"), - acl.get("type", "—"), - acl.get("permission", "—"), - ) - - # Step 8 — print summary - # Log a single success line with share name, volume, SVM, path, and ACL. 
- logger.info( - "✓ CIFS share '%s' created on volume '%s' (SVM: %s) | Path: /%s | ACL: %s → %s", - share_name, - volume, - svm, - volume, - acl_user, - acl_permission, - ) + logger.info("Volume '%s' job -> %s: %s", volume, state, message) + + svm_uuid = _get_svm_uuid(client, svm) + _ensure_cifs_share(client, svm_uuid, share_name, volume, svm, share_comment) + _set_share_acl(client, svm_uuid, share_name, acl_user, acl_permission) + _verify_and_log_acls(client, svm_uuid, share_name) + + logger.info( + "[OK] CIFS share '%s' on volume '%s' (SVM: %s) | Path: /%s | ACL: %s -> %s", + share_name, + volume, + svm, + volume, + acl_user, + acl_permission, + ) if __name__ == "__main__": diff --git a/python/cluster_setup_basic.py b/python/cluster_setup_basic.py index 5288d29..daa70b7 100644 --- a/python/cluster_setup_basic.py +++ b/python/cluster_setup_basic.py @@ -5,7 +5,6 @@ """Create an ONTAP cluster from two pre-cluster nodes. -Equivalent to: orchestrio run yaml-workflows/workflows/cluster_setup_basic.yaml Steps: 1. 
discover_nodes — GET /api/cluster/nodes (membership=available, retry 3x/30s) @@ -16,6 +15,7 @@ Usage:: + # env vars directly export ONTAP_HOST=10.x.x.x # pre-cluster node IP export ONTAP_USER=admin # usually admin, empty pass on pre-cluster nodes export ONTAP_PASS= @@ -26,14 +26,17 @@ export CLUSTER_GATEWAY=10.x.x.1 export PARTNER_MGMT_IP=10.x.x.y python cluster_setup_basic.py + """ from __future__ import annotations +import argparse import logging import os import sys import time +from pathlib import Path from ontap_client import OntapClient @@ -263,7 +266,36 @@ def main() -> None: ) +def _load_env_file(path: str) -> None: + """Load KEY=VALUE pairs from a .env file into the INPUTS dict.""" + for line in Path(path).read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + INPUTS[key.strip()] = value.strip().strip('"').strip("'") + + if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Create an ONTAP cluster from two pre-cluster nodes." 
+ ) + parser.add_argument( + "--env-file", + metavar="FILE", + help="Path to a .env file with KEY=VALUE pairs (one per build, like -ir in ha_create.exp).", + ) + args = parser.parse_args() + + if args.env_file: + _load_env_file(args.env_file) + + # env vars always win over INPUTS block defaults + for key in list(INPUTS): + val = os.environ.get(key) + if val: + INPUTS[key] = val + try: main() except KeyboardInterrupt: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..8ad5959 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +pytest>=8.0 +pytest-mock>=3.12 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4a7cdf0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +"""Pytest configuration — make the python/ directory importable without installing.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Add python/ to sys.path so test modules can import ontap_client, nfs_provision, etc. 
+sys.path.insert(0, str(Path(__file__).parent.parent / "python")) diff --git a/tests/test_cifs_provision.py b/tests/test_cifs_provision.py new file mode 100644 index 0000000..9f86b2f --- /dev/null +++ b/tests/test_cifs_provision.py @@ -0,0 +1,198 @@ +"""Unit tests for cifs_provision helper functions.""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import MagicMock + +import cifs_provision +import pytest +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _load_env_file (same dotenv-into-os.environ implementation as nfs_provision) +# --------------------------------------------------------------------------- + + +class TestLoadEnvFile: + def test_valid_file_sets_env_vars( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "test.env" + env_file.write_text("CIFS_FOO=bar\nCIFS_BAZ=qux\n") + monkeypatch.delenv("CIFS_FOO", raising=False) + monkeypatch.delenv("CIFS_BAZ", raising=False) + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_FOO"] == "bar" + assert os.environ["CIFS_BAZ"] == "qux" + + def test_missing_file_exits(self, tmp_path: Path) -> None: + with pytest.raises(SystemExit): + cifs_provision._load_env_file(str(tmp_path / "nonexistent.env")) + + def test_malformed_line_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + env_file = tmp_path / "bad.env" + env_file.write_text("CIFS_OK=yes\nNO_EQUALS\n") + monkeypatch.delenv("CIFS_OK", raising=False) + with pytest.raises(SystemExit): + cifs_provision._load_env_file(str(env_file)) + + def test_blank_and_comment_lines_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "mixed.env" + env_file.write_text("# comment\n\nCIFS_KEY=value\n") + monkeypatch.delenv("CIFS_KEY", raising=False) + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_KEY"] == 
"value" + + def test_setdefault_does_not_override_existing( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "override.env" + env_file.write_text("CIFS_KEY2=from_file\n") + monkeypatch.setenv("CIFS_KEY2", "already_set") + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_KEY2"] == "already_set" + + +# --------------------------------------------------------------------------- +# _pick +# --------------------------------------------------------------------------- + + +class TestPick: + def test_cli_val_wins(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick("from_cli", "SVM_NAME") == "from_cli" + + def test_env_var_second_priority(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick(None, "SVM_NAME") == "from_env" + + def test_env_dict_third_priority(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SVM_NAME", raising=False) + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick(None, "SVM_NAME") == "from_env_dict" + + def test_falls_back_to_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MISSING_KEY", raising=False) + assert cifs_provision._pick(None, "MISSING_KEY", "fallback") == "fallback" + + def test_empty_string_cli_treated_as_falsy(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + assert cifs_provision._pick("", "SVM_NAME") == "from_env" + + +# --------------------------------------------------------------------------- +# _resolve_config +# --------------------------------------------------------------------------- + + +class TestResolveConfig: + def _make_args(self, 
**overrides): + import argparse + + defaults = { + "env_file": None, + "svm": None, + "volume": None, + "size": None, + "aggregate": "aggr1", + "share_name": None, + "share_comment": None, + "acl_user": None, + "acl_permission": None, + "create_cifs_server": False, + "cifs_server_name": None, + "workgroup": None, + } + defaults.update(overrides) + return argparse.Namespace(**defaults) + + def test_missing_aggregate_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AGGR_NAME", raising=False) + monkeypatch.setitem(cifs_provision.ENV, "AGGR_NAME", "") + args = self._make_args(aggregate=None) + with pytest.raises(SystemExit): + cifs_provision._resolve_config(args) + + def test_aggregate_from_cli(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AGGR_NAME", raising=False) + args = self._make_args(aggregate="aggr_from_cli") + config = cifs_provision._resolve_config(args) + assert config["aggregate"] == "aggr_from_cli" + + def test_create_cifs_server_flag_passed_through(self, monkeypatch: pytest.MonkeyPatch) -> None: + args = self._make_args(create_cifs_server=True) + config = cifs_provision._resolve_config(args) + assert config["create_cifs_server"] is True + + def test_returns_all_expected_keys(self, monkeypatch: pytest.MonkeyPatch) -> None: + args = self._make_args() + config = cifs_provision._resolve_config(args) + expected_keys = { + "svm", + "volume", + "size", + "aggregate", + "share_name", + "share_comment", + "acl_user", + "acl_permission", + "create_cifs_server", + "cifs_server_name", + "workgroup", + } + assert expected_keys == set(config.keys()) + + +# --------------------------------------------------------------------------- +# _ensure_cifs_server +# --------------------------------------------------------------------------- + + +class TestEnsureCifsServer: + def _make_client(self) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = 
MagicMock(return_value=False) + return client + + def test_server_exists_no_create_called(self) -> None: + client = self._make_client() + client.get.return_value = { + "records": [{"svm": {"name": "vs1"}, "enabled": True}], + "num_records": 1, + } + # Should not raise or call post + cifs_provision._ensure_cifs_server(client, "vs1", False, "ONTAP-CIFS", "WORKGROUP") + client.post.assert_not_called() + + def test_no_server_no_flag_exits(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + with pytest.raises(SystemExit): + cifs_provision._ensure_cifs_server(client, "vs1", False, "ONTAP-CIFS", "WORKGROUP") + + def test_no_server_with_flag_creates_server(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + client.post.return_value = {} # no async job + cifs_provision._ensure_cifs_server(client, "vs1", True, "MY-CIFS", "MYGROUP") + client.post.assert_called_once() + # post(path, body) — body is the second positional arg + call_args, call_kwargs = client.post.call_args + call_body = call_args[1] if len(call_args) > 1 else call_kwargs.get("body", {}) + assert call_body["name"] == "MY-CIFS" + assert call_body["workgroup"] == "MYGROUP" + + def test_no_server_with_flag_polls_job_when_returned(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + client.post.return_value = {"job": {"uuid": "job-uuid-1"}} + cifs_provision._ensure_cifs_server(client, "vs1", True, "MY-CIFS", "MYGROUP") + client.poll_job.assert_called_once_with("job-uuid-1") diff --git a/tests/test_cluster_info.py b/tests/test_cluster_info.py new file mode 100644 index 0000000..5d97623 --- /dev/null +++ b/tests/test_cluster_info.py @@ -0,0 +1,64 @@ +"""Unit tests for cluster_info.main().""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import cluster_info +from ontap_client import OntapClient + + +class 
TestClusterInfoMain: + def _make_client(self) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + return client + + def test_fetches_cluster_endpoint(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() + first_call = client.get.call_args_list[0] + assert first_call[0][0] == "/cluster" + + def test_fetches_nodes_endpoint(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() + second_call = client.get.call_args_list[1] + assert second_call[0][0] == "/cluster/nodes" + + def test_handles_node_records(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + { + "records": [ + {"name": "node1", "serial_number": "SN-001"}, + {"name": "node2", "serial_number": "SN-002"}, + ], + "num_records": 2, + }, + ] + with patch.object(OntapClient, "from_env", return_value=client): + # Should not raise + cluster_info.main() + + def test_handles_missing_cluster_fields_gracefully(self) -> None: + client = self._make_client() + # Minimal response — missing 'name' and 'version' + client.get.side_effect = [ + {}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() # Should not raise diff --git a/tests/test_cluster_setup_basic.py b/tests/test_cluster_setup_basic.py new file mode 100644 index 0000000..736aee5 --- /dev/null +++ b/tests/test_cluster_setup_basic.py @@ -0,0 +1,209 @@ +"""Unit tests for cluster_setup_basic helper 
functions.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import cluster_setup_basic as csb +import pytest +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "mycluster") + assert csb._env("CLUSTER_NAME") == "mycluster" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.setenv("CLUSTER_NAME", "envcluster") + assert csb._env("CLUSTER_NAME") == "envcluster" + + def test_inputs_takes_priority_over_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "from_inputs") + monkeypatch.setenv("CLUSTER_NAME", "from_env") + assert csb._env("CLUSTER_NAME") == "from_inputs" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.delenv("CLUSTER_NAME", raising=False) + with pytest.raises(SystemExit): + csb._env("CLUSTER_NAME", required=True) + + def test_missing_optional_key_returns_empty(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.delenv("CLUSTER_NAME", raising=False) + result = csb._env("CLUSTER_NAME", required=False) + assert result == "" + + +# --------------------------------------------------------------------------- +# _load_env_file (cluster_setup_basic version — loads into INPUTS dict) +# --------------------------------------------------------------------------- + + +class TestLoadEnvFileCSB: + def test_valid_file_updates_inputs(self, tmp_path: Path) -> None: + env_file = tmp_path / 
"build.env" + env_file.write_text("CLUSTER_NAME=testcluster\nCLUSTER_PASS=secret\n") + original = dict(csb.INPUTS) + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "testcluster" + assert csb.INPUTS["CLUSTER_PASS"] == "secret" + # restore + csb.INPUTS.update(original) + + def test_strips_double_quotes_from_values(self, tmp_path: Path) -> None: + env_file = tmp_path / "quoted.env" + env_file.write_text('CLUSTER_NAME="quoted-name"\n') + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "quoted-name" + + def test_strips_single_quotes_from_values(self, tmp_path: Path) -> None: + env_file = tmp_path / "single.env" + env_file.write_text("CLUSTER_NAME='single-name'\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "single-name" + + def test_blank_lines_skipped(self, tmp_path: Path) -> None: + env_file = tmp_path / "blanks.env" + env_file.write_text("\n\nCLUSTER_NAME=ok\n\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "ok" + + def test_comment_lines_skipped(self, tmp_path: Path) -> None: + env_file = tmp_path / "comments.env" + env_file.write_text("# a comment\nCLUSTER_NAME=fromcomment\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "fromcomment" + + def test_lines_without_equals_skipped(self, tmp_path: Path) -> None: + """Lines without '=' are silently skipped (no sys.exit in this version).""" + env_file = tmp_path / "noeq.env" + env_file.write_text("CLUSTER_NAME=safe\nNO_EQUALS\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "safe" + + +# --------------------------------------------------------------------------- +# _get_nodes +# --------------------------------------------------------------------------- + + +class TestGetNodes: + def test_returns_result_on_first_try(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1"}], "num_records": 1} + 
client.get.return_value = expected + result = csb._get_nodes(client, membership="available") + assert result == expected + + def test_falls_back_on_262197_error(self) -> None: + client = MagicMock(spec=OntapClient) + fallback = {"records": [{"name": "node1"}], "num_records": 1} + client.get.side_effect = [ + RuntimeError("error code 262197"), + fallback, + ] + result = csb._get_nodes(client) + assert result == fallback + assert client.get.call_count == 2 + + def test_raises_immediately_on_non_262197_error(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = RuntimeError("some other error 999") + with pytest.raises(RuntimeError, match="999"): + csb._get_nodes(client) + assert client.get.call_count == 1 + + def test_raises_last_exc_when_all_field_sets_fail(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = RuntimeError("error code 262197") + with pytest.raises(RuntimeError): + csb._get_nodes(client) + + +# --------------------------------------------------------------------------- +# discover_nodes +# --------------------------------------------------------------------------- + + +class TestDiscoverNodes: + def test_returns_node_list_on_success(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_nodes(client) + assert result == expected + + def test_retries_then_succeeds(self) -> None: + client = MagicMock(spec=OntapClient) + success = {"records": [{"name": "node1"}], "num_records": 1} + with ( + patch.object(csb, "_get_nodes", side_effect=[RuntimeError("transient"), success]), + patch("cluster_setup_basic.time.sleep"), + ): + result = csb.discover_nodes(client, attempts=3, delay=0) + assert result == success + + def test_raises_runtime_error_after_all_attempts(self) -> None: + client = MagicMock(spec=OntapClient) + with ( + patch.object(csb, "_get_nodes", 
side_effect=RuntimeError("persistent")), + patch("cluster_setup_basic.time.sleep"), + ): + with pytest.raises(RuntimeError, match="failed after 2 attempts"): + csb.discover_nodes(client, attempts=2, delay=0) + + +# --------------------------------------------------------------------------- +# discover_local +# --------------------------------------------------------------------------- + + +class TestDiscoverLocal: + def test_returns_result_when_records_present(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1", "uuid": "uuid-1"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_local(client) + assert result == expected + + def test_raises_when_no_records(self) -> None: + client = MagicMock(spec=OntapClient) + with patch.object(csb, "_get_nodes", return_value={"records": [], "num_records": 0}): + with pytest.raises(RuntimeError, match="no local node"): + csb.discover_local(client) + + +# --------------------------------------------------------------------------- +# discover_partner +# --------------------------------------------------------------------------- + + +class TestDiscoverPartner: + def test_returns_result_when_records_present(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node2", "uuid": "uuid-2"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_partner(client, local_uuid="uuid-1") + assert result == expected + + def test_passes_exclusion_filter(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node2"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected) as mock_get: + csb.discover_partner(client, local_uuid="uuid-1") + call_kwargs = mock_get.call_args[1] + assert call_kwargs.get("uuid") == "!uuid-1" + + def test_raises_when_no_records(self) -> None: + client = 
MagicMock(spec=OntapClient) + with patch.object(csb, "_get_nodes", return_value={"records": [], "num_records": 0}): + with pytest.raises(RuntimeError, match="no partner node"): + csb.discover_partner(client, local_uuid="uuid-1") diff --git a/tests/test_nfs_provision.py b/tests/test_nfs_provision.py new file mode 100644 index 0000000..4dc1f78 --- /dev/null +++ b/tests/test_nfs_provision.py @@ -0,0 +1,72 @@ +"""Unit tests for nfs_provision helper functions.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import nfs_provision +import pytest + + +class TestLoadEnvFile: + def test_valid_file_sets_env_vars( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "test.env" + env_file.write_text("FOO_NFS=bar\nBAZ_NFS=qux\n") + monkeypatch.delenv("FOO_NFS", raising=False) + monkeypatch.delenv("BAZ_NFS", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["FOO_NFS"] == "bar" + assert os.environ["BAZ_NFS"] == "qux" + + def test_missing_file_exits(self, tmp_path: Path) -> None: + with pytest.raises(SystemExit): + nfs_provision._load_env_file(str(tmp_path / "nonexistent.env")) + + def test_malformed_line_no_equals_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "bad.env" + env_file.write_text("VALID=ok\nNO_EQUALS_HERE\n") + monkeypatch.delenv("VALID", raising=False) + with pytest.raises(SystemExit): + nfs_provision._load_env_file(str(env_file)) + + def test_blank_lines_are_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "blanks.env" + env_file.write_text("\n\nKEY_NFS2=value\n\n") + monkeypatch.delenv("KEY_NFS2", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS2"] == "value" + + def test_comment_lines_are_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "comments.env" + 
env_file.write_text("# this is a comment\nKEY_NFS3=value\n") + monkeypatch.delenv("KEY_NFS3", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS3"] == "value" + + def test_setdefault_does_not_override_existing( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "override.env" + env_file.write_text("KEY_NFS4=from_file\n") + monkeypatch.setenv("KEY_NFS4", "already_set") + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS4"] == "already_set" + + def test_value_with_equals_sign_handled( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Values containing '=' are preserved (partition only splits on first '=').""" + env_file = tmp_path / "equals.env" + env_file.write_text("KEY_NFS5=a=b=c\n") + monkeypatch.delenv("KEY_NFS5", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS5"] == "a=b=c" diff --git a/tests/test_ontap_client.py b/tests/test_ontap_client.py new file mode 100644 index 0000000..bde2b16 --- /dev/null +++ b/tests/test_ontap_client.py @@ -0,0 +1,403 @@ +"""Unit tests for ontap_client.OntapClient.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import requests +from ontap_client import OntapApiError, OntapClient + +# --------------------------------------------------------------------------- +# OntapApiError +# --------------------------------------------------------------------------- + + +class TestOntapApiError: + def test_json_detail_stored(self) -> None: + resp = MagicMock(spec=requests.Response) + resp.status_code = 400 + resp.json.return_value = {"message": "Bad request", "code": "123"} + err = OntapApiError(resp) + assert err.status_code == 400 + assert err.detail == {"message": "Bad request", "code": "123"} + assert "400" in str(err) + + def test_text_fallback_when_not_json(self) -> None: + resp = MagicMock(spec=requests.Response) + 
resp.status_code = 500 + resp.json.side_effect = ValueError("no JSON") + resp.text = "Internal Server Error" + err = OntapApiError(resp) + assert err.detail == "Internal Server Error" + assert "500" in str(err) + + def test_is_exception_subclass(self) -> None: + resp = MagicMock(spec=requests.Response) + resp.status_code = 404 + resp.json.return_value = {} + assert isinstance(OntapApiError(resp), Exception) + + +# --------------------------------------------------------------------------- +# OntapClient.__init__ +# --------------------------------------------------------------------------- + + +class TestOntapClientInit: + def test_base_url_formed_correctly(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.base_url == "https://10.0.0.1/api" + client.close() + + def test_session_auth_set(self) -> None: + client = OntapClient("10.0.0.1", "admin", "secret") + assert client._session.auth == ("admin", "secret") + client.close() + + def test_verify_ssl_defaults_false(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client._session.verify is False + client.close() + + def test_verify_ssl_can_be_enabled(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass", verify_ssl=True) + assert client._session.verify is True + client.close() + + def test_default_timeout_stored(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.timeout == 30 + client.close() + + def test_custom_timeout_stored(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass", timeout=60) + assert client.timeout == 60 + client.close() + + def test_default_headers_include_accept(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert "application/hal+json" in client._session.headers.get("Accept", "") + client.close() + + def test_default_headers_include_content_type(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client._session.headers.get("Content-Type") 
== "application/json" + client.close() + + +# --------------------------------------------------------------------------- +# OntapClient.from_env +# --------------------------------------------------------------------------- + + +class TestFromEnv: + def test_missing_ontap_host_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("ONTAP_HOST", raising=False) + monkeypatch.delenv("ONTAP_PASS", raising=False) + with pytest.raises(SystemExit): + OntapClient.from_env() + + def test_missing_ontap_pass_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.delenv("ONTAP_PASS", raising=False) + with pytest.raises(SystemExit): + OntapClient.from_env() + + def test_builds_client_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.setenv("ONTAP_USER", "testuser") + client = OntapClient.from_env() + assert client.base_url == "https://10.0.0.1/api" + assert client._session.auth == ("testuser", "secret") + client.close() + + def test_default_user_is_admin(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.delenv("ONTAP_USER", raising=False) + client = OntapClient.from_env() + assert client._session.auth == ("admin", "secret") + client.close() + + def test_verify_ssl_true_when_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.setenv("ONTAP_VERIFY_SSL", "true") + client = OntapClient.from_env() + assert client._session.verify is True + client.close() + + def test_verify_ssl_false_when_not_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.delenv("ONTAP_VERIFY_SSL", 
raising=False) + client = OntapClient.from_env() + assert client._session.verify is False + client.close() + + +# --------------------------------------------------------------------------- +# OntapClient._url +# --------------------------------------------------------------------------- + + +class TestUrl: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + + def teardown_method(self) -> None: + self.client.close() + + def test_absolute_path_prefixed_with_base(self) -> None: + assert self.client._url("/cluster") == "https://10.0.0.1/api/cluster" + + def test_relative_path_prefixed_with_base(self) -> None: + assert self.client._url("cluster/nodes") == "https://10.0.0.1/api/cluster/nodes" + + def test_absolute_https_url_returned_unchanged(self) -> None: + url = "https://other.host/api/cluster" + assert self.client._url(url) == url + + +# --------------------------------------------------------------------------- +# OntapClient._request +# --------------------------------------------------------------------------- + + +class TestRequest: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + self.mock_resp = MagicMock(spec=requests.Response) + self.client._session.request = MagicMock(return_value=self.mock_resp) + + def teardown_method(self) -> None: + self.client.close() + + def test_success_returns_json(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b'{"name": "cluster1"}' + self.mock_resp.json.return_value = {"name": "cluster1"} + result = self.client._request("GET", "/cluster") + assert result == {"name": "cluster1"} + + def test_204_returns_empty_dict(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 204 + self.mock_resp.content = b"" + result = self.client._request("DELETE", "/some/resource") + assert result == {} + + def test_empty_content_returns_empty_dict(self) -> None: + self.mock_resp.ok = True + 
self.mock_resp.status_code = 200 + self.mock_resp.content = b"" + result = self.client._request("GET", "/cluster") + assert result == {} + + def test_non_ok_raises_ontap_api_error(self) -> None: + self.mock_resp.ok = False + self.mock_resp.status_code = 404 + self.mock_resp.json.return_value = {"message": "Not found"} + with pytest.raises(OntapApiError) as exc_info: + self.client._request("GET", "/missing") + assert exc_info.value.status_code == 404 + + def test_uses_default_timeout(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b"{}" + self.mock_resp.json.return_value = {} + self.client._request("GET", "/cluster") + call_kwargs = self.client._session.request.call_args[1] + assert call_kwargs["timeout"] == self.client.timeout + + def test_url_built_from_path(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b"{}" + self.mock_resp.json.return_value = {} + self.client._request("GET", "/cluster") + call_args = self.client._session.request.call_args[0] + assert call_args[1] == "https://10.0.0.1/api/cluster" + + +# --------------------------------------------------------------------------- +# OntapClient HTTP convenience methods +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def client_with_mock_session() -> OntapClient: + client = OntapClient("10.0.0.1", "admin", "pass") + mock_resp = MagicMock(spec=requests.Response) + mock_resp.ok = True + mock_resp.status_code = 200 + mock_resp.content = b'{"records": []}' + mock_resp.json.return_value = {"records": []} + client._session.request = MagicMock(return_value=mock_resp) + yield client + client.close() + + +class TestHttpMethods: + def test_get_adds_fields_param(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes", fields="name,serial_number") + params = 
client_with_mock_session._session.request.call_args[1]["params"] + assert params["fields"] == "name,serial_number" + + def test_get_adds_default_return_timeout(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert params["return_timeout"] == "120" + + def test_get_no_fields_key_when_omitted(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert "fields" not in params + + def test_get_passes_extra_params(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes", membership="available") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert params["membership"] == "available" + + def test_post_uses_post_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.post("/cluster", {"name": "c1"}) + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "POST" + + def test_post_sends_json_body(self, client_with_mock_session: OntapClient) -> None: + body = {"name": "c1", "password": "secret"} + client_with_mock_session.post("/cluster", body) + json_arg = client_with_mock_session._session.request.call_args[1]["json"] + assert json_arg == body + + def test_patch_uses_patch_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.patch("/cluster/nodes/uuid1", {"state": "up"}) + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "PATCH" + + def test_patch_sends_json_body(self, client_with_mock_session: OntapClient) -> None: + body = {"state": "up"} + client_with_mock_session.patch("/cluster/nodes/uuid1", body) + json_arg = client_with_mock_session._session.request.call_args[1]["json"] + assert json_arg == body + + def 
test_delete_uses_delete_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.delete("/volumes/uuid1") + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "DELETE" + + +# --------------------------------------------------------------------------- +# OntapClient.poll_job +# --------------------------------------------------------------------------- + +_JOB_UUID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + +class TestPollJob: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + + def teardown_method(self) -> None: + self.client.close() + + def test_success_state_returns_job(self) -> None: + self.client.get = MagicMock(return_value={"state": "success", "message": "done"}) + result = self.client.poll_job(_JOB_UUID) + assert result["state"] == "success" + + def test_failure_state_raises_runtime_error(self) -> None: + self.client.get = MagicMock(return_value={"state": "failure", "message": "boom"}) + with pytest.raises(RuntimeError, match="failed"): + self.client.poll_job(_JOB_UUID) + + def test_failure_includes_job_message(self) -> None: + self.client.get = MagicMock(return_value={"state": "failure", "message": "disk error"}) + with pytest.raises(RuntimeError, match="disk error"): + self.client.poll_job(_JOB_UUID) + + def test_polls_until_success(self) -> None: + responses = [ + {"state": "running"}, + {"state": "running"}, + {"state": "success"}, + ] + self.client.get = MagicMock(side_effect=responses) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", return_value=0), + ): + result = self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + assert result["state"] == "success" + assert self.client.get.call_count == 3 + + def test_timeout_raises_timeout_error(self) -> None: + self.client.get = MagicMock(return_value={"state": "running"}) + # First monotonic() → start time (0), second → past deadline (400) + with ( + 
patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", side_effect=[0, 400]), + ): + with pytest.raises(TimeoutError): + self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + + def test_connection_error_retries_then_succeeds(self) -> None: + self.client.get = MagicMock( + side_effect=[ + requests.exceptions.ConnectionError("disconnected"), + {"state": "success"}, + ] + ) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", return_value=0), + ): + result = self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + assert result["state"] == "success" + + def test_connection_error_past_deadline_raises_timeout(self) -> None: + self.client.get = MagicMock( + side_effect=requests.exceptions.ConnectionError("disconnected") + ) + # 1st monotonic → deadline start (0), 2nd → past deadline (400) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", side_effect=[0, 400]), + ): + with pytest.raises(TimeoutError): + self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + + def test_polls_correct_job_url(self) -> None: + self.client.get = MagicMock(return_value={"state": "success"}) + self.client.poll_job(_JOB_UUID) + call_args = self.client.get.call_args[0][0] + assert _JOB_UUID in call_args + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + def test_enter_returns_self(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.__enter__() is client + client.close() + + def test_exit_closes_session(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + close_mock = MagicMock() + client._session.close = close_mock + with client: + pass + close_mock.assert_called_once() + + def test_with_statement_usage(self) -> None: + """OntapClient can be used as a context manager without errors.""" + with 
OntapClient("10.0.0.1", "admin", "pass") as client: + assert isinstance(client, OntapClient) diff --git a/tests/test_snapmirror_cleanup_test_failover.py b/tests/test_snapmirror_cleanup_test_failover.py new file mode 100644 index 0000000..3ed448c --- /dev/null +++ b/tests/test_snapmirror_cleanup_test_failover.py @@ -0,0 +1,149 @@ +"""Unit tests for shared helpers in snapmirror_cleanup_test_failover.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_cleanup_test_failover as sm_clean +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "10.5.0.1") + assert sm_clean._env("CLUSTER_A") == "10.5.0.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.setenv("CLUSTER_A", "10.5.0.2") + assert sm_clean._env("CLUSTER_A") == "10.5.0.2" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + with pytest.raises(SystemExit): + sm_clean._env("CLUSTER_A") + + def test_returns_default_when_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + assert sm_clean._env("CLUSTER_A", default="fallback") == "fallback" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def test_returns_immediately_on_non_running(self) -> None: 
+ client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + result = sm_clean._poll_job(client, "job-abc") + assert result["state"] == "success" + + def test_polls_until_done(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "running"}, + {"state": "success"}, + ] + with patch("snapmirror_cleanup_test_failover.time.sleep"): + result = sm_clean._poll_job(client, "job-abc", interval=1) + assert client.get.call_count == 2 + assert result["state"] == "success" + + def test_uses_correct_job_url(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + sm_clean._poll_job(client, "my-job-id") + assert "my-job-id" in client.get.call_args[0][0] + + +# --------------------------------------------------------------------------- +# _pick_cluster_by_relationship +# --------------------------------------------------------------------------- + + +class TestPickClusterByRelationship: + def _vol_client(self, records: list[dict]) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"records": records, "num_records": len(records)} + return client + + def test_picks_cluster_a_when_it_has_relationship(self) -> None: + rel = { + "uuid": "rel-uuid-1", + "source": {"path": "vs0:vol_rw_01"}, + "destination": {"path": "vs1:vol_rw_01_dest"}, + "state": "snapmirrored", + "healthy": True, + } + a_client = self._vol_client([rel]) + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=a_client): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + assert cluster == "10.0.0.1" + assert found_rel["uuid"] == "rel-uuid-1" + + def test_falls_through_to_cluster_b(self) -> None: + rel = { + "uuid": "rel-uuid-2", + "source": {"path": "vs0:vol_rw_01"}, + "destination": {"path": "vs1:vol_rw_01_dest"}, + "state": "snapmirrored", + "healthy": True, + } + a_client 
= self._vol_client([]) + b_client = self._vol_client([rel]) + + with patch( + "snapmirror_cleanup_test_failover.OntapClient", side_effect=[a_client, b_client] + ): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + assert cluster == "10.0.0.2" + assert found_rel["uuid"] == "rel-uuid-2" + + def test_exits_when_neither_cluster_has_relationship(self) -> None: + no_rel_client = self._vol_client([]) + + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=no_rel_client): + with pytest.raises(SystemExit): + sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + + def test_skips_unreachable_cluster_and_continues(self) -> None: + rel = {"uuid": "rel-uuid-b", "source": {"path": "vs0:vol"}, "state": "snapmirrored"} + a_client = MagicMock(spec=OntapClient) + a_client.get.side_effect = ConnectionError("unreachable") + b_client = self._vol_client([rel]) + + with patch( + "snapmirror_cleanup_test_failover.OntapClient", side_effect=[a_client, b_client] + ): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol" + ) + assert cluster == "10.0.0.2" + + def test_passes_source_path_filter(self) -> None: + """Verify the API is called with source.path=:.""" + rel = {"uuid": "rel-uuid-x", "source": {"path": "vs0:myvol"}, "state": "snapmirrored"} + a_client = self._vol_client([rel]) + + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=a_client): + sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "myvol" + ) + + call_kwargs = a_client.get.call_args[1] + assert call_kwargs.get("source.path") == "vs0:myvol" diff --git a/tests/test_snapmirror_provision_dest_managed.py b/tests/test_snapmirror_provision_dest_managed.py new file mode 100644 index 0000000..2ba82d5 --- /dev/null +++ 
b/tests/test_snapmirror_provision_dest_managed.py @@ -0,0 +1,173 @@ +"""Unit tests for shared helpers in snapmirror_provision_dest_managed.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_provision_dest_managed as sm_dst +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "10.1.0.1") + assert sm_dst._env("SOURCE_HOST") == "10.1.0.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.setenv("SOURCE_HOST", "10.1.0.2") + assert sm_dst._env("SOURCE_HOST") == "10.1.0.2" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + with pytest.raises(SystemExit): + sm_dst._env("SOURCE_HOST") + + def test_returns_default_when_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + assert sm_dst._env("SOURCE_HOST", default="default_val") == "default_val" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def test_returns_immediately_on_non_running_state(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + result = sm_dst._poll_job(client, "job-uuid-1") + assert result["state"] == "success" + + def test_polls_multiple_times_until_done(self) -> None: + client 
= MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "running"}, + {"state": "success"}, + ] + with patch("snapmirror_provision_dest_managed.time.sleep"): + result = sm_dst._poll_job(client, "job-uuid-1", interval=1) + assert result["state"] == "success" + assert client.get.call_count == 2 + + def test_passes_correct_url(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + sm_dst._poll_job(client, "my-job-abc") + assert "my-job-abc" in client.get.call_args[0][0] + + +# --------------------------------------------------------------------------- +# _wait_snapmirrored +# --------------------------------------------------------------------------- + + +class TestWaitSnapmirrored: + def test_returns_immediately_when_snapmirrored(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "snapmirrored", "healthy": True} + result = sm_dst._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=60) + assert result["state"] == "snapmirrored" + + def test_polls_until_snapmirrored(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "transferring"}, + {"state": "snapmirrored"}, + ] + with patch("snapmirror_provision_dest_managed.time.sleep"): + result = sm_dst._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=300) + assert result["state"] == "snapmirrored" + + def test_raises_on_timeout(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "transferring"} + with patch("snapmirror_provision_dest_managed.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + sm_dst._wait_snapmirrored(client, "rel-uuid", interval=2, max_wait=1) + + +# --------------------------------------------------------------------------- +# _get_ic_lif_ips +# --------------------------------------------------------------------------- + + +class TestGetIcLifIps: + def 
# ---------------------------------------------------------------------------
# _check_ic_lif_preconditions
# ---------------------------------------------------------------------------


class TestCheckIcLifPreconditions:
    """Behaviour of the intercluster-LIF precondition gate."""

    @staticmethod
    def _clients():
        """Return a (source, destination) pair of mocked ONTAP clients."""
        return MagicMock(spec=OntapClient), MagicMock(spec=OntapClient)

    def test_exits_when_no_src_ips(self) -> None:
        source, destination = self._clients()
        with pytest.raises(SystemExit):
            sm_dst._check_ic_lif_preconditions(source, destination, [], ["10.0.0.1"])

    def test_exits_when_no_dst_ips(self) -> None:
        source, destination = self._clients()
        with pytest.raises(SystemExit):
            sm_dst._check_ic_lif_preconditions(source, destination, ["10.0.0.1"], [])

    def test_no_error_when_same_subnet(self) -> None:
        source, destination = self._clients()
        # Both LIFs sit in the same /24, so the check must complete silently.
        sm_dst._check_ic_lif_preconditions(source, destination, ["10.0.0.1"], ["10.0.0.2"])

    def test_warns_when_different_subnets(self, caplog: pytest.LogCaptureFixture) -> None:
        import logging

        source, destination = self._clients()
        with caplog.at_level(logging.WARNING, logger="snapmirror_provision_dest_managed"):
            sm_dst._check_ic_lif_preconditions(source, destination, ["10.0.0.1"], ["192.168.1.1"])
        assert any("subnet" in message.lower() for message in caplog.messages)


# --- begin new file: tests/test_snapmirror_provision_src_managed.py ---
# (original preamble: module docstring + `from __future__ import annotations`)

from unittest.mock import MagicMock, patch

import pytest
import snapmirror_provision_src_managed as sm_src
from ontap_client import OntapClient

# ---------------------------------------------------------------------------
# _env
# ---------------------------------------------------------------------------


class TestEnv:
    """Resolution order of _env: the INPUTS dict wins over os.environ."""

    def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "10.0.0.1")
        assert sm_src._env("SOURCE_HOST") == "10.0.0.1"

    def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "")
        monkeypatch.setenv("SOURCE_HOST", "10.0.0.2")
        assert sm_src._env("SOURCE_HOST") == "10.0.0.2"

    def test_inputs_takes_priority_over_environ(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("SOURCE_HOST", "from_env")
        monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "from_inputs")
        assert sm_src._env("SOURCE_HOST") == "from_inputs"

    def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "")
        monkeypatch.delenv("SOURCE_HOST", raising=False)
        with pytest.raises(SystemExit):
            sm_src._env("SOURCE_HOST")

    def test_returns_default_when_not_required(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "")
        monkeypatch.delenv("SOURCE_HOST", raising=False)
        assert sm_src._env("SOURCE_HOST", default="fallback") == "fallback"


# ---------------------------------------------------------------------------
# _poll_job
# ---------------------------------------------------------------------------


class TestPollJob:
    """Polling behaviour of _poll_job against the jobs endpoint."""

    @staticmethod
    def _client() -> MagicMock:
        return MagicMock(spec=OntapClient)

    def test_returns_when_state_not_running(self) -> None:
        client = self._client()
        client.get.return_value = {"state": "success", "message": "done"}
        assert sm_src._poll_job(client, "job-uuid-1")["state"] == "success"

    def test_polls_until_non_running_state(self) -> None:
        client = self._client()
        client.get.side_effect = [
            {"state": "running"},
            {"state": "running"},
            {"state": "success"},
        ]
        with patch("snapmirror_provision_src_managed.time.sleep"):
            outcome = sm_src._poll_job(client, "job-uuid-1", interval=1)
        assert outcome["state"] == "success"
        assert client.get.call_count == 3

    def test_passes_correct_job_url(self) -> None:
        client = self._client()
        client.get.return_value = {"state": "success"}
        sm_src._poll_job(client, "abc-123")
        requested_path = client.get.call_args[0][0]
        assert "abc-123" in requested_path

    def test_returns_failure_state_without_raising(self) -> None:
        """A failed job is returned as data — the caller decides how to react."""
        client = self._client()
        client.get.return_value = {"state": "failure", "error": {"message": "boom"}}
        assert sm_src._poll_job(client, "job-uuid-fail")["state"] == "failure"


# ---------------------------------------------------------------------------
# _wait_snapmirrored
# ---------------------------------------------------------------------------


class TestWaitSnapmirrored:
    """Convergence polling for the 'snapmirrored' relationship state."""

    @staticmethod
    def _client() -> MagicMock:
        return MagicMock(spec=OntapClient)

    def test_returns_immediately_when_already_snapmirrored(self) -> None:
        client = self._client()
        client.get.return_value = {"state": "snapmirrored", "lag_time": "PT5M", "healthy": True}
        outcome = sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=1, max_wait=60)
        assert outcome["state"] == "snapmirrored"
        assert client.get.call_count == 1

    def test_polls_until_snapmirrored(self) -> None:
        client = self._client()
        client.get.side_effect = [
            {"state": "transferring"},
            {"state": "transferring"},
            {"state": "snapmirrored"},
        ]
        with patch("snapmirror_provision_src_managed.time.sleep"):
            outcome = sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=1, max_wait=600)
        assert outcome["state"] == "snapmirrored"

    def test_raises_timeout_if_never_snapmirrored(self) -> None:
        client = self._client()
        client.get.return_value = {"state": "transferring"}
        # max_wait=1 with interval=2 guarantees the deadline passes after one poll.
        with patch("snapmirror_provision_src_managed.time.sleep"):
            with pytest.raises(RuntimeError, match="Timed out"):
                sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=2, max_wait=1)

    def test_queries_correct_relationship_url(self) -> None:
        client = self._client()
        client.get.return_value = {"state": "snapmirrored"}
        sm_src._wait_snapmirrored(client, "my-rel-uuid", interval=1, max_wait=60)
        assert "my-rel-uuid" in client.get.call_args[0][0]


# --- begin new file: tests/test_snapmirror_test_failover.py ---
# (original preamble: module docstring + `from __future__ import annotations`)

from unittest.mock import MagicMock, patch
import pytest
import snapmirror_test_failover as sm_tf
from ontap_client import OntapClient

# ---------------------------------------------------------------------------
# _env
# ---------------------------------------------------------------------------


class TestEnv:
    """Resolution order of the failover script's _env helper."""

    def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "10.0.1.1")
        assert sm_tf._env("CLUSTER_A") == "10.0.1.1"

    def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "")
        monkeypatch.setenv("CLUSTER_A", "10.0.1.2")
        assert sm_tf._env("CLUSTER_A") == "10.0.1.2"

    def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "")
        monkeypatch.delenv("CLUSTER_A", raising=False)
        with pytest.raises(SystemExit):
            sm_tf._env("CLUSTER_A")

    def test_returns_default_when_not_required(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "")
        monkeypatch.delenv("CLUSTER_A", raising=False)
        assert sm_tf._env("CLUSTER_A", default="x") == "x"


# ---------------------------------------------------------------------------
# _poll_job
# ---------------------------------------------------------------------------


class TestPollJob:
    """Job polling in the failover script."""

    def test_returns_on_first_non_running_state(self) -> None:
        client = MagicMock(spec=OntapClient)
        client.get.return_value = {"state": "success"}
        assert sm_tf._poll_job(client, "job-1")["state"] == "success"

    def test_polls_until_done(self) -> None:
        client = MagicMock(spec=OntapClient)
        client.get.side_effect = [{"state": "running"}, {"state": "success"}]
        with patch("snapmirror_test_failover.time.sleep"):
            outcome = sm_tf._poll_job(client, "job-1", interval=1)
        assert client.get.call_count == 2
        assert outcome["state"] == "success"
# ---------------------------------------------------------------------------
# _wait_snapmirrored
# ---------------------------------------------------------------------------


class TestWaitSnapmirrored:
    """Convergence polling in the failover script."""

    def test_returns_immediately_when_snapmirrored(self) -> None:
        client = MagicMock(spec=OntapClient)
        client.get.return_value = {"state": "snapmirrored"}
        outcome = sm_tf._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=60)
        assert outcome["state"] == "snapmirrored"

    def test_raises_timeout_when_never_converges(self) -> None:
        client = MagicMock(spec=OntapClient)
        client.get.return_value = {"state": "transferring"}
        with patch("snapmirror_test_failover.time.sleep"):
            with pytest.raises(RuntimeError, match="Timed out"):
                sm_tf._wait_snapmirrored(client, "rel-uuid", interval=2, max_wait=1)


# ---------------------------------------------------------------------------
# _pick_cluster
# ---------------------------------------------------------------------------


class TestPickCluster:
    """Selection of the cluster that hosts the DP (destination) volume."""

    @staticmethod
    def _client_with(records: list[dict]) -> MagicMock:
        """Mocked client whose volume query returns exactly *records*."""
        client = MagicMock(spec=OntapClient)
        client.get.return_value = {"records": records, "num_records": len(records)}
        return client

    def test_picks_cluster_a_when_it_has_dp_volume(self) -> None:
        volume = {"name": "vol_rw_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}}
        with patch("snapmirror_test_failover.OntapClient") as client_cls:
            client_cls.return_value = self._client_with([volume])
            cluster, dp_volume = sm_tf._pick_cluster(
                "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01"
            )
        assert cluster == "10.0.0.1"
        assert dp_volume["name"] == "vol_rw_01_dest"

    def test_falls_through_to_cluster_b(self) -> None:
        volume = {"name": "vol_rw_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}}
        cluster_a = self._client_with([])
        cluster_b = self._client_with([volume])
        with patch("snapmirror_test_failover.OntapClient", side_effect=[cluster_a, cluster_b]):
            cluster, _ = sm_tf._pick_cluster(
                "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01"
            )
        assert cluster == "10.0.0.2"

    def test_exits_when_no_cluster_has_dp_volume(self) -> None:
        empty = self._client_with([])
        with patch("snapmirror_test_failover.OntapClient", return_value=empty):
            with pytest.raises(SystemExit):
                sm_tf._pick_cluster("10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01")

    def test_uses_wildcard_filter_in_auto_mode(self) -> None:
        """A '*' volume filter must translate into a '*_dest' DP name filter."""
        empty = self._client_with([])
        with patch("snapmirror_test_failover.OntapClient", return_value=empty):
            with pytest.raises(SystemExit):
                sm_tf._pick_cluster("10.0.0.1", "10.0.0.2", "admin", "pass", "*")
        # Verify the name filter actually sent to the API.
        assert empty.get.call_args[1].get("name") == "*_dest"

    def test_skips_unreachable_cluster_and_continues(self) -> None:
        volume = {"name": "vol_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}}
        unreachable = MagicMock(spec=OntapClient)
        unreachable.get.side_effect = ConnectionError("unreachable")
        reachable = self._client_with([volume])
        with patch("snapmirror_test_failover.OntapClient", side_effect=[unreachable, reachable]):
            cluster, _ = sm_tf._pick_cluster(
                "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_01"
            )
        assert cluster == "10.0.0.2"