diff --git a/CLAUDE.md b/CLAUDE.md index fa2f707..91a49d5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,6 +65,8 @@ The SDK uses a src layout with the main package at `src/hubblenetwork/`. Public - **`crypto.py`** - Local packet decryption. Implements AES-CTR decryption with CMAC-based key derivation (SP800_108_Counter KDF). Supports both AES-256-CTR and AES-128-CTR. `decrypt()` accepts `counter_mode` as `"UNIX_TIME"` (default, UTC day-based) or `"DEVICE_UPTIME"` (counter-based, fixed pool size 128). Exports `UNIX_TIME` and `DEVICE_UPTIME` constants. `decrypt_eax()` decrypts AES-EAX packets by iterating counters 0-127, generating candidate EIDs via AES-ECB, and using `AES.MODE_EAX` for authenticated decryption. Uses key directly (no KDF). `decrypt_satellite()` decrypts a satellite packet's payload using the same AES-CTR/CMAC scheme as `decrypt()`; satellite packets deliver `seq_num`, `auth_tag`, and encrypted payload as separate fields (not packed into one advertisement). It accepts `counter_mode` (`"UNIX_TIME"` default, day-based; or `"DEVICE_UPTIME"`, sweeping the fixed 0-127 counter pool) just like `decrypt()`. +- **`detect.py`** - Decryption-strategy auto-detection, decoupled from the CLI (imports no `click`, never prints). `detect_eid_type()` classifies a key's EID rotation mode (UNIX_TIME / DEVICE_UPTIME / AMBIGUOUS) from sample packets. `CtrCounterModeDetector` and `EaxExponentDetector` are per-scan stateful objects that hold the detection cache and own the detect/sweep loop; their `decrypt()` takes a caller-supplied packet-bound `decrypt_fn` and returns a `Detection(result, label)` — `label` is set only on the first successful detection of a scan. The CLI builds the `decrypt_fn` (so test mocks of `cli.decrypt`/`cli.decrypt_eax` still apply) and prints the `[INFO] Detected:` line itself when `label` is set. + - **`packets.py`** - Data classes: `Location`, `EncryptedPacket`, `DecryptedPacket`, `AesEaxPacket`, `UnknownPacket`. - **`device.py`** - `Device` dataclass representing a registered device. diff --git a/src/hubblenetwork/cli.py b/src/hubblenetwork/cli.py index 556d8a2..9c88a83 100644 --- a/src/hubblenetwork/cli.py +++ b/src/hubblenetwork/cli.py @@ -14,7 +14,7 @@ from dataclasses import replace from datetime import datetime from functools import partial -from typing import Callable, Optional, List, TypeVar +from typing import Optional, List from tabulate import tabulate from hubblenetwork import Organization from hubblenetwork import Device, DecryptedPacket, EncryptedPacket, decrypt_eax @@ -24,6 +24,11 @@ from hubblenetwork import sat as sat_mod from hubblenetwork import decrypt, decrypt_satellite, UNIX_TIME, DEVICE_UPTIME from hubblenetwork.crypto import find_time_counter_delta +from hubblenetwork.detect import ( + CtrCounterModeDetector, + EaxExponentDetector, + detect_eid_type, +) from hubblenetwork import cloud from hubblenetwork import InvalidCredentialsError from hubblenetwork.errors import BackendError @@ -83,36 +88,6 @@ def _get_pkt_from_be_with_timestamp(org, device, timestamp): return None -def _detect_eid_type( - key: bytes, - pkts: List[EncryptedPacket], -) -> tuple[Optional[EncryptedPacket], Optional[DecryptedPacket], Optional[str], bool]: - epoch_pkt = None - epoch_dec = None - counter_pkt = None - counter_dec = None - for pkt in pkts: - if epoch_pkt is None: - result = decrypt(key, pkt) - if result: - epoch_pkt = pkt - epoch_dec = result - if counter_pkt is None: - result = decrypt(key, pkt, counter_mode=DEVICE_UPTIME) - if result: - counter_pkt = pkt - counter_dec = result - if epoch_pkt and counter_pkt: - break - if epoch_pkt and counter_pkt: - return (epoch_pkt, epoch_dec, "AMBIGUOUS", True) - if epoch_pkt: - return (epoch_pkt, epoch_dec, UNIX_TIME, False) - if counter_pkt: - return (counter_pkt, counter_dec, DEVICE_UPTIME, False) - return (None, None, None, False) - - def _announce_auto_detect(auto_ctr: bool, auto_eax: bool, *, suppress: bool) -> None: if suppress or not (auto_ctr or auto_eax): return @@ -129,161 +104,14 @@ def _announce_auto_detect(auto_ctr: bool, auto_eax: bool, *, suppress: bool) -> ) -def _decrypt_eax_with_detect( - key: bytes, - pkt: AesEaxPacket, - *, - auto_detect: bool, - fixed_exponent: int, - cache: dict, - announced: list[str], - suppress_info: bool, -) -> Optional[DecryptedPacket]: - if not auto_detect: - return decrypt_eax(key, pkt, period_exponent=fixed_exponent) - - cached = cache.get(pkt.eid) - if cached is not None: - result = decrypt_eax(key, pkt, period_exponent=cached) - if result: - return result - - for candidate in range(16): - result = decrypt_eax(key, pkt, period_exponent=candidate) - if result is None: - continue - cache[pkt.eid] = candidate - if not announced and not suppress_info: - announced.append("eax") - click.secho( - f"[INFO] Detected: AES-128-EAX, counter_source=DEVICE_UPTIME, " - f"period_exponent={candidate} (period={1 << candidate}s)", - fg="green", - err=True, - ) - return result - return None - - -_T = TypeVar("_T") - -# Satellite streams have no per-packet EID to key the counter-mode cache on, so -# the detected mode lives in a single shared slot for the whole scan. -_SAT_CTR_CACHE_KEY = "mode" - +def _announce_detection(label: Optional[str], *, suppress: bool) -> None: + """Print the one-shot ``[INFO] Detected:`` line for a fresh detection. -def _detect_ctr_counter_mode( - *, - decrypt_fn: Callable[..., Optional[_T]], - days: int, - auto_detect: bool, - fixed_counter_mode: str, - key_len: int, - cache: dict, - cache_key: object, - announced: list[str], - suppress_info: bool, -) -> Optional[_T]: - """Decrypt an AES-CTR packet, auto-detecting the counter source if asked. - - Shared by the BLE and satellite scan paths. ``decrypt_fn`` is a packet-bound - adapter accepting ``counter_mode`` (and, for UNIX_TIME, ``days``) and - returning the decrypted result or None. - - When ``auto_detect`` is False the ``fixed_counter_mode`` is used directly. - Otherwise the mode cached under ``cache_key`` (a BLE EID, or a per-stream - sentinel for satellite) is tried first, then UNIX_TIME and DEVICE_UPTIME are - swept; the first that succeeds is cached and announced once via ``announced``. - A ``cache_key`` of None disables caching (BLE packets without an EID). + ``label`` is set by the detector only on the first successful detection of a + scan, so a no-op when it is None or output is suppressed (JSON mode). """ - - def _try(mode: str) -> Optional[_T]: - kwargs = {"counter_mode": mode} - if mode == UNIX_TIME: - kwargs["days"] = days - return decrypt_fn(**kwargs) - - if not auto_detect: - return _try(fixed_counter_mode) - - if cache_key is not None: - cached = cache.get(cache_key) - if cached is not None: - result = _try(cached) - if result is not None: - return result - - for mode in (UNIX_TIME, DEVICE_UPTIME): - result = _try(mode) - if result is None: - continue - if cache_key is not None: - cache[cache_key] = mode - if not announced and not suppress_info: - announced.append("ctr") - variant = "AES-128-CTR" if key_len == 16 else "AES-256-CTR" - click.secho( - f"[INFO] Detected: {variant}, counter_source={mode}", - fg="green", - err=True, - ) - return result - return None - - -def _decrypt_ctr_with_detect( - key: bytes, - pkt: EncryptedPacket, - *, - auto_detect: bool, - fixed_counter_mode: str, - days: int, - cache: dict, - announced: list[str], - suppress_info: bool, -) -> Optional[DecryptedPacket]: - return _detect_ctr_counter_mode( - decrypt_fn=lambda **kw: decrypt(key, pkt, **kw), - days=days, - auto_detect=auto_detect, - fixed_counter_mode=fixed_counter_mode, - key_len=len(key), - cache=cache, - cache_key=pkt.eid, - announced=announced, - suppress_info=suppress_info, - ) - - -def _decrypt_satellite_with_detect( - key: bytes, - pkt, - *, - auto_detect: bool, - fixed_counter_mode: str, - days: int, - state: dict, - announced: list[str], - suppress_info: bool, -) -> Optional[bytes]: - return _detect_ctr_counter_mode( - decrypt_fn=lambda **kw: decrypt_satellite( - key, - seq_no=pkt.seq_num, - auth_tag=pkt.auth_tag, - encrypted_payload=pkt.payload, - timestamp=pkt.timestamp, - **kw, - ), - days=days, - auto_detect=auto_detect, - fixed_counter_mode=fixed_counter_mode, - key_len=len(key), - cache=state, - cache_key=_SAT_CTR_CACHE_KEY, - announced=announced, - suppress_info=suppress_info, - ) + if label and not suppress: + click.secho(f"[INFO] Detected: {label}", fg="green", err=True) def _format_payload(payload, fmt: str) -> str: @@ -1018,9 +846,15 @@ def _explicit(name: str) -> bool: _announce_auto_detect(auto_detect_ctr, auto_detect_eax, suppress=use_json) - detected_ctr_modes: dict = {} - detected_eax_exponents: dict = {} - announced: list[str] = [] + ctr_detector = CtrCounterModeDetector( + auto_detect=auto_detect_ctr, + fixed_counter_mode=counter_mode, + days=days, + key_len=len(decoded_key), + ) + eax_detector = EaxExponentDetector( + auto_detect=auto_detect_eax, fixed_exponent=period_exponent + ) # Set up timeout tracking start = time.monotonic() @@ -1054,26 +888,21 @@ def _explicit(name: str) -> bool: decrypted_pkt = None if isinstance(pkt, AesEaxPacket): - decrypted_pkt = _decrypt_eax_with_detect( - decoded_key, - pkt, - auto_detect=auto_detect_eax, - fixed_exponent=period_exponent, - cache=detected_eax_exponents, - announced=announced, - suppress_info=use_json, + d = eax_detector.decrypt( + decrypt_fn=lambda exp: decrypt_eax( + decoded_key, pkt, period_exponent=exp + ), + cache_key=pkt.eid, ) + _announce_detection(d.label, suppress=use_json) + decrypted_pkt = d.result elif isinstance(pkt, EncryptedPacket): - decrypted_pkt = _decrypt_ctr_with_detect( - decoded_key, - pkt, - auto_detect=auto_detect_ctr, - fixed_counter_mode=counter_mode, - days=days, - cache=detected_ctr_modes, - announced=announced, - suppress_info=use_json, + d = ctr_detector.decrypt( + decrypt_fn=lambda **kw: decrypt(decoded_key, pkt, **kw), + cache_key=pkt.eid, ) + _announce_detection(d.label, suppress=use_json) + decrypted_pkt = d.result # UnencryptedPacket and UnknownPacket fall through — keep scanning. if decrypted_pkt: @@ -1269,9 +1098,15 @@ def _explicit(name: str) -> bool: auto_detect_ctr, auto_detect_eax, suppress=printer.suppress_info_messages ) - detected_ctr_modes: dict = {} - detected_eax_exponents: dict = {} - announced: list[str] = [] + ctr_detector = CtrCounterModeDetector( + auto_detect=auto_detect_ctr, + fixed_counter_mode=counter_mode, + days=days, + key_len=len(decoded_key) if decoded_key is not None else 0, + ) + eax_detector = EaxExponentDetector( + auto_detect=auto_detect_eax, fixed_exponent=period_exponent + ) try: while deadline is None or time.monotonic() < deadline: @@ -1303,15 +1138,16 @@ def _explicit(name: str) -> bool: # AES-EAX packets: decrypt if key provided, else show raw fields elif isinstance(pkt, AesEaxPacket): if decoded_key: - decrypted_pkt = _decrypt_eax_with_detect( - decoded_key, - pkt, - auto_detect=auto_detect_eax, - fixed_exponent=period_exponent, - cache=detected_eax_exponents, - announced=announced, - suppress_info=printer.suppress_info_messages, + d = eax_detector.decrypt( + decrypt_fn=lambda exp: decrypt_eax( + decoded_key, pkt, period_exponent=exp + ), + cache_key=pkt.eid, ) + _announce_detection( + d.label, suppress=printer.suppress_info_messages + ) + decrypted_pkt = d.result if decrypted_pkt: printer.print_row(decrypted_pkt, decrypt_status="ok") elif show_failed_decryption: @@ -1320,16 +1156,14 @@ def _explicit(name: str) -> bool: printer.print_row(pkt) elif isinstance(pkt, EncryptedPacket): if decoded_key: - decrypted_pkt = _decrypt_ctr_with_detect( - decoded_key, - pkt, - auto_detect=auto_detect_ctr, - fixed_counter_mode=counter_mode, - days=days, - cache=detected_ctr_modes, - announced=announced, - suppress_info=printer.suppress_info_messages, + d = ctr_detector.decrypt( + decrypt_fn=lambda **kw: decrypt(decoded_key, pkt, **kw), + cache_key=pkt.eid, + ) + _announce_detection( + d.label, suppress=printer.suppress_info_messages ) + decrypted_pkt = d.result if decrypted_pkt: printer.print_row(decrypted_pkt, decrypt_status="ok") if ingest: @@ -1594,7 +1428,7 @@ def ble_validate(key: str, device_id: str, org_id: str, token: str, timeout: int # Step 6: Validate encryption and detect EID type _validate_info("Validating encryption of received packets") - pkt_to_ingest, dec_result, eid_label, _ = _detect_eid_type(decoded_key, pkts) + pkt_to_ingest, dec_result, eid_label, _ = detect_eid_type(decoded_key, pkts) if not pkt_to_ingest: _validate_error( 'Unable to decrypt packet with given device key.' @@ -3261,8 +3095,12 @@ def _run_sat_scan( auto_ctr=True, auto_eax=False, suppress=printer.suppress_info_messages ) - detected_ctr_state: dict = {} - announced: list[str] = [] + ctr_detector = CtrCounterModeDetector( + auto_detect=auto_detect_ctr, + fixed_counter_mode=counter_mode, + days=days, + key_len=len(decoded_key) if decoded_key is not None else 0, + ) if debug: sat_logger = logging.getLogger("hubblenetwork.sat") @@ -3310,16 +3148,22 @@ def _on_interrupt(sig, frame): # With a key, the user wants only packets the key can decrypt. decrypted = None if pkt.auth_tag is not None: - decrypted = _decrypt_satellite_with_detect( - decoded_key, - pkt, - auto_detect=auto_detect_ctr, - fixed_counter_mode=counter_mode, - days=days, - state=detected_ctr_state, - announced=announced, - suppress_info=printer.suppress_info_messages, + # Satellite streams have no per-packet EID; omitting cache_key + # lets the detector share one per-stream slot for the scan. + d = ctr_detector.decrypt( + decrypt_fn=lambda **kw: decrypt_satellite( + decoded_key, + seq_no=pkt.seq_num, + auth_tag=pkt.auth_tag, + encrypted_payload=pkt.payload, + timestamp=pkt.timestamp, + **kw, + ), + ) + _announce_detection( + d.label, suppress=printer.suppress_info_messages ) + decrypted = d.result if decrypted is not None: printer.print_row( replace(pkt, payload=decrypted), decrypt_status="ok" diff --git a/src/hubblenetwork/detect.py b/src/hubblenetwork/detect.py new file mode 100644 index 0000000..df734e0 --- /dev/null +++ b/src/hubblenetwork/detect.py @@ -0,0 +1,196 @@ +# hubblenetwork/detect.py +"""Auto-detect the decryption configuration of incoming Hubble packets. + +A device's packets are decryptable only once you know how its keys rotate: the +AES-CTR counter source (UNIX_TIME vs DEVICE_UPTIME) or, for AES-128-EAX, the +period exponent. That configuration is not carried in the packet, so when the +caller supplies only a key this module discovers it by trying each candidate +until one decrypts, then caches the winner so the rest of the scan skips the +sweep. Caching is keyed per EID for BLE (a scan can see many devices) and shared +across the whole stream for satellite (no per-packet EID). + +Detection returns its outcome as a :class:`Detection` rather than printing — +``result`` is the decrypted packet/payload and ``label`` describes the detected +configuration, set once per scan so the caller can announce it exactly once. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Generic, List, Optional, Tuple, TypeVar + +from .crypto import DEVICE_UPTIME, UNIX_TIME, decrypt +from .packets import DecryptedPacket, EncryptedPacket + +T = TypeVar("T") + +# Default cache key for streams with no per-packet EID (satellite): all packets +# in the scan share one detected-mode slot. +_SINGLE_STREAM = object() + + +@dataclass +class Detection(Generic[T]): + """Outcome of one detect-and-decrypt attempt. + + ``result`` is the decrypted packet (BLE) or payload bytes (satellite), or + ``None`` if decryption failed. A zero-length payload (``b""``) is a *success*, + so callers must test ``result is not None`` rather than truthiness. + + ``label`` is the human description of the detected configuration (e.g. + ``"AES-256-CTR, counter_source=UNIX_TIME"``). It is set only on the *first* + successful detection of a scan and is ``None`` otherwise — i.e. when decrypt + failed, when the mode came from cache, or when a configuration was already + announced this scan. + """ + + result: Optional[T] + label: Optional[str] = None + + +def detect_eid_type( + key: bytes, + pkts: List[EncryptedPacket], +) -> Tuple[Optional[EncryptedPacket], Optional[DecryptedPacket], Optional[str], bool]: + """Classify a key's EID rotation mode from sample packets. + + Returns ``(packet, decrypted, label, ambiguous)`` where ``label`` is + ``UNIX_TIME``, ``DEVICE_UPTIME``, ``"AMBIGUOUS"`` (both modes decrypt + something), or ``None`` (neither decrypts any packet). + """ + epoch_pkt = None + epoch_dec = None + counter_pkt = None + counter_dec = None + for pkt in pkts: + if epoch_pkt is None: + result = decrypt(key, pkt) + if result: + epoch_pkt = pkt + epoch_dec = result + if counter_pkt is None: + result = decrypt(key, pkt, counter_mode=DEVICE_UPTIME) + if result: + counter_pkt = pkt + counter_dec = result + if epoch_pkt and counter_pkt: + break + if epoch_pkt and counter_pkt: + return (epoch_pkt, epoch_dec, "AMBIGUOUS", True) + if epoch_pkt: + return (epoch_pkt, epoch_dec, UNIX_TIME, False) + if counter_pkt: + return (counter_pkt, counter_dec, DEVICE_UPTIME, False) + return (None, None, None, False) + + +class CtrCounterModeDetector: + """Per-scan AES-CTR counter-source auto-detection. + + Shared by the BLE and satellite scan paths. The caller supplies a + packet-bound ``decrypt_fn`` (accepting ``counter_mode`` and, for UNIX_TIME, + ``days``) so this module never references a concrete decrypt primitive + directly. + + When ``auto_detect`` is False the ``fixed_counter_mode`` is used directly. + Otherwise the mode cached under ``cache_key`` is tried first, then UNIX_TIME + and DEVICE_UPTIME are swept; the first that succeeds is cached and labelled + once. BLE passes the packet's EID as ``cache_key`` (``None`` for EID-less + packets, which disables caching); satellite omits it, so all packets share + one per-stream slot. + """ + + def __init__( + self, + *, + auto_detect: bool, + fixed_counter_mode: str, + days: int, + key_len: int, + ) -> None: + self._auto_detect = auto_detect + self._fixed_counter_mode = fixed_counter_mode + self._days = days + self._key_len = key_len + self._cache: dict = {} + self._announced = False + + def decrypt( + self, + *, + decrypt_fn: Callable[..., Optional[T]], + cache_key: object = _SINGLE_STREAM, + ) -> Detection[T]: + def _try(mode: str) -> Optional[T]: + kwargs = {"counter_mode": mode} + if mode == UNIX_TIME: + kwargs["days"] = self._days + return decrypt_fn(**kwargs) + + if not self._auto_detect: + return Detection(_try(self._fixed_counter_mode)) + + if cache_key is not None: + cached = self._cache.get(cache_key) + if cached is not None: + result = _try(cached) + if result is not None: + return Detection(result) + + for mode in (UNIX_TIME, DEVICE_UPTIME): + result = _try(mode) + if result is None: + continue + if cache_key is not None: + self._cache[cache_key] = mode + label = None + if not self._announced: + self._announced = True + variant = "AES-128-CTR" if self._key_len == 16 else "AES-256-CTR" + label = f"{variant}, counter_source={mode}" + return Detection(result, label) + return Detection(None) + + +class EaxExponentDetector: + """Per-scan AES-128-EAX period-exponent auto-detection. + + Mirrors :class:`CtrCounterModeDetector` but sweeps period exponents 0-15 on a + caller-supplied ``decrypt_fn(period_exponent)``. EAX packets always carry an + EID, so caching is always keyed on ``cache_key``. + """ + + def __init__(self, *, auto_detect: bool, fixed_exponent: int) -> None: + self._auto_detect = auto_detect + self._fixed_exponent = fixed_exponent + self._cache: dict = {} + self._announced = False + + def decrypt( + self, + *, + decrypt_fn: Callable[[int], Optional[T]], + cache_key: object, + ) -> Detection[T]: + if not self._auto_detect: + return Detection(decrypt_fn(self._fixed_exponent)) + + cached = self._cache.get(cache_key) + if cached is not None: + result = decrypt_fn(cached) + if result is not None: + return Detection(result) + + for candidate in range(16): + result = decrypt_fn(candidate) + if result is None: + continue + self._cache[cache_key] = candidate + label = None + if not self._announced: + self._announced = True + label = ( + f"AES-128-EAX, counter_source=DEVICE_UPTIME, " + f"period_exponent={candidate} (period={1 << candidate}s)" + ) + return Detection(result, label) + return Detection(None) diff --git a/tests/test_ble_validate.py b/tests/test_ble_validate.py index de55e0e..25644c8 100644 --- a/tests/test_ble_validate.py +++ b/tests/test_ble_validate.py @@ -10,7 +10,6 @@ _validate_info, _validate_success, _validate_error, - _detect_eid_type, cli, ) class TestValidateHelpers: @@ -118,7 +117,7 @@ def test_decryption_failure_error(self): device_id = str(uuid.uuid4()) with patch("hubblenetwork.cli.Organization") as mock_org_cls, \ patch("hubblenetwork.cli.ble_mod") as mock_ble, \ - patch("hubblenetwork.cli.decrypt") as mock_decrypt: + patch("hubblenetwork.detect.decrypt") as mock_decrypt: mock_org = mock_org_cls.return_value mock_org.list_devices.return_value = [MagicMock(id=device_id)] mock_ble.scan.return_value = [object()] @@ -158,106 +157,6 @@ def test_returns_none_when_no_match(self): assert result is None -class TestDetectEidType: - """Unit tests for the _detect_eid_type helper.""" - - def test_epoch_only(self): - pkt = MagicMock() - mock_dec = MagicMock() - - def side_effect(*args, **kwargs): - return None if kwargs.get("counter_mode") == "DEVICE_UPTIME" else mock_dec - - with patch("hubblenetwork.cli.decrypt", side_effect=side_effect): - enc, dec, label, ambiguous = _detect_eid_type(b"k" * 16, [pkt]) - - assert enc is pkt - assert dec is mock_dec - assert label == "UNIX_TIME" - assert ambiguous is False - - def test_counter_only(self): - pkt = MagicMock() - mock_dec = MagicMock() - - def side_effect(*args, **kwargs): - return mock_dec if kwargs.get("counter_mode") == "DEVICE_UPTIME" else None - - with patch("hubblenetwork.cli.decrypt", side_effect=side_effect): - enc, dec, label, ambiguous = _detect_eid_type(b"k" * 16, [pkt]) - - assert enc is pkt - assert dec is mock_dec - assert label == "DEVICE_UPTIME" - assert ambiguous is False - - def test_ambiguous(self): - pkt = MagicMock() - epoch_dec = MagicMock() - counter_dec = MagicMock() - - def side_effect(*args, **kwargs): - return counter_dec if kwargs.get("counter_mode") == "DEVICE_UPTIME" else epoch_dec - - with patch("hubblenetwork.cli.decrypt", side_effect=side_effect): - enc, dec, label, ambiguous = _detect_eid_type(b"k" * 16, [pkt]) - - assert enc is pkt - assert dec is epoch_dec # epoch preferred - assert label == "AMBIGUOUS" - assert ambiguous is True - - def test_neither(self): - pkt = MagicMock() - - with patch("hubblenetwork.cli.decrypt", return_value=None): - enc, dec, label, ambiguous = _detect_eid_type(b"k" * 16, [pkt]) - - assert enc is None - assert dec is None - assert label is None - assert ambiguous is False - - def test_stops_early_when_both_found(self): - """Helper stops after pkts[0] resolves both modes; pkts[1] is never processed.""" - pkt0 = MagicMock() - pkt1 = MagicMock() - - with patch("hubblenetwork.cli.decrypt", return_value=MagicMock()) as mock_decrypt: - enc, dec, label, ambiguous = _detect_eid_type( - b"k" * 16, [pkt0, pkt1] - ) - - # Both modes resolved on pkt0: 1 epoch call + 1 counter call = 2 total - assert mock_decrypt.call_count == 2 - assert enc is pkt0 - assert label == "AMBIGUOUS" - assert ambiguous is True - - def test_advances_to_next_packet_when_first_fails(self): - """Loop continues past pkts[0] when it fails both modes.""" - pkt0 = MagicMock() - pkt1 = MagicMock() - mock_dec = MagicMock() - - call_count = {"n": 0} - - def side_effect(*args, **kwargs): - call_count["n"] += 1 - # pkt0 always fails; pkt1 succeeds epoch only - if args[1] is pkt0: - return None - return None if kwargs.get("counter_mode") == "DEVICE_UPTIME" else mock_dec - - with patch("hubblenetwork.cli.decrypt", side_effect=side_effect): - enc, dec, label, ambiguous = _detect_eid_type(b"k" * 16, [pkt0, pkt1]) - - assert enc is pkt1 - assert dec is mock_dec - assert label == "UNIX_TIME" - assert ambiguous is False - - class TestBleValidateEidOutput: """Integration tests verifying EID type is echoed in Step 6 output.""" @@ -271,7 +170,7 @@ def decrypt_side_effect(*args, **kwargs): with patch("hubblenetwork.cli.Organization") as mock_org_cls, \ patch("hubblenetwork.cli.ble_mod") as mock_ble, \ - patch("hubblenetwork.cli.decrypt", side_effect=decrypt_side_effect), \ + patch("hubblenetwork.detect.decrypt", side_effect=decrypt_side_effect), \ patch("hubblenetwork.cli.time.sleep"), \ patch("hubblenetwork.cli._get_pkt_from_be_with_timestamp", return_value=MagicMock(device_name="n", payload=b"p", sequence=1)): @@ -299,7 +198,7 @@ def decrypt_side_effect(*args, **kwargs): with patch("hubblenetwork.cli.Organization") as mock_org_cls, \ patch("hubblenetwork.cli.ble_mod") as mock_ble, \ - patch("hubblenetwork.cli.decrypt", side_effect=decrypt_side_effect), \ + patch("hubblenetwork.detect.decrypt", side_effect=decrypt_side_effect), \ patch("hubblenetwork.cli.time.sleep"), \ patch("hubblenetwork.cli._get_pkt_from_be_with_timestamp", return_value=MagicMock(device_name="n", payload=b"p", sequence=1)): diff --git a/tests/test_detect.py b/tests/test_detect.py new file mode 100644 index 0000000..9702969 --- /dev/null +++ b/tests/test_detect.py @@ -0,0 +1,349 @@ +"""Unit tests for the decryption auto-detection module (`hubblenetwork.detect`). + +These exercise the detect/cache/label logic directly — coverage that previously +only existed indirectly through full CLI scan invocations. +""" +from __future__ import annotations + +import ast +import inspect +from unittest.mock import MagicMock, patch + +from hubblenetwork import DEVICE_UPTIME, UNIX_TIME +from hubblenetwork import detect as detect_mod +from hubblenetwork.detect import ( + CtrCounterModeDetector, + Detection, + EaxExponentDetector, + detect_eid_type, +) + + +# --------------------------------------------------------------------------- +# CtrCounterModeDetector +# --------------------------------------------------------------------------- + + +class TestCtrCounterModeDetector: + def _detector(self, *, auto_detect=True, key_len=32, days=2): + return CtrCounterModeDetector( + auto_detect=auto_detect, + fixed_counter_mode=UNIX_TIME, + days=days, + key_len=key_len, + ) + + def test_detects_unix_time(self): + det = self._detector(key_len=32) + + def fn(**kw): + return "PKT" if kw.get("counter_mode") == UNIX_TIME else None + + d = det.decrypt(decrypt_fn=fn, cache_key=0xAB) + assert d.result == "PKT" + assert d.label == "AES-256-CTR, counter_source=UNIX_TIME" + + def test_detects_device_uptime(self): + det = self._detector(key_len=16) + + def fn(**kw): + return "PKT" if kw.get("counter_mode") == DEVICE_UPTIME else None + + d = det.decrypt(decrypt_fn=fn, cache_key=0xAB) + assert d.result == "PKT" + assert d.label == "AES-128-CTR, counter_source=DEVICE_UPTIME" + + def test_passes_days_only_for_unix_time(self): + det = self._detector(days=7) + seen = [] + + def fn(**kw): + seen.append(kw) + return "PKT" if kw.get("counter_mode") == DEVICE_UPTIME else None + + det.decrypt(decrypt_fn=fn, cache_key=1) + unix_kw = next(k for k in seen if k["counter_mode"] == UNIX_TIME) + uptime_kw = next(k for k in seen if k["counter_mode"] == DEVICE_UPTIME) + assert unix_kw["days"] == 7 + assert "days" not in uptime_kw + + def test_caches_mode_after_first_hit(self): + det = self._detector() + calls = {"n": 0} + + def fn(**kw): + calls["n"] += 1 + return "PKT" if kw.get("counter_mode") == DEVICE_UPTIME else None + + d1 = det.decrypt(decrypt_fn=fn, cache_key=0xCAFE) + # First packet sweeps UNIX_TIME (miss) then DEVICE_UPTIME (hit) = 2 calls. + assert d1.result == "PKT" + assert calls["n"] == 2 + + d2 = det.decrypt(decrypt_fn=fn, cache_key=0xCAFE) + # Second packet hits the cached DEVICE_UPTIME directly = 1 more call. + assert d2.result == "PKT" + assert calls["n"] == 3 + + def test_label_set_only_on_first_success(self): + det = self._detector() + + def fn(**kw): + return "PKT" if kw.get("counter_mode") == UNIX_TIME else None + + d1 = det.decrypt(decrypt_fn=fn, cache_key=1) + d2 = det.decrypt(decrypt_fn=fn, cache_key=2) + assert d1.label is not None + assert d2.label is None + + def test_wrong_key_returns_none(self): + det = self._detector() + d = det.decrypt(decrypt_fn=lambda **kw: None, cache_key=1) + assert d.result is None + assert d.label is None + + def test_cache_key_none_disables_caching(self): + det = self._detector() + calls = {"n": 0} + + def fn(**kw): + calls["n"] += 1 + return "PKT" if kw.get("counter_mode") == DEVICE_UPTIME else None + + det.decrypt(decrypt_fn=fn, cache_key=None) + det.decrypt(decrypt_fn=fn, cache_key=None) + # Both packets re-sweep (2 calls each) since nothing is cached. + assert calls["n"] == 4 + + def test_omitted_cache_key_shares_one_stream_slot(self): + # Satellite path omits cache_key; all packets share a single slot, so the + # second packet hits the cached mode instead of re-sweeping. + det = self._detector() + calls = {"n": 0} + + def fn(**kw): + calls["n"] += 1 + return "PKT" if kw.get("counter_mode") == DEVICE_UPTIME else None + + det.decrypt(decrypt_fn=fn) # sweep UNIX_TIME (miss) + DEVICE_UPTIME (hit) + det.decrypt(decrypt_fn=fn) # cached DEVICE_UPTIME hit only + assert calls["n"] == 3 + + def test_zero_length_payload_is_success(self): + det = self._detector() + # b"" is falsy but a valid decryption — must be treated as a hit. + def fn(**kw): + return b"" if kw.get("counter_mode") == UNIX_TIME else None + + d = det.decrypt(decrypt_fn=fn, cache_key=1) + assert d.result == b"" + assert d.label is not None + + def test_no_auto_detect_uses_fixed_mode(self): + det = CtrCounterModeDetector( + auto_detect=False, fixed_counter_mode=DEVICE_UPTIME, days=2, key_len=32 + ) + seen = [] + + def fn(**kw): + seen.append(kw["counter_mode"]) + return "PKT" + + d = det.decrypt(decrypt_fn=fn, cache_key=1) + assert d.result == "PKT" + assert d.label is None # no announcement in non-auto mode + assert seen == [DEVICE_UPTIME] + + +# --------------------------------------------------------------------------- +# EaxExponentDetector +# --------------------------------------------------------------------------- + + +class TestEaxExponentDetector: + def test_detects_correct_exponent(self): + det = EaxExponentDetector(auto_detect=True, fixed_exponent=15) + + def fn(exp): + return "PKT" if exp == 11 else None + + d = det.decrypt(decrypt_fn=fn, cache_key=0xAB) + assert d.result == "PKT" + assert d.label == ( + "AES-128-EAX, counter_source=DEVICE_UPTIME, " + "period_exponent=11 (period=2048s)" + ) + + def test_caches_exponent_per_eid(self): + det = EaxExponentDetector(auto_detect=True, fixed_exponent=15) + calls = {"n": 0} + + def fn(exp): + calls["n"] += 1 + return "PKT" if exp == 3 else None + + det.decrypt(decrypt_fn=fn, cache_key=0xAB) + # Sweep 0,1,2,3 = 4 calls. + assert calls["n"] == 4 + det.decrypt(decrypt_fn=fn, cache_key=0xAB) + # Cached exponent 3 hit directly = 1 more call. + assert calls["n"] == 5 + + def test_label_set_only_once(self): + det = EaxExponentDetector(auto_detect=True, fixed_exponent=15) + + def fn(exp): + return "PKT" if exp == 0 else None + + d1 = det.decrypt(decrypt_fn=fn, cache_key=1) + d2 = det.decrypt(decrypt_fn=fn, cache_key=2) + assert d1.label is not None + assert d2.label is None + + def test_wrong_key_returns_none(self): + det = EaxExponentDetector(auto_detect=True, fixed_exponent=15) + d = det.decrypt(decrypt_fn=lambda exp: None, cache_key=1) + assert d.result is None + assert d.label is None + + def test_no_auto_detect_uses_fixed_exponent(self): + det = EaxExponentDetector(auto_detect=False, fixed_exponent=12) + seen = [] + + def fn(exp): + seen.append(exp) + return "PKT" + + d = det.decrypt(decrypt_fn=fn, cache_key=1) + assert d.result == "PKT" + assert d.label is None + assert seen == [12] + + +# --------------------------------------------------------------------------- +# detect_eid_type +# --------------------------------------------------------------------------- + + +class TestDetectEidType: + def test_epoch_only(self): + pkt = MagicMock() + mock_dec = MagicMock() + + def side_effect(*args, **kwargs): + return None if kwargs.get("counter_mode") == "DEVICE_UPTIME" else mock_dec + + with patch("hubblenetwork.detect.decrypt", side_effect=side_effect): + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt]) + + assert enc is pkt + assert dec is mock_dec + assert label == "UNIX_TIME" + assert ambiguous is False + + def test_counter_only(self): + pkt = MagicMock() + mock_dec = MagicMock() + + def side_effect(*args, **kwargs): + return mock_dec if kwargs.get("counter_mode") == "DEVICE_UPTIME" else None + + with patch("hubblenetwork.detect.decrypt", side_effect=side_effect): + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt]) + + assert enc is pkt + assert dec is mock_dec + assert label == "DEVICE_UPTIME" + assert ambiguous is False + + def test_ambiguous(self): + pkt = MagicMock() + epoch_dec = MagicMock() + counter_dec = MagicMock() + + def side_effect(*args, **kwargs): + return counter_dec if kwargs.get("counter_mode") == "DEVICE_UPTIME" else epoch_dec + + with patch("hubblenetwork.detect.decrypt", side_effect=side_effect): + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt]) + + assert enc is pkt + assert dec is epoch_dec # epoch preferred + assert label == "AMBIGUOUS" + assert ambiguous is True + + def test_neither(self): + pkt = MagicMock() + + with patch("hubblenetwork.detect.decrypt", return_value=None): + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt]) + + assert enc is None + assert dec is None + assert label is None + assert ambiguous is False + + def test_stops_early_when_both_found(self): + """Stops after pkts[0] resolves both modes; pkts[1] is never processed.""" + pkt0 = MagicMock() + pkt1 = MagicMock() + + with patch("hubblenetwork.detect.decrypt", return_value=MagicMock()) as mock_decrypt: + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt0, pkt1]) + + # Both modes resolved on pkt0: 1 epoch call + 1 counter call = 2 total. + assert mock_decrypt.call_count == 2 + assert enc is pkt0 + assert label == "AMBIGUOUS" + assert ambiguous is True + + def test_advances_to_next_packet_when_first_fails(self): + """Loop continues past pkts[0] when it fails both modes.""" + pkt0 = MagicMock() + pkt1 = MagicMock() + mock_dec = MagicMock() + + def side_effect(*args, **kwargs): + if args[1] is pkt0: + return None + return None if kwargs.get("counter_mode") == "DEVICE_UPTIME" else mock_dec + + with patch("hubblenetwork.detect.decrypt", side_effect=side_effect): + enc, dec, label, ambiguous = detect_eid_type(b"k" * 16, [pkt0, pkt1]) + + assert enc is pkt1 + assert dec is mock_dec + assert label == "UNIX_TIME" + assert ambiguous is False + + +# --------------------------------------------------------------------------- +# Decoupling guard — the whole point of this module +# --------------------------------------------------------------------------- + + +class TestDecoupling: + def test_detect_module_does_not_import_click(self): + source = inspect.getsource(detect_mod) + tree = ast.parse(source) + imported = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported += [a.name for a in node.names] + elif isinstance(node, ast.ImportFrom): + imported.append(node.module or "") + assert not any( + name == "click" or name.startswith("click.") for name in imported + ), f"detect.py must not import click; found imports: {imported}" + + +# --------------------------------------------------------------------------- +# Detection dataclass +# --------------------------------------------------------------------------- + + +class TestDetection: + def test_label_defaults_to_none(self): + d = Detection(result="x") + assert d.result == "x" + assert d.label is None