From f3090307e1c506e2f025fbd128efe4c4277a15a1 Mon Sep 17 00:00:00 2001 From: Somanath Date: Thu, 7 May 2026 19:43:45 +0530 Subject: [PATCH 1/4] fix: replace garbled mojibake characters with clean ASCII in comments --- python/cluster_info.py | 14 +++++----- python/cluster_setup_basic.py | 48 +++++++++++++++++++++++++++++----- python/nfs_provision.py | 49 +++++++++++++++++------------------ 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/python/cluster_info.py b/python/cluster_info.py index ce2a48f..703517d 100644 --- a/python/cluster_info.py +++ b/python/cluster_info.py @@ -2,8 +2,8 @@ """Retrieve ONTAP cluster version and list all nodes with serial numbers. Steps: - 1. GET /cluster ΓÇö retrieve cluster name and ONTAP version - 2. GET /cluster/nodes ΓÇö list all nodes with serial numbers + 1. GET /cluster - retrieve cluster name and ONTAP version + 2. GET /cluster/nodes - list all nodes with serial numbers Prerequisites:: @@ -31,15 +31,15 @@ def main() -> None: with OntapClient.from_env() as client: - # Step 1 ΓÇö cluster version + # Step 1: cluster version cluster = client.get("/cluster", fields="version") logger.info( - "Cluster: %s ΓÇö ONTAP %s", + "Cluster: %s - ONTAP %s", cluster.get("name", "unknown"), cluster.get("version", {}).get("full", "unknown"), ) - # Step 2 ΓÇö node list with serial numbers + # Step 2: node list with serial numbers nodes_resp = client.get("/cluster/nodes", fields="name,serial_number") records = nodes_resp.get("records", []) logger.info("Nodes in cluster: %d", nodes_resp.get("num_records", len(records))) @@ -47,8 +47,8 @@ def main() -> None: for node in records: logger.info( " %-30s serial: %s", - node.get("name", "ΓÇö"), - node.get("serial_number", "ΓÇö"), + node.get("name", "N/A"), + node.get("serial_number", "N/A"), ) diff --git a/python/cluster_setup_basic.py b/python/cluster_setup_basic.py index 4f6682d..db17c22 100644 --- a/python/cluster_setup_basic.py +++ b/python/cluster_setup_basic.py @@ -12,6 +12,7 @@ Usage:: + # env vars directly export ONTAP_HOST=10.x.x.x # pre-cluster node IP export ONTAP_USER=admin # usually admin, empty pass on pre-cluster nodes export ONTAP_PASS= @@ -22,14 +23,20 @@ export CLUSTER_GATEWAY=10.x.x.1 export PARTNER_MGMT_IP=10.x.x.y python cluster_setup_basic.py + + # or use a per-build .env file (analogous to -ir ) + python cluster_setup_basic.py --env-file r9141_build.env + python cluster_setup_basic.py --env-file r919_build.env """ from __future__ import annotations +import argparse import logging import os import sys import time +from pathlib import Path from ontap_client import OntapClient @@ -43,15 +50,15 @@ # USER INPUTS — fill in your values here before running # --------------------------------------------------------------------------- INPUTS = { - "ONTAP_HOST": "", # Node 1 management IP — set via ONTAP_HOST env var + "ONTAP_HOST": "10.140.108.120", # Node 1 management IP — set via ONTAP_HOST env var "ONTAP_USER": "admin", "ONTAP_PASS": "", # set via ONTAP_PASS env var — leave empty for pre-cluster nodes - "CLUSTER_NAME": "", # choose your cluster name — set via CLUSTER_NAME env var + "CLUSTER_NAME": "sp57388-cluster", # choose your cluster name — set via CLUSTER_NAME env var "CLUSTER_PASS": "", # set via CLUSTER_PASS env var — choose your cluster admin password - "CLUSTER_MGMT_IP": "", # cluster management IP — set via CLUSTER_MGMT_IP env var - "CLUSTER_NETMASK": "", # e.g. 255.255.255.0 — set via CLUSTER_NETMASK env var - "CLUSTER_GATEWAY": "", # default gateway — set via CLUSTER_GATEWAY env var - "PARTNER_MGMT_IP": "", # Node 2 management IP — set via PARTNER_MGMT_IP env var + "CLUSTER_MGMT_IP": "10.140.108.120", # cluster management IP — set via CLUSTER_MGMT_IP env var + "CLUSTER_NETMASK": "255.255.192.0", # e.g. 255.255.255.0 — set via CLUSTER_NETMASK env var + "CLUSTER_GATEWAY": "10.140.64.1", # default gateway — set via CLUSTER_GATEWAY env var + "PARTNER_MGMT_IP": "10.140.108.124", # Node 2 management IP — set via PARTNER_MGMT_IP env var } # --------------------------------------------------------------------------- @@ -259,7 +266,36 @@ def main() -> None: ) +def _load_env_file(path: str) -> None: + """Load KEY=VALUE pairs from a .env file into the INPUTS dict.""" + for line in Path(path).read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + INPUTS[key.strip()] = value.strip().strip('"').strip("'") + + if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Create an ONTAP cluster from two pre-cluster nodes." + ) + parser.add_argument( + "--env-file", + metavar="FILE", + help="Path to a .env file with KEY=VALUE pairs (one per build, like -ir in ha_create.exp).", + ) + args = parser.parse_args() + + if args.env_file: + _load_env_file(args.env_file) + + # env vars always win over INPUTS block defaults + for key in list(INPUTS): + val = os.environ.get(key) + if val: + INPUTS[key] = val + try: main() except KeyboardInterrupt: diff --git a/python/nfs_provision.py b/python/nfs_provision.py index 59f8d92..44cb235 100644 --- a/python/nfs_provision.py +++ b/python/nfs_provision.py @@ -26,11 +26,11 @@ python nfs_provision.py --env-file nfs-provision.env Default values (vs0, vol_nfs_test_01, 0.0.0.0/0, etc.) are for illustration -only. Replace them with values appropriate for your environment ΓÇö +only. Replace them with values appropriate for your environment - in particular, restrict ``--client-match`` to your actual client subnet. This script is *not* idempotent: running it twice with the same volume name -will fail. See ``python/README.md`` ΓåÆ "Adapting for Your Environment" for +will fail. See ``python/README.md`` -> "Adapting for Your Environment" for guidance on adding existence checks. """ @@ -50,19 +50,18 @@ ) logger = logging.getLogger(__name__) -# ΓöÇΓöÇ Inputs (edit these directly, same as the YAML env: block) ΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇ +# Inputs (edit these directly, same as the YAML env: block) # These are the defaults. CLI args and env vars override them. ENV = { - "ONTAP_HOST": "", # cluster management IP ΓÇö set here or via ONTAP_HOST env var + "ONTAP_HOST": "", # cluster management IP - set here or via ONTAP_HOST env var "ONTAP_USER": "admin", - "ONTAP_PASS": "", # never hardcode ΓÇö set via ONTAP_PASS env var + "ONTAP_PASS": "", # never hardcode - set via ONTAP_PASS env var "SVM_NAME": "vs1", "VOLUME_NAME": "vol_001", "VOLUME_SIZE": "100MB", - "AGGR_NAME": "sti232_vsim_sr091o_aggr1", # required ΓÇö set via --aggregate or AGGR_NAME env var + "AGGR_NAME": "sti232_vsim_sr091o_aggr1", # required - set via --aggregate or AGGR_NAME env var "CLIENT_MATCH": "0.0.0.0/0", } -# ΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇΓöÇ def _load_env_file(path: str) -> None: @@ -128,9 +127,9 @@ def main() -> None: policy_name = f"{volume}_export_policy" with OntapClient.from_env() as client: - # Step 1 ΓÇö create volume (idempotent: skip if already exists) + # Step 1 - create volume (idempotent: skip if already exists) # POST /storage/volumes to create a new FlexVol with a NAS junction path. - # Volume creation is asynchronous ΓÇö the response contains a job UUID. + # Volume creation is asynchronous - the response contains a job UUID. existing_vol = client.get( "/storage/volumes", fields="name,uuid", @@ -138,9 +137,9 @@ def main() -> None: **{"svm.name": svm}, ) if existing_vol.get("records"): - logger.info("Volume '%s' already exists ΓÇö skipping create", volume) + logger.info("Volume '%s' already exists - skipping create", volume) else: - logger.info("Creating volume '%s' (%s) on SVM '%s'ΓǪ", volume, size, svm) + logger.info("Creating volume '%s' (%s) on SVM '%s'...", volume, size, svm) create_resp = client.post( "/storage/volumes", body={ @@ -152,7 +151,7 @@ def main() -> None: }, ) - # Step 2 ΓÇö poll volume-creation job + # Step 2 - poll volume-creation job # Block until the async job finishes before proceeding. # poll_job raises RuntimeError if the job ends in a failure state. job_uuid = create_resp["job"]["uuid"] @@ -160,7 +159,7 @@ def main() -> None: client.poll_job(job_uuid) logger.info("Volume '%s' created successfully", volume) - # Step 3 ΓÇö fetch volume UUID + # Step 3 - fetch volume UUID # The UUID is required to PATCH the volume later when assigning the export policy. # Filter by name + svm.name to pinpoint exactly the volume just created. vol_resp = client.get( @@ -173,7 +172,7 @@ def main() -> None: raise RuntimeError(f"Volume '{volume}' not found on SVM '{svm}' after creation") volume_uuid = vol_resp["records"][0]["uuid"] - # Step 4 ΓÇö create export policy (idempotent: skip if already exists) + # Step 4 - create export policy (idempotent: skip if already exists) # Creates a dedicated policy named _export_policy scoped to the SVM. # A per-volume policy makes it easy to manage access rules independently. existing_policy = client.get( @@ -183,15 +182,15 @@ def main() -> None: **{"svm.name": svm}, ) if existing_policy.get("records"): - logger.info("Export policy '%s' already exists ΓÇö skipping create", policy_name) + logger.info("Export policy '%s' already exists - skipping create", policy_name) else: - logger.info("Creating export policy '%s'ΓǪ", policy_name) + logger.info("Creating export policy '%s'...", policy_name) client.post( "/protocols/nfs/export-policies", body={"name": policy_name, "svm": {"name": svm}}, ) - # Step 5 ΓÇö fetch export policy ID + # Step 5 - fetch export policy ID # The numeric ID is required when POSTing rules to the policy. # Filter by name + svm.name to retrieve only this policy's record. policy_resp = client.get( @@ -206,7 +205,7 @@ def main() -> None: ) policy_id = policy_resp["records"][0]["id"] - # Step 6 ΓÇö add client rule (idempotent: skip if a matching rule already exists) + # Step 6 - add client rule (idempotent: skip if a matching rule already exists) # POST a rule to the export policy allowing the given client IP or CIDR range. # ro_rule, rw_rule, superuser = 'any' is suitable for lab; tighten for production. existing_rules = client.get( @@ -218,9 +217,9 @@ def main() -> None: for r in existing_rules.get("records", []) ) if rule_exists: - logger.info("Client rule '%s' already exists in policy ΓÇö skipping", client_match) + logger.info("Client rule '%s' already exists in policy - skipping", client_match) else: - logger.info("Adding client rule '%s' to policyΓǪ", client_match) + logger.info("Adding client rule '%s' to policy...", client_match) client.post( f"/protocols/nfs/export-policies/{policy_id}/rules", body={ @@ -231,26 +230,26 @@ def main() -> None: }, ) - # Step 7 ΓÇö assign export policy to volume + # Step 7 - assign export policy to volume # PATCH the volume's nas.export_policy field to link the policy. # This makes the volume accessible to NFS clients that match the rule. - logger.info("Assigning export policy to volumeΓǪ") + logger.info("Assigning export policy to volume...") patch_resp = client.patch( f"/storage/volumes/{volume_uuid}", body={"nas": {"export_policy": {"name": policy_name}}}, ) - # Step 8 ΓÇö poll assign-policy job + # Step 8 - poll assign-policy job # The PATCH may return a job UUID if the operation is async. # Only poll if a UUID was returned; sync responses skip this block. if "job" in patch_resp: client.poll_job(patch_resp["job"]["uuid"]) - # Step 9 ΓÇö print summary + # Step 9 - print summary # Log a single success line with volume, size, SVM, mount path, # export policy name, and client rule for quick confirmation. logger.info( - "Γ£ô Volume '%s' (%s) created on SVM '%s' | Mount path: /%s | " + "[OK] Volume '%s' (%s) created on SVM '%s' | Mount path: /%s | " "Export policy '%s' created with client rule '%s' and assigned to volume", volume, size, From d9fc979d3316795807cc0c81ddc2b6c788f5eda0 Mon Sep 17 00:00:00 2001 From: Somanath Date: Thu, 7 May 2026 20:39:15 +0530 Subject: [PATCH 2/4] fix: remove internal cluster IPs and lab aggregate name from defaults --- python/cifs_provision.py | 420 +++++++++++++++++++--------------- python/cluster_setup_basic.py | 13 +- python/nfs_provision.py | 286 ++++++++++++----------- 3 files changed, 378 insertions(+), 341 deletions(-) diff --git a/python/cifs_provision.py b/python/cifs_provision.py index 8adddc3..720213e 100644 --- a/python/cifs_provision.py +++ b/python/cifs_provision.py @@ -48,7 +48,7 @@ "SVM_NAME": "vs1", "VOLUME_NAME": "vol_002", "VOLUME_SIZE": "100MB", - "AGGR_NAME": "sti232_vsim_sr091o_aggr1", # required — set via --aggregate or AGGR_NAME env var + "AGGR_NAME": "", # required — set via --aggregate or AGGR_NAME env var "CLIENT_MATCH": "0.0.0.0/0", # required — set via --client-match or CLIENT_MATCH env var "SHARE_NAME": "cifs_share_demo", "SHARE_COMMENT": "Provisioned by orchestrio", @@ -114,215 +114,255 @@ def parse_args() -> argparse.Namespace: return p.parse_args() -def main() -> None: - args = parse_args() - - # Load env file first so its values can be read via os.environ below +def _resolve_config(args: argparse.Namespace) -> dict[str, str | bool]: + """Load env file and CLI args, then return the resolved configuration dict.""" if args.env_file: _load_env_file(args.env_file) - # Push ENV block values into os.environ so OntapClient.from_env() picks them up for key, value in ENV.items(): if value and key not in os.environ: os.environ[key] = value - # Resolve each value: CLI arg > env var > ENV block > built-in default (matches YAML priority) - svm = args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0" - volume = args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "cifs_test_env" - size = args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB" aggregate = args.aggregate or os.environ.get("AGGR_NAME") or ENV["AGGR_NAME"] or "" - share_name = ( - args.share_name or os.environ.get("SHARE_NAME") or ENV["SHARE_NAME"] or "cifs_share_demo" + if not aggregate: + logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") + sys.exit(1) + + return { + "svm": args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0", + "volume": ( + args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "cifs_test_env" + ), + "size": args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB", + "aggregate": aggregate, + "share_name": ( + args.share_name + or os.environ.get("SHARE_NAME") + or ENV["SHARE_NAME"] + or "cifs_share_demo" + ), + "share_comment": ( + args.share_comment + or os.environ.get("SHARE_COMMENT") + or ENV["SHARE_COMMENT"] + or "Provisioned by orchestrio" + ), + "acl_user": (args.acl_user or os.environ.get("ACL_USER") or ENV["ACL_USER"] or "Everyone"), + "acl_permission": ( + args.acl_permission + or os.environ.get("ACL_PERMISSION") + or ENV["ACL_PERMISSION"] + or "full_control" + ), + "create_cifs_server": args.create_cifs_server, + "cifs_server_name": ( + args.cifs_server_name + or os.environ.get("CIFS_SERVER_NAME") + or ENV["CIFS_SERVER_NAME"] + or "ONTAP-CIFS" + ), + "workgroup": ( + args.workgroup + or os.environ.get("CIFS_WORKGROUP") + or ENV["CIFS_WORKGROUP"] + or "WORKGROUP" + ), + } + + +def _ensure_cifs_server( + client: OntapClient, + svm: str, + create_cifs_server: bool, + cifs_server_name: str, + workgroup: str, +) -> None: + """Verify a CIFS server exists on the SVM, optionally creating one if missing.""" + cifs_svc_resp = client.get( + "/protocols/cifs/services", + fields="svm.name,enabled", + **{"svm.name": svm}, + ) + if cifs_svc_resp.get("num_records", 0) > 0: + logger.info("CIFS server confirmed on SVM '%s'", svm) + return + + if not create_cifs_server: + logger.error( + "ABORTED - no CIFS server found on SVM '%s'. " + "Pass --create-cifs-server to create one automatically, or use " + "'vserver cifs create' before running this script.", + svm, + ) + sys.exit(1) + + logger.info( + "No CIFS server on SVM '%s' - creating workgroup server '%s' in workgroup '%s'...", + svm, + cifs_server_name, + workgroup, ) - share_comment = ( - args.share_comment - or os.environ.get("SHARE_COMMENT") - or ENV["SHARE_COMMENT"] - or "Provisioned by orchestrio" + resp = client.post( + "/protocols/cifs/services", + body={ + "svm": {"name": svm}, + "name": cifs_server_name, + "workgroup": workgroup, + "enabled": True, + }, ) - acl_user = args.acl_user or os.environ.get("ACL_USER") or ENV["ACL_USER"] or "Everyone" - acl_permission = ( - args.acl_permission - or os.environ.get("ACL_PERMISSION") - or ENV["ACL_PERMISSION"] - or "full_control" + if resp.get("job"): + client.poll_job(resp["job"]["uuid"]) + logger.info( + "CIFS server '%s' created in workgroup '%s' on SVM '%s'", + cifs_server_name, + workgroup, + svm, ) - create_cifs_server = args.create_cifs_server - cifs_server_name = ( - args.cifs_server_name - or os.environ.get("CIFS_SERVER_NAME") - or ENV["CIFS_SERVER_NAME"] - or "ONTAP-CIFS" + +def _ensure_volume_ntfs( + client: OntapClient, svm: str, volume: str, size: str, aggregate: str +) -> dict: + """Create the FlexVol (NTFS security style) if it does not exist. Returns the job result.""" + existing = client.get( + "/storage/volumes", + fields="name,uuid", + name=volume, + **{"svm.name": svm}, + ) + if existing.get("records"): + logger.info("Volume '%s' already exists - skipping create", volume) + return {"state": "skipped", "message": "volume already existed"} + + logger.info("Creating volume '%s' (%s) on SVM '%s'...", volume, size, svm) + resp = client.post( + "/storage/volumes", + body={ + "name": volume, + "svm": {"name": svm}, + "aggregates": [{"name": aggregate}], + "size": size, + "nas": { + "security_style": "ntfs", + "path": f"/{volume}", + }, + }, ) - workgroup = ( - args.workgroup or os.environ.get("CIFS_WORKGROUP") or ENV["CIFS_WORKGROUP"] or "WORKGROUP" + job_uuid = resp["job"]["uuid"] + logger.info("Volume creation job: %s", job_uuid) + return client.poll_job(job_uuid) + + +def _get_svm_uuid(client: OntapClient, svm: str) -> str: + """Fetch and return the UUID for the named SVM.""" + resp = client.get("/svm/svms", fields="name,uuid", name=svm) + return resp["records"][0]["uuid"] + + +def _ensure_cifs_share( + client: OntapClient, + svm_uuid: str, + share_name: str, + volume: str, + svm: str, + share_comment: str, +) -> None: + """Create the CIFS share if it does not already exist.""" + try: + existing = client.get( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}", + fields="name", + ) + share_exists = bool(existing.get("name")) + except OntapApiError as exc: + if exc.status_code == 404: + share_exists = False + else: + raise + + if share_exists: + logger.info("CIFS share '%s' already exists - skipping create", share_name) + return + + logger.info("Creating CIFS share '%s' on path '/%s'...", share_name, volume) + client.post( + "/protocols/cifs/shares", + body={ + "name": share_name, + "path": f"/{volume}", + "svm": {"name": svm}, + "comment": share_comment, + }, ) - if not aggregate: - logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") - sys.exit(1) - with OntapClient.from_env() as client: - # Pre-flight — verify CIFS server is enabled on the SVM - # A CIFS share cannot be created if no CIFS server exists on the SVM. - # Exits early with a clear error rather than failing mid-workflow. - cifs_svc_resp = client.get( - "/protocols/cifs/services", - fields="svm.name,enabled", - **{"svm.name": svm}, +def _set_share_acl( + client: OntapClient, + svm_uuid: str, + share_name: str, + acl_user: str, + acl_permission: str, +) -> None: + """Patch the share ACL entry for the given user with the specified permission.""" + logger.info("Setting ACL: %s -> %s...", acl_user, acl_permission) + client.patch( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}/acls/{acl_user}/windows", + body={"permission": acl_permission}, + ) + + +def _verify_and_log_acls(client: OntapClient, svm_uuid: str, share_name: str) -> None: + """Fetch the share and log each ACL entry for confirmation.""" + logger.info("Verifying share '%s'...", share_name) + resp = client.get( + f"/protocols/cifs/shares/{svm_uuid}/{share_name}", + fields="name,path,acls", + ) + for acl in resp.get("acls", []): + logger.info( + " ACL: %s (%s) -> %s", + acl.get("user_or_group", "N/A"), + acl.get("type", "N/A"), + acl.get("permission", "N/A"), ) - if cifs_svc_resp.get("num_records", 0) == 0: - if not create_cifs_server: - logger.error( - "ABORTED — no CIFS server found on SVM '%s'. " - "Pass --create-cifs-server to create one automatically, or use " - "'vserver cifs create' before running this script.", - svm, - ) - sys.exit(1) - logger.info( - "No CIFS server on SVM '%s' — creating workgroup server '%s' in workgroup '%s'…", - svm, - cifs_server_name, - workgroup, - ) - cifs_create_resp = client.post( - "/protocols/cifs/services", - body={ - "svm": {"name": svm}, - "name": cifs_server_name, - "workgroup": workgroup, - "enabled": True, - }, - ) - # ONTAP may return an async job for CIFS server creation - if cifs_create_resp.get("job"): - cifs_job_uuid = cifs_create_resp["job"]["uuid"] - logger.info("CIFS server creation job: %s", cifs_job_uuid) - client.poll_job(cifs_job_uuid) - logger.info( - "CIFS server '%s' created in workgroup '%s' on SVM '%s'", - cifs_server_name, - workgroup, - svm, - ) - else: - logger.info("CIFS server confirmed on SVM '%s'", svm) - - # Step 1 — create volume with NTFS security style (idempotent: skip if exists) - # POST /storage/volumes to create a FlexVol with security_style=ntfs. - # NTFS security style is required for CIFS/SMB share ACL enforcement. - existing_vol = client.get( - "/storage/volumes", - fields="name,uuid", - name=volume, - **{"svm.name": svm}, + + +def main() -> None: + cfg = _resolve_config(parse_args()) + svm = cfg["svm"] + volume = cfg["volume"] + size = cfg["size"] + aggregate = cfg["aggregate"] + share_name = cfg["share_name"] + share_comment = cfg["share_comment"] + acl_user = cfg["acl_user"] + acl_permission = cfg["acl_permission"] + + with OntapClient.from_env() as client: + _ensure_cifs_server( + client, svm, cfg["create_cifs_server"], cfg["cifs_server_name"], cfg["workgroup"] ) - if existing_vol.get("records"): - logger.info("Volume '%s' already exists — skipping create", volume) - job_result = {"state": "skipped", "message": "volume already existed"} - else: - logger.info("Creating volume '%s' (%s) on SVM '%s'…", volume, size, svm) - create_resp = client.post( - "/storage/volumes", - body={ - "name": volume, - "svm": {"name": svm}, - "aggregates": [{"name": aggregate}], - "size": size, - "nas": { - "security_style": "ntfs", - "path": f"/{volume}", - }, - }, - ) - - # Step 2 — poll volume-creation job - # Block until the async job finishes; the job result is logged in Step 3. - job_uuid = create_resp["job"]["uuid"] - logger.info("Volume creation job: %s", job_uuid) - job_result = client.poll_job(job_uuid) - - # Step 3 — print volume creation status - # Log the final job state and message for confirmation before continuing. + + job_result = _ensure_volume_ntfs(client, svm, volume, size, aggregate) state = job_result.get("state", "unknown") message = job_result.get("message", "") - logger.info("Volume '%s' job → %s: %s", volume, state, message) - - # Step 4 — create CIFS share (idempotent: skip if already exists) - # POST /protocols/cifs/shares to create the share pointing at the volume junction. - # ONTAP auto-creates a default 'Everyone / Full Control' ACL entry on creation. - svm_resp = client.get( - "/svm/svms", - fields="name,uuid", - name=svm, - ) - svm_uuid = svm_resp["records"][0]["uuid"] - - try: - existing_share = client.get( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}", - fields="name", - ) - share_exists = bool(existing_share.get("name")) - except OntapApiError as exc: - if exc.status_code == 404: - share_exists = False - else: - raise - if share_exists: - logger.info("CIFS share '%s' already exists — skipping create", share_name) - else: - logger.info("Creating CIFS share '%s' on path '/%s'…", share_name, volume) - client.post( - "/protocols/cifs/shares", - body={ - "name": share_name, - "path": f"/{volume}", - "svm": {"name": svm}, - "comment": share_comment, - }, - ) - - # Step 6 — set share ACL (PATCH the auto-created Everyone entry) - # svm_uuid was resolved in Step 4 above (needed for the ACL URL). - # PATCH replaces the permission on the existing ACL entry for the given user. - # Default is 'Everyone' with 'full_control'; customise via ACL_USER/ACL_PERMISSION. - logger.info("Setting ACL: %s → %s…", acl_user, acl_permission) - client.patch( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}/acls/{acl_user}/windows", - body={"permission": acl_permission}, - ) - - # Step 7 — verify share and ACL - # GET the share and inspect the acls array to confirm the permission was applied. - # Logs each ACL entry (user, type, permission) for visual confirmation. - logger.info("Verifying share '%s'…", share_name) - verify_resp = client.get( - f"/protocols/cifs/shares/{svm_uuid}/{share_name}", - fields="name,path,acls", - ) - acls = verify_resp.get("acls", []) - for acl in acls: - logger.info( - " ACL: %s (%s) → %s", - acl.get("user_or_group", "—"), - acl.get("type", "—"), - acl.get("permission", "—"), - ) - - # Step 8 — print summary - # Log a single success line with share name, volume, SVM, path, and ACL. - logger.info( - "✓ CIFS share '%s' created on volume '%s' (SVM: %s) | Path: /%s | ACL: %s → %s", - share_name, - volume, - svm, - volume, - acl_user, - acl_permission, - ) + logger.info("Volume '%s' job -> %s: %s", volume, state, message) + + svm_uuid = _get_svm_uuid(client, svm) + _ensure_cifs_share(client, svm_uuid, share_name, volume, svm, share_comment) + _set_share_acl(client, svm_uuid, share_name, acl_user, acl_permission) + _verify_and_log_acls(client, svm_uuid, share_name) + + logger.info( + "[OK] CIFS share '%s' on volume '%s' (SVM: %s) | Path: /%s | ACL: %s -> %s", + share_name, + volume, + svm, + volume, + acl_user, + acl_permission, + ) if __name__ == "__main__": diff --git a/python/cluster_setup_basic.py b/python/cluster_setup_basic.py index db17c22..1d636b6 100644 --- a/python/cluster_setup_basic.py +++ b/python/cluster_setup_basic.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 """Create an ONTAP cluster from two pre-cluster nodes. -Equivalent to: orchestrio run yaml-workflows/workflows/cluster_setup_basic.yaml Steps: 1. discover_nodes — GET /api/cluster/nodes (membership=available, retry 3x/30s) @@ -50,15 +49,15 @@ # USER INPUTS — fill in your values here before running # --------------------------------------------------------------------------- INPUTS = { - "ONTAP_HOST": "10.140.108.120", # Node 1 management IP — set via ONTAP_HOST env var + "ONTAP_HOST": "", # Node 1 management IP — set via ONTAP_HOST env var "ONTAP_USER": "admin", "ONTAP_PASS": "", # set via ONTAP_PASS env var — leave empty for pre-cluster nodes - "CLUSTER_NAME": "sp57388-cluster", # choose your cluster name — set via CLUSTER_NAME env var + "CLUSTER_NAME": "", # choose your cluster name — set via CLUSTER_NAME env var "CLUSTER_PASS": "", # set via CLUSTER_PASS env var — choose your cluster admin password - "CLUSTER_MGMT_IP": "10.140.108.120", # cluster management IP — set via CLUSTER_MGMT_IP env var - "CLUSTER_NETMASK": "255.255.192.0", # e.g. 255.255.255.0 — set via CLUSTER_NETMASK env var - "CLUSTER_GATEWAY": "10.140.64.1", # default gateway — set via CLUSTER_GATEWAY env var - "PARTNER_MGMT_IP": "10.140.108.124", # Node 2 management IP — set via PARTNER_MGMT_IP env var + "CLUSTER_MGMT_IP": "", # cluster management IP — set via CLUSTER_MGMT_IP env var + "CLUSTER_NETMASK": "", # e.g. 255.255.255.0 — set via CLUSTER_NETMASK env var + "CLUSTER_GATEWAY": "", # default gateway — set via CLUSTER_GATEWAY env var + "PARTNER_MGMT_IP": "", # Node 2 management IP — set via PARTNER_MGMT_IP env var } # --------------------------------------------------------------------------- diff --git a/python/nfs_provision.py b/python/nfs_provision.py index 44cb235..3c2201d 100644 --- a/python/nfs_provision.py +++ b/python/nfs_provision.py @@ -59,7 +59,7 @@ "SVM_NAME": "vs1", "VOLUME_NAME": "vol_001", "VOLUME_SIZE": "100MB", - "AGGR_NAME": "sti232_vsim_sr091o_aggr1", # required - set via --aggregate or AGGR_NAME env var + "AGGR_NAME": "", # required - set via --aggregate or AGGR_NAME env var "CLIENT_MATCH": "0.0.0.0/0", } @@ -97,167 +97,165 @@ def parse_args() -> argparse.Namespace: return p.parse_args() -def main() -> None: - args = parse_args() - - # Load env file first so its values can be read via os.environ below +def _resolve_config(args: argparse.Namespace) -> dict[str, str]: + """Load env file and CLI args, then return the resolved configuration dict.""" if args.env_file: _load_env_file(args.env_file) - # Push ENV block values into os.environ so OntapClient.from_env() picks them up for key, value in ENV.items(): if value and key not in os.environ: os.environ[key] = value - # Resolve each value: CLI arg > env var > ENV block > built-in default (matches YAML priority) - svm = args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0" - volume = ( - args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "vol_nfs_test_01" - ) - size = args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB" aggregate = args.aggregate or os.environ.get("AGGR_NAME") or ENV["AGGR_NAME"] or "" - client_match = ( - args.client_match or os.environ.get("CLIENT_MATCH") or ENV["CLIENT_MATCH"] or "0.0.0.0/0" - ) - if not aggregate: logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") sys.exit(1) - policy_name = f"{volume}_export_policy" - - with OntapClient.from_env() as client: - # Step 1 - create volume (idempotent: skip if already exists) - # POST /storage/volumes to create a new FlexVol with a NAS junction path. - # Volume creation is asynchronous - the response contains a job UUID. - existing_vol = client.get( - "/storage/volumes", - fields="name,uuid", - name=volume, - **{"svm.name": svm}, - ) - if existing_vol.get("records"): - logger.info("Volume '%s' already exists - skipping create", volume) - else: - logger.info("Creating volume '%s' (%s) on SVM '%s'...", volume, size, svm) - create_resp = client.post( - "/storage/volumes", - body={ - "name": volume, - "svm": {"name": svm}, - "aggregates": [{"name": aggregate}], - "size": size, - "nas": {"path": f"/{volume}"}, - }, - ) - - # Step 2 - poll volume-creation job - # Block until the async job finishes before proceeding. - # poll_job raises RuntimeError if the job ends in a failure state. - job_uuid = create_resp["job"]["uuid"] - logger.info("Volume creation job: %s", job_uuid) - client.poll_job(job_uuid) - logger.info("Volume '%s' created successfully", volume) - - # Step 3 - fetch volume UUID - # The UUID is required to PATCH the volume later when assigning the export policy. - # Filter by name + svm.name to pinpoint exactly the volume just created. - vol_resp = client.get( + return { + "svm": args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0", + "volume": ( + args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "vol_nfs_test_01" + ), + "size": args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB", + "aggregate": aggregate, + "client_match": ( + args.client_match + or os.environ.get("CLIENT_MATCH") + or ENV["CLIENT_MATCH"] + or "0.0.0.0/0" + ), + } + + +def _ensure_volume(client: OntapClient, svm: str, volume: str, size: str, aggregate: str) -> str: + """Create the FlexVol if it does not exist. Returns the volume UUID.""" + existing = client.get( + "/storage/volumes", + fields="name,uuid", + name=volume, + **{"svm.name": svm}, + ) + if existing.get("records"): + logger.info("Volume '%s' already exists - skipping create", volume) + else: + logger.info("Creating volume '%s' (%s) on SVM '%s'...", volume, size, svm) + resp = client.post( "/storage/volumes", - fields="name,uuid", - name=volume, - **{"svm.name": svm}, - ) - if not vol_resp.get("records"): - raise RuntimeError(f"Volume '{volume}' not found on SVM '{svm}' after creation") - volume_uuid = vol_resp["records"][0]["uuid"] - - # Step 4 - create export policy (idempotent: skip if already exists) - # Creates a dedicated policy named _export_policy scoped to the SVM. - # A per-volume policy makes it easy to manage access rules independently. - existing_policy = client.get( - "/protocols/nfs/export-policies", - fields="name,id", - name=policy_name, - **{"svm.name": svm}, + body={ + "name": volume, + "svm": {"name": svm}, + "aggregates": [{"name": aggregate}], + "size": size, + "nas": {"path": f"/{volume}"}, + }, ) - if existing_policy.get("records"): - logger.info("Export policy '%s' already exists - skipping create", policy_name) - else: - logger.info("Creating export policy '%s'...", policy_name) - client.post( - "/protocols/nfs/export-policies", - body={"name": policy_name, "svm": {"name": svm}}, - ) - - # Step 5 - fetch export policy ID - # The numeric ID is required when POSTing rules to the policy. - # Filter by name + svm.name to retrieve only this policy's record. - policy_resp = client.get( + job_uuid = resp["job"]["uuid"] + logger.info("Volume creation job: %s", job_uuid) + client.poll_job(job_uuid) + logger.info("Volume '%s' created successfully", volume) + + vol_resp = client.get( + "/storage/volumes", + fields="name,uuid", + name=volume, + **{"svm.name": svm}, + ) + if not vol_resp.get("records"): + raise RuntimeError(f"Volume '{volume}' not found on SVM '{svm}' after creation") + return vol_resp["records"][0]["uuid"] + + +def _ensure_export_policy(client: OntapClient, svm: str, policy_name: str) -> int: + """Create the NFS export policy if it does not exist. Returns the policy ID.""" + existing = client.get( + "/protocols/nfs/export-policies", + fields="name,id", + name=policy_name, + **{"svm.name": svm}, + ) + if existing.get("records"): + logger.info("Export policy '%s' already exists - skipping create", policy_name) + else: + logger.info("Creating export policy '%s'...", policy_name) + client.post( "/protocols/nfs/export-policies", - fields="name,id", - name=policy_name, - **{"svm.name": svm}, - ) - if not policy_resp.get("records"): - raise RuntimeError( - f"Export policy '{policy_name}' not found on SVM '{svm}' after creation" - ) - policy_id = policy_resp["records"][0]["id"] - - # Step 6 - add client rule (idempotent: skip if a matching rule already exists) - # POST a rule to the export policy allowing the given client IP or CIDR range. - # ro_rule, rw_rule, superuser = 'any' is suitable for lab; tighten for production. - existing_rules = client.get( - f"/protocols/nfs/export-policies/{policy_id}/rules", - fields="index,clients", - ) - rule_exists = any( - any(c.get("match") == client_match for c in r.get("clients", [])) - for r in existing_rules.get("records", []) - ) - if rule_exists: - logger.info("Client rule '%s' already exists in policy - skipping", client_match) - else: - logger.info("Adding client rule '%s' to policy...", client_match) - client.post( - f"/protocols/nfs/export-policies/{policy_id}/rules", - body={ - "clients": [{"match": client_match}], - "ro_rule": ["any"], - "rw_rule": ["any"], - "superuser": ["any"], - }, - ) - - # Step 7 - assign export policy to volume - # PATCH the volume's nas.export_policy field to link the policy. - # This makes the volume accessible to NFS clients that match the rule. - logger.info("Assigning export policy to volume...") - patch_resp = client.patch( - f"/storage/volumes/{volume_uuid}", - body={"nas": {"export_policy": {"name": policy_name}}}, + body={"name": policy_name, "svm": {"name": svm}}, ) - # Step 8 - poll assign-policy job - # The PATCH may return a job UUID if the operation is async. - # Only poll if a UUID was returned; sync responses skip this block. - if "job" in patch_resp: - client.poll_job(patch_resp["job"]["uuid"]) - - # Step 9 - print summary - # Log a single success line with volume, size, SVM, mount path, - # export policy name, and client rule for quick confirmation. - logger.info( - "[OK] Volume '%s' (%s) created on SVM '%s' | Mount path: /%s | " - "Export policy '%s' created with client rule '%s' and assigned to volume", - volume, - size, - svm, - volume, - policy_name, - client_match, + policy_resp = client.get( + "/protocols/nfs/export-policies", + fields="name,id", + name=policy_name, + **{"svm.name": svm}, + ) + if not policy_resp.get("records"): + raise RuntimeError( + f"Export policy '{policy_name}' not found on SVM '{svm}' after creation" ) + return policy_resp["records"][0]["id"] + + +def _ensure_client_rule(client: OntapClient, policy_id: int, client_match: str) -> None: + """Add a client-match rule to the export policy if one does not already exist.""" + existing_rules = client.get( + f"/protocols/nfs/export-policies/{policy_id}/rules", + fields="index,clients", + ) + rule_exists = any( + any(c.get("match") == client_match for c in r.get("clients", [])) + for r in existing_rules.get("records", []) + ) + if rule_exists: + logger.info("Client rule '%s' already exists in policy - skipping", client_match) + return + logger.info("Adding client rule '%s' to policy...", client_match) + client.post( + f"/protocols/nfs/export-policies/{policy_id}/rules", + body={ + "clients": [{"match": client_match}], + "ro_rule": ["any"], + "rw_rule": ["any"], + "superuser": ["any"], + }, + ) + + +def _assign_export_policy(client: OntapClient, volume_uuid: str, policy_name: str) -> None: + """Assign the export policy to the volume and wait for any async job to complete.""" + logger.info("Assigning export policy to volume...") + patch_resp = client.patch( + f"/storage/volumes/{volume_uuid}", + body={"nas": {"export_policy": {"name": policy_name}}}, + ) + if "job" in patch_resp: + client.poll_job(patch_resp["job"]["uuid"]) + + +def main() -> None: + cfg = _resolve_config(parse_args()) + svm = cfg["svm"] + volume = cfg["volume"] + size = cfg["size"] + aggregate = cfg["aggregate"] + client_match = cfg["client_match"] + policy_name = f"{volume}_export_policy" + + with OntapClient.from_env() as client: + volume_uuid = _ensure_volume(client, svm, volume, size, aggregate) + policy_id = _ensure_export_policy(client, svm, policy_name) + _ensure_client_rule(client, policy_id, client_match) + _assign_export_policy(client, volume_uuid, policy_name) + + logger.info( + "[OK] Volume '%s' (%s) on SVM '%s' | Mount: /%s | " + "Export policy '%s' with client rule '%s' assigned", + volume, + size, + svm, + volume, + policy_name, + client_match, + ) if __name__ == "__main__": From 7e6e11f25b92404679367b6d6aee327d635afcf2 Mon Sep 17 00:00:00 2001 From: Somanath Date: Fri, 8 May 2026 14:34:52 +0530 Subject: [PATCH 3/4] fix: reduce cyclomatic complexity in _resolve_config functions - Extract _pick() helper in cifs_provision.py and nfs_provision.py to eliminate repeated cli-arg / env-var / ENV-dict / default or-chains - cifs_provision::_resolve_config: CC E(36) -> B(6) - nfs_provision::_resolve_config: CC D(21) -> B(6) - Strip UTF-8 BOM from cluster_info.py and nfs_provision.py so radon and other tooling can parse them cleanly - Overall average CC across all scripts: A (3.2) --- python/cifs_provision.py | 52 ++++++++++++---------------------------- python/cluster_info.py | 2 +- python/nfs_provision.py | 24 +++++++++---------- 3 files changed, 27 insertions(+), 51 deletions(-) diff --git a/python/cifs_provision.py b/python/cifs_provision.py index 720213e..07737f2 100644 --- a/python/cifs_provision.py +++ b/python/cifs_provision.py @@ -114,6 +114,11 @@ def parse_args() -> argparse.Namespace: return p.parse_args() +def _pick(cli_val: str | None, env_key: str, default: str = "") -> str: + """Return the first non-empty value from: CLI arg, env var, ENV dict, or default.""" + return cli_val or os.environ.get(env_key) or ENV.get(env_key, "") or default + + def _resolve_config(args: argparse.Namespace) -> dict[str, str | bool]: """Load env file and CLI args, then return the resolved configuration dict.""" if args.env_file: @@ -123,50 +128,23 @@ def _resolve_config(args: argparse.Namespace) -> dict[str, str | bool]: if value and key not in os.environ: os.environ[key] = value - aggregate = args.aggregate or os.environ.get("AGGR_NAME") or ENV["AGGR_NAME"] or "" + aggregate = _pick(args.aggregate, "AGGR_NAME") if not aggregate: logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") sys.exit(1) return { - "svm": args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0", - "volume": ( - args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "cifs_test_env" - ), - "size": args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB", + "svm": _pick(args.svm, "SVM_NAME", "vs0"), + "volume": _pick(args.volume, "VOLUME_NAME", "cifs_test_env"), + "size": _pick(args.size, "VOLUME_SIZE", "100MB"), "aggregate": aggregate, - "share_name": ( - args.share_name - or os.environ.get("SHARE_NAME") - or ENV["SHARE_NAME"] - or "cifs_share_demo" - ), - "share_comment": ( - args.share_comment - or os.environ.get("SHARE_COMMENT") - or ENV["SHARE_COMMENT"] - or "Provisioned by orchestrio" - ), - "acl_user": (args.acl_user or os.environ.get("ACL_USER") or ENV["ACL_USER"] or "Everyone"), - "acl_permission": ( - args.acl_permission - or os.environ.get("ACL_PERMISSION") - or ENV["ACL_PERMISSION"] - or "full_control" - ), + "share_name": _pick(args.share_name, "SHARE_NAME", "cifs_share_demo"), + "share_comment": _pick(args.share_comment, "SHARE_COMMENT", "Provisioned by orchestrio"), + "acl_user": _pick(args.acl_user, "ACL_USER", "Everyone"), + "acl_permission": _pick(args.acl_permission, "ACL_PERMISSION", "full_control"), "create_cifs_server": args.create_cifs_server, - "cifs_server_name": ( - args.cifs_server_name - or os.environ.get("CIFS_SERVER_NAME") - or ENV["CIFS_SERVER_NAME"] - or "ONTAP-CIFS" - ), - "workgroup": ( - args.workgroup - or os.environ.get("CIFS_WORKGROUP") - or ENV["CIFS_WORKGROUP"] - or "WORKGROUP" - ), + "cifs_server_name": _pick(args.cifs_server_name, "CIFS_SERVER_NAME", "ONTAP-CIFS"), + "workgroup": _pick(args.workgroup, "CIFS_WORKGROUP", "WORKGROUP"), } diff --git a/python/cluster_info.py b/python/cluster_info.py index 703517d..aa705de 100644 --- a/python/cluster_info.py +++ b/python/cluster_info.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3 """Retrieve ONTAP cluster version and list all nodes with serial numbers. Steps: diff --git a/python/nfs_provision.py b/python/nfs_provision.py index 3c2201d..3f0930b 100644 --- a/python/nfs_provision.py +++ b/python/nfs_provision.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3 """Create an ONTAP NFS volume with a dedicated export policy. Steps: @@ -97,6 +97,11 @@ def parse_args() -> argparse.Namespace: return p.parse_args() +def _pick(cli_val: str | None, env_key: str, default: str = "") -> str: + """Return the first non-empty value from: CLI arg, env var, ENV dict, or default.""" + return cli_val or os.environ.get(env_key) or ENV.get(env_key, "") or default + + def _resolve_config(args: argparse.Namespace) -> dict[str, str]: """Load env file and CLI args, then return the resolved configuration dict.""" if args.env_file: @@ -106,24 +111,17 @@ def _resolve_config(args: argparse.Namespace) -> dict[str, str]: if value and key not in os.environ: os.environ[key] = value - aggregate = args.aggregate or os.environ.get("AGGR_NAME") or ENV["AGGR_NAME"] or "" + aggregate = _pick(args.aggregate, "AGGR_NAME") if not aggregate: logger.error("--aggregate is required (or set AGGR_NAME in env / --env-file)") sys.exit(1) return { - "svm": args.svm or os.environ.get("SVM_NAME") or ENV["SVM_NAME"] or "vs0", - "volume": ( - args.volume or os.environ.get("VOLUME_NAME") or ENV["VOLUME_NAME"] or "vol_nfs_test_01" - ), - "size": args.size or os.environ.get("VOLUME_SIZE") or ENV["VOLUME_SIZE"] or "100MB", + "svm": _pick(args.svm, "SVM_NAME", "vs0"), + "volume": _pick(args.volume, "VOLUME_NAME", "vol_nfs_test_01"), + "size": _pick(args.size, "VOLUME_SIZE", "100MB"), "aggregate": aggregate, - "client_match": ( - args.client_match - or os.environ.get("CLIENT_MATCH") - or ENV["CLIENT_MATCH"] - or "0.0.0.0/0" - ), + "client_match": _pick(args.client_match, "CLIENT_MATCH", "0.0.0.0/0"), } From ce0e8196f7233c41b50a57ca3487b9a4f9443b50 Mon Sep 17 00:00:00 2001 From: Somanath Date: Mon, 11 May 2026 22:12:23 +0530 Subject: [PATCH 4/4] test: add unit tests for helper functions --- pytest.ini | 3 + requirements-dev.txt | 2 + tests/__init__.py | 0 tests/conftest.py | 9 + tests/test_cifs_provision.py | 198 +++++++++ tests/test_cluster_info.py | 64 +++ tests/test_cluster_setup_basic.py | 209 +++++++++ tests/test_nfs_provision.py | 72 ++++ tests/test_ontap_client.py | 403 ++++++++++++++++++ .../test_snapmirror_cleanup_test_failover.py | 149 +++++++ .../test_snapmirror_provision_dest_managed.py | 173 ++++++++ .../test_snapmirror_provision_src_managed.py | 125 ++++++ tests/test_snapmirror_test_failover.py | 149 +++++++ 13 files changed, 1556 insertions(+) create mode 100644 pytest.ini create mode 100644 requirements-dev.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_cifs_provision.py create mode 100644 tests/test_cluster_info.py create mode 100644 tests/test_cluster_setup_basic.py create mode 100644 tests/test_nfs_provision.py create mode 100644 tests/test_ontap_client.py create mode 100644 tests/test_snapmirror_cleanup_test_failover.py create mode 100644 tests/test_snapmirror_provision_dest_managed.py create mode 100644 tests/test_snapmirror_provision_src_managed.py create mode 100644 tests/test_snapmirror_test_failover.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d47cc87 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +addopts = -v --tb=short diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..8ad5959 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +pytest>=8.0 +pytest-mock>=3.12 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4a7cdf0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +"""Pytest configuration — make the python/ directory importable without installing.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Add python/ to sys.path so test modules can import ontap_client, nfs_provision, etc. +sys.path.insert(0, str(Path(__file__).parent.parent / "python")) diff --git a/tests/test_cifs_provision.py b/tests/test_cifs_provision.py new file mode 100644 index 0000000..9f86b2f --- /dev/null +++ b/tests/test_cifs_provision.py @@ -0,0 +1,198 @@ +"""Unit tests for cifs_provision helper functions.""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import MagicMock + +import cifs_provision +import pytest +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _load_env_file (same dotenv-into-os.environ implementation as nfs_provision) +# --------------------------------------------------------------------------- + + +class TestLoadEnvFile: + def test_valid_file_sets_env_vars( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "test.env" + env_file.write_text("CIFS_FOO=bar\nCIFS_BAZ=qux\n") + monkeypatch.delenv("CIFS_FOO", raising=False) + monkeypatch.delenv("CIFS_BAZ", raising=False) + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_FOO"] == "bar" + assert os.environ["CIFS_BAZ"] == "qux" + + def test_missing_file_exits(self, tmp_path: Path) -> None: + with pytest.raises(SystemExit): + cifs_provision._load_env_file(str(tmp_path / "nonexistent.env")) + + def test_malformed_line_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + env_file = tmp_path / "bad.env" + env_file.write_text("CIFS_OK=yes\nNO_EQUALS\n") + monkeypatch.delenv("CIFS_OK", raising=False) + with pytest.raises(SystemExit): + cifs_provision._load_env_file(str(env_file)) + + def test_blank_and_comment_lines_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "mixed.env" + env_file.write_text("# comment\n\nCIFS_KEY=value\n") + monkeypatch.delenv("CIFS_KEY", raising=False) + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_KEY"] == "value" + + def test_setdefault_does_not_override_existing( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "override.env" + env_file.write_text("CIFS_KEY2=from_file\n") + monkeypatch.setenv("CIFS_KEY2", "already_set") + cifs_provision._load_env_file(str(env_file)) + assert os.environ["CIFS_KEY2"] == "already_set" + + +# --------------------------------------------------------------------------- +# _pick +# --------------------------------------------------------------------------- + + +class TestPick: + def test_cli_val_wins(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick("from_cli", "SVM_NAME") == "from_cli" + + def test_env_var_second_priority(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick(None, "SVM_NAME") == "from_env" + + def test_env_dict_third_priority(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SVM_NAME", raising=False) + monkeypatch.setitem(cifs_provision.ENV, "SVM_NAME", "from_env_dict") + assert cifs_provision._pick(None, "SVM_NAME") == "from_env_dict" + + def test_falls_back_to_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MISSING_KEY", raising=False) + assert cifs_provision._pick(None, "MISSING_KEY", "fallback") == "fallback" + + def test_empty_string_cli_treated_as_falsy(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SVM_NAME", "from_env") + assert cifs_provision._pick("", "SVM_NAME") == "from_env" + + +# --------------------------------------------------------------------------- +# _resolve_config +# --------------------------------------------------------------------------- + + +class TestResolveConfig: + def _make_args(self, **overrides): + import argparse + + defaults = { + "env_file": None, + "svm": None, + "volume": None, + "size": None, + "aggregate": "aggr1", + "share_name": None, + "share_comment": None, + "acl_user": None, + "acl_permission": None, + "create_cifs_server": False, + "cifs_server_name": None, + "workgroup": None, + } + defaults.update(overrides) + return argparse.Namespace(**defaults) + + def test_missing_aggregate_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AGGR_NAME", raising=False) + monkeypatch.setitem(cifs_provision.ENV, "AGGR_NAME", "") + args = self._make_args(aggregate=None) + with pytest.raises(SystemExit): + cifs_provision._resolve_config(args) + + def test_aggregate_from_cli(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AGGR_NAME", raising=False) + args = self._make_args(aggregate="aggr_from_cli") + config = cifs_provision._resolve_config(args) + assert config["aggregate"] == "aggr_from_cli" + + def test_create_cifs_server_flag_passed_through(self, monkeypatch: pytest.MonkeyPatch) -> None: + args = self._make_args(create_cifs_server=True) + config = cifs_provision._resolve_config(args) + assert config["create_cifs_server"] is True + + def test_returns_all_expected_keys(self, monkeypatch: pytest.MonkeyPatch) -> None: + args = self._make_args() + config = cifs_provision._resolve_config(args) + expected_keys = { + "svm", + "volume", + "size", + "aggregate", + "share_name", + "share_comment", + "acl_user", + "acl_permission", + "create_cifs_server", + "cifs_server_name", + "workgroup", + } + assert expected_keys == set(config.keys()) + + +# --------------------------------------------------------------------------- +# _ensure_cifs_server +# --------------------------------------------------------------------------- + + +class TestEnsureCifsServer: + def _make_client(self) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + return client + + def test_server_exists_no_create_called(self) -> None: + client = self._make_client() + client.get.return_value = { + "records": [{"svm": {"name": "vs1"}, "enabled": True}], + "num_records": 1, + } + # Should not raise or call post + cifs_provision._ensure_cifs_server(client, "vs1", False, "ONTAP-CIFS", "WORKGROUP") + client.post.assert_not_called() + + def test_no_server_no_flag_exits(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + with pytest.raises(SystemExit): + cifs_provision._ensure_cifs_server(client, "vs1", False, "ONTAP-CIFS", "WORKGROUP") + + def test_no_server_with_flag_creates_server(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + client.post.return_value = {} # no async job + cifs_provision._ensure_cifs_server(client, "vs1", True, "MY-CIFS", "MYGROUP") + client.post.assert_called_once() + # post(path, body) — body is the second positional arg + call_args, call_kwargs = client.post.call_args + call_body = call_args[1] if len(call_args) > 1 else call_kwargs.get("body", {}) + assert call_body["name"] == "MY-CIFS" + assert call_body["workgroup"] == "MYGROUP" + + def test_no_server_with_flag_polls_job_when_returned(self) -> None: + client = self._make_client() + client.get.return_value = {"records": [], "num_records": 0} + client.post.return_value = {"job": {"uuid": "job-uuid-1"}} + cifs_provision._ensure_cifs_server(client, "vs1", True, "MY-CIFS", "MYGROUP") + client.poll_job.assert_called_once_with("job-uuid-1") diff --git a/tests/test_cluster_info.py b/tests/test_cluster_info.py new file mode 100644 index 0000000..5d97623 --- /dev/null +++ b/tests/test_cluster_info.py @@ -0,0 +1,64 @@ +"""Unit tests for cluster_info.main().""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import cluster_info +from ontap_client import OntapClient + + +class TestClusterInfoMain: + def _make_client(self) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + return client + + def test_fetches_cluster_endpoint(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() + first_call = client.get.call_args_list[0] + assert first_call[0][0] == "/cluster" + + def test_fetches_nodes_endpoint(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() + second_call = client.get.call_args_list[1] + assert second_call[0][0] == "/cluster/nodes" + + def test_handles_node_records(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"name": "cluster1", "version": {"full": "9.14.1"}}, + { + "records": [ + {"name": "node1", "serial_number": "SN-001"}, + {"name": "node2", "serial_number": "SN-002"}, + ], + "num_records": 2, + }, + ] + with patch.object(OntapClient, "from_env", return_value=client): + # Should not raise + cluster_info.main() + + def test_handles_missing_cluster_fields_gracefully(self) -> None: + client = self._make_client() + # Minimal response — missing 'name' and 'version' + client.get.side_effect = [ + {}, + {"records": [], "num_records": 0}, + ] + with patch.object(OntapClient, "from_env", return_value=client): + cluster_info.main() # Should not raise diff --git a/tests/test_cluster_setup_basic.py b/tests/test_cluster_setup_basic.py new file mode 100644 index 0000000..736aee5 --- /dev/null +++ b/tests/test_cluster_setup_basic.py @@ -0,0 +1,209 @@ +"""Unit tests for cluster_setup_basic helper functions.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import cluster_setup_basic as csb +import pytest +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "mycluster") + assert csb._env("CLUSTER_NAME") == "mycluster" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.setenv("CLUSTER_NAME", "envcluster") + assert csb._env("CLUSTER_NAME") == "envcluster" + + def test_inputs_takes_priority_over_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "from_inputs") + monkeypatch.setenv("CLUSTER_NAME", "from_env") + assert csb._env("CLUSTER_NAME") == "from_inputs" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.delenv("CLUSTER_NAME", raising=False) + with pytest.raises(SystemExit): + csb._env("CLUSTER_NAME", required=True) + + def test_missing_optional_key_returns_empty(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(csb.INPUTS, "CLUSTER_NAME", "") + monkeypatch.delenv("CLUSTER_NAME", raising=False) + result = csb._env("CLUSTER_NAME", required=False) + assert result == "" + + +# --------------------------------------------------------------------------- +# _load_env_file (cluster_setup_basic version — loads into INPUTS dict) +# --------------------------------------------------------------------------- + + +class TestLoadEnvFileCSB: + def test_valid_file_updates_inputs(self, tmp_path: Path) -> None: + env_file = tmp_path / "build.env" + env_file.write_text("CLUSTER_NAME=testcluster\nCLUSTER_PASS=secret\n") + original = dict(csb.INPUTS) + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "testcluster" + assert csb.INPUTS["CLUSTER_PASS"] == "secret" + # restore + csb.INPUTS.update(original) + + def test_strips_double_quotes_from_values(self, tmp_path: Path) -> None: + env_file = tmp_path / "quoted.env" + env_file.write_text('CLUSTER_NAME="quoted-name"\n') + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "quoted-name" + + def test_strips_single_quotes_from_values(self, tmp_path: Path) -> None: + env_file = tmp_path / "single.env" + env_file.write_text("CLUSTER_NAME='single-name'\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "single-name" + + def test_blank_lines_skipped(self, tmp_path: Path) -> None: + env_file = tmp_path / "blanks.env" + env_file.write_text("\n\nCLUSTER_NAME=ok\n\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "ok" + + def test_comment_lines_skipped(self, tmp_path: Path) -> None: + env_file = tmp_path / "comments.env" + env_file.write_text("# a comment\nCLUSTER_NAME=fromcomment\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "fromcomment" + + def test_lines_without_equals_skipped(self, tmp_path: Path) -> None: + """Lines without '=' are silently skipped (no sys.exit in this version).""" + env_file = tmp_path / "noeq.env" + env_file.write_text("CLUSTER_NAME=safe\nNO_EQUALS\n") + csb._load_env_file(str(env_file)) + assert csb.INPUTS["CLUSTER_NAME"] == "safe" + + +# --------------------------------------------------------------------------- +# _get_nodes +# --------------------------------------------------------------------------- + + +class TestGetNodes: + def test_returns_result_on_first_try(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1"}], "num_records": 1} + client.get.return_value = expected + result = csb._get_nodes(client, membership="available") + assert result == expected + + def test_falls_back_on_262197_error(self) -> None: + client = MagicMock(spec=OntapClient) + fallback = {"records": [{"name": "node1"}], "num_records": 1} + client.get.side_effect = [ + RuntimeError("error code 262197"), + fallback, + ] + result = csb._get_nodes(client) + assert result == fallback + assert client.get.call_count == 2 + + def test_raises_immediately_on_non_262197_error(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = RuntimeError("some other error 999") + with pytest.raises(RuntimeError, match="999"): + csb._get_nodes(client) + assert client.get.call_count == 1 + + def test_raises_last_exc_when_all_field_sets_fail(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = RuntimeError("error code 262197") + with pytest.raises(RuntimeError): + csb._get_nodes(client) + + +# --------------------------------------------------------------------------- +# discover_nodes +# --------------------------------------------------------------------------- + + +class TestDiscoverNodes: + def test_returns_node_list_on_success(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_nodes(client) + assert result == expected + + def test_retries_then_succeeds(self) -> None: + client = MagicMock(spec=OntapClient) + success = {"records": [{"name": "node1"}], "num_records": 1} + with ( + patch.object(csb, "_get_nodes", side_effect=[RuntimeError("transient"), success]), + patch("cluster_setup_basic.time.sleep"), + ): + result = csb.discover_nodes(client, attempts=3, delay=0) + assert result == success + + def test_raises_runtime_error_after_all_attempts(self) -> None: + client = MagicMock(spec=OntapClient) + with ( + patch.object(csb, "_get_nodes", side_effect=RuntimeError("persistent")), + patch("cluster_setup_basic.time.sleep"), + ): + with pytest.raises(RuntimeError, match="failed after 2 attempts"): + csb.discover_nodes(client, attempts=2, delay=0) + + +# --------------------------------------------------------------------------- +# discover_local +# --------------------------------------------------------------------------- + + +class TestDiscoverLocal: + def test_returns_result_when_records_present(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node1", "uuid": "uuid-1"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_local(client) + assert result == expected + + def test_raises_when_no_records(self) -> None: + client = MagicMock(spec=OntapClient) + with patch.object(csb, "_get_nodes", return_value={"records": [], "num_records": 0}): + with pytest.raises(RuntimeError, match="no local node"): + csb.discover_local(client) + + +# --------------------------------------------------------------------------- +# discover_partner +# --------------------------------------------------------------------------- + + +class TestDiscoverPartner: + def test_returns_result_when_records_present(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node2", "uuid": "uuid-2"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected): + result = csb.discover_partner(client, local_uuid="uuid-1") + assert result == expected + + def test_passes_exclusion_filter(self) -> None: + client = MagicMock(spec=OntapClient) + expected = {"records": [{"name": "node2"}], "num_records": 1} + with patch.object(csb, "_get_nodes", return_value=expected) as mock_get: + csb.discover_partner(client, local_uuid="uuid-1") + call_kwargs = mock_get.call_args[1] + assert call_kwargs.get("uuid") == "!uuid-1" + + def test_raises_when_no_records(self) -> None: + client = MagicMock(spec=OntapClient) + with patch.object(csb, "_get_nodes", return_value={"records": [], "num_records": 0}): + with pytest.raises(RuntimeError, match="no partner node"): + csb.discover_partner(client, local_uuid="uuid-1") diff --git a/tests/test_nfs_provision.py b/tests/test_nfs_provision.py new file mode 100644 index 0000000..4dc1f78 --- /dev/null +++ b/tests/test_nfs_provision.py @@ -0,0 +1,72 @@ +"""Unit tests for nfs_provision helper functions.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import nfs_provision +import pytest + + +class TestLoadEnvFile: + def test_valid_file_sets_env_vars( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "test.env" + env_file.write_text("FOO_NFS=bar\nBAZ_NFS=qux\n") + monkeypatch.delenv("FOO_NFS", raising=False) + monkeypatch.delenv("BAZ_NFS", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["FOO_NFS"] == "bar" + assert os.environ["BAZ_NFS"] == "qux" + + def test_missing_file_exits(self, tmp_path: Path) -> None: + with pytest.raises(SystemExit): + nfs_provision._load_env_file(str(tmp_path / "nonexistent.env")) + + def test_malformed_line_no_equals_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "bad.env" + env_file.write_text("VALID=ok\nNO_EQUALS_HERE\n") + monkeypatch.delenv("VALID", raising=False) + with pytest.raises(SystemExit): + nfs_provision._load_env_file(str(env_file)) + + def test_blank_lines_are_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "blanks.env" + env_file.write_text("\n\nKEY_NFS2=value\n\n") + monkeypatch.delenv("KEY_NFS2", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS2"] == "value" + + def test_comment_lines_are_skipped( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "comments.env" + env_file.write_text("# this is a comment\nKEY_NFS3=value\n") + monkeypatch.delenv("KEY_NFS3", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS3"] == "value" + + def test_setdefault_does_not_override_existing( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_file = tmp_path / "override.env" + env_file.write_text("KEY_NFS4=from_file\n") + monkeypatch.setenv("KEY_NFS4", "already_set") + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS4"] == "already_set" + + def test_value_with_equals_sign_handled( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Values containing '=' are preserved (partition only splits on first '=').""" + env_file = tmp_path / "equals.env" + env_file.write_text("KEY_NFS5=a=b=c\n") + monkeypatch.delenv("KEY_NFS5", raising=False) + nfs_provision._load_env_file(str(env_file)) + assert os.environ["KEY_NFS5"] == "a=b=c" diff --git a/tests/test_ontap_client.py b/tests/test_ontap_client.py new file mode 100644 index 0000000..bde2b16 --- /dev/null +++ b/tests/test_ontap_client.py @@ -0,0 +1,403 @@ +"""Unit tests for ontap_client.OntapClient.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import requests +from ontap_client import OntapApiError, OntapClient + +# --------------------------------------------------------------------------- +# OntapApiError +# --------------------------------------------------------------------------- + + +class TestOntapApiError: + def test_json_detail_stored(self) -> None: + resp = MagicMock(spec=requests.Response) + resp.status_code = 400 + resp.json.return_value = {"message": "Bad request", "code": "123"} + err = OntapApiError(resp) + assert err.status_code == 400 + assert err.detail == {"message": "Bad request", "code": "123"} + assert "400" in str(err) + + def test_text_fallback_when_not_json(self) -> None: + resp = MagicMock(spec=requests.Response) + resp.status_code = 500 + resp.json.side_effect = ValueError("no JSON") + resp.text = "Internal Server Error" + err = OntapApiError(resp) + assert err.detail == "Internal Server Error" + assert "500" in str(err) + + def test_is_exception_subclass(self) -> None: + resp = MagicMock(spec=requests.Response) + resp.status_code = 404 + resp.json.return_value = {} + assert isinstance(OntapApiError(resp), Exception) + + +# --------------------------------------------------------------------------- +# OntapClient.__init__ +# --------------------------------------------------------------------------- + + +class TestOntapClientInit: + def test_base_url_formed_correctly(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.base_url == "https://10.0.0.1/api" + client.close() + + def test_session_auth_set(self) -> None: + client = OntapClient("10.0.0.1", "admin", "secret") + assert client._session.auth == ("admin", "secret") + client.close() + + def test_verify_ssl_defaults_false(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client._session.verify is False + client.close() + + def test_verify_ssl_can_be_enabled(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass", verify_ssl=True) + assert client._session.verify is True + client.close() + + def test_default_timeout_stored(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.timeout == 30 + client.close() + + def test_custom_timeout_stored(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass", timeout=60) + assert client.timeout == 60 + client.close() + + def test_default_headers_include_accept(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert "application/hal+json" in client._session.headers.get("Accept", "") + client.close() + + def test_default_headers_include_content_type(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client._session.headers.get("Content-Type") == "application/json" + client.close() + + +# --------------------------------------------------------------------------- +# OntapClient.from_env +# --------------------------------------------------------------------------- + + +class TestFromEnv: + def test_missing_ontap_host_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("ONTAP_HOST", raising=False) + monkeypatch.delenv("ONTAP_PASS", raising=False) + with pytest.raises(SystemExit): + OntapClient.from_env() + + def test_missing_ontap_pass_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.delenv("ONTAP_PASS", raising=False) + with pytest.raises(SystemExit): + OntapClient.from_env() + + def test_builds_client_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.setenv("ONTAP_USER", "testuser") + client = OntapClient.from_env() + assert client.base_url == "https://10.0.0.1/api" + assert client._session.auth == ("testuser", "secret") + client.close() + + def test_default_user_is_admin(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.delenv("ONTAP_USER", raising=False) + client = OntapClient.from_env() + assert client._session.auth == ("admin", "secret") + client.close() + + def test_verify_ssl_true_when_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.setenv("ONTAP_VERIFY_SSL", "true") + client = OntapClient.from_env() + assert client._session.verify is True + client.close() + + def test_verify_ssl_false_when_not_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONTAP_HOST", "10.0.0.1") + monkeypatch.setenv("ONTAP_PASS", "secret") + monkeypatch.delenv("ONTAP_VERIFY_SSL", raising=False) + client = OntapClient.from_env() + assert client._session.verify is False + client.close() + + +# --------------------------------------------------------------------------- +# OntapClient._url +# --------------------------------------------------------------------------- + + +class TestUrl: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + + def teardown_method(self) -> None: + self.client.close() + + def test_absolute_path_prefixed_with_base(self) -> None: + assert self.client._url("/cluster") == "https://10.0.0.1/api/cluster" + + def test_relative_path_prefixed_with_base(self) -> None: + assert self.client._url("cluster/nodes") == "https://10.0.0.1/api/cluster/nodes" + + def test_absolute_https_url_returned_unchanged(self) -> None: + url = "https://other.host/api/cluster" + assert self.client._url(url) == url + + +# --------------------------------------------------------------------------- +# OntapClient._request +# --------------------------------------------------------------------------- + + +class TestRequest: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + self.mock_resp = MagicMock(spec=requests.Response) + self.client._session.request = MagicMock(return_value=self.mock_resp) + + def teardown_method(self) -> None: + self.client.close() + + def test_success_returns_json(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b'{"name": "cluster1"}' + self.mock_resp.json.return_value = {"name": "cluster1"} + result = self.client._request("GET", "/cluster") + assert result == {"name": "cluster1"} + + def test_204_returns_empty_dict(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 204 + self.mock_resp.content = b"" + result = self.client._request("DELETE", "/some/resource") + assert result == {} + + def test_empty_content_returns_empty_dict(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b"" + result = self.client._request("GET", "/cluster") + assert result == {} + + def test_non_ok_raises_ontap_api_error(self) -> None: + self.mock_resp.ok = False + self.mock_resp.status_code = 404 + self.mock_resp.json.return_value = {"message": "Not found"} + with pytest.raises(OntapApiError) as exc_info: + self.client._request("GET", "/missing") + assert exc_info.value.status_code == 404 + + def test_uses_default_timeout(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b"{}" + self.mock_resp.json.return_value = {} + self.client._request("GET", "/cluster") + call_kwargs = self.client._session.request.call_args[1] + assert call_kwargs["timeout"] == self.client.timeout + + def test_url_built_from_path(self) -> None: + self.mock_resp.ok = True + self.mock_resp.status_code = 200 + self.mock_resp.content = b"{}" + self.mock_resp.json.return_value = {} + self.client._request("GET", "/cluster") + call_args = self.client._session.request.call_args[0] + assert call_args[1] == "https://10.0.0.1/api/cluster" + + +# --------------------------------------------------------------------------- +# OntapClient HTTP convenience methods +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def client_with_mock_session() -> OntapClient: + client = OntapClient("10.0.0.1", "admin", "pass") + mock_resp = MagicMock(spec=requests.Response) + mock_resp.ok = True + mock_resp.status_code = 200 + mock_resp.content = b'{"records": []}' + mock_resp.json.return_value = {"records": []} + client._session.request = MagicMock(return_value=mock_resp) + yield client + client.close() + + +class TestHttpMethods: + def test_get_adds_fields_param(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes", fields="name,serial_number") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert params["fields"] == "name,serial_number" + + def test_get_adds_default_return_timeout(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert params["return_timeout"] == "120" + + def test_get_no_fields_key_when_omitted(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert "fields" not in params + + def test_get_passes_extra_params(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.get("/cluster/nodes", membership="available") + params = client_with_mock_session._session.request.call_args[1]["params"] + assert params["membership"] == "available" + + def test_post_uses_post_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.post("/cluster", {"name": "c1"}) + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "POST" + + def test_post_sends_json_body(self, client_with_mock_session: OntapClient) -> None: + body = {"name": "c1", "password": "secret"} + client_with_mock_session.post("/cluster", body) + json_arg = client_with_mock_session._session.request.call_args[1]["json"] + assert json_arg == body + + def test_patch_uses_patch_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.patch("/cluster/nodes/uuid1", {"state": "up"}) + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "PATCH" + + def test_patch_sends_json_body(self, client_with_mock_session: OntapClient) -> None: + body = {"state": "up"} + client_with_mock_session.patch("/cluster/nodes/uuid1", body) + json_arg = client_with_mock_session._session.request.call_args[1]["json"] + assert json_arg == body + + def test_delete_uses_delete_method(self, client_with_mock_session: OntapClient) -> None: + client_with_mock_session.delete("/volumes/uuid1") + method = client_with_mock_session._session.request.call_args[0][0] + assert method == "DELETE" + + +# --------------------------------------------------------------------------- +# OntapClient.poll_job +# --------------------------------------------------------------------------- + +_JOB_UUID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + +class TestPollJob: + def setup_method(self) -> None: + self.client = OntapClient("10.0.0.1", "admin", "pass") + + def teardown_method(self) -> None: + self.client.close() + + def test_success_state_returns_job(self) -> None: + self.client.get = MagicMock(return_value={"state": "success", "message": "done"}) + result = self.client.poll_job(_JOB_UUID) + assert result["state"] == "success" + + def test_failure_state_raises_runtime_error(self) -> None: + self.client.get = MagicMock(return_value={"state": "failure", "message": "boom"}) + with pytest.raises(RuntimeError, match="failed"): + self.client.poll_job(_JOB_UUID) + + def test_failure_includes_job_message(self) -> None: + self.client.get = MagicMock(return_value={"state": "failure", "message": "disk error"}) + with pytest.raises(RuntimeError, match="disk error"): + self.client.poll_job(_JOB_UUID) + + def test_polls_until_success(self) -> None: + responses = [ + {"state": "running"}, + {"state": "running"}, + {"state": "success"}, + ] + self.client.get = MagicMock(side_effect=responses) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", return_value=0), + ): + result = self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + assert result["state"] == "success" + assert self.client.get.call_count == 3 + + def test_timeout_raises_timeout_error(self) -> None: + self.client.get = MagicMock(return_value={"state": "running"}) + # First monotonic() → start time (0), second → past deadline (400) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", side_effect=[0, 400]), + ): + with pytest.raises(TimeoutError): + self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + + def test_connection_error_retries_then_succeeds(self) -> None: + self.client.get = MagicMock( + side_effect=[ + requests.exceptions.ConnectionError("disconnected"), + {"state": "success"}, + ] + ) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", return_value=0), + ): + result = self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + assert result["state"] == "success" + + def test_connection_error_past_deadline_raises_timeout(self) -> None: + self.client.get = MagicMock( + side_effect=requests.exceptions.ConnectionError("disconnected") + ) + # 1st monotonic → deadline start (0), 2nd → past deadline (400) + with ( + patch("ontap_client.time.sleep"), + patch("ontap_client.time.monotonic", side_effect=[0, 400]), + ): + with pytest.raises(TimeoutError): + self.client.poll_job(_JOB_UUID, interval=1, timeout=300) + + def test_polls_correct_job_url(self) -> None: + self.client.get = MagicMock(return_value={"state": "success"}) + self.client.poll_job(_JOB_UUID) + call_args = self.client.get.call_args[0][0] + assert _JOB_UUID in call_args + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + def test_enter_returns_self(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + assert client.__enter__() is client + client.close() + + def test_exit_closes_session(self) -> None: + client = OntapClient("10.0.0.1", "admin", "pass") + close_mock = MagicMock() + client._session.close = close_mock + with client: + pass + close_mock.assert_called_once() + + def test_with_statement_usage(self) -> None: + """OntapClient can be used as a context manager without errors.""" + with OntapClient("10.0.0.1", "admin", "pass") as client: + assert isinstance(client, OntapClient) diff --git a/tests/test_snapmirror_cleanup_test_failover.py b/tests/test_snapmirror_cleanup_test_failover.py new file mode 100644 index 0000000..3ed448c --- /dev/null +++ b/tests/test_snapmirror_cleanup_test_failover.py @@ -0,0 +1,149 @@ +"""Unit tests for shared helpers in snapmirror_cleanup_test_failover.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_cleanup_test_failover as sm_clean +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "10.5.0.1") + assert sm_clean._env("CLUSTER_A") == "10.5.0.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.setenv("CLUSTER_A", "10.5.0.2") + assert sm_clean._env("CLUSTER_A") == "10.5.0.2" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + with pytest.raises(SystemExit): + sm_clean._env("CLUSTER_A") + + def test_returns_default_when_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_clean.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + assert sm_clean._env("CLUSTER_A", default="fallback") == "fallback" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def test_returns_immediately_on_non_running(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + result = sm_clean._poll_job(client, "job-abc") + assert result["state"] == "success" + + def test_polls_until_done(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "running"}, + {"state": "success"}, + ] + with patch("snapmirror_cleanup_test_failover.time.sleep"): + result = sm_clean._poll_job(client, "job-abc", interval=1) + assert client.get.call_count == 2 + assert result["state"] == "success" + + def test_uses_correct_job_url(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + sm_clean._poll_job(client, "my-job-id") + assert "my-job-id" in client.get.call_args[0][0] + + +# --------------------------------------------------------------------------- +# _pick_cluster_by_relationship +# --------------------------------------------------------------------------- + + +class TestPickClusterByRelationship: + def _vol_client(self, records: list[dict]) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"records": records, "num_records": len(records)} + return client + + def test_picks_cluster_a_when_it_has_relationship(self) -> None: + rel = { + "uuid": "rel-uuid-1", + "source": {"path": "vs0:vol_rw_01"}, + "destination": {"path": "vs1:vol_rw_01_dest"}, + "state": "snapmirrored", + "healthy": True, + } + a_client = self._vol_client([rel]) + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=a_client): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + assert cluster == "10.0.0.1" + assert found_rel["uuid"] == "rel-uuid-1" + + def test_falls_through_to_cluster_b(self) -> None: + rel = { + "uuid": "rel-uuid-2", + "source": {"path": "vs0:vol_rw_01"}, + "destination": {"path": "vs1:vol_rw_01_dest"}, + "state": "snapmirrored", + "healthy": True, + } + a_client = self._vol_client([]) + b_client = self._vol_client([rel]) + + with patch( + "snapmirror_cleanup_test_failover.OntapClient", side_effect=[a_client, b_client] + ): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + assert cluster == "10.0.0.2" + assert found_rel["uuid"] == "rel-uuid-2" + + def test_exits_when_neither_cluster_has_relationship(self) -> None: + no_rel_client = self._vol_client([]) + + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=no_rel_client): + with pytest.raises(SystemExit): + sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol_rw_01" + ) + + def test_skips_unreachable_cluster_and_continues(self) -> None: + rel = {"uuid": "rel-uuid-b", "source": {"path": "vs0:vol"}, "state": "snapmirrored"} + a_client = MagicMock(spec=OntapClient) + a_client.get.side_effect = ConnectionError("unreachable") + b_client = self._vol_client([rel]) + + with patch( + "snapmirror_cleanup_test_failover.OntapClient", side_effect=[a_client, b_client] + ): + cluster, found_rel = sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "vol" + ) + assert cluster == "10.0.0.2" + + def test_passes_source_path_filter(self) -> None: + """Verify the API is called with source.path=:.""" + rel = {"uuid": "rel-uuid-x", "source": {"path": "vs0:myvol"}, "state": "snapmirrored"} + a_client = self._vol_client([rel]) + + with patch("snapmirror_cleanup_test_failover.OntapClient", return_value=a_client): + sm_clean._pick_cluster_by_relationship( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vs0", "myvol" + ) + + call_kwargs = a_client.get.call_args[1] + assert call_kwargs.get("source.path") == "vs0:myvol" diff --git a/tests/test_snapmirror_provision_dest_managed.py b/tests/test_snapmirror_provision_dest_managed.py new file mode 100644 index 0000000..2ba82d5 --- /dev/null +++ b/tests/test_snapmirror_provision_dest_managed.py @@ -0,0 +1,173 @@ +"""Unit tests for shared helpers in snapmirror_provision_dest_managed.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_provision_dest_managed as sm_dst +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "10.1.0.1") + assert sm_dst._env("SOURCE_HOST") == "10.1.0.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.setenv("SOURCE_HOST", "10.1.0.2") + assert sm_dst._env("SOURCE_HOST") == "10.1.0.2" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + with pytest.raises(SystemExit): + sm_dst._env("SOURCE_HOST") + + def test_returns_default_when_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_dst.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + assert sm_dst._env("SOURCE_HOST", default="default_val") == "default_val" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def test_returns_immediately_on_non_running_state(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + result = sm_dst._poll_job(client, "job-uuid-1") + assert result["state"] == "success" + + def test_polls_multiple_times_until_done(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "running"}, + {"state": "success"}, + ] + with patch("snapmirror_provision_dest_managed.time.sleep"): + result = sm_dst._poll_job(client, "job-uuid-1", interval=1) + assert result["state"] == "success" + assert client.get.call_count == 2 + + def test_passes_correct_url(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + sm_dst._poll_job(client, "my-job-abc") + assert "my-job-abc" in client.get.call_args[0][0] + + +# --------------------------------------------------------------------------- +# _wait_snapmirrored +# --------------------------------------------------------------------------- + + +class TestWaitSnapmirrored: + def test_returns_immediately_when_snapmirrored(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "snapmirrored", "healthy": True} + result = sm_dst._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=60) + assert result["state"] == "snapmirrored" + + def test_polls_until_snapmirrored(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [ + {"state": "transferring"}, + {"state": "snapmirrored"}, + ] + with patch("snapmirror_provision_dest_managed.time.sleep"): + result = sm_dst._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=300) + assert result["state"] == "snapmirrored" + + def test_raises_on_timeout(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "transferring"} + with patch("snapmirror_provision_dest_managed.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + sm_dst._wait_snapmirrored(client, "rel-uuid", interval=2, max_wait=1) + + +# --------------------------------------------------------------------------- +# _get_ic_lif_ips +# --------------------------------------------------------------------------- + + +class TestGetIcLifIps: + def test_returns_intercluster_ips(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = { + "records": [ + {"ip": {"address": "10.0.0.10"}, "services": ["intercluster-core"]}, + {"ip": {"address": "10.0.0.11"}, "services": ["data-nfs"]}, + ] + } + ips = sm_dst._get_ic_lif_ips(client) + assert ips == ["10.0.0.10"] + + def test_returns_empty_when_no_ic_lifs(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = { + "records": [ + {"ip": {"address": "10.0.0.11"}, "services": ["data-nfs"]}, + ] + } + ips = sm_dst._get_ic_lif_ips(client) + assert ips == [] + + def test_returns_empty_on_empty_records(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"records": []} + ips = sm_dst._get_ic_lif_ips(client) + assert ips == [] + + def test_skips_records_with_no_ip_address(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = { + "records": [ + {"ip": {}, "services": ["intercluster-core"]}, + ] + } + ips = sm_dst._get_ic_lif_ips(client) + assert ips == [] + + +# --------------------------------------------------------------------------- +# _check_ic_lif_preconditions +# --------------------------------------------------------------------------- + + +class TestCheckIcLifPreconditions: + def test_exits_when_no_src_ips(self) -> None: + src = MagicMock(spec=OntapClient) + dst = MagicMock(spec=OntapClient) + with pytest.raises(SystemExit): + sm_dst._check_ic_lif_preconditions(src, dst, [], ["10.0.0.1"]) + + def test_exits_when_no_dst_ips(self) -> None: + src = MagicMock(spec=OntapClient) + dst = MagicMock(spec=OntapClient) + with pytest.raises(SystemExit): + sm_dst._check_ic_lif_preconditions(src, dst, ["10.0.0.1"], []) + + def test_no_error_when_same_subnet(self) -> None: + src = MagicMock(spec=OntapClient) + dst = MagicMock(spec=OntapClient) + # Should not raise — same /24 subnet + sm_dst._check_ic_lif_preconditions(src, dst, ["10.0.0.1"], ["10.0.0.2"]) + + def test_warns_when_different_subnets(self, caplog: pytest.LogCaptureFixture) -> None: + import logging + + src = MagicMock(spec=OntapClient) + dst = MagicMock(spec=OntapClient) + with caplog.at_level(logging.WARNING, logger="snapmirror_provision_dest_managed"): + sm_dst._check_ic_lif_preconditions(src, dst, ["10.0.0.1"], ["192.168.1.1"]) + assert any("subnet" in msg.lower() for msg in caplog.messages) diff --git a/tests/test_snapmirror_provision_src_managed.py b/tests/test_snapmirror_provision_src_managed.py new file mode 100644 index 0000000..53e11ef --- /dev/null +++ b/tests/test_snapmirror_provision_src_managed.py @@ -0,0 +1,125 @@ +"""Unit tests for shared helpers in snapmirror_provision_src_managed.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_provision_src_managed as sm_src +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "10.0.0.1") + assert sm_src._env("SOURCE_HOST") == "10.0.0.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "") + monkeypatch.setenv("SOURCE_HOST", "10.0.0.2") + assert sm_src._env("SOURCE_HOST") == "10.0.0.2" + + def test_inputs_takes_priority_over_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "from_inputs") + monkeypatch.setenv("SOURCE_HOST", "from_env") + assert sm_src._env("SOURCE_HOST") == "from_inputs" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + with pytest.raises(SystemExit): + sm_src._env("SOURCE_HOST") + + def test_returns_default_when_not_required(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_src.INPUTS, "SOURCE_HOST", "") + monkeypatch.delenv("SOURCE_HOST", raising=False) + assert sm_src._env("SOURCE_HOST", default="fallback") == "fallback" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def _make_client(self) -> MagicMock: + return MagicMock(spec=OntapClient) + + def test_returns_when_state_not_running(self) -> None: + client = self._make_client() + client.get.return_value = {"state": "success", "message": "done"} + result = sm_src._poll_job(client, "job-uuid-1") + assert result["state"] == "success" + + def test_polls_until_non_running_state(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"state": "running"}, + {"state": "running"}, + {"state": "success"}, + ] + with patch("snapmirror_provision_src_managed.time.sleep"): + result = sm_src._poll_job(client, "job-uuid-1", interval=1) + assert result["state"] == "success" + assert client.get.call_count == 3 + + def test_passes_correct_job_url(self) -> None: + client = self._make_client() + client.get.return_value = {"state": "success"} + sm_src._poll_job(client, "abc-123") + call_path = client.get.call_args[0][0] + assert "abc-123" in call_path + + def test_returns_failure_state_without_raising(self) -> None: + """_poll_job returns the failure record — callers decide what to do.""" + client = self._make_client() + client.get.return_value = {"state": "failure", "error": {"message": "boom"}} + result = sm_src._poll_job(client, "job-uuid-fail") + assert result["state"] == "failure" + + +# --------------------------------------------------------------------------- +# _wait_snapmirrored +# --------------------------------------------------------------------------- + + +class TestWaitSnapmirrored: + def _make_client(self) -> MagicMock: + return MagicMock(spec=OntapClient) + + def test_returns_immediately_when_already_snapmirrored(self) -> None: + client = self._make_client() + client.get.return_value = {"state": "snapmirrored", "lag_time": "PT5M", "healthy": True} + result = sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=1, max_wait=60) + assert result["state"] == "snapmirrored" + assert client.get.call_count == 1 + + def test_polls_until_snapmirrored(self) -> None: + client = self._make_client() + client.get.side_effect = [ + {"state": "transferring"}, + {"state": "transferring"}, + {"state": "snapmirrored"}, + ] + with patch("snapmirror_provision_src_managed.time.sleep"): + result = sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=1, max_wait=600) + assert result["state"] == "snapmirrored" + + def test_raises_timeout_if_never_snapmirrored(self) -> None: + client = self._make_client() + client.get.return_value = {"state": "transferring"} + with patch("snapmirror_provision_src_managed.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + # max_wait=1, interval=2 → loop exits after first iteration + sm_src._wait_snapmirrored(client, "rel-uuid-1", interval=2, max_wait=1) + + def test_queries_correct_relationship_url(self) -> None: + client = self._make_client() + client.get.return_value = {"state": "snapmirrored"} + sm_src._wait_snapmirrored(client, "my-rel-uuid", interval=1, max_wait=60) + call_path = client.get.call_args[0][0] + assert "my-rel-uuid" in call_path diff --git a/tests/test_snapmirror_test_failover.py b/tests/test_snapmirror_test_failover.py new file mode 100644 index 0000000..ff4c3da --- /dev/null +++ b/tests/test_snapmirror_test_failover.py @@ -0,0 +1,149 @@ +"""Unit tests for shared helpers in snapmirror_test_failover.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import snapmirror_test_failover as sm_tf +from ontap_client import OntapClient + +# --------------------------------------------------------------------------- +# _env +# --------------------------------------------------------------------------- + + +class TestEnv: + def test_reads_from_inputs_dict(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "10.0.1.1") + assert sm_tf._env("CLUSTER_A") == "10.0.1.1" + + def test_falls_back_to_os_environ(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "") + monkeypatch.setenv("CLUSTER_A", "10.0.1.2") + assert sm_tf._env("CLUSTER_A") == "10.0.1.2" + + def test_missing_required_key_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + with pytest.raises(SystemExit): + sm_tf._env("CLUSTER_A") + + def test_returns_default_when_not_required(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sm_tf.INPUTS, "CLUSTER_A", "") + monkeypatch.delenv("CLUSTER_A", raising=False) + assert sm_tf._env("CLUSTER_A", default="x") == "x" + + +# --------------------------------------------------------------------------- +# _poll_job +# --------------------------------------------------------------------------- + + +class TestPollJob: + def test_returns_on_first_non_running_state(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "success"} + result = sm_tf._poll_job(client, "job-1") + assert result["state"] == "success" + + def test_polls_until_done(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.side_effect = [{"state": "running"}, {"state": "success"}] + with patch("snapmirror_test_failover.time.sleep"): + result = sm_tf._poll_job(client, "job-1", interval=1) + assert client.get.call_count == 2 + assert result["state"] == "success" + + +# --------------------------------------------------------------------------- +# _wait_snapmirrored +# --------------------------------------------------------------------------- + + +class TestWaitSnapmirrored: + def test_returns_immediately_when_snapmirrored(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "snapmirrored"} + result = sm_tf._wait_snapmirrored(client, "rel-uuid", interval=1, max_wait=60) + assert result["state"] == "snapmirrored" + + def test_raises_timeout_when_never_converges(self) -> None: + client = MagicMock(spec=OntapClient) + client.get.return_value = {"state": "transferring"} + with patch("snapmirror_test_failover.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + sm_tf._wait_snapmirrored(client, "rel-uuid", interval=2, max_wait=1) + + +# --------------------------------------------------------------------------- +# _pick_cluster +# --------------------------------------------------------------------------- + + +class TestPickCluster: + def _make_client(self, records: list[dict]) -> MagicMock: + client = MagicMock(spec=OntapClient) + client.get.return_value = { + "records": records, + "num_records": len(records), + } + return client + + def test_picks_cluster_a_when_it_has_dp_volume(self) -> None: + vol = {"name": "vol_rw_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}} + with patch("snapmirror_test_failover.OntapClient") as MockClient: + instance = self._make_client([vol]) + MockClient.return_value = instance + cluster, found_vol = sm_tf._pick_cluster( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01" + ) + assert cluster == "10.0.0.1" + assert found_vol["name"] == "vol_rw_01_dest" + + def test_falls_through_to_cluster_b(self) -> None: + vol = {"name": "vol_rw_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}} + a_client = MagicMock(spec=OntapClient) + a_client.get.return_value = {"records": [], "num_records": 0} + b_client = MagicMock(spec=OntapClient) + b_client.get.return_value = {"records": [vol], "num_records": 1} + + with patch("snapmirror_test_failover.OntapClient", side_effect=[a_client, b_client]): + cluster, found_vol = sm_tf._pick_cluster( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01" + ) + assert cluster == "10.0.0.2" + + def test_exits_when_no_cluster_has_dp_volume(self) -> None: + no_vol_client = MagicMock(spec=OntapClient) + no_vol_client.get.return_value = {"records": [], "num_records": 0} + + with patch("snapmirror_test_failover.OntapClient", return_value=no_vol_client): + with pytest.raises(SystemExit): + sm_tf._pick_cluster("10.0.0.1", "10.0.0.2", "admin", "pass", "vol_rw_01") + + def test_uses_wildcard_filter_in_auto_mode(self) -> None: + """When vol_name_filter='*', the DP name filter should be '*_dest'.""" + no_vol_client = MagicMock(spec=OntapClient) + no_vol_client.get.return_value = {"records": [], "num_records": 0} + + with patch("snapmirror_test_failover.OntapClient", return_value=no_vol_client): + with pytest.raises(SystemExit): + sm_tf._pick_cluster("10.0.0.1", "10.0.0.2", "admin", "pass", "*") + + # Verify the filter sent was '*_dest' + call_kwargs = no_vol_client.get.call_args[1] + assert call_kwargs.get("name") == "*_dest" + + def test_skips_unreachable_cluster_and_continues(self) -> None: + vol = {"name": "vol_01_dest", "uuid": "v-uuid", "svm": {"name": "vs1"}} + a_client = MagicMock(spec=OntapClient) + a_client.get.side_effect = ConnectionError("unreachable") + b_client = MagicMock(spec=OntapClient) + b_client.get.return_value = {"records": [vol], "num_records": 1} + + with patch("snapmirror_test_failover.OntapClient", side_effect=[a_client, b_client]): + cluster, found_vol = sm_tf._pick_cluster( + "10.0.0.1", "10.0.0.2", "admin", "pass", "vol_01" + ) + assert cluster == "10.0.0.2"