diff --git a/.github/scripts/release-policy.sh b/.github/scripts/release-policy.sh new file mode 100644 index 0000000..0f634ee --- /dev/null +++ b/.github/scripts/release-policy.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash + +# Shared helpers for the release pipeline (release-intake / release-guard / +# post-merge-release). This project derives its version purely from the git tag +# at build time via GoReleaser ldflags (internal/version.Version), so there is NO +# in-repo version file to bump and NO manifest assertions here. + +# Keep this regex identical to the SemVer check in release.yml so the two never +# disagree (allows vX.Y.Z and prereleases like vX.Y.Z-rc1 / vX.Y.Z-beta1). +RELEASE_TAG_REGEX='^v[0-9]+\.[0-9]+\.[0-9]+(-[0-9A-Za-z.-]+)?$' +# Unprotected trigger tag that starts a release. It deliberately does NOT match +# the v* glob (release.yml / a tag-immutability ruleset), so it can be created and +# deleted freely; the real vX.Y.Z tag is CREATED ONCE on the squash commit by +# post-merge-release (a tag creation, which a tag-immutability ruleset allows). +PR_TAG_REGEX='^pr-v[0-9]+\.[0-9]+\.[0-9]+(-[0-9A-Za-z.-]+)?$' + +die() { + echo "::error::$*" >&2 + exit 1 +} + +notice() { + echo "::notice::$*" +} + +is_release_tag() { + [[ "${1:-}" =~ ${RELEASE_TAG_REGEX} ]] +} + +validate_release_tag() { + local tag="${1:-}" + + if ! is_release_tag "${tag}"; then + die "Invalid release tag '${tag}'. Allowed formats are vX.Y.Z and vX.Y.Z-rc1/-beta1." + fi +} + +is_pr_tag() { + [[ "${1:-}" =~ ${PR_TAG_REGEX} ]] +} + +# pr-vX.Y.Z -> vX.Y.Z (the release tag that will be CREATED at merge). +release_tag_from_pr_tag() { + local pr_tag="${1:-}" + + if ! is_pr_tag "${pr_tag}"; then + die "Invalid trigger tag '${pr_tag}'. Expected pr-vX.Y.Z (or pr-vX.Y.Z-rc1/-beta1)." + fi + printf '%s\n' "${pr_tag#pr-}" +} + +# Informational only: GoReleaser's `release.prerelease: auto` already marks +# prereleases from the -suffix, so nothing branches on this. +is_prerelease_tag() { + [[ "${1:-}" == *-* ]] +} + +# Success if the given commit is contained in (ancestor of, or equal to) +# origin/main. The caller MUST `git fetch origin main` first. +tag_commit_on_main() { + git merge-base --is-ancestor "${1:-}" origin/main +} + +extract_pr_marker() { + local marker="${1:-}" + + python3 - "${marker}" <<'PY' +import os +import re +import sys + +marker = sys.argv[1] +body = os.environ.get("PR_BODY", "") +pattern = rf"^$" +match = re.search(pattern, body, re.MULTILINE) +if not match: + sys.exit(1) +print(match.group(1).strip()) +PY +} + +delete_remote_tag() { + local tag="${1:-}" + + git push origin ":refs/tags/${tag}" || true +} + +# Existence probes that DISTINGUISH absent from error. A transient auth/network +# failure must never be silently read as "does not exist" (which would defeat the +# preflight / immutability gates). Echo: present|absent|error. + +remote_tag_state() { + local tag="${1:-}" + local rc=0 + # git ls-remote --exit-code: 0 = ref found, 2 = no matching ref, other = failure. + git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1 || rc=$? + case "${rc}" in + 0) echo "present" ;; + 2) echo "absent" ;; + *) echo "error" ;; + esac +} + +remote_release_state() { + local tag="${1:-}" + local out + local rc=0 + out="$(gh api "repos/${GITHUB_REPOSITORY}/releases/tags/${tag}" 2>&1)" || rc=$? + if [[ "${rc}" -eq 0 ]]; then + echo "present" + elif printf '%s' "${out}" | grep -qi 'HTTP 404\|Not Found'; then + echo "absent" + else + echo "error" + fi +} + +# Hard gates: die on "present" AND on "error" (fail closed). Used where the only +# acceptable state to proceed is a confirmed "absent". +assert_release_tag_absent() { + local tag="${1:-}" + case "$(remote_tag_state "${tag}")" in + present) die "Tag ${tag} already exists and is immutable." ;; + error) die "Could not determine whether tag ${tag} exists (git ls-remote failed); aborting." ;; + esac +} + +assert_release_absent() { + local tag="${1:-}" + case "$(remote_release_state "${tag}")" in + present) die "Release ${tag} already exists." ;; + error) die "Could not determine whether release ${tag} exists (gh api failed); aborting." ;; + esac +} diff --git a/.github/workflows/autotag.yml b/.github/workflows/autotag.yml deleted file mode 100644 index d34b90b..0000000 --- a/.github/workflows/autotag.yml +++ /dev/null @@ -1,116 +0,0 @@ -name: Auto Tag (DISABLED) - -on: - push: - branches: - - main # 🔥 SOLO MAIN — mai dev, mai PR - -permissions: - contents: write - -jobs: - autotag: - if: false # 🔴 DISATTIVATO — togli questa riga per attivarlo - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 - with: - fetch-depth: 0 # necessario per leggere commit + tag - - - name: Determine next version from commits - id: next_version - run: | - echo "➡ Detecting last tag..." - LAST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") - echo "Last tag: $LAST_TAG" - - echo "➡ Collecting commits after last tag..." - COMMITS=$(git log "${LAST_TAG}..HEAD" --pretty=format:%s || true) - echo "Commits:" - echo "$COMMITS" - - echo "➡ Filtering version-relevant commits..." - # Rimuove commit irrilevanti - VALID_COMMITS=$(echo "$COMMITS" | grep -viE '^(docs:|ci:|chore:|style:|refactor:)' || true) - - if [ -z "$VALID_COMMITS" ]; then - echo "ℹ️ No version-relevant commits found." - echo "ℹ️ Skipping tag creation." - echo "skip=true" >> "$GITHUB_OUTPUT" - exit 0 - fi - - echo "➡ Relevant commits:" - echo "$VALID_COMMITS" - - # Determina bump type - BUMP="patch" - - if echo "$VALID_COMMITS" | grep -qiE 'BREAKING CHANGE|feat!:'; then - BUMP="major" - elif echo "$VALID_COMMITS" | grep -qiE '^feat:'; then - BUMP="minor" - fi - - echo "➡ Version bump type: $BUMP" - - # Calcolo nuova versione - VERSION_NO_V=${LAST_TAG#v} - IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION_NO_V" - - case "$BUMP" in - major) - MAJOR=$((MAJOR+1)) - MINOR=0 - PATCH=0 - ;; - minor) - MINOR=$((MINOR+1)) - PATCH=0 - ;; - patch) - PATCH=$((PATCH+1)) - ;; - esac - - NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}" - echo "➡ New tag: $NEW_TAG" - - echo "new_tag=$NEW_TAG" >> "$GITHUB_OUTPUT" - echo "skip=false" >> "$GITHUB_OUTPUT" - - - name: Show next tag (dry run) - if: steps.next_version.outputs.skip == 'false' - run: | - echo "Auto-tagging is currently DISABLED." - echo "If enabled, next tag would be: ${{ steps.next_version.outputs.new_tag }}" - - # Quando vuoi attivarlo: - # basta togliere if:false sopra, e sbloccare questo step: - # - # - name: Create and push tag - # if: steps.next_version.outputs.skip == 'false' - # run: | - # NEW_TAG="${{ steps.next_version.outputs.new_tag }}" - # git config user.name "github-actions[bot]" - # git config user.email "41898282+github-actions[bot]@users.noreply.github.com" - # git tag "$NEW_TAG" - # git push origin "$NEW_TAG" - -# Auto Tag Workflow -# ------------------ -# This workflow automatically determines the next semantic version -# based on commits pushed to the main branch. It analyzes all commits -# made after the latest existing tag and applies the following rules: -# -# - "feat:" → minor version bump -# - "feat!:" or "BREAKING CHANGE" → major version bump -# - "fix:" or other relevant commits → patch bump -# -# Commits that do NOT affect the program version (docs:, ci:, chore:, -# refactor:, style:) are ignored. If all commits are irrelevant, no tag -# is created. -# -# The job is currently disabled (if: false). Remove that condition to enable it. diff --git a/.github/workflows/post-merge-release.yml b/.github/workflows/post-merge-release.yml new file mode 100644 index 0000000..e53bf05 --- /dev/null +++ b/.github/workflows/post-merge-release.yml @@ -0,0 +1,110 @@ +name: Post Merge Release + +on: + pull_request: + branches: + - main + types: + - closed + +permissions: + contents: write + pull-requests: read + +concurrency: + group: post-merge-release + cancel-in-progress: false + +env: + GH_TOKEN: ${{ secrets.RELEASE_BOT_TOKEN }} + GITHUB_TOKEN: ${{ secrets.RELEASE_BOT_TOKEN }} + +jobs: + finalize-release: + name: finalize-release + if: > + github.event.pull_request.merged == true && + github.event.pull_request.base.ref == 'main' && + github.event.pull_request.head.ref == 'dev' && + github.event.pull_request.head.repo.full_name == github.repository + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 + with: + fetch-depth: 0 + token: ${{ secrets.RELEASE_BOT_TOKEN || github.token }} + + - name: Finalize release + shell: bash + env: + MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }} + PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }} + PR_BODY: ${{ github.event.pull_request.body }} + run: | + set -euo pipefail + source .github/scripts/release-policy.sh + + HAS_RELEASE_MARKER="false" + TAG="" + + if TAG="$(extract_pr_marker release-tag)"; then + HAS_RELEASE_MARKER="true" + validate_release_tag "${TAG}" + else + notice "No release marker found; syncing dev to the squash commit without publishing a release." + fi + + git fetch --force origin dev main + + MAIN_SHA="$(git rev-parse origin/main)" + DEV_SHA="$(git rev-parse origin/dev)" + + [[ "${MERGE_SHA}" == "${MAIN_SHA}" ]] || die "main (${MAIN_SHA}) does not point to the PR merge commit (${MERGE_SHA})." + + PARENT_COUNT="$(git show -s --format=%P "${MERGE_SHA}" | wc -w | tr -d ' ')" + [[ "${PARENT_COUNT}" == "1" ]] || die "The PR was not squash-merged; refusing to create a release." + + # Content gate that tolerates dev advancing during review: the squash + # commit on main must capture EXACTLY the current dev tree. We compare + # trees, not commit shas, so review fix commits (which rewrite history + # and change the head sha) are fine as long as the released content + # equals dev. If dev carries extra content the merge does not include, + # we refuse to clobber it and ask for a manual reconcile. + MERGE_TREE="$(git rev-parse "${MERGE_SHA}^{tree}")" + DEV_TREE="$(git rev-parse "origin/dev^{tree}")" + [[ "${MERGE_TREE}" == "${DEV_TREE}" ]] || die "origin/dev content differs from the merged release; refusing to overwrite dev. Reconcile dev with main manually." + + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + + # Fast-forward dev onto the squash commit. This REPLACES the old + # sync-dev.yml (which did a destructive `reset --hard` + bare `--force`). + # Lease on the dev we just observed so a genuine concurrent push aborts + # instead of being silently lost. Runs for both release and maintenance + # (no-marker) PRs. + git push origin "${MERGE_SHA}:refs/heads/dev" --force-with-lease="refs/heads/dev:${DEV_SHA}" + + if [[ "${HAS_RELEASE_MARKER}" != "true" ]]; then + notice "dev synchronized to ${MERGE_SHA}; no release marker was present." + exit 0 + fi + + # The vX.Y.Z tag is immutable (a tag-immutability ruleset blocks update + + # deletion), so it must NOT exist yet: we only ever CREATE it, once, here. + # Fail closed: a transient auth/network error aborts rather than being + # misread as "absent" (which would let us try to release a stale state). + assert_release_absent "${TAG}" + assert_release_tag_absent "${TAG}" + + # Create the authoritative ANNOTATED tag ONCE on the squash commit and push + # it with a PLAIN (non-force) push = a tag CREATION, which the ruleset + # allows. We never run `git tag -f`, force-push, or delete a v* tag. Pushing + # the tag (via RELEASE_BOT_TOKEN, a PAT) re-triggers release.yml, whose gate + # now passes (tag is on main) and GoReleaser builds + signs + publishes the + # GitHub release. + git tag -a "${TAG}" -m "Release ${TAG}" "${MERGE_SHA}" + git push origin "refs/tags/${TAG}" + + notice "Tag ${TAG} created on squash commit ${MERGE_SHA}; release.yml will now cut the release." diff --git a/.github/workflows/release-guard.yml b/.github/workflows/release-guard.yml new file mode 100644 index 0000000..32b2f2c --- /dev/null +++ b/.github/workflows/release-guard.yml @@ -0,0 +1,65 @@ +name: Release Guard + +on: + pull_request: + branches: + - main + types: + - opened + - reopened + - synchronize + - edited + - ready_for_review + +permissions: + contents: read + pull-requests: read + +jobs: + release-guard: + name: release-guard + runs-on: ubuntu-latest + + steps: + - name: Checkout trusted base policy + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.base.sha }} + persist-credentials: false + + - name: Validate release PR + shell: bash + env: + PR_BASE_REF: ${{ github.event.pull_request.base.ref }} + PR_HEAD_REF: ${{ github.event.pull_request.head.ref }} + PR_HEAD_REPO: ${{ github.event.pull_request.head.repo.full_name }} + PR_BODY: ${{ github.event.pull_request.body }} + REPOSITORY: ${{ github.repository }} + run: | + set -euo pipefail + + if [[ ! -f .github/scripts/release-policy.sh ]]; then + echo "::notice::Release policy is not present on base yet; allowing bootstrap PR." + exit 0 + fi + + source .github/scripts/release-policy.sh + + [[ "${PR_BASE_REF}" == "main" ]] || die "Release PRs must target main." + [[ "${PR_HEAD_REF}" == "dev" ]] || die "main only accepts release PRs from dev." + [[ "${PR_HEAD_REPO}" == "${REPOSITORY}" ]] || die "Release PRs must come from the same repository." + + # Maintenance (non-release) PRs dev -> main carry no release marker. + # Allow them through: post-merge syncs dev without publishing a release. + if ! TAG="$(extract_pr_marker release-tag)"; then + notice "No release marker in PR body; treating as a maintenance dev -> main PR." + exit 0 + fi + + # This project has no in-repo version file, so there is no manifest + # invariant to assert against the dev tree. The guard's job is to enforce + # the PR shape/provenance and that the release marker carries a + # well-formed tag. + validate_release_tag "${TAG}" + notice "Release PR for ${TAG} validated." diff --git a/.github/workflows/release-intake.yml b/.github/workflows/release-intake.yml new file mode 100644 index 0000000..208135d --- /dev/null +++ b/.github/workflows/release-intake.yml @@ -0,0 +1,127 @@ +name: Release Intake + +# Started by pushing an UNPROTECTED pr-vX.Y.Z tag on dev. That tag does NOT match +# the v* glob (so it never triggers release.yml and is not covered by a +# tag-immutability ruleset), and it can be created/deleted freely. This workflow +# never creates or moves a v* tag: the authoritative vX.Y.Z tag is CREATED ONCE on +# the squash commit by post-merge-release (a tag creation, which the ruleset allows), +# and that push re-triggers release.yml -> GoReleaser. +on: + push: + tags: + - "pr-v*" + +permissions: + contents: write + pull-requests: write + +concurrency: + group: release-intake-${{ github.ref_name }} + cancel-in-progress: false + +env: + GH_TOKEN: ${{ secrets.RELEASE_BOT_TOKEN || github.token }} + GITHUB_TOKEN: ${{ secrets.RELEASE_BOT_TOKEN || github.token }} + +jobs: + open-release-pr: + name: open-release-pr + # Skip the tag-DELETION push event (this workflow deletes the pr-v* trigger + # tag itself; without this guard that deletion would re-trigger the job). + if: ${{ github.event.deleted != true }} + runs-on: ubuntu-latest + + steps: + - name: Checkout tag + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 + with: + fetch-depth: 0 + # Persist the token (default): this job runs `git push` to delete the + # unprotected pr-v* trigger tag, which needs the credential in git config. + token: ${{ secrets.RELEASE_BOT_TOKEN || github.token }} + + - name: Validate trigger tag and open release PR + shell: bash + run: | + set -euo pipefail + + # SECURITY: source the release policy from the TRUSTED base (origin/main), + # NEVER from the checked-out tag tree. On a `push: tags` event the default + # checkout is the tag = dev HEAD (the to-be-released, untrusted code), so + # sourcing its release-policy.sh would let a tampered policy govern this + # write-credentialed job. Mirrors release-guard, which sources from base. + git fetch --force origin main + POLICY="$(mktemp)" + if git cat-file -e "origin/main:.github/scripts/release-policy.sh" 2>/dev/null; then + git show "origin/main:.github/scripts/release-policy.sh" > "${POLICY}" + else + echo "::notice::Release policy absent on origin/main; bootstrap release, using the tag's policy." + cp .github/scripts/release-policy.sh "${POLICY}" + fi + # shellcheck source=/dev/null + source "${POLICY}" + + TRIGGER="${GITHUB_REF_NAME}" + + if ! is_pr_tag "${TRIGGER}"; then + echo "::error::Invalid trigger '${TRIGGER}'. Push pr-vX.Y.Z (or pr-vX.Y.Z-rc1/-beta1) on dev to start a release." + delete_remote_tag "${TRIGGER}" + exit 1 + fi + + TAG="$(release_tag_from_pr_tag "${TRIGGER}")" + + git fetch --force origin dev "refs/tags/${TRIGGER}:refs/tags/${TRIGGER}" + + TRIGGER_SHA="$(git rev-list -n 1 "${TRIGGER}")" + DEV_SHA="$(git rev-parse origin/dev)" + + if [[ "${TRIGGER_SHA}" != "${DEV_SHA}" ]]; then + echo "::error::${TRIGGER} must point to HEAD of dev (${DEV_SHA}), but points to ${TRIGGER_SHA}." + delete_remote_tag "${TRIGGER}" + exit 1 + fi + + # v* tags are immutable: refuse to start a release whose version already + # exists (release or tag). The probes distinguish absent from error, so a + # transient failure is NOT misread as "absent" (fail closed). + REL_STATE="$(remote_release_state "${TAG}")" + TAG_STATE="$(remote_tag_state "${TAG}")" + if [[ "${REL_STATE}" == "present" || "${TAG_STATE}" == "present" ]]; then + delete_remote_tag "${TRIGGER}" + die "Version ${TAG} already exists (release=${REL_STATE}, tag=${TAG_STATE}); pick a new version." + fi + if [[ "${REL_STATE}" == "error" || "${TAG_STATE}" == "error" ]]; then + # Transient query failure: do NOT open a PR and do NOT delete the trigger, + # so re-pushing ${TRIGGER} retries cleanly. + die "Could not verify whether ${TAG} already exists (release=${REL_STATE}, tag=${TAG_STATE}); aborting." + fi + + OPEN_PR="$(gh pr list --base main --head dev --state open --json number --jq '.[0].number // ""')" + if [[ -n "${OPEN_PR}" ]]; then + PR_BODY="$(gh pr view "${OPEN_PR}" --json body --jq '.body')" + export PR_BODY + if PR_TAG="$(extract_pr_marker release-tag)" && [[ "${PR_TAG}" == "${TAG}" ]]; then + notice "Release PR #${OPEN_PR} for ${TAG} already open." + delete_remote_tag "${TRIGGER}" + exit 0 + fi + delete_remote_tag "${TRIGGER}" + die "An open dev -> main PR already exists (#${OPEN_PR}). Close or merge it before starting a new release." + fi + + # The trigger tag has done its job. Remove it (unprotected -> allowed). + delete_remote_tag "${TRIGGER}" + + BODY="$(cat < + EOF + )" + + gh pr create \ + --base main \ + --head dev \ + --title "Release ${TAG}" \ + --body "${BODY}" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 02cf9c0..71e406b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,11 +5,49 @@ on: tags: - "v*" +# Serialize runs for the same tag. The intake-time run (gate skips) and the +# post-merge re-tag run are far apart in time, so this never queues in practice; +# it makes a manual "re-run all jobs" deterministic instead of racing GoReleaser. +concurrency: + group: release-${{ github.ref_name }} + cancel-in-progress: false + jobs: + # Gate: only cut a release when the tag is on main (i.e. it was created on the + # squash commit by post-merge-release after the dev->main PR merged). A v* tag + # that is not reachable from main is NOT released, which neutralizes a stray or + # legacy direct tag push on a dev commit. A tag pushed directly onto a commit + # already on main (legacy escape hatch) passes the gate immediately. + gate: + name: gate + runs-on: ubuntu-latest + permissions: + contents: read + outputs: + on_main: ${{ steps.check.outputs.on_main }} + steps: + - name: Checkout repository + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 + with: + fetch-depth: 0 + persist-credentials: false + - name: Check tag is on main + id: check + run: | + set -euo pipefail + git fetch --force origin main + TAG_SHA="$(git rev-list -n 1 "${GITHUB_REF#refs/tags/}")" + if git merge-base --is-ancestor "$TAG_SHA" origin/main; then + echo "on_main=true" >> "$GITHUB_OUTPUT" + echo "Tag is on main; proceeding with release." + else + echo "on_main=false" >> "$GITHUB_OUTPUT" + echo "::notice::Tag is not on main; skipping release. A release fires only after the squash-merge creates the tag on main." + fi + release: - # Run the job ONLY when the tag is pushed directly (not from a PR) - # and the ref is a tag. - if: startsWith(github.ref, 'refs/tags/') && github.event.base_ref == '' + needs: gate + if: needs.gate.outputs.on_main == 'true' runs-on: ubuntu-latest diff --git a/.github/workflows/sync-dev.yml b/.github/workflows/sync-dev.yml deleted file mode 100644 index cf1bae0..0000000 --- a/.github/workflows/sync-dev.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Sync dev after squash merge - -on: - pull_request: - types: - - closed - branches: - - main - -jobs: - sync: - if: github.event.pull_request.merged == true - - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 - with: - fetch-depth: 0 - - - name: Configure Git - run: | - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - - - name: Sync dev with main (HARD RESET) - run: | - git fetch origin - git checkout dev - git reset --hard origin/main - - - name: Push new cleaned dev - run: git push origin dev --force diff --git a/cmd/proxsave/backup_execution.go b/cmd/proxsave/backup_execution.go index 8ed9644..44e8941 100644 --- a/cmd/proxsave/backup_execution.go +++ b/cmd/proxsave/backup_execution.go @@ -20,9 +20,16 @@ func runConfiguredBackup(opts backupModeOptions, orch *orchestrator.Orchestrator return nil, nil, types.ExitSuccess.Int() } - if earlyErrorState, exitCode := runPreBackupChecks(opts, orch); earlyErrorState != nil { + skip, earlyErrorState, exitCode := runPreBackupChecks(opts, orch) + if earlyErrorState != nil { return nil, earlyErrorState, exitCode } + if skip { + // Benign concurrency skip (another backup is already running): no failure + // notification, exit 0. The deferred ReleaseBackupLock is a no-op because + // this process never acquired the lock. + return nil, nil, exitCode + } logging.Step("Start Go backup orchestration") hostname := resolveHostname() @@ -47,12 +54,19 @@ func runConfiguredBackup(opts backupModeOptions, orch *orchestrator.Orchestrator return stats, nil, stats.ExitCode } -func runPreBackupChecks(opts backupModeOptions, orch *orchestrator.Orchestrator) (*orchestrator.EarlyErrorState, int) { +// runPreBackupChecks returns (skip, earlyError, exitCode). skip=true means a +// benign concurrency skip (another backup is already running): no early error, +// no notification, exit 0. +func runPreBackupChecks(opts backupModeOptions, orch *orchestrator.Orchestrator) (bool, *orchestrator.EarlyErrorState, int) { preCheckDone := logging.DebugStart(opts.logger, "pre-backup checks", "") if err := orch.RunPreBackupChecks(opts.ctx); err != nil { preCheckDone(err) + if errors.Is(err, orchestrator.ErrBackupInProgress) { + logging.Warning("Skipping backup: %v", err) + return true, nil, types.ExitSuccess.Int() + } logging.Error("Pre-backup validation failed: %v", err) - return &orchestrator.EarlyErrorState{ + return false, &orchestrator.EarlyErrorState{ Phase: "pre_backup_checks", Error: err, ExitCode: types.ExitBackupError, @@ -61,7 +75,7 @@ func runPreBackupChecks(opts backupModeOptions, orch *orchestrator.Orchestrator) } preCheckDone(nil) fmt.Println() - return nil, types.ExitSuccess.Int() + return false, nil, types.ExitSuccess.Int() } func handleBackupRunError(ctx context.Context, orch *orchestrator.Orchestrator, stats *orchestrator.BackupStats, err error) (*orchestrator.BackupStats, *orchestrator.EarlyErrorState, int) { diff --git a/cmd/proxsave/backup_mode.go b/cmd/proxsave/backup_mode.go index e542487..551ec64 100644 --- a/cmd/proxsave/backup_mode.go +++ b/cmd/proxsave/backup_mode.go @@ -63,7 +63,6 @@ func runBackupMode(opts backupModeOptions) backupModeResult { return finishBackupMode(orch, earlyErrorState, nil, exitCode) } - initializeBackupNotifications(opts, orch) logBackupRuntimeSummary(opts.cfg, storageState) stats, earlyErrorState, exitCode := runConfiguredBackup(opts, orch) @@ -87,6 +86,12 @@ func initializeBackupOrchestrator(opts backupModeOptions) (*orchestrator.Orchest orch := orchestrator.New(logger, opts.dryRun) configureBackupOrchestrator(opts, orch) + // Register external notification channels NOW (before any fallible init step), + // so an early-init failure (encryption_setup / checker_config / storage_init) + // still reaches the configured channels via the deferred early-error dispatch. + // Notifiers are built purely from config and have no storage/encryption deps. + initializeBackupNotifications(opts, orch) + if earlyErrorState, exitCode := ensureBackupAgeRecipientsReady(opts, orch, orchInitDone); earlyErrorState != nil { return orch, earlyErrorState, exitCode } diff --git a/cmd/proxsave/backup_notifications.go b/cmd/proxsave/backup_notifications.go index 85cd6b4..714171f 100644 --- a/cmd/proxsave/backup_notifications.go +++ b/cmd/proxsave/backup_notifications.go @@ -14,6 +14,10 @@ import ( func initializeBackupNotifications(opts backupModeOptions, orch *orchestrator.Orchestrator) { logger := opts.logger + // Register notifier secrets so they are scrubbed from every log line + // (defense-in-depth on top of the per-notifier source redaction). + registerNotificationSecrets(logger, opts.cfg) + logging.Step("Initializing notification channels") notifyDone := logging.DebugStart(logger, "notifications init", "") initializeEmailNotification(opts, orch) @@ -25,6 +29,26 @@ func initializeBackupNotifications(opts backupModeOptions, orch *orchestrator.Or fmt.Println() } +// registerNotificationSecrets registers the notifier credentials with the logger +// so they are masked out of any log line. The public Cloudflare relay +// worker token / HMAC secret are intentionally NOT registered (documented +// shared anti-abuse credentials, not confidential). +func registerNotificationSecrets(logger *logging.Logger, cfg *config.Config) { + if logger == nil || cfg == nil { + return + } + logger.RegisterSecret(cfg.TelegramBotToken) + logger.RegisterSecret(cfg.GotifyToken) + if cfg.WebhookEnabled { + for _, ep := range cfg.BuildWebhookConfig().Endpoints { + logger.RegisterSecret(ep.URL) + logger.RegisterSecret(ep.Auth.Token) + logger.RegisterSecret(ep.Auth.Secret) + logger.RegisterSecret(ep.Auth.Pass) + } + } +} + func initializeEmailNotification(opts backupModeOptions, orch *orchestrator.Orchestrator) { cfg := opts.cfg logger := opts.logger diff --git a/cmd/proxsave/backup_storage.go b/cmd/proxsave/backup_storage.go index 5cfb621..12ed997 100644 --- a/cmd/proxsave/backup_storage.go +++ b/cmd/proxsave/backup_storage.go @@ -33,9 +33,9 @@ func initializeBackupStorage(opts backupModeOptions, orch *orchestrator.Orchestr return state, &orchestrator.EarlyErrorState{ Phase: "storage_init", Error: err, - ExitCode: types.ExitConfigError, + ExitCode: types.ExitStorageError, Timestamp: time.Now(), - }, types.ExitConfigError.Int() + }, types.ExitStorageError.Int() } state.localFS = localFS registerPrimaryStorage(opts, orch, localBackend, localFS) diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 84b1546..b57cd40 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -140,18 +140,28 @@ PORT_WHITELIST= # e.g., "sshd:22,nginx:443" # Built-in defaults: ncat, cryptominer, xmrig, kdevtmpfsi, kinsing, minerd, mr.sh SUSPICIOUS_PROCESSES="ncat,cryptominer,xmrig,kdevtmpfsi,kinsing,minerd,mr.sh" -# Safe process names (won't trigger alerts) -# NOTE: Your values are ADDED to the built-in defaults (not replaced) -# Supports exact match, prefix with *, or regex: patterns (case-insensitive) +# Safe process names for the bracketed "kernel-style" process warning. +# That warning ("Suspicious kernel-style process: ...") fires for any process the +# host's `ps` reports inside square brackets, e.g. a real kernel thread +# `[kworker/0:1]` or a container worker an unprivileged LXC exposes to the host as +# `[celeryd: celery@paperless:ForkPoolWorker-3057]`. Both lists below are checked. +# NOTE: Your values are ADDED to the built-in defaults (not replaced). +# Matching is against the text BETWEEN the brackets, case-insensitive, and a plain +# entry is an EXACT whole-name match (NOT a prefix). Use "name*" (prefix) or +# "regex:pattern" (unanchored) to match part of the name. So the celery worker +# above is matched by `celeryd*` or `regex:^celeryd`, but NOT by a plain `celeryd`. # Built-in defaults for SAFE_BRACKET_PROCESSES: sshd:, systemd, cron, rsyslogd, dbus-daemon, zvol_tq*, arc_*, dbu_*, dbuf_*, l2arc_feed, lockd, nfsd*, nfsv4 callback* SAFE_BRACKET_PROCESSES="sshd:,systemd,cron,rsyslogd,dbus-daemon" # Built-in defaults for SAFE_KERNEL_PROCESSES: ksgxd, hwrng, usb-storage, vdev_autotrim, card1-crtc0, card1-crtc1, card1-crtc2, kvm-pit*, and various regex patterns SAFE_KERNEL_PROCESSES="ksgxd,hwrng,usb-storage,vdev_autotrim,card1-crtc0,card1-crtc1,card1-crtc2,kvm-pit,regex:^card[0-9]+-crtc[0-9]+$,regex:^drbd_[wrs]_.+,regex:^kvm-pit/[0-9]+$,regex:^kmmpd-drbd[0-9]+$" -# Allowlist for the suspicious-process scan (no built-in defaults; purely user-driven) -# A process is never flagged if any token of its command line (or that token's basename) -# matches an entry, even if it also matches SUSPICIOUS_PROCESSES. +# Allowlist for the suspicious-process scan ONLY (no built-in defaults; purely user-driven). +# This list does NOT silence the bracketed "kernel-style" warning above; use +# SAFE_BRACKET_PROCESSES / SAFE_KERNEL_PROCESSES for those. +# A process is never flagged by the suspicious-process scan if any token of its +# command line (or that token's basename) matches an entry, even if it also matches +# SUSPICIOUS_PROCESSES. # Matching is anchored to the start of each token: a plain entry matches any token that # STARTS WITH it (e.g. "ssh" also matches "sshd"), so use "regex:^name$" for an exact match. # "name*" wildcard and "regex:pattern" are also supported (case-insensitive). @@ -194,12 +204,18 @@ This means you don't need to repeat the default values - just add your custom en #### Process Matching -`SUSPICIOUS_PROCESSES` and `SAFE_PROCESSES` are matched per command-line token (and each token's path basename), anchored to the **start** of the token: +proxsave runs two independent process detectors, and their allowlists are **not** interchangeable. Pick the list that matches the warning you see. + +**Suspicious-process scan** (`SUSPICIOUS_PROCESSES`, allowlisted by `SAFE_PROCESSES`) is matched per command-line token (and each token's path basename), anchored to the **start** of the token: - A **plain entry** matches any token that **starts with** it. For example `ncat` matches `ncat` and `/usr/bin/ncat`, but no longer matches the substring inside `concat`; note that `ssh` would also match `sshd`. - Use **`regex:^name$`** for an exact match, or **`name*`** / **`regex:pattern`** for explicit wildcard/regex control (case-insensitive). -`SAFE_BRACKET_PROCESSES` and `SAFE_KERNEL_PROCESSES` match a single process name and behave differently: a plain entry there is an **exact** match (use `name*` / `regex:` for broader matching). +**Bracketed "kernel-style" detector** (`SAFE_BRACKET_PROCESSES` and `SAFE_KERNEL_PROCESSES`) handles the `Suspicious kernel-style process: ...` warning. It fires for any process the host's `ps` reports inside square brackets, such as a real kernel thread (`[kworker/0:1]`) or a container worker an unprivileged LXC exposes to the host (`[celeryd: celery@paperless:ForkPoolWorker-3057]`). It matches a **single name** (the text *between* the brackets), case-insensitively, and behaves differently from the scan above: + +- A **plain entry is an exact, whole-name match** (not a prefix). `celeryd` does **not** match `celeryd: celery@paperless:ForkPoolWorker-3057`. +- Use **`name*`** for a prefix match or **`regex:pattern`** for an unanchored regex. The celery worker above is matched by `celeryd*` or `regex:^celeryd`, but not by a plain `celeryd` or by an anchored `regex:^celeryd$`. +- `SAFE_PROCESSES` has **no effect** on this detector. Allowlist bracketed processes via `SAFE_BRACKET_PROCESSES` (or `SAFE_KERNEL_PROCESSES`). ### Permission Management @@ -1005,7 +1021,7 @@ BACKUP_PVE_FIREWALL=true # PVE firewall configuration BACKUP_VZDUMP_CONFIG=true # /etc/vzdump.conf # Access control lists -BACKUP_PVE_ACL=true # Access control (users/roles/groups/ACL; realms when configured) +BACKUP_PVE_ACL=true # Access control + priv credentials (shadow/token/tfa); realms when configured # Scheduled jobs BACKUP_PVE_JOBS=true # Backup jobs configuration @@ -1030,7 +1046,9 @@ CEPH_CONFIG_PATH=/etc/ceph # Ceph config directory BACKUP_VM_CONFIGS=true # VM/CT config files ``` -**Note (PVE snapshot behavior)**: ProxSave snapshots `PVE_CONFIG_PATH` for completeness. When a PVE feature is disabled, proxsave also excludes its well-known files from that snapshot to avoid “still included via full directory copy” surprises (e.g. `qemu-server/` + `lxc/` for `BACKUP_VM_CONFIGS=false`, `firewall/` + `host.fw` for `BACKUP_PVE_FIREWALL=false`, `user.cfg`/`domains.cfg` for `BACKUP_PVE_ACL=false` (ACLs are stored in `user.cfg` on PVE), `jobs.cfg` + `vzdump.cron` for `BACKUP_PVE_JOBS=false`, `corosync.conf` (and `config.db` capture) for `BACKUP_CLUSTER_CONFIG=false`). +**Note (PVE snapshot behavior)**: ProxSave snapshots `PVE_CONFIG_PATH` for completeness. When a PVE feature is disabled, proxsave also excludes its well-known files from that snapshot to avoid “still included via full directory copy” surprises (e.g. `qemu-server/` + `lxc/` for `BACKUP_VM_CONFIGS=false`, `firewall/` + `host.fw` for `BACKUP_PVE_FIREWALL=false`, `user.cfg`/`domains.cfg` plus the credential files `priv/shadow.cfg`/`priv/token.cfg`/`priv/tfa.cfg` for `BACKUP_PVE_ACL=false` (ACLs are stored in `user.cfg` on PVE), `jobs.cfg` + `vzdump.cron` for `BACKUP_PVE_JOBS=false`, `corosync.conf` (and `config.db` capture) for `BACKUP_CLUSTER_CONFIG=false`). + +> **Security note**: `/etc/pve` is a pmxcfs mount backed by the cluster database `config.db`. Setting `BACKUP_PVE_ACL=false` removes the flat `priv/*` credential files from the snapshot, but the same secrets remain inside `config.db` (captured when `BACKUP_CLUSTER_CONFIG=true`). To exclude PVE access-control secrets from the backup entirely, set both `BACKUP_PVE_ACL=false` and `BACKUP_CLUSTER_CONFIG=false`. ProxSave logs a WARNING during backup when this combination leaves secrets in `config.db`. ### PBS-Specific @@ -1057,7 +1075,7 @@ BACKUP_PBS_NOTIFICATIONS=true # notifications.cfg (targets/matchers/endpoin BACKUP_PBS_NOTIFICATIONS_PRIV=true # notifications-priv.cfg (secrets/credentials for endpoints) # User and permissions -BACKUP_USER_CONFIGS=true # PBS users and tokens +BACKUP_USER_CONFIGS=true # PBS users/ACLs/realms + credentials (token.cfg, shadow.json, token.shadow, tfa.json) # Remote configurations BACKUP_REMOTE_CONFIGS=true # Remote PBS servers @@ -1161,7 +1179,7 @@ BACKUP_ZFS_CONFIG=true # /etc/zfs, /etc/hostid, zpool cache & proper BACKUP_ROOT_HOME=true # /root (excluding .cache, .local/share/Trash) # Backup script repository -BACKUP_SCRIPT_REPOSITORY=false # Include .git directory +BACKUP_SCRIPT_REPOSITORY=false # Snapshot the ProxSave install dir (excludes .git and backup/log output) # Backup configuration file BACKUP_CONFIG_FILE=true # Include this backup.env configuration file in the backup diff --git a/docs/ENCRYPTION.md b/docs/ENCRYPTION.md index 99fffc3..1df901f 100644 --- a/docs/ENCRYPTION.md +++ b/docs/ENCRYPTION.md @@ -244,12 +244,20 @@ The `--decrypt` workflow converts an encrypted backup into a decrypted bundle fo **Output**: - A decrypted bundle saved as: `*.decrypted.bundle.tar` -If you need fully scripted/non-interactive decryption, use the official `age` CLI tool: +If you need fully scripted/non-interactive decryption with a **private key**, use the official `age` CLI tool: ```bash age --decrypt -i /path/to/age-keys.txt host-backup-YYYYMMDD-HHMMSS.tar.xz.age > host-backup-YYYYMMDD-HHMMSS.tar.xz ``` +> **Passphrase recipients are not native age passphrases.** A passphrase recipient +> is an X25519 key *derived* from the passphrase, so the raw `age --decrypt` (which +> only understands age's own scrypt passphrase stanza) cannot decrypt it from the +> passphrase alone — use `proxsave --decrypt`. proxsave re-derives the identity from +> the passphrase plus the **per-installation random salt** generated at setup, which +> is stored next to the recipient (`identity/age/passphrase.salt`) and embedded in +> every backup manifest (`passphrase_salt`) so recovery works on any host. + --- ## Restoring Encrypted Backups @@ -370,7 +378,7 @@ age --decrypt -i /path/to/age-keys.txt host-backup-YYYYMMDD-HHMMSS.tar.xz.age | ### Encryption Implementation - **Algorithm**: ChaCha20-Poly1305 (AEAD) with X25519 ECDH -- **Key derivation**: scrypt (N=2^15, r=8, p=1) for passphrases +- **Key derivation**: scrypt (N=2^15, r=8, p=1) for passphrases, with a per-installation random salt (stored in `identity/age/passphrase.salt` and embedded in each manifest as `passphrase_salt`; legacy archives used a fixed salt and remain decryptable) - **Random nonces**: Unique per encryption operation - **Authentication**: Poly1305 MAC prevents tampering diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md index 0d0376a..6565aba 100644 --- a/docs/TROUBLESHOOTING.md +++ b/docs/TROUBLESHOOTING.md @@ -467,12 +467,15 @@ AGE_RECIPIENT="age1abc123def456..." ```bash # With private key file age --decrypt -i configs/age-keys.txt backup.tar.xz.age > test.tar.xz - -# With passphrase -age --decrypt backup.tar.xz.age > test.tar.xz -# (prompts for passphrase) ``` +> **Passphrase-based backups**: the archive is encrypted to an X25519 recipient +> *derived* from your passphrase, not with age's native passphrase (scrypt) mode, +> so `age --decrypt` will not prompt for a passphrase and cannot decrypt it on its +> own. Use `proxsave --decrypt` and enter the passphrase when prompted: proxsave +> re-derives the matching identity using the per-installation salt recorded in the +> backup manifest (`passphrase_salt`). + **Verify backup integrity**: ```bash # Check SHA256 checksum diff --git a/go.mod b/go.mod index 947bd99..049ef5f 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( filippo.io/age v1.3.1 github.com/gdamore/tcell/v2 v2.13.10 github.com/rivo/tview v0.42.0 - golang.org/x/crypto v0.52.0 - golang.org/x/term v0.43.0 - golang.org/x/text v0.37.0 + golang.org/x/crypto v0.53.0 + golang.org/x/term v0.44.0 + golang.org/x/text v0.38.0 ) require ( @@ -17,5 +17,5 @@ require ( github.com/gdamore/encoding v1.0.1 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - golang.org/x/sys v0.45.0 // indirect + golang.org/x/sys v0.46.0 // indirect ) diff --git a/go.sum b/go.sum index 8bf2957..f9ced8b 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,8 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= -golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= +golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= +golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -36,20 +36,20 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= -golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= +golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= -golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= +golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc= +golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= -golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= +golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/backup/archiver.go b/internal/backup/archiver.go index 3d1cd4b..40f509e 100644 --- a/internal/backup/archiver.go +++ b/internal/backup/archiver.go @@ -49,6 +49,27 @@ func WithLookPathOverride(fn func(string) (string, error)) func() { } } +// ErrArchiveIncomplete signals that the archive was written and closed cleanly +// but one or more source entries could not be added (omitted before their header) +// or were stored with truncated content. The archive is structurally valid (it +// passes `tar -t`) yet does not faithfully represent the source, so CreateArchive +// returns this rather than a success - the run then fails with ExitArchiveError +// instead of silently shipping a valid-looking but incomplete backup (H04). +var ErrArchiveIncomplete = errors.New("archive incomplete: one or more source entries could not be added") + +// ErrArchiveEntryCountMismatch signals that the finished archive lists a different +// number of entries than the archiver wrote, i.e. entries were lost after being +// written (on-disk corruption/truncation that `tar -t` integrity alone does not +// catch). Returned by VerifyArchive so the run fails with ExitVerificationError. +var ErrArchiveEntryCountMismatch = errors.New("archive entry count mismatch") + +// skippedEntry records a source path the archiver could not add to the tar during +// the current walk (omitted entirely before its header, or stored truncated). +type skippedEntry struct { + path string + reason string +} + // Archiver handles tar archive creation with compression type Archiver struct { logger *logging.Logger @@ -62,6 +83,14 @@ type Archiver struct { ageRecipients []age.Recipient excludePatterns []string deps ArchiverDeps + + // State for the current CreateArchive run, reset at its start. Written only by + // the (single) walk goroutine and read by CreateArchive/VerifyArchive after the + // walk has finished (happens-before via the compressor errChan), so no locking + // is needed. + skipped []skippedEntry // source entries that could not be archived + entriesWritten int // tar headers successfully written, for verify reconciliation + contentVerify bool // true once this instance produced the archive (enables entry reconciliation) } // ArchiverConfig holds configuration for archive creation @@ -417,25 +446,94 @@ func (a *Archiver) CreateArchive(ctx context.Context, sourceDir, outputPath stri return fmt.Errorf("failed to create output directory: %w", err) } + a.resetArchiveState() + // Choose compression method + var archiveErr error switch actualCompression { case types.CompressionGzip: - return a.createGzipArchive(ctx, sourceDir, outputPath) + archiveErr = a.createGzipArchive(ctx, sourceDir, outputPath) case types.CompressionPigz: - return a.createPigzArchive(ctx, sourceDir, outputPath) + archiveErr = a.createPigzArchive(ctx, sourceDir, outputPath) case types.CompressionXZ: - return a.createXZArchive(ctx, sourceDir, outputPath) + archiveErr = a.createXZArchive(ctx, sourceDir, outputPath) case types.CompressionBzip2: - return a.createBzip2Archive(ctx, sourceDir, outputPath) + archiveErr = a.createBzip2Archive(ctx, sourceDir, outputPath) case types.CompressionLZMA: - return a.createLzmaArchive(ctx, sourceDir, outputPath) + archiveErr = a.createLzmaArchive(ctx, sourceDir, outputPath) case types.CompressionZstd: - return a.createZstdArchive(ctx, sourceDir, outputPath) + archiveErr = a.createZstdArchive(ctx, sourceDir, outputPath) case types.CompressionNone: - return a.createTarArchive(ctx, sourceDir, outputPath) + archiveErr = a.createTarArchive(ctx, sourceDir, outputPath) default: return fmt.Errorf("unsupported compression type: %s", actualCompression) } + if archiveErr != nil { + return archiveErr + } + + // The compressor finished cleanly; the walk goroutine has joined, so reading + // the walk's accumulated state here is race-free. This instance produced the + // archive, so VerifyArchive may reconcile its entry count. + a.contentVerify = true + return a.incompleteArchiveError() +} + +// resetArchiveState clears the per-run accounting before a new CreateArchive walk. +func (a *Archiver) resetArchiveState() { + a.skipped = nil + a.entriesWritten = 0 + a.contentVerify = false +} + +// recordSkipped notes a source path that could not be added to the archive. +func (a *Archiver) recordSkipped(path, reason string) { + a.skipped = append(a.skipped, skippedEntry{path: path, reason: reason}) +} + +// incompleteArchiveError returns an ErrArchiveIncomplete error naming a sample of +// the skipped entries, or nil if the archive captured every source entry. +func (a *Archiver) incompleteArchiveError() error { + if len(a.skipped) == 0 { + return nil + } + const sample = 5 + names := make([]string, 0, sample) + for i, e := range a.skipped { + if i >= sample { + break + } + names = append(names, fmt.Sprintf("%s (%s)", e.path, e.reason)) + } + more := "" + if len(a.skipped) > sample { + more = fmt.Sprintf(" (and %d more)", len(a.skipped)-sample) + } + return fmt.Errorf("%w: %d source entries could not be archived: %s%s", + ErrArchiveIncomplete, len(a.skipped), strings.Join(names, "; "), more) +} + +// reconcileEntryCount compares the entries listed in the finished archive against +// the number addToTar wrote, catching entries lost after they were written +// (corruption/truncation that `tar -t` integrity alone does not surface). Skipped +// for archives this instance did not create (preserving legacy verify-only +// behaviour) and for encrypted archives (no plaintext listing is available). +// +// The check is "listed < written", not "!=": a genuine entry loss shortens the +// listing, whereas a listing LONGER than expected is never data loss - it only +// arises when the external `tar` splits a member name on an embedded newline +// (e.g. busybox tar, which does not escape control characters like GNU tar does). +// Tolerating the longer case avoids failing a healthy backup over a tar-flavour +// quirk while still catching every dropped entry. +func (a *Archiver) reconcileEntryCount(listed int) error { + if !a.contentVerify || a.encryptArchive { + return nil + } + if listed < a.entriesWritten { + return fmt.Errorf("%w: wrote %d entries but archive lists only %d", ErrArchiveEntryCountMismatch, a.entriesWritten, listed) + } + a.logger.Debug("Archive entry-count reconciliation passed (wrote %d, listed %d)", a.entriesWritten, listed) + return nil } // createGzipArchive creates a gzip-compressed tar archive using Go's stdlib @@ -802,6 +900,10 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi if err != nil { a.logger.Warning("Error accessing path %s: %v", path, err) + // A walk error on a directory means filepath.Walk does not descend into + // it, so a whole subtree may be missing from the archive: record it so + // the run fails instead of shipping a silently incomplete archive. + a.recordSkipped(path, fmt.Sprintf("access error: %v", err)) return nil // Continue with other files } @@ -830,6 +932,7 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi linkInfo, err := os.Lstat(path) if err != nil { a.logger.Warning("Failed to stat path %s: %v", path, err) + a.recordSkipped(path, fmt.Sprintf("stat failed: %v", err)) return nil } @@ -842,6 +945,7 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi linkTarget, err = os.Readlink(path) if err != nil { a.logger.Warning("Failed to read symlink %s: %v", path, err) + a.recordSkipped(path, fmt.Sprintf("readlink failed: %v", err)) return nil } } @@ -850,6 +954,7 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi header, err := tar.FileInfoHeader(linkInfo, linkTarget) if err != nil { a.logger.Warning("Failed to create header for %s: %v", path, err) + a.recordSkipped(path, fmt.Sprintf("header build failed: %v", err)) return nil } @@ -882,18 +987,29 @@ func (a *Archiver) addToTar(ctx context.Context, tarWriter *tar.Writer, sourceDi if err := tarWriter.WriteHeader(header); err != nil { return fmt.Errorf("failed to write tar header: %w", err) } + a.entriesWritten++ // If it's a regular file (not symlink, dir, etc), write its content if linkInfo.Mode().IsRegular() { file, err := root.Open(relPath) if err != nil { a.logger.Warning("Failed to open file %s: %v", path, err) + // For a file with content the header (Size>0) is already written with + // no body, so the next WriteHeader/Close fails hard with "missed + // writing N bytes" - the run already fails, no silent loss. A 0-byte + // file has no body to lose, so its empty entry is faithful. Either way + // there is nothing to record here. return nil } if _, err := io.Copy(tarWriter, file); err != nil { _ = file.Close() a.logger.Warning("Failed to write file %s to archive: %v", path, err) + // A short read leaves the entry under-filled (next op hard-fails); a + // file that grew yields io.ErrShortWrite/ErrWriteTooLong with the full + // Size written but truncated content, which `tar -t` never flags. Record + // it so the archive is treated as incomplete in both cases. + a.recordSkipped(path, fmt.Sprintf("content copy failed: %v", err)) return nil } if err := file.Close(); err != nil { @@ -1030,9 +1146,13 @@ func (a *Archiver) verifyXZArchive(ctx context.Context, archivePath string) erro if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar listing failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: XZ compression and tar structure are valid") return nil @@ -1058,9 +1178,13 @@ func (a *Archiver) verifyZstdArchive(ctx context.Context, archivePath string) er if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar listing failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: Zstd compression and tar structure are valid") return nil @@ -1075,9 +1199,13 @@ func (a *Archiver) verifyGzipArchive(ctx context.Context, archivePath string) er if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar/gzip verification failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: Gzip compression and tar structure are valid") return nil @@ -1092,9 +1220,13 @@ func (a *Archiver) verifyBzip2Archive(ctx context.Context, archivePath string) e if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar/bzip2 verification failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: Bzip2 compression and tar structure are valid") return nil @@ -1109,9 +1241,13 @@ func (a *Archiver) verifyLzmaArchive(ctx context.Context, archivePath string) er if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar/lzma verification failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: LZMA compression and tar structure are valid") return nil @@ -1126,9 +1262,13 @@ func (a *Archiver) verifyTarArchive(ctx context.Context, archivePath string) err if err != nil { return err } - if err := runTarListVerification(cmd); err != nil { + listed, err := runTarListVerification(cmd) + if err != nil { return fmt.Errorf("tar verification failed: %w", err) } + if err := a.reconcileEntryCount(listed); err != nil { + return err + } a.logger.Debug("Archive verification passed: Tar structure is valid") return nil @@ -1159,22 +1299,37 @@ func (c *cappedBuffer) Write(p []byte) (int, error) { func (c *cappedBuffer) String() string { return string(c.buf) } +// lineCountWriter counts newline-terminated lines without retaining them, so a +// `tar -t` listing (one line per entry, potentially enormous) can be counted for +// entry reconciliation without buffering it in memory. +type lineCountWriter struct{ lines int } + +func (c *lineCountWriter) Write(p []byte) (int, error) { + for _, b := range p { + if b == '\n' { + c.lines++ + } + } + return len(p), nil +} + // runTarListVerification runs a `tar -t...` listing command used only to verify // archive integrity. The listing prints one line per entry and can be enormous, -// so its stdout is discarded instead of buffered in memory (the previous -// CombinedOutput kept the whole listing despite the "discard" intent). Only a +// so its stdout is counted (one line per archive member) instead of buffered in +// memory; the count is returned for source-vs-archive entry reconciliation. Only a // bounded amount of stderr is captured so a failure stays actionable. -func runTarListVerification(cmd *exec.Cmd) error { - cmd.Stdout = io.Discard +func runTarListVerification(cmd *exec.Cmd) (int, error) { + counter := &lineCountWriter{} + cmd.Stdout = counter stderr := &cappedBuffer{cap: verifyStderrCap} cmd.Stderr = stderr if err := cmd.Run(); err != nil { if msg := strings.TrimSpace(stderr.String()); msg != "" { - return fmt.Errorf("%w (stderr: %s)", err, msg) + return 0, fmt.Errorf("%w (stderr: %s)", err, msg) } - return err + return 0, err } - return nil + return counter.lines, nil } // GetArchiveSize returns the size of the archive in bytes diff --git a/internal/backup/archiver_completeness_audited_test.go b/internal/backup/archiver_completeness_audited_test.go new file mode 100644 index 0000000..5648ca2 --- /dev/null +++ b/internal/backup/archiver_completeness_audited_test.go @@ -0,0 +1,223 @@ +package backup + +import ( + "archive/tar" + "context" + "errors" + "io" + "net" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +// tarEntryNames reads an uncompressed tar and returns the set of member names +// (with the leading "./" stripped) for assertions. +func tarEntryNames(t *testing.T, archivePath string) map[string]bool { + t.Helper() + f, err := os.Open(archivePath) + if err != nil { + t.Fatalf("open archive: %v", err) + } + defer func() { _ = f.Close() }() + + names := map[string]bool{} + tr := tar.NewReader(f) + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("read tar: %v", err) + } + names[strings.TrimPrefix(hdr.Name, "./")] = true + } + return names +} + +// newTestSocket creates a unix socket inside dir. A socket is a file type +// tar.FileInfoHeader cannot represent ("sockets not supported"), which is how the +// tests deterministically force a per-file archiving failure even when running as +// root (permission-based failures do not block root). Returns the socket path. +func newTestSocket(t *testing.T, dir string) string { + t.Helper() + sockPath := filepath.Join(dir, "x.sock") + lc := net.ListenConfig{} + l, err := lc.Listen(context.Background(), "unix", sockPath) + if err != nil { + t.Skipf("cannot create unix socket for test (path length / platform?): %v", err) + } + t.Cleanup(func() { _ = l.Close() }) + return sockPath +} + +// TestCreateArchive_UnarchivableEntryFailsClosed covers H04 fix (a): a source +// entry that cannot be added to the tar makes CreateArchive fail with +// ErrArchiveIncomplete instead of returning a valid-looking but incomplete +// archive as success. The walk still continues so every other file is captured. +func TestCreateArchive_UnarchivableEntryFailsClosed(t *testing.T) { + dir, err := os.MkdirTemp("", "arx") + if err != nil { + t.Fatalf("mkdir temp: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + src := filepath.Join(dir, "s") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatalf("mkdir src: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "keep.txt"), []byte("data"), 0o644); err != nil { + t.Fatalf("write keep: %v", err) + } + newTestSocket(t, src) + + archiver := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + out := filepath.Join(dir, "out.tar") + + err = archiver.CreateArchive(context.Background(), src, out) + if err == nil { + t.Fatal("expected CreateArchive to fail for an incomplete archive, got nil") + } + if !errors.Is(err, ErrArchiveIncomplete) { + t.Fatalf("expected ErrArchiveIncomplete, got %v", err) + } + + // The failure is recorded but the walk is not aborted: the good file is still + // in the archive, only the unrepresentable socket is missing. + names := tarEntryNames(t, out) + if !names["keep.txt"] { + t.Errorf("expected keep.txt to be archived despite the skipped socket, got %v", names) + } + if names["x.sock"] { + t.Errorf("the socket must not appear in the archive") + } +} + +// TestCreateArchive_CompleteArchiveSucceeds is the negative control for fix (a): +// a fully archivable tree yields no error, marks the instance as the producer, +// and counts every written entry. +func TestCreateArchive_CompleteArchiveSucceeds(t *testing.T) { + tempDir := t.TempDir() + src := filepath.Join(tempDir, "src") + if err := os.MkdirAll(filepath.Join(src, "sub"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "a.txt"), []byte("a"), 0o644); err != nil { + t.Fatalf("write a: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "sub", "b.txt"), []byte("bb"), 0o644); err != nil { + t.Fatalf("write b: %v", err) + } + + archiver := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + out := filepath.Join(tempDir, "out.tar") + if err := archiver.CreateArchive(context.Background(), src, out); err != nil { + t.Fatalf("CreateArchive: %v", err) + } + if len(archiver.skipped) != 0 { + t.Errorf("expected no skipped entries, got %v", archiver.skipped) + } + if !archiver.contentVerify { + t.Error("expected contentVerify to be set after a successful create") + } + // sub dir + a.txt + sub/b.txt = 3 tar entries. + if archiver.entriesWritten != 3 { + t.Errorf("expected 3 entries written, got %d", archiver.entriesWritten) + } +} + +// TestVerifyArchive_EntryCountMismatchFailsClosed covers H04 fix (b): if the +// finished archive lists fewer entries than were written (entries lost to on-disk +// corruption/truncation that `tar -t` integrity alone does not catch), +// VerifyArchive fails with ErrArchiveEntryCountMismatch. +func TestVerifyArchive_EntryCountMismatchFailsClosed(t *testing.T) { + tempDir := t.TempDir() + src := filepath.Join(tempDir, "src") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "a.txt"), []byte("a"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + archiver := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + out := filepath.Join(tempDir, "out.tar") + ctx := context.Background() + if err := archiver.CreateArchive(ctx, src, out); err != nil { + t.Fatalf("CreateArchive: %v", err) + } + + // Simulate an entry that vanished after being written: the archive on disk now + // lists one fewer entry than the archiver believes it wrote. + archiver.entriesWritten++ + + err := archiver.VerifyArchive(ctx, out) + if !errors.Is(err, ErrArchiveEntryCountMismatch) { + t.Fatalf("expected ErrArchiveEntryCountMismatch, got %v", err) + } +} + +// TestVerifyArchive_ToleratesListingLongerThanWritten guards the "<" (not "!=") +// reconciliation rule: a listing with MORE lines than entries written is never +// data loss (it happens when a tar flavour like busybox splits a member name on +// an embedded newline), so VerifyArchive must accept it. Only a SHORTER listing +// (a lost entry) is a failure. +func TestVerifyArchive_ToleratesListingLongerThanWritten(t *testing.T) { + tempDir := t.TempDir() + src := filepath.Join(tempDir, "src") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "a.txt"), []byte("a"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + archiver := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + out := filepath.Join(tempDir, "out.tar") + ctx := context.Background() + if err := archiver.CreateArchive(ctx, src, out); err != nil { + t.Fatalf("CreateArchive: %v", err) + } + + // Pretend we wrote one fewer entry than the listing reports (listed > written), + // mimicking a tar flavour that splits a name across lines. + archiver.entriesWritten-- + + if err := archiver.VerifyArchive(ctx, out); err != nil { + t.Fatalf("a listing longer than written must be tolerated, got %v", err) + } +} + +// TestVerifyArchive_SkipsReconciliationForForeignArchive guards the contentVerify +// gate: an archiver that did not create the archive must NOT reconcile entry +// counts (its entriesWritten is 0), preserving legacy verify-only behaviour and +// avoiding a false mismatch. +func TestVerifyArchive_SkipsReconciliationForForeignArchive(t *testing.T) { + tempDir := t.TempDir() + src := filepath.Join(tempDir, "src") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(src, "a.txt"), []byte("a"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + ctx := context.Background() + creator := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + out := filepath.Join(tempDir, "out.tar") + if err := creator.CreateArchive(ctx, src, out); err != nil { + t.Fatalf("CreateArchive: %v", err) + } + + // A different instance verifies the archive; it never ran CreateArchive, so it + // must skip reconciliation rather than compare against its zero count. + verifier := NewArchiver(logging.New(types.LogLevelError, false), &ArchiverConfig{Compression: types.CompressionNone}) + if err := verifier.VerifyArchive(ctx, out); err != nil { + t.Fatalf("foreign-archive verification must skip entry reconciliation, got %v", err) + } +} diff --git a/internal/backup/archiver_verification_test.go b/internal/backup/archiver_verification_test.go index b99a46d..b053d8d 100644 --- a/internal/backup/archiver_verification_test.go +++ b/internal/backup/archiver_verification_test.go @@ -395,13 +395,18 @@ func TestVerifyGzipArchive_ValidTarContent(t *testing.T) { // successful command's (potentially huge) stdout is discarded rather than // buffered, and a failure surfaces the command's stderr (not its stdout). func TestRunTarListVerification(t *testing.T) { - // Success: stdout is produced (and discarded), exit 0 -> no error. - if err := runTarListVerification(exec.Command("sh", "-c", "echo a; echo b")); err != nil { + // Success: stdout is produced (and counted, not buffered), exit 0 -> no error. + // The returned count is the number of newline-terminated lines. + lines, err := runTarListVerification(exec.Command("sh", "-c", "echo a; echo b")) + if err != nil { t.Fatalf("expected success, got %v", err) } + if lines != 2 { + t.Fatalf("expected 2 counted lines, got %d", lines) + } // Failure: the error carries stderr, not the discarded stdout. - err := runTarListVerification(exec.Command("sh", "-c", "echo discarded-stdout; echo boom 1>&2; exit 3")) + _, err = runTarListVerification(exec.Command("sh", "-c", "echo discarded-stdout; echo boom 1>&2; exit 3")) if err == nil { t.Fatal("expected an error") } diff --git a/internal/backup/checksum.go b/internal/backup/checksum.go index 75835d8..f3fbe2c 100644 --- a/internal/backup/checksum.go +++ b/internal/backup/checksum.go @@ -35,6 +35,11 @@ type Manifest struct { ScriptVersion string `json:"script_version,omitempty"` EncryptionMode string `json:"encryption_mode,omitempty"` ClusterMode string `json:"cluster_mode,omitempty"` + // PassphraseSalt is the per-installation random salt used to derive a + // passphrase-based AGE recipient. It is a public value embedded so the + // archive stays decryptable from the passphrase alone on any host. Empty + // for X25519/SSH recipients and for legacy archives (which used a fixed salt). + PassphraseSalt string `json:"passphrase_salt,omitempty"` } // NormalizeChecksum validates and normalizes a SHA256 checksum string. diff --git a/internal/backup/collector.go b/internal/backup/collector.go index b92a38c..556eb00 100644 --- a/internal/backup/collector.go +++ b/internal/backup/collector.go @@ -61,6 +61,15 @@ type Collector struct { pbsManifest map[string]ManifestEntry pveManifest map[string]ManifestEntry systemManifest map[string]ManifestEntry + // recordSystemManifest gates population of systemManifest to the system + // collection phase; systemManifestDepth>0 means a directory walk is in + // progress, so only the top-level target is recorded, not every nested file + // (issue #59). + recordSystemManifest bool + systemManifestDepth int + // collectingCustomPaths is set while copying operator-supplied CUSTOM_BACKUP_PATHS, + // during which the source walk prunes the staging workspace to avoid self-copy (#56). + collectingCustomPaths bool } var osSymlink = os.Symlink @@ -417,7 +426,7 @@ func GetDefaultCollectorConfig() *CollectorConfig { BackupSSHKeys: true, BackupZFSConfig: true, BackupRootHome: true, - BackupScriptRepository: true, + BackupScriptRepository: false, BackupUserHomes: true, BackupConfigFile: true, SystemRootPrefix: "", @@ -753,6 +762,24 @@ func (c *Collector) applySymlinkOwnership(dest string, info os.FileInfo) { } } +// recordSystemManifestEntry records a system collection target into the manifest +// (issue #59). It is a no-op outside the system collection phase and inside +// directory walks, so the manifest lists collection targets rather than every +// nested file (no bloat). pveManifestKey computes a tempDir-relative key. +func (c *Collector) recordSystemManifestEntry(dest string, entry ManifestEntry) { + if !c.recordSystemManifest || c.systemManifestDepth != 0 || c.systemManifest == nil { + return + } + c.systemManifest[pveManifestKey(c.tempDir, dest)] = entry +} + +func manifestEntryFromResult(err error, size int64) ManifestEntry { + if err != nil { + return ManifestEntry{Status: StatusFailed, Error: err.Error()} + } + return ManifestEntry{Status: StatusCollected, Size: size} +} + func (c *Collector) safeCopyFile(ctx context.Context, src, dest, description string) error { if err := ctx.Err(); err != nil { return err @@ -760,31 +787,48 @@ func (c *Collector) safeCopyFile(ctx context.Context, src, dest, description str c.logger.Debug("Collecting %s: %s -> %s", description, src, dest) + if c.isWithinStagingDir(src) { + c.logger.Debug("Skipping file %s: inside the staging workspace (#56)", src) + return nil + } + info, found, err := c.statCopySource(src, description) - if err != nil || !found { + if err != nil { + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusFailed, Error: err.Error()}) return err } + if !found { + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusNotFound}) + return nil + } if c.shouldSkipCopy(src, dest) { + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusSkipped}) return nil } if c.dryRun { c.logger.Debug("[DRY RUN] Would copy file: %s -> %s", src, dest) c.incFilesProcessed() + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusCollected}) return nil } if info.Mode()&os.ModeSymlink != 0 { - return c.copySymlinkFile(src, dest, info) + err := c.copySymlinkFile(src, dest, info) + c.recordSystemManifestEntry(dest, manifestEntryFromResult(err, 0)) + return err } if !info.Mode().IsRegular() { c.logger.Debug("Skipping non-regular file: %s", src) + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusSkipped}) return nil } - return c.copyRegularFile(src, dest, description, info) + err = c.copyRegularFile(src, dest, description, info) + c.recordSystemManifestEntry(dest, manifestEntryFromResult(err, info.Size())) + return err } func (c *Collector) statCopySource(src, description string) (os.FileInfo, bool, error) { @@ -899,6 +943,19 @@ func copyRegularFileContents(srcFile io.Reader, src, dest string) (int64, error) return written, nil } +// isWithinStagingDir reports whether path is the staging tempDir or lives under +// it. The source walk must never descend into the destination staging tree, or a +// broad CUSTOM_BACKUP_PATHS entry (e.g. "/", "/tmp", "/tmp/proxsave") would copy +// the growing archive into itself, recursing and ballooning the backup (#56). +func (c *Collector) isWithinStagingDir(path string) bool { + if !c.collectingCustomPaths || c.tempDir == "" { + return false + } + clean := filepath.Clean(path) + root := filepath.Clean(c.tempDir) + return clean == root || strings.HasPrefix(clean, root+string(os.PathSeparator)) +} + func (c *Collector) safeCopyDir(ctx context.Context, src, dest, description string) error { if err := ctx.Err(); err != nil { return err @@ -906,68 +963,96 @@ func (c *Collector) safeCopyDir(ctx context.Context, src, dest, description stri c.logger.Debug("Collecting directory %s: %s -> %s", description, src, dest) + if c.isWithinStagingDir(src) { + c.logger.Debug("Skipping directory %s: inside the staging workspace (would copy the archive into itself)", src) + return nil + } + if c.shouldExclude(src) || c.shouldExclude(dest) { c.logger.Debug("Skipping directory %s due to exclusion pattern", src) c.incFilesSkipped() + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusSkipped}) return nil } if _, err := os.Stat(src); os.IsNotExist(err) { c.logger.Debug("%s not found: %s (skipping)", description, src) + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusNotFound}) return nil } if c.dryRun { c.logger.Debug("[DRY RUN] Would copy directory: %s -> %s", src, dest) + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusCollected}) return nil } - // Ensure destination exists - if err := c.ensureDir(dest); err != nil { - return err - } - - // Walk source directory - err := filepath.Walk(src, func(path string, info os.FileInfo, err error) error { - if errCtx := ctx.Err(); errCtx != nil { - return errCtx - } - - if err != nil { + // Suppress per-file recording during the walk so the manifest stays at target + // granularity (#59), then record the directory's FINAL status from the actual + // outcome below. Recording StatusCollected up front would misreport a directory + // whose ensureDir/walk later fails as successfully collected. + c.systemManifestDepth++ + walkErr := func() error { + // Ensure destination exists + if err := c.ensureDir(dest); err != nil { return err } - // Calculate relative path and destination path for archive matching. - relPath, err := filepath.Rel(src, path) - if err != nil { - return err - } - destPath := filepath.Join(dest, relPath) + // Walk source directory + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if errCtx := ctx.Err(); errCtx != nil { + return errCtx + } - // Check if this path should be excluded - if c.shouldExclude(path) || c.shouldExclude(destPath) { - // If it's a directory, skip it entirely - if info.IsDir() { - return filepath.SkipDir + if err != nil { + return err + } + + // Never descend into the staging workspace: a broad source (e.g. a custom + // path of "/" or "/tmp") would otherwise copy the in-progress archive into + // itself (#56). + if c.isWithinStagingDir(path) { + if info.IsDir() { + return filepath.SkipDir + } + return nil } - return nil - } - if info.IsDir() { - if err := c.ensureDir(destPath); err != nil { + // Calculate relative path and destination path for archive matching. + relPath, err := filepath.Rel(src, path) + if err != nil { return err } - c.applyMetadata(destPath, info) - return nil - } + destPath := filepath.Join(dest, relPath) - return c.safeCopyFile(ctx, path, destPath, filepath.Base(path)) - }) + // Check if this path should be excluded + if c.shouldExclude(path) || c.shouldExclude(destPath) { + // If it's a directory, skip it entirely + if info.IsDir() { + return filepath.SkipDir + } + return nil + } - if err != nil { - c.logger.Warning("Failed to copy directory %s: %v", description, err) - return err + if info.IsDir() { + if err := c.ensureDir(destPath); err != nil { + return err + } + c.applyMetadata(destPath, info) + return nil + } + + return c.safeCopyFile(ctx, path, destPath, filepath.Base(path)) + }) + }() + c.systemManifestDepth-- + + if walkErr != nil { + c.logger.Warning("Failed to copy directory %s: %v", description, walkErr) + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusFailed, Error: walkErr.Error()}) + return walkErr } + c.recordSystemManifestEntry(dest, ManifestEntry{Status: StatusCollected}) c.logger.Debug("Successfully collected %s: %s", description, src) return nil diff --git a/internal/backup/collector_manifest.go b/internal/backup/collector_manifest.go index 6a373e2..88c87c0 100644 --- a/internal/backup/collector_manifest.go +++ b/internal/backup/collector_manifest.go @@ -24,7 +24,12 @@ type ManifestEntry struct { Error string `json:"error,omitempty"` } -// BackupManifest contains metadata about all files in the backup +// BackupManifest is the PRE-OPTIMIZATION collection inventory: its per-file Size +// and Stats.BytesCollected describe files as collected, BEFORE the dedup/prefilter +// stages mutate the staging tree (issue #73). It is an ExportOnly diagnostic +// (category proxsave_info) and is never read back by restore. The authoritative +// record of the shipped payload is the archive sidecar (.sha256 and +// .manifest.json), computed after the archive is built. type BackupManifest struct { CreatedAt time.Time `json:"created_at"` Hostname string `json:"hostname"` diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index 1b2d635..8113f8a 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -89,6 +89,20 @@ func (c *Collector) CollectPBSConfigs(ctx context.Context) error { return nil } +// pbsUserConfigSecretExcludes are the PBS access-control credential files dropped +// from the /etc/proxmox-backup snapshot when BACKUP_USER_CONFIGS is disabled, in +// addition to user.cfg/acl.cfg/domains.cfg. They mirror the secret files listed by +// the pbs_access_control restore category (internal/orchestrator/categories.go) and +// the pbs*Path constants in internal/orchestrator/restore_access_control.go (kept in +// sync manually: internal/backup cannot import internal/orchestrator). These are +// top-level files in /etc/proxmox-backup, so plain basename patterns match them. +var pbsUserConfigSecretExcludes = []string{ + "token.cfg", + "shadow.json", + "token.shadow", + "tfa.json", +} + func (c *Collector) collectPBSConfigSnapshot(ctx context.Context, root string) error { c.logger.Debug("Collecting PBS directories (source=%s, dest=%s)", root, filepath.Join(c.tempDir, "etc/proxmox-backup")) @@ -122,6 +136,10 @@ func (c *Collector) collectPBSConfigSnapshot(ctx context.Context, root string) e } if !c.config.BackupUserConfigs { extraExclude = append(extraExclude, "user.cfg", "acl.cfg", "domains.cfg") + // Users/ACLs are not the whole story: token secrets, password hashes and TFA + // secrets live alongside in /etc/proxmox-backup. Exclude them too so the toggle + // removes the whole access-control domain. + extraExclude = append(extraExclude, pbsUserConfigSecretExcludes...) } if !c.config.BackupRemoteConfigs { extraExclude = append(extraExclude, "remote.cfg") diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index 878d0f1..9a05618 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -798,11 +798,11 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error if ctxErr := ctx.Err(); ctxErr != nil { return nil, ctxErr } - c.logger.Debug("PBS datastore CLI enumeration failed: %v", err) + c.logger.Warning("PBS datastore enumeration via proxmox-backup-manager failed; per-datastore status may be incomplete (raw datastore.cfg is still collected): %v", err) } else { var entries []datastoreEntry if err := json.Unmarshal(output, &entries); err != nil { - c.logger.Debug("Failed to parse PBS datastore list JSON: %v", err) + c.logger.Warning("Could not parse 'proxmox-backup-manager datastore list' output; per-datastore status may be incomplete (raw datastore.cfg is still collected): %v", err) } else { datastores = make([]pbsDatastore, 0, len(entries)+len(c.config.PBSDatastorePaths)) for _, entry := range entries { diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go index b690cb3..07919a7 100644 --- a/internal/backup/collector_pbs_test.go +++ b/internal/backup/collector_pbs_test.go @@ -227,6 +227,9 @@ func TestGetDatastoreListCommandError(t *testing.T) { if datastores[0].Name != "from-error" || datastores[0].Path != "/override/from-error" || datastores[0].Source != pbsDatastoreSourceOverride { t.Fatalf("unexpected override datastore after command failure: %+v", datastores[0]) } + if got := collector.logger.WarningCount(); got != 1 { + t.Fatalf("expected the enumeration failure surfaced as 1 warning (issue #62), got %d", got) + } } func TestGetDatastoreListBadJSON(t *testing.T) { @@ -250,6 +253,9 @@ func TestGetDatastoreListBadJSON(t *testing.T) { if datastores[0].Name != "from-parse" || datastores[0].Path != "/override/from-parse" || datastores[0].Source != pbsDatastoreSourceOverride { t.Fatalf("unexpected override datastore after parse failure: %+v", datastores[0]) } + if got := collector.logger.WarningCount(); got != 1 { + t.Fatalf("expected the JSON parse failure surfaced as 1 warning (issue #62), got %d", got) + } } func TestHasTapeSupportContextCanceled(t *testing.T) { @@ -926,3 +932,72 @@ func (fakeFileInfo) Mode() os.FileMode { return 0 } func (fakeFileInfo) ModTime() time.Time { return time.Time{} } func (fakeFileInfo) IsDir() bool { return false } func (fakeFileInfo) Sys() interface{} { return nil } + +func TestCollectPBSConfigSnapshotExcludesUserConfigSecretsWhenDisabled(t *testing.T) { + pbsRoot := t.TempDir() + for _, f := range []string{ + "user.cfg", "acl.cfg", "domains.cfg", + "token.cfg", "shadow.json", "token.shadow", "tfa.json", + "datastore.cfg", "notifications-priv.cfg", + } { + if err := os.WriteFile(filepath.Join(pbsRoot, f), []byte("x"), 0o600); err != nil { + t.Fatalf("write %s: %v", f, err) + } + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + cfg.BackupUserConfigs = false + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{}) + + if err := collector.collectPBSConfigSnapshot(context.Background(), pbsRoot); err != nil { + t.Fatalf("collectPBSConfigSnapshot failed: %v", err) + } + + dest := filepath.Join(collector.tempDir, "etc/proxmox-backup") + for _, excluded := range []string{ + "user.cfg", "acl.cfg", "domains.cfg", + "token.cfg", "shadow.json", "token.shadow", "tfa.json", + } { + if _, err := os.Stat(filepath.Join(dest, excluded)); err == nil { + t.Fatalf("expected %s excluded when BACKUP_USER_CONFIGS=false", excluded) + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("stat %s: %v", excluded, err) + } + } + + // Files governed by other toggles must remain: proves the exclusion is scoped to + // the access-control domain, not a blanket wipe. + for _, kept := range []string{"datastore.cfg", "notifications-priv.cfg"} { + if _, err := os.Stat(filepath.Join(dest, kept)); err != nil { + t.Fatalf("expected %s retained (not governed by BACKUP_USER_CONFIGS): %v", kept, err) + } + } +} + +func TestCollectPBSConfigSnapshotKeepsUserConfigSecretsWhenEnabled(t *testing.T) { + pbsRoot := t.TempDir() + for _, f := range []string{"user.cfg", "token.cfg", "shadow.json", "token.shadow", "tfa.json"} { + if err := os.WriteFile(filepath.Join(pbsRoot, f), []byte("x"), 0o600); err != nil { + t.Fatalf("write %s: %v", f, err) + } + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + cfg.BackupUserConfigs = true + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{}) + + if err := collector.collectPBSConfigSnapshot(context.Background(), pbsRoot); err != nil { + t.Fatalf("collectPBSConfigSnapshot failed: %v", err) + } + + dest := filepath.Join(collector.tempDir, "etc/proxmox-backup") + for _, kept := range []string{"user.cfg", "token.cfg", "shadow.json", "token.shadow", "tfa.json"} { + if _, err := os.Stat(filepath.Join(dest, kept)); err != nil { + t.Fatalf("expected %s collected when BACKUP_USER_CONFIGS=true, got %v", kept, err) + } + } +} diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index 8ff6e13..d04e0ef 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -53,19 +53,42 @@ type pveStorageScanResult struct { SkipRemaining bool } +// defaultPVEBackupPatterns are the glob patterns used to recognise PVE backup +// files for sampling/analysis and for the opt-in small/selected copy. It includes +// the legacy vzdump compression variants (.lzo from lzop, .xz) so backups produced +// by older or non-default vzdump settings are not silently missed (issue #65). var defaultPVEBackupPatterns = []string{ "*.vma", "*.vma.gz", "*.vma.lz4", + "*.vma.lzo", + "*.vma.xz", "*.vma.zst", "*.tar", "*.tar.gz", "*.tar.lz4", + "*.tar.lzo", + "*.tar.xz", "*.tar.zst", "*.log", "*.notes", } +// pveACLPrivExcludePatterns are the access-control credential files under +// /etc/pve/priv/ that must be dropped from the flat /etc/pve snapshot when +// BACKUP_PVE_ACL is disabled. They mirror the access-control material listed by +// the pve_access_control restore category (internal/orchestrator/categories.go) +// and the pve*CfgPath constants in internal/orchestrator/restore_access_control.go +// (kept in sync manually: internal/backup cannot import internal/orchestrator). +// The "**/priv/.cfg" form anchors on the priv parent so it matches the +// path candidate "etc/pve/priv/shadow.cfg" (see matchesGlob/globToRegex) without +// touching priv/notifications.cfg (pve_notifications domain), authkey.key or acme/. +var pveACLPrivExcludePatterns = []string{ + "**/priv/shadow.cfg", + "**/priv/token.cfg", + "**/priv/tfa.cfg", +} + var errStopWalk = errors.New("stop walk") // CollectPVEConfigs collects Proxmox VE specific configurations @@ -209,6 +232,22 @@ func (c *Collector) populatePVEManifest() { countNotFound: false, suppressNotFoundLog: true, }) + // Access-control credential material under priv/ (gated by the same toggle). + // These are absent on a fresh install with no custom users/tokens/2FA, so a + // missing file is not an error. + for _, privFile := range []struct{ name, description string }{ + {"shadow.cfg", "User password hashes"}, + {"token.cfg", "API token secrets"}, + {"tfa.cfg", "TFA secrets"}, + } { + record(filepath.Join(pveConfigPath, "priv", privFile.name), c.config.BackupPVEACL, manifestLogOpts{ + description: privFile.description, + disableHint: "BACKUP_PVE_ACL", + log: true, + countNotFound: false, + suppressNotFoundLog: true, + }) + } // Scheduled jobs. record(filepath.Join(pveConfigPath, "jobs.cfg"), c.config.BackupPVEJobs, manifestLogOpts{ @@ -245,13 +284,7 @@ func (c *Collector) populatePVEManifest() { }) // VZDump configuration. - vzdumpPath := c.config.VzdumpConfigPath - if vzdumpPath == "" { - vzdumpPath = "/etc/vzdump.conf" - } else if !filepath.IsAbs(vzdumpPath) { - vzdumpPath = filepath.Join(pveConfigPath, vzdumpPath) - } - record(vzdumpPath, c.config.BackupVZDumpConfig, manifestLogOpts{ + record(c.effectiveVzdumpConfigPath(), c.config.BackupVZDumpConfig, manifestLogOpts{ description: "VZDump configuration", disableHint: "BACKUP_VZDUMP_CONFIG", log: true, @@ -321,6 +354,10 @@ func (c *Collector) collectPVEConfigSnapshot(ctx context.Context) error { } if !c.config.BackupPVEACL { extraExclude = append(extraExclude, "user.cfg", "domains.cfg") + // ACLs/users are not just user.cfg/domains.cfg: the credential material + // lives under priv/ (password hashes, API token secrets, TFA secrets). + // Exclude it too so the toggle removes the whole access-control domain. + extraExclude = append(extraExclude, pveACLPrivExcludePatterns...) } if !c.config.BackupPVEJobs { extraExclude = append(extraExclude, "jobs.cfg", "vzdump.cron") @@ -341,16 +378,19 @@ func (c *Collector) collectPVEConfigSnapshot(ctx context.Context) error { } func (c *Collector) collectPVEClusterSnapshot(ctx context.Context, clustered bool) error { - pveConfigPath := c.effectivePVEConfigPath() clusterPath := c.effectivePVEClusterPath() + // /etc/pve is a pmxcfs mount backed by config.db: the cluster database still + // contains the PVE access-control secrets even though BACKUP_PVE_ACL=false + // excludes the flat priv files. Warn so the operator is not left with a false + // sense of exclusion; the only way to drop them entirely is to also disable + // cluster backup. + if c.config.BackupClusterConfig && !c.config.BackupPVEACL { + c.logger.Warning("PVE access control: BACKUP_PVE_ACL=false excludes /etc/pve/priv/{shadow,token,tfa}.cfg, but the same secrets remain inside the cluster database config.db; set BACKUP_CLUSTER_CONFIG=false to exclude them entirely") + } + if c.config.BackupClusterConfig { - corosyncPath := c.config.CorosyncConfigPath - if corosyncPath == "" { - corosyncPath = filepath.Join(pveConfigPath, "corosync.conf") - } else if !filepath.IsAbs(corosyncPath) { - corosyncPath = filepath.Join(pveConfigPath, corosyncPath) - } + corosyncPath := c.effectiveCorosyncConfigPath() if err := c.safeCopyFile(ctx, corosyncPath, c.targetPathFor(corosyncPath), @@ -432,15 +472,9 @@ func (c *Collector) collectPVEFirewallSnapshot(ctx context.Context) error { } func (c *Collector) collectPVEVZDumpSnapshot(ctx context.Context) error { - pveConfigPath := c.effectivePVEConfigPath() if c.config.BackupVZDumpConfig { c.logger.Info("Collecting VZDump backup configuration") - vzdumpPath := c.config.VzdumpConfigPath - if vzdumpPath == "" { - vzdumpPath = "/etc/vzdump.conf" - } else if !filepath.IsAbs(vzdumpPath) { - vzdumpPath = filepath.Join(pveConfigPath, vzdumpPath) - } + vzdumpPath := c.effectiveVzdumpConfigPath() if err := c.safeCopyFile(ctx, vzdumpPath, c.targetPathFor(vzdumpPath), @@ -1236,11 +1270,9 @@ func (c *Collector) collectPVEStorageMetadataJSONStep(ctx context.Context, resul includePatterns := c.config.PxarFileIncludePatterns if len(includePatterns) == 0 { - includePatterns = []string{ - "*.vma", "*.vma.gz", "*.vma.lz4", "*.vma.zst", - "*.tar", "*.tar.gz", "*.tar.lz4", "*.tar.zst", - "*.log", "*.notes", - } + // Same default set as the analysis scan; keep them unified so legacy + // variants stay in sync (issue #65). + includePatterns = defaultPVEBackupPatterns } excludePatterns := c.config.PxarFileExcludePatterns @@ -2373,11 +2405,25 @@ func (c *Collector) effectiveCorosyncConfigPath() string { return filepath.Join(c.effectivePVEConfigPath(), "corosync.conf") } if filepath.IsAbs(corosyncPath) { - return corosyncPath + // Honor SystemRootPrefix for an absolute override, like effectivePVEConfigPath. + return c.systemPath(corosyncPath) } return filepath.Join(c.effectivePVEConfigPath(), corosyncPath) } +// effectiveVzdumpConfigPath resolves the vzdump.conf source, honoring an optional +// SystemRootPrefix for the default and for absolute overrides (mirroring corosync). +func (c *Collector) effectiveVzdumpConfigPath() string { + vzdumpPath := strings.TrimSpace(c.config.VzdumpConfigPath) + if vzdumpPath == "" { + return c.systemPath("/etc/vzdump.conf") + } + if filepath.IsAbs(vzdumpPath) { + return c.systemPath(vzdumpPath) + } + return filepath.Join(c.effectivePVEConfigPath(), vzdumpPath) +} + func (c *Collector) hasMultiplePVENodes() bool { count, err := c.pveNodesDirCount() return err == nil && count > 1 diff --git a/internal/backup/collector_pve_test.go b/internal/backup/collector_pve_test.go index dac67b0..9d40e10 100644 --- a/internal/backup/collector_pve_test.go +++ b/internal/backup/collector_pve_test.go @@ -696,6 +696,37 @@ func TestPVEJobBricksComprehensive(t *testing.T) { }) } +// TestSystemRootPrefixAppliesToCorosyncAndVzdumpPaths guards #68: the documented +// SYSTEM_ROOT_PREFIX must be honored for the corosync/vzdump config reads, whose +// defaults are absolute (/etc/pve/corosync.conf, /etc/vzdump.conf) and previously +// bypassed the prefix. +func TestSystemRootPrefixAppliesToCorosyncAndVzdumpPaths(t *testing.T) { + t.Run("absolute defaults are prefixed when SystemRootPrefix is set", func(t *testing.T) { + root := "/mnt/fixture" + cfg := GetDefaultCollectorConfig() // corosync/vzdump defaults are absolute + cfg.SystemRootPrefix = root + c := NewCollector(newTestLogger(), cfg, t.TempDir(), "pve", false) + + if got, want := c.effectiveCorosyncConfigPath(), filepath.Join(root, "etc/pve/corosync.conf"); got != want { + t.Errorf("corosync: got %q, want %q", got, want) + } + if got, want := c.effectiveVzdumpConfigPath(), filepath.Join(root, "etc/vzdump.conf"); got != want { + t.Errorf("vzdump: got %q, want %q", got, want) + } + }) + t.Run("no prefix is identity (production default unchanged)", func(t *testing.T) { + cfg := GetDefaultCollectorConfig() // SystemRootPrefix == "" + c := NewCollector(newTestLogger(), cfg, t.TempDir(), "pve", false) + + if got := c.effectiveCorosyncConfigPath(); got != "/etc/pve/corosync.conf" { + t.Errorf("corosync without prefix: got %q, want /etc/pve/corosync.conf", got) + } + if got := c.effectiveVzdumpConfigPath(); got != "/etc/vzdump.conf" { + t.Errorf("vzdump without prefix: got %q, want /etc/vzdump.conf", got) + } + }) +} + // TestPVEScheduleBricks runs the real PVE schedule bricks. func TestPVEScheduleBricks(t *testing.T) { collector := newPVECollector(t) @@ -754,6 +785,10 @@ func TestCollectPVEDirectoriesExcludesDisabledPVEConfigFiles(t *testing.T) { mustWrite(filepath.Join(pveRoot, "lxc", "101.conf"), "ct") mustWrite(filepath.Join(pveRoot, "firewall", "cluster.fw"), "fw") mustWrite(filepath.Join(pveRoot, "nodes", "node1", "host.fw"), "hostfw") + mustWrite(filepath.Join(pveRoot, "priv", "shadow.cfg"), "hash") + mustWrite(filepath.Join(pveRoot, "priv", "token.cfg"), "token") + mustWrite(filepath.Join(pveRoot, "priv", "tfa.cfg"), "tfa") + mustWrite(filepath.Join(pveRoot, "priv", "notifications.cfg"), "notif") clusterPath := filepath.Join(t.TempDir(), "pve-cluster") mustWrite(filepath.Join(clusterPath, "config.db"), "db") @@ -787,6 +822,9 @@ func TestCollectPVEDirectoriesExcludesDisabledPVEConfigFiles(t *testing.T) { filepath.Join("lxc", "101.conf"), filepath.Join("firewall", "cluster.fw"), filepath.Join("nodes", "node1", "host.fw"), + filepath.Join("priv", "shadow.cfg"), + filepath.Join("priv", "token.cfg"), + filepath.Join("priv", "tfa.cfg"), } { _, err := os.Stat(filepath.Join(destPVE, excluded)) if err == nil { @@ -797,12 +835,94 @@ func TestCollectPVEDirectoriesExcludesDisabledPVEConfigFiles(t *testing.T) { } } + // priv/notifications.cfg belongs to the notifications domain, NOT access control: + // the BACKUP_PVE_ACL toggle must not exclude it (proves the exclusion is file-scoped, + // not a priv/** subtree wipe). + if _, err := os.Stat(filepath.Join(destPVE, "priv", "notifications.cfg")); err != nil { + t.Fatalf("expected priv/notifications.cfg retained (not governed by BACKUP_PVE_ACL): %v", err) + } + destDB := collector.targetPathFor(filepath.Join(clusterPath, "config.db")) if _, err := os.Stat(destDB); err == nil { t.Fatalf("expected config.db excluded when BACKUP_CLUSTER_CONFIG=false") } } +// TestPVEConfigSnapshotKeepsPrivWhenACLEnabled guards against over-exclusion: +// with BACKUP_PVE_ACL=true the priv credential files must still be collected. +func TestPVEConfigSnapshotKeepsPrivWhenACLEnabled(t *testing.T) { + collector := newPVECollector(t) + pveRoot := collector.config.PVEConfigPath + + write := func(rel, contents string) { + t.Helper() + path := filepath.Join(pveRoot, rel) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", filepath.Dir(path), err) + } + if err := os.WriteFile(path, []byte(contents), 0o600); err != nil { + t.Fatalf("write %s: %v", path, err) + } + } + write("user.cfg", "user") + write(filepath.Join("priv", "shadow.cfg"), "hash") + write(filepath.Join("priv", "token.cfg"), "token") + write(filepath.Join("priv", "tfa.cfg"), "tfa") + + collector.config.BackupPVEACL = true + + runSelectedBricksForTest(t, context.Background(), collector, newPVERecipe(), nil, + brickPVEConfigSnapshot, + ) + + destPVE := collector.targetPathFor(pveRoot) + for _, kept := range []string{ + "user.cfg", + filepath.Join("priv", "shadow.cfg"), + filepath.Join("priv", "token.cfg"), + filepath.Join("priv", "tfa.cfg"), + } { + if _, err := os.Stat(filepath.Join(destPVE, kept)); err != nil { + t.Fatalf("expected %s collected when BACKUP_PVE_ACL=true, got %v", kept, err) + } + } +} + +func TestCollectPVEConfigsPopulatesManifestDisabledForPrivWhenACLOff(t *testing.T) { + collector := newPVECollectorWithDeps(t, CollectorDeps{ + RunCommand: func(context.Context, string, ...string) ([]byte, error) { + return []byte("{}"), nil + }, + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + }) + + pveConfigPath := collector.config.PVEConfigPath + if err := os.MkdirAll(filepath.Join(pveConfigPath, "priv"), 0o700); err != nil { + t.Fatalf("mkdir priv: %v", err) + } + if err := os.WriteFile(filepath.Join(pveConfigPath, "priv", "shadow.cfg"), []byte("hash"), 0o600); err != nil { + t.Fatalf("write shadow.cfg: %v", err) + } + collector.config.BackupPVEACL = false + + if err := collector.CollectPVEConfigs(context.Background()); err != nil { + t.Fatalf("CollectPVEConfigs failed: %v", err) + } + + src := filepath.Join(collector.effectivePVEConfigPath(), "priv", "shadow.cfg") + dest := collector.targetPathFor(src) + key := pveManifestKey(collector.tempDir, dest) + entry, ok := collector.pveManifest[key] + if !ok { + t.Fatalf("expected manifest entry for %s (key=%s)", src, key) + } + if entry.Status != StatusDisabled { + t.Fatalf("expected %s status, got %s", StatusDisabled, entry.Status) + } +} + // TestPVEStoragePipeline runs the real PVE storage pipeline bricks. func TestPVEStoragePipeline(t *testing.T) { collector := newPVECollector(t) @@ -980,3 +1100,24 @@ func TestPVECephBricks(t *testing.T) { runSelectedBricksForTest(t, context.Background(), collector, newPVERecipe(), nil, brickPVECephConfigSnapshot, brickPVECephRuntime) }) } + +// TestDefaultPVEBackupPatternsCoverLegacyVariants verifies the default PVE backup +// patterns include legacy vzdump compression variants (.lzo/.xz) so they are not +// silently skipped during sampling/small-copy (issue #65). +func TestDefaultPVEBackupPatternsCoverLegacyVariants(t *testing.T) { + have := make(map[string]bool, len(defaultPVEBackupPatterns)) + for _, p := range defaultPVEBackupPatterns { + have[p] = true + } + for _, want := range []string{"*.vma.lzo", "*.vma.xz", "*.tar.lzo", "*.tar.xz"} { + if !have[want] { + t.Errorf("defaultPVEBackupPatterns missing legacy variant %q", want) + } + } + if !matchPattern("vzdump-qemu-100-2024_01_01.vma.lzo", "*.vma.lzo") { + t.Fatal("expected *.vma.lzo to match a legacy lzo backup file") + } + if !matchPattern("vzdump-lxc-101-2024_01_01.tar.xz", "*.tar.xz") { + t.Fatal("expected *.tar.xz to match a legacy xz backup file") + } +} diff --git a/internal/backup/collector_system.go b/internal/backup/collector_system.go index 78158f3..cd90fc7 100644 --- a/internal/backup/collector_system.go +++ b/internal/backup/collector_system.go @@ -137,6 +137,15 @@ func (c *Collector) CollectSystemInfo(ctx context.Context) error { ensureSystemPath() c.logger.Debug("System PATH verified for command execution") + + // Populate the system_files manifest for the duration of system collection so + // the backup manifest records which system targets were collected (issue #59). + if c.systemManifest == nil { + c.systemManifest = make(map[string]ManifestEntry) + } + c.recordSystemManifest = true + defer func() { c.recordSystemManifest = false }() + state := newCollectionState(c) if err := runRecipe(ctx, newSystemRecipe(), state); err != nil { return err @@ -1364,6 +1373,13 @@ func (c *Collector) collectConfigFile(ctx context.Context) error { func (c *Collector) collectCustomPaths(ctx context.Context) error { c.logger.Debug("Collecting custom paths defined in configuration") + // Operator-supplied paths may be broad (e.g. "/", "/tmp", "/tmp/proxsave") and + // thus contain the staging workspace; prune it from the source walk so the + // in-progress archive is never copied into itself (#56). Other collection + // sources are fixed system paths that never contain tempDir, so the prune is + // scoped to this phase only. + c.collectingCustomPaths = true + defer func() { c.collectingCustomPaths = false }() seen := make(map[string]struct{}) for _, rawPath := range c.config.CustomBackupPaths { @@ -1508,14 +1524,15 @@ func (c *Collector) collectScriptRepository(ctx context.Context) error { if err != nil || rel == "." { return nil } - parts := strings.Split(rel, string(filepath.Separator)) - if len(parts) > 0 { - if parts[0] == "backup" || parts[0] == "log" { - if d.IsDir() { - return filepath.SkipDir - } - return nil + // Skip VCS metadata and runtime/output dirs at ANY depth (not just the top + // level): .git/.svn/.hg carry full history/objects (large and sensitive), and + // backup(s)/log(s) are regenerated output that only bloats the snapshot. + switch d.Name() { + case ".git", ".svn", ".hg", "backup", "backups", "log", "logs": + if d.IsDir() { + return filepath.SkipDir } + return nil } dest := filepath.Join(target, rel) diff --git a/internal/backup/collector_system_test.go b/internal/backup/collector_system_test.go index 9e28159..b7a9556 100644 --- a/internal/backup/collector_system_test.go +++ b/internal/backup/collector_system_test.go @@ -1416,10 +1416,21 @@ func TestCollectScriptRepositoryCopiesAndSkipsRuntimeDirs(t *testing.T) { repo := t.TempDir() collector.config.ScriptRepositoryPath = repo + // Kept: real repo content. writeFileAt(t, filepath.Join(repo, "keep.sh"), "#!/bin/sh\n") writeFileAt(t, filepath.Join(repo, "nested", "config.env"), "A=1\n") + // Skipped (pre-existing): top-level backup/log as a dir and as a file. writeFileAt(t, filepath.Join(repo, "backup", "skip.tar"), "backup\n") writeFileAt(t, filepath.Join(repo, "log", "skip.log"), "log\n") + // Skipped (the #69 fix): .git at any depth, plural runtime dirs, and nested + // backup/log dirs that the old parts[0]-only check let through. + writeFileAt(t, filepath.Join(repo, ".git", "objects", "ab", "cd"), "obj\n") + writeFileAt(t, filepath.Join(repo, ".git", "logs", "HEAD"), "ref\n") + writeFileAt(t, filepath.Join(repo, ".svn", "entries"), "svn\n") + writeFileAt(t, filepath.Join(repo, "backups", "old.tar"), "old\n") + writeFileAt(t, filepath.Join(repo, "logs", "app.log"), "log\n") + writeFileAt(t, filepath.Join(repo, "nested", "backup", "skip"), "x\n") + writeFileAt(t, filepath.Join(repo, "nested", "log", "skip"), "y\n") if err := collector.collectScriptRepository(context.Background()); err != nil { t.Fatalf("collectScriptRepository: %v", err) @@ -1428,12 +1439,23 @@ func TestCollectScriptRepositoryCopiesAndSkipsRuntimeDirs(t *testing.T) { target := collector.proxsaveInfoDir("script-repository", filepath.Base(repo)) assertFileExists(t, filepath.Join(target, "keep.sh")) assertFileExists(t, filepath.Join(target, "nested", "config.env")) - if _, err := os.Stat(filepath.Join(target, "backup", "skip.tar")); !os.IsNotExist(err) { - t.Fatalf("expected backup dir skipped, stat err=%v", err) - } - if _, err := os.Stat(filepath.Join(target, "log", "skip.log")); !os.IsNotExist(err) { - t.Fatalf("expected log dir skipped, stat err=%v", err) - } + + assertAbsent := func(rel string) { + t.Helper() + if _, err := os.Stat(filepath.Join(target, rel)); !os.IsNotExist(err) { + t.Fatalf("expected %q to be skipped, stat err=%v", rel, err) + } + } + assertAbsent(filepath.Join("backup", "skip.tar")) + assertAbsent(filepath.Join("log", "skip.log")) + assertAbsent(".git") + assertAbsent(filepath.Join(".git", "objects", "ab", "cd")) + assertAbsent(filepath.Join(".git", "logs", "HEAD")) + assertAbsent(filepath.Join(".svn", "entries")) + assertAbsent(filepath.Join("backups", "old.tar")) + assertAbsent(filepath.Join("logs", "app.log")) + assertAbsent(filepath.Join("nested", "backup", "skip")) + assertAbsent(filepath.Join("nested", "log", "skip")) } func TestCollectScriptRepositorySkipAndCancelBranches(t *testing.T) { @@ -2900,3 +2922,122 @@ func newTestCollectorWithDeps(t *testing.T, override CollectorDeps) *Collector { tempDir := t.TempDir() return NewCollectorWithDeps(logger, config, tempDir, types.ProxmoxUnknown, false, deps) } + +// TestSystemManifestRecordsTargetsNotNestedFiles verifies system collection +// populates systemManifest at collection-target granularity (issue #59): a direct +// file copy and a directory copy each yield one entry, a missing source is +// recorded as not_found, and files nested inside a copied directory are NOT +// recorded individually. +func TestSystemManifestRecordsTargetsNotNestedFiles(t *testing.T) { + src := t.TempDir() + if err := os.WriteFile(filepath.Join(src, "hostname"), []byte("h"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(src, "netdir", "sub"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(src, "netdir", "a.conf"), []byte("a"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(src, "netdir", "sub", "b.conf"), []byte("b"), 0o644); err != nil { + t.Fatal(err) + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{}) + collector.systemManifest = make(map[string]ManifestEntry) + collector.recordSystemManifest = true + + ctx := context.Background() + if err := collector.safeCopyFile(ctx, filepath.Join(src, "hostname"), filepath.Join(collector.tempDir, "etc/hostname"), "Hostname"); err != nil { + t.Fatalf("safeCopyFile hostname: %v", err) + } + if err := collector.safeCopyFile(ctx, filepath.Join(src, "missing"), filepath.Join(collector.tempDir, "etc/missing"), "Missing"); err != nil { + t.Fatalf("safeCopyFile missing: %v", err) + } + if err := collector.safeCopyDir(ctx, filepath.Join(src, "netdir"), filepath.Join(collector.tempDir, "etc/netdir"), "Net dir"); err != nil { + t.Fatalf("safeCopyDir netdir: %v", err) + } + + m := collector.systemManifest + if got := m["etc/hostname"]; got.Status != StatusCollected { + t.Fatalf("etc/hostname: want collected, got %+v", got) + } + if got := m["etc/missing"]; got.Status != StatusNotFound { + t.Fatalf("etc/missing: want not_found, got %+v", got) + } + if got := m["etc/netdir"]; got.Status != StatusCollected { + t.Fatalf("etc/netdir: want collected dir target, got %+v", got) + } + for k := range m { + if strings.HasPrefix(k, "etc/netdir/") { + t.Fatalf("nested file %q must not be recorded (only the dir target)", k) + } + } + if len(m) != 3 { + t.Fatalf("expected 3 system manifest entries (hostname, missing, netdir), got %d: %+v", len(m), m) + } +} + +// TestSafeCopyDirRecordsFailedOnError is the #2 guard: a directory copy that fails +// (here via a canceled context during the walk) must be recorded in the system +// manifest as failed, not left as the up-front "collected" status. +func TestSafeCopyDirRecordsFailedOnError(t *testing.T) { + src := t.TempDir() + if err := os.WriteFile(filepath.Join(src, "a.conf"), []byte("a"), 0o644); err != nil { + t.Fatal(err) + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{}) + collector.systemManifest = make(map[string]ManifestEntry) + collector.recordSystemManifest = true + + // Make ensureDir(dest) fail: put a regular FILE where dest's parent dir would be, + // so MkdirAll(dest) returns ENOTDIR. ctx stays valid (the top-of-function guard + // must not short-circuit), so the failure happens inside the copy itself. + parent := filepath.Join(collector.tempDir, "etc") + if err := os.WriteFile(parent, []byte("blocker"), 0o644); err != nil { + t.Fatal(err) + } + dest := filepath.Join(parent, "netdir") // tempDir/etc/netdir, but tempDir/etc is a file + + if err := collector.safeCopyDir(context.Background(), src, dest, "Net dir"); err == nil { + t.Fatal("safeCopyDir should return an error when ensureDir(dest) fails") + } + if got := collector.systemManifest["etc/netdir"]; got.Status != StatusFailed { + t.Fatalf("etc/netdir: want failed status on a failed copy, got %+v", got) + } +} + +// TestSafeCopyDirSkipsStagingWorkspace verifies a broad source does not copy the +// staging workspace into itself (issue #56): the staging subtree under the source +// must be pruned while the real content is still collected. +func TestSafeCopyDirSkipsStagingWorkspace(t *testing.T) { + collector := newTestCollectorWithDeps(t, CollectorDeps{}) + + srcDir := t.TempDir() + if err := os.WriteFile(filepath.Join(srcDir, "real.conf"), []byte("keep"), 0o600); err != nil { + t.Fatal(err) + } + // The staging workspace lives under the source (the self-recursion case). + staging := filepath.Join(srcDir, "proxsave-staging") + if err := os.MkdirAll(staging, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(staging, "ARCHIVE_DATA"), []byte("must not be copied"), 0o600); err != nil { + t.Fatal(err) + } + collector.tempDir = staging + collector.collectingCustomPaths = true // the prune is scoped to custom-path collection + + dest := filepath.Join(staging, "etc", "custom") + if err := collector.safeCopyDir(context.Background(), srcDir, dest, "custom"); err != nil { + t.Fatalf("safeCopyDir: %v", err) + } + + if _, err := os.Stat(filepath.Join(dest, "real.conf")); err != nil { + t.Fatalf("expected real.conf to be collected: %v", err) + } + if _, err := os.Stat(filepath.Join(dest, "proxsave-staging")); !os.IsNotExist(err) { + t.Fatalf("staging workspace must not be copied into itself (#56), stat err=%v", err) + } +} diff --git a/internal/backup/optimizations.go b/internal/backup/optimizations.go index af3f663..0ca55ba 100644 --- a/internal/backup/optimizations.go +++ b/internal/backup/optimizations.go @@ -19,6 +19,19 @@ const ( defaultOptimizedFilePerm = 0o640 ) +// DedupManifestRelPath is where deduplicateFiles records the symlinks it created, +// relative to the staging/archive root. The restore reads it (always extracted) +// to materialize those symlinks back into regular files, so selective restore +// never produces dangling links and full restore preserves file-type fidelity +// (issue #70). +const DedupManifestRelPath = "var/lib/proxsave-info/dedup_manifest.json" + +// DedupManifestEntry records one file that deduplication replaced with a symlink. +type DedupManifestEntry struct { + Path string `json:"path"` // path relative to the archive root, slash-separated + Mode uint32 `json:"mode"` // original regular-file permission bits +} + // OptimizationConfig controls optional preprocessing steps executed before archiving. type OptimizationConfig struct { EnableDeduplication bool @@ -31,10 +44,22 @@ func (c OptimizationConfig) Enabled() bool { return c.EnableDeduplication || c.EnablePrefilter } -// ApplyOptimizations executes the requested optimizations in sequence. -func ApplyOptimizations(ctx context.Context, logger *logging.Logger, root string, cfg OptimizationConfig) error { +// OptimizationResult reports what the optimization stages removed from the staged +// tree. Callers use BytesReclaimed to correct the reported uncompressed-payload size +// (issue #73): dedup and prefilter shrink the tree AFTER the collection stats were +// snapshotted, so the pre-optimization byte total would otherwise inflate the +// compression ratio shown in reports/notifications/metrics. +type OptimizationResult struct { + BytesReclaimed int64 // bytes removed from the staged tree by dedup + prefilter + DuplicatesReplaced int +} + +// ApplyOptimizations executes the requested optimizations in sequence and reports +// how many bytes they reclaimed. +func ApplyOptimizations(ctx context.Context, logger *logging.Logger, root string, cfg OptimizationConfig) (OptimizationResult, error) { + var res OptimizationResult if !cfg.Enabled() { - return nil + return res, nil } logger.Info("Running backup optimizations (dedup=%v prefilter=%v)", @@ -42,34 +67,45 @@ func ApplyOptimizations(ctx context.Context, logger *logging.Logger, root string if cfg.EnableDeduplication { logger.Debug("Starting deduplication stage") - if err := deduplicateFiles(ctx, logger, root); err != nil { - logger.Warning("File deduplication failed: %v", err) - } else { - logger.Debug("Deduplication stage completed") + dups, reclaimed, err := deduplicateFiles(ctx, logger, root) + if err != nil { + // A dedup error means the staging tree may still hold symlinks the restore + // cannot materialize (manifest unwritten, partial revert): fail rather than + // archive a tree that would lose fidelity on restore (issue #70). The + // happy path and a fully-reverted manifest failure both return nil. + return OptimizationResult{}, fmt.Errorf("deduplication: %w", err) } + res.DuplicatesReplaced = dups + res.BytesReclaimed += reclaimed + logger.Debug("Deduplication stage completed") } if cfg.EnablePrefilter { logger.Debug("Starting prefilter stage (max file size %d bytes)", cfg.PrefilterMaxFileSizeBytes) - if err := prefilterFiles(ctx, logger, root, cfg.PrefilterMaxFileSizeBytes); err != nil { + reclaimed, err := prefilterFiles(ctx, logger, root, cfg.PrefilterMaxFileSizeBytes) + if err != nil { logger.Warning("Content prefilter failed: %v", err) } else { + res.BytesReclaimed += reclaimed logger.Debug("Prefilter stage completed") } } - return nil + return res, nil } -func deduplicateFiles(ctx context.Context, logger *logging.Logger, root string) error { +func deduplicateFiles(ctx context.Context, logger *logging.Logger, root string) (int, int64, error) { logger.Debug("Scanning files for deduplication") hashes := make(map[string]string) var duplicates int + var bytesReclaimed int64 + var manifest []DedupManifestEntry + var replaced []dedupReplacement rootFS, err := os.OpenRoot(root) if err != nil { - return fmt.Errorf("open dedup root: %w", err) + return 0, 0, fmt.Errorf("open dedup root: %w", err) } defer func() { _ = rootFS.Close() }() @@ -115,6 +151,16 @@ func deduplicateFiles(ctx context.Context, logger *logging.Logger, root string) return nil } duplicates++ + bytesReclaimed += info.Size() + manifest = append(manifest, DedupManifestEntry{ + Path: filepath.ToSlash(rel), + Mode: uint32(info.Mode().Perm()), + }) + replaced = append(replaced, dedupReplacement{ + duplicate: path, + canonical: existing, + mode: info.Mode().Perm(), + }) logger.Debug("Deduplicated %s → %s", path, existing) } else { hashes[hash] = path @@ -123,11 +169,51 @@ func deduplicateFiles(ctx context.Context, logger *logging.Logger, root string) }) if err != nil { - return fmt.Errorf("deduplication walk failed: %w", err) + return 0, 0, fmt.Errorf("deduplication walk failed: %w", err) + } + + if err := writeDedupManifest(root, manifest); err != nil { + // Without the manifest the restore cannot materialize these symlinks, so an + // unrecorded symlink would ship and break fidelity (issue #70). Revert every + // symlink back to a regular file so the archive degrades to "no dedup this + // run" rather than carrying unrecoverable links. + logger.Warning("Failed to write dedup manifest; reverting %d deduplicated symlink(s) to regular files: %v", len(replaced), err) + reverted := 0 + for _, r := range replaced { + if rerr := revertDedupSymlink(r); rerr != nil { + logger.Warning("Failed to revert deduplicated symlink %s: %v", r.duplicate, rerr) + continue + } + reverted++ + } + if reverted != len(replaced) { + return 0, 0, fmt.Errorf("write dedup manifest: %w (reverted %d/%d symlinks)", err, reverted, len(replaced)) + } + // All symlinks reverted to regular files: nothing was actually reclaimed. + logger.Info("Deduplication aborted (manifest unwritable); %d symlink(s) reverted to regular files", reverted) + return 0, 0, nil } logger.Info("Deduplication completed: %d duplicates replaced", duplicates) - return nil + return duplicates, bytesReclaimed, nil +} + +// writeDedupManifest records the deduplicated symlinks so the restore can +// materialize them back into regular files (issue #70). It is a no-op when no +// files were deduplicated. +func writeDedupManifest(root string, entries []DedupManifestEntry) error { + if len(entries) == 0 { + return nil + } + data, err := json.Marshal(entries) + if err != nil { + return err + } + dest := filepath.Join(root, filepath.FromSlash(DedupManifestRelPath)) + if err := os.MkdirAll(filepath.Dir(dest), 0o700); err != nil { + return err + } + return os.WriteFile(dest, data, 0o600) } func shouldSkipDedupPath(rel string) bool { @@ -158,17 +244,81 @@ func hashFile(root *os.Root, name string) (sum string, err error) { } func replaceWithSymlink(target, duplicate string) error { - if err := os.Remove(duplicate); err != nil { - return err - } rel, err := filepath.Rel(filepath.Dir(duplicate), target) if err != nil { rel = target } - return os.Symlink(rel, duplicate) + // Create the symlink at a UNIQUE temporary name in the same directory, then + // atomically rename it over the duplicate. A unique name (not the fixed + // duplicate+".dedup.tmp") avoids destroying a real staged file that happens to + // carry that suffix, and the rename keeps the replacement fail-closed: on any + // error the original duplicate is left untouched (issues #70/#71). + tmpFile, err := os.CreateTemp(filepath.Dir(duplicate), ".proxsave-dedup-*") + if err != nil { + return err + } + tmp := tmpFile.Name() + _ = tmpFile.Close() + _ = os.Remove(tmp) // os.Symlink needs a non-existent path + if err := os.Symlink(rel, tmp); err != nil { + return err + } + if err := os.Rename(tmp, duplicate); err != nil { + _ = os.Remove(tmp) + return err + } + return nil +} + +// dedupReplacement remembers a symlink dedup created so it can be reverted to a +// regular file if the manifest cannot be written (so an unrecorded symlink, which +// the restore could not materialize, is never shipped). +type dedupReplacement struct { + duplicate string // absolute staged path now holding the symlink + canonical string // absolute staged path of the kept original + mode os.FileMode +} + +// revertDedupSymlink turns one dedup symlink back into a regular copy of its +// canonical. Used when the manifest write fails so the archive carries plain files. +// It writes to a sibling temp then renames over the symlink, so a failed write never +// leaves the duplicate missing (no remove-then-write window). +func revertDedupSymlink(r dedupReplacement) error { + content, err := os.ReadFile(r.canonical) + if err != nil { + return err + } + mode := r.mode.Perm() + if mode == 0 { + mode = 0o600 + } + tmp, err := os.CreateTemp(filepath.Dir(r.duplicate), ".proxsave-dedup-revert-*") + if err != nil { + return err + } + tmpPath := tmp.Name() + if _, err := tmp.Write(content); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return err + } + if err := tmp.Chmod(mode); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return err + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return err + } + if err := os.Rename(tmpPath, r.duplicate); err != nil { + _ = os.Remove(tmpPath) + return err + } + return nil } -func prefilterFiles(ctx context.Context, logger *logging.Logger, root string, maxSize int64) error { +func prefilterFiles(ctx context.Context, logger *logging.Logger, root string, maxSize int64) (int64, error) { if maxSize <= 0 { maxSize = defaultPrefilterMaxSizeBytes } @@ -181,6 +331,7 @@ func prefilterFiles(ctx context.Context, logger *logging.Logger, root string, ma skippedSymlink int } var stats prefilterStats + var reclaimed int64 isStructuredConfigPath := func(path string) bool { rel, err := filepath.Rel(root, path) @@ -207,7 +358,7 @@ func prefilterFiles(ctx context.Context, logger *logging.Logger, root string, ma rootFS, err := os.OpenRoot(root) if err != nil { - return fmt.Errorf("open prefilter root: %w", err) + return 0, fmt.Errorf("open prefilter root: %w", err) } defer func() { _ = rootFS.Close() }() @@ -243,34 +394,48 @@ func prefilterFiles(ctx context.Context, logger *logging.Logger, root string, ma } stats.scanned++ + before := info.Size() + changed := false ext := strings.ToLower(filepath.Ext(path)) switch ext { case ".txt", ".log", ".md": - if changed, err := normalizeTextFile(rootFS, rel); err == nil && changed { - stats.optimized++ + if c, err := normalizeTextFile(rootFS, rel); err == nil && c { + changed = true } case ".conf", ".cfg", ".ini": if isStructuredConfigPath(path) { stats.skippedStructured++ return nil } - if changed, err := normalizeConfigFile(rootFS, rel); err == nil && changed { - stats.optimized++ + if c, err := normalizeConfigFile(rootFS, rel); err == nil && c { + changed = true } case ".json": - if changed, err := minifyJSON(rootFS, rel); err == nil && changed { - stats.optimized++ + if isStructuredConfigPath(path) { + stats.skippedStructured++ + return nil + } + if c, err := minifyJSON(rootFS, rel); err == nil && c { + changed = true + } + } + if changed { + stats.optimized++ + // Account for bytes removed (issue #73 ratio correction); re-stat the + // rewritten file (best-effort). + if newInfo, serr := os.Lstat(path); serr == nil && newInfo.Size() < before { + reclaimed += before - newInfo.Size() } } return nil }) if err != nil { - return fmt.Errorf("prefilter walk failed: %w", err) + return 0, fmt.Errorf("prefilter walk failed: %w", err) } - logger.Info("Prefilter completed: optimized=%d scanned=%d skipped_structured=%d skipped_symlink=%d", stats.optimized, stats.scanned, stats.skippedStructured, stats.skippedSymlink) - return nil + logger.Info("Prefilter completed: optimized=%d scanned=%d skipped_structured=%d skipped_symlink=%d reclaimed_bytes=%d", stats.optimized, stats.scanned, stats.skippedStructured, stats.skippedSymlink, reclaimed) + return reclaimed, nil } // normalizeTextFile reads and rewrites name through root, an *os.Root opened on @@ -300,14 +465,15 @@ func minifyJSON(root *os.Root, name string) (bool, error) { if err != nil { return false, err } - var tmp any - if err := json.Unmarshal(data, &tmp); err != nil { - return false, err - } - minified, err := json.Marshal(tmp) - if err != nil { + // json.Compact strips only insignificant whitespace at the token level. Unlike + // an Unmarshal-into-any + Marshal round-trip it preserves number text/precision + // (no >2^53 rounding), key order and duplicate keys, so the payload stays + // byte-faithful aside from whitespace (issue #72). + var buf bytes.Buffer + if err := json.Compact(&buf, data); err != nil { return false, err } + minified := buf.Bytes() if bytes.Equal(bytes.TrimSpace(data), minified) { return false, nil } diff --git a/internal/backup/optimizations_bench_test.go b/internal/backup/optimizations_bench_test.go index 1180a42..ba43409 100644 --- a/internal/backup/optimizations_bench_test.go +++ b/internal/backup/optimizations_bench_test.go @@ -42,7 +42,7 @@ func BenchmarkPrefilterFiles(b *testing.B) { if err := copyDir(template, iterRoot); err != nil { b.Fatalf("copy template: %v", err) } - if err := prefilterFiles(ctx, logger, iterRoot, maxSize); err != nil { + if _, err := prefilterFiles(ctx, logger, iterRoot, maxSize); err != nil { b.Fatalf("prefilterFiles: %v", err) } } diff --git a/internal/backup/optimizations_structured_test.go b/internal/backup/optimizations_structured_test.go index 5d20129..f168864 100644 --- a/internal/backup/optimizations_structured_test.go +++ b/internal/backup/optimizations_structured_test.go @@ -46,7 +46,7 @@ func TestPrefilterSkipsStructuredConfigs(t *testing.T) { // Run prefilter logger := logging.New(types.LogLevelError, false) - if err := prefilterFiles(context.Background(), logger, tmp, 8*1024*1024); err != nil { + if _, err := prefilterFiles(context.Background(), logger, tmp, 8*1024*1024); err != nil { t.Fatalf("prefilterFiles: %v", err) } diff --git a/internal/backup/optimizations_test.go b/internal/backup/optimizations_test.go index c50cf7f..ada857d 100644 --- a/internal/backup/optimizations_test.go +++ b/internal/backup/optimizations_test.go @@ -3,8 +3,10 @@ package backup import ( "bytes" "context" + "encoding/json" "os" "path/filepath" + "strings" "testing" "github.com/tis24dev/proxsave/internal/logging" @@ -51,7 +53,7 @@ func TestApplyOptimizationsRunsAllStages(t *testing.T) { PrefilterMaxFileSizeBytes: 1024, } - if err := ApplyOptimizations(context.Background(), logger, root, cfg); err != nil { + if _, err := ApplyOptimizations(context.Background(), logger, root, cfg); err != nil { t.Fatalf("ApplyOptimizations: %v", err) } @@ -117,7 +119,7 @@ func TestDedupDoesNotReplaceCriticalFilesWithSymlinks(t *testing.T) { cfg := OptimizationConfig{ EnableDeduplication: true, } - if err := ApplyOptimizations(context.Background(), logger, root, cfg); err != nil { + if _, err := ApplyOptimizations(context.Background(), logger, root, cfg); err != nil { t.Fatalf("ApplyOptimizations: %v", err) } @@ -136,3 +138,230 @@ func TestDedupDoesNotReplaceCriticalFilesWithSymlinks(t *testing.T) { t.Fatalf("resolv.conf content mismatch: got %q want %q", got, resolvContent) } } + +// TestApplyOptimizationsFailsFatallyOnDedupError guards the #70 safety contract: an +// unsafe deduplication state must NOT be swallowed to a warning; ApplyOptimizations +// must return the error so the backup run aborts rather than ship a damaged tree. +func TestApplyOptimizationsFailsFatallyOnDedupError(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + // A non-existent dedup root makes deduplicateFiles fail (os.OpenRoot error); the + // happy/fully-reverted paths return nil, so a returned error here can only come + // from an unsafe state that must abort. + _, err := ApplyOptimizations(context.Background(), logger, "/proxsave-nonexistent-root-xyz", OptimizationConfig{EnableDeduplication: true}) + if err == nil { + t.Fatal("ApplyOptimizations must return (not swallow) a deduplication error so the backup aborts") + } + if !strings.Contains(err.Error(), "deduplication") { + t.Fatalf("error should identify the deduplication stage, got: %v", err) + } +} + +// TestDeduplicationRevertsSymlinksWhenManifestUnwritable guards #70: if the dedup +// manifest cannot be written, the symlinks are reverted to regular files so the +// archive never ships unrecorded (unrecoverable) links. +func TestDeduplicationRevertsSymlinksWhenManifestUnwritable(t *testing.T) { + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "a.txt"), []byte("same"), 0o640); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "b.txt"), []byte("same"), 0o640); err != nil { + t.Fatal(err) + } + // Make the manifest path a DIRECTORY so writeDedupManifest's WriteFile fails. + manifestAsDir := filepath.Join(root, filepath.FromSlash(DedupManifestRelPath)) + if err := os.MkdirAll(manifestAsDir, 0o755); err != nil { + t.Fatal(err) + } + + logger := logging.New(types.LogLevelError, false) + if _, _, err := deduplicateFiles(context.Background(), logger, root); err != nil { + t.Fatalf("deduplicateFiles should succeed (revert) when the manifest cannot be written: %v", err) + } + + for _, name := range []string{"a.txt", "b.txt"} { + p := filepath.Join(root, name) + info, err := os.Lstat(p) + if err != nil { + t.Fatalf("lstat %s: %v", name, err) + } + if info.Mode()&os.ModeSymlink != 0 { + t.Fatalf("%s must be reverted to a regular file when the manifest is unwritable, got symlink", name) + } + if data, err := os.ReadFile(p); err != nil || string(data) != "same" { + t.Fatalf("%s content lost after revert: %q err=%v", name, data, err) + } + } +} + +// TestReplaceWithSymlinkDoesNotClobberSuffixFile guards #71: using a unique temp +// name must never destroy a real staged file that happens to carry a dedup temp +// suffix. +func TestReplaceWithSymlinkDoesNotClobberSuffixFile(t *testing.T) { + root := t.TempDir() + target := filepath.Join(root, "a.txt") + duplicate := filepath.Join(root, "b.txt") + suffixFile := duplicate + ".dedup.tmp" // a legitimate backed-up file, not our temp + if err := os.WriteFile(target, []byte("data"), 0o640); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(duplicate, []byte("data"), 0o640); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(suffixFile, []byte("precious"), 0o640); err != nil { + t.Fatal(err) + } + + if err := replaceWithSymlink(target, duplicate); err != nil { + t.Fatalf("replaceWithSymlink: %v", err) + } + + if info, err := os.Lstat(duplicate); err != nil || info.Mode()&os.ModeSymlink == 0 { + t.Fatalf("duplicate should be a symlink after dedup: info=%v err=%v", info, err) + } + // The pre-existing .dedup.tmp file must be intact (the old fixed-name temp would + // have removed/clobbered it). + if got, err := os.ReadFile(suffixFile); err != nil || string(got) != "precious" { + t.Fatalf("a real *.dedup.tmp file must not be clobbered: got %q err=%v", got, err) + } +} + +// TestReplaceWithSymlinkFailClosedOnRenameFailure verifies that when the rename +// over the duplicate fails, the original duplicate is left untouched (fail-closed, +// issue #71). +func TestReplaceWithSymlinkFailClosedOnRenameFailure(t *testing.T) { + root := t.TempDir() + target := filepath.Join(root, "a.txt") + duplicate := filepath.Join(root, "dup") + if err := os.WriteFile(target, []byte("data"), 0o640); err != nil { + t.Fatal(err) + } + // Make the duplicate a NON-EMPTY directory so renaming a file over it fails. + if err := os.MkdirAll(filepath.Join(duplicate, "child"), 0o755); err != nil { + t.Fatal(err) + } + + if err := replaceWithSymlink(target, duplicate); err == nil { + t.Fatal("expected replaceWithSymlink to fail when rename over the duplicate cannot succeed") + } + + info, err := os.Lstat(duplicate) + if err != nil { + t.Fatalf("duplicate must still exist after a failed dedup: %v", err) + } + if !info.IsDir() { + t.Fatal("duplicate must be left untouched (still a directory) after a failed dedup") + } + if _, err := os.Stat(filepath.Join(duplicate, "child")); err != nil { + t.Fatalf("duplicate contents must be preserved on failure: %v", err) + } +} + +// TestMinifyJSONIsLossless verifies json.Compact preserves number precision, +// duplicate keys and key order (issue #72) while still stripping whitespace. +func TestMinifyJSONIsLossless(t *testing.T) { + root := t.TempDir() + const name = "data.json" + input := `{ "id": 123456789012345678, "b": 1, "b": 2, "ratio": 1.0 }` + if err := os.WriteFile(filepath.Join(root, name), []byte(input), 0o640); err != nil { + t.Fatal(err) + } + rootFS, err := os.OpenRoot(root) + if err != nil { + t.Fatal(err) + } + defer func() { _ = rootFS.Close() }() + + changed, err := minifyJSON(rootFS, name) + if err != nil { + t.Fatalf("minifyJSON: %v", err) + } + if !changed { + t.Fatal("expected JSON whitespace to be stripped") + } + got, err := os.ReadFile(filepath.Join(root, name)) + if err != nil { + t.Fatal(err) + } + want := `{"id":123456789012345678,"b":1,"b":2,"ratio":1.0}` + if string(got) != want { + t.Fatalf("minifyJSON is not lossless:\n got %q\nwant %q", got, want) + } +} + +// TestPrefilterSkipsStructuredConfigJSON verifies JSON under sensitive config +// directories is left untouched by the prefilter (issue #72 defense-in-depth). +func TestPrefilterSkipsStructuredConfigJSON(t *testing.T) { + root := t.TempDir() + path := filepath.Join(root, "etc", "proxmox-backup", "shadow.json") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + original := "{\n \"user\": \"x\"\n}\n" + if err := os.WriteFile(path, []byte(original), 0o640); err != nil { + t.Fatal(err) + } + + logger := logging.New(types.LogLevelError, false) + if _, err := prefilterFiles(context.Background(), logger, root, 1024); err != nil { + t.Fatalf("prefilterFiles: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(got) != original { + t.Fatalf("structured-config JSON must not be modified, got %q", got) + } +} + +// TestDeduplicationWritesManifest verifies dedup records each created symlink in +// the manifest the restore uses to materialize them back (issue #70). +func TestDeduplicationWritesManifest(t *testing.T) { + root := t.TempDir() + write := func(rel, content string, mode os.FileMode) { + t.Helper() + p := filepath.Join(root, rel) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(p, []byte(content), mode); err != nil { + t.Fatal(err) + } + } + write(filepath.Join("a", "one.cfg"), "same content here", 0o640) + write(filepath.Join("a", "two.cfg"), "same content here", 0o600) + + logger := logging.New(types.LogLevelError, false) + res, err := ApplyOptimizations(context.Background(), logger, root, OptimizationConfig{EnableDeduplication: true}) + if err != nil { + t.Fatalf("ApplyOptimizations: %v", err) + } + // #73: the result reports the reclaimed bytes (the deduplicated duplicate's size) + // so the caller can correct the uncompressed-payload figure / compression ratio. + if res.DuplicatesReplaced != 1 { + t.Fatalf("expected 1 duplicate replaced, got %d", res.DuplicatesReplaced) + } + if want := int64(len("same content here")); res.BytesReclaimed != want { + t.Fatalf("expected BytesReclaimed=%d (the duplicate size), got %d", want, res.BytesReclaimed) + } + + data, err := os.ReadFile(filepath.Join(root, filepath.FromSlash(DedupManifestRelPath))) + if err != nil { + t.Fatalf("read dedup manifest: %v", err) + } + var entries []DedupManifestEntry + if err := json.Unmarshal(data, &entries); err != nil { + t.Fatalf("unmarshal dedup manifest: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 dedup entry, got %d (%+v)", len(entries), entries) + } + // WalkDir visits a/one.cfg before a/two.cfg, so two.cfg is the one symlinked. + if entries[0].Path != "a/two.cfg" { + t.Fatalf("unexpected dedup path %q", entries[0].Path) + } + if entries[0].Mode != uint32(0o600) { + t.Fatalf("expected recorded mode 0600, got %o", entries[0].Mode) + } +} diff --git a/internal/checks/checks.go b/internal/checks/checks.go index 44b56e3..4dd223a 100644 --- a/internal/checks/checks.go +++ b/internal/checks/checks.go @@ -25,6 +25,7 @@ var closeTestFile = func(f *os.File) error { return f.Close() } var ( osStat = os.Stat + osLstat = os.Lstat osRemove = os.Remove osOpenFile = os.OpenFile osMkdirAll = os.MkdirAll @@ -279,6 +280,11 @@ func sameHost(a, b string) bool { } // CheckLockFile checks for stale lock files and creates a new lock +// CheckCodeBackupInProgress marks a lock-file check that failed because another +// backup is actively running. It is a benign concurrency skip, not a real +// failure, so callers can treat it differently (no failure notification). +const CheckCodeBackupInProgress = "BACKUP_IN_PROGRESS" + func (c *Checker) CheckLockFile() CheckResult { result := CheckResult{ Name: "Lock File", @@ -303,6 +309,7 @@ func (c *Checker) CheckLockFile() CheckResult { age := time.Since(info.ModTime()) formatInProgress := func(age time.Duration, meta lockFileMetadata) string { + result.Code = CheckCodeBackupInProgress parts := []string{fmt.Sprintf("lock age: %v", age)} if meta.PID > 0 { parts = append(parts, fmt.Sprintf("pid=%d", meta.PID)) @@ -377,6 +384,12 @@ func (c *Checker) CheckLockFile() CheckResult { f, err := osOpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0640) if err != nil { if os.IsExist(err) { + // Lost the atomic O_EXCL create race: another backup created the lock + // between the stat above and this create. Set the in-progress code so + // callers treat it as a benign concurrency skip, exactly like the + // stat-found-a-live-lock path above; otherwise this race path raises a + // spurious failure notification. + result.Code = CheckCodeBackupInProgress result.Message = "Another backup acquired the lock" c.logger.Error("%s", result.Message) return result @@ -630,6 +643,16 @@ func (c *Checker) CheckTempDirectory() CheckResult { c.logger.Debug("Temp directory exists: %s", tempRoot) } + // osStat follows symlinks, so a pre-created /tmp/proxsave pointing at an + // attacker-controlled directory would otherwise pass. Reject a symlinked root + // (issue #54). + if linfo, lerr := osLstat(tempRoot); lerr == nil && linfo.Mode()&os.ModeSymlink != 0 { + result.Code = "SYMLINK_REJECTED" + result.Error = fmt.Errorf("temp path is a symlink - path: %s", tempRoot) + result.Message = result.Error.Error() + return result + } + if !info.IsDir() { result.Code = "NOT_DIRECTORY" result.Error = fmt.Errorf("temp path is not a directory - path: %s", tempRoot) diff --git a/internal/checks/checks_test.go b/internal/checks/checks_test.go index f06f210..6480f7b 100644 --- a/internal/checks/checks_test.go +++ b/internal/checks/checks_test.go @@ -114,6 +114,79 @@ func TestCheckLockFile(t *testing.T) { } } +// TestCheckLockFile_InProgressSetsCode verifies that a fresh (live) lock makes +// CheckLockFile report the BACKUP_IN_PROGRESS code, so callers can treat it as a +// benign concurrency skip rather than a failure. +func TestCheckLockFile_InProgressSetsCode(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, ".backup.lock") + + cfg := &CheckerConfig{ + BackupPath: tmpDir, + LogPath: tmpDir, + LockDirPath: tmpDir, + LockFilePath: lockPath, + MaxLockAge: 1 * time.Hour, + DryRun: false, + } + checker := NewChecker(logger, cfg) + + // First call creates and acquires a fresh lock (passes). + if result := checker.CheckLockFile(); !result.Passed { + t.Fatalf("first CheckLockFile should pass: %s", result.Message) + } + t.Cleanup(func() { _ = checker.ReleaseLock() }) + + // Second call sees the fresh lock as another backup in progress. + result := checker.CheckLockFile() + if result.Passed { + t.Fatal("second CheckLockFile should fail with a live lock present") + } + if result.Code != CheckCodeBackupInProgress { + t.Fatalf("expected Code=%q, got %q (message=%q)", CheckCodeBackupInProgress, result.Code, result.Message) + } +} + +// TestCheckLockFile_CreateRaceSetsCode verifies that LOSING the atomic O_EXCL +// create race (another backup created the lock between our stat and our create) +// also reports the BACKUP_IN_PROGRESS code, so it is treated as a benign +// concurrency skip rather than a failure notification (matching the +// stat-found-a-live-lock path). +func TestCheckLockFile_CreateRaceSetsCode(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + logger.SetOutput(io.Discard) + + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, ".backup.lock") + + // No lock file exists at stat time (fresh dir), so CheckLockFile proceeds to the + // atomic create; force that create to lose the race with EEXIST. + origOpen := osOpenFile + t.Cleanup(func() { osOpenFile = origOpen }) + osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return nil, &os.PathError{Op: "open", Path: name, Err: syscall.EEXIST} + } + + cfg := &CheckerConfig{ + BackupPath: tmpDir, + LogPath: tmpDir, + LockDirPath: tmpDir, + LockFilePath: lockPath, + MaxLockAge: time.Hour, + DryRun: false, + } + checker := NewChecker(logger, cfg) + + result := checker.CheckLockFile() + if result.Passed { + t.Fatal("CheckLockFile should fail when the atomic create loses the race") + } + if result.Code != CheckCodeBackupInProgress { + t.Fatalf("create-race must set Code=%q, got %q (message=%q)", CheckCodeBackupInProgress, result.Code, result.Message) + } +} + func TestCheckLockFileStaleLock(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) tmpDir := t.TempDir() diff --git a/internal/config/config.go b/internal/config/config.go index f0ac7e6..f0c6883 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -287,6 +287,7 @@ type Config struct { BackupScriptRepository bool BackupUserHomes bool BackupConfigFile bool + SystemRootPrefix string PVEConfigPath string PBSConfigPath string PVEClusterPath string @@ -854,9 +855,14 @@ func (c *Config) parseSystemSettings() { c.BackupSSHKeys = c.getBool("BACKUP_SSH_KEYS", true) c.BackupZFSConfig = c.getBool("BACKUP_ZFS_CONFIG", true) c.BackupRootHome = c.getBool("BACKUP_ROOT_HOME", true) - c.BackupScriptRepository = c.getBool("BACKUP_SCRIPT_REPOSITORY", true) + // Default false to match the shipped template (backup.env) and the project + // convention; a config missing this key must not silently snapshot /opt/proxsave. + c.BackupScriptRepository = c.getBool("BACKUP_SCRIPT_REPOSITORY", false) c.BackupUserHomes = c.getBool("BACKUP_USER_HOMES", true) c.BackupConfigFile = c.getBool("BACKUP_CONFIG_FILE", true) + // Optional system-root override (chroot/test fixture). Empty or "/" means real + // root; CollectorConfig.Validate rejects a non-absolute value. + c.SystemRootPrefix = strings.TrimSpace(c.getString("SYSTEM_ROOT_PREFIX", "")) c.PBSDatastorePaths = normalizeList(c.getStringSlice("PBS_DATASTORE_PATH", nil)) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 279724f..663b13c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1359,3 +1359,36 @@ func TestParseSecuritySettingsSafeProcessesDefaultsEmpty(t *testing.T) { t.Fatalf("SafeProcesses should default to empty, got %#v", cfg.SafeProcesses) } } + +func TestParseSystemSettingsSystemRootPrefix(t *testing.T) { + t.Run("parses and trims the override", func(t *testing.T) { + cfg := &Config{raw: map[string]string{"SYSTEM_ROOT_PREFIX": " /mnt/fixture "}} + cfg.parseSystemSettings() + if cfg.SystemRootPrefix != "/mnt/fixture" { + t.Fatalf("SystemRootPrefix = %q; want /mnt/fixture", cfg.SystemRootPrefix) + } + }) + t.Run("defaults to empty (real root) when unset", func(t *testing.T) { + cfg := &Config{raw: map[string]string{}} + cfg.parseSystemSettings() + if cfg.SystemRootPrefix != "" { + t.Fatalf("SystemRootPrefix should default to empty, got %q", cfg.SystemRootPrefix) + } + }) +} + +func TestParseSystemSettingsScriptRepositoryDefaultsFalse(t *testing.T) { + // A config missing the key must default to false, matching the shipped template + // (#69), so it does not silently snapshot /opt/proxsave. + cfg := &Config{raw: map[string]string{}} + cfg.parseSystemSettings() + if cfg.BackupScriptRepository { + t.Fatalf("BACKUP_SCRIPT_REPOSITORY must default to false, got true") + } + // An explicit opt-in is still honored. + on := &Config{raw: map[string]string{"BACKUP_SCRIPT_REPOSITORY": "true"}} + on.parseSystemSettings() + if !on.BackupScriptRepository { + t.Fatalf("explicit BACKUP_SCRIPT_REPOSITORY=true must enable collection") + } +} diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index f298c9c..4bb57d2 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -34,20 +34,35 @@ PORT_WHITELIST= # Format: service:port (e.g. sshd:22,nginx:4 # use "name*" or "regex:^name$" for wildcard/exact control. SUSPICIOUS_PROCESSES="ncat,cryptominer,xmrig,kdevtmpfsi,kinsing,minerd,mr.sh" +# SAFE_KERNEL_PROCESSES and SAFE_BRACKET_PROCESSES below both allowlist the +# "Suspicious kernel-style process: ..." warning. That warning fires for any +# process the host's `ps` reports inside square brackets, e.g. a real kernel +# thread `[kworker/0:1]` or a container worker an unprivileged LXC exposes to the +# host as `[celeryd: celery@paperless:ForkPoolWorker-3057]`. +# IMPORTANT: matching is against the text BETWEEN the brackets (brackets stripped), +# case-insensitive, and a plain entry is an EXACT whole-name match (NOT a prefix). +# To match part of the name use "name*" (prefix) or "regex:pattern" (unanchored). +# So the celery worker above is matched by `celeryd*` or `regex:^celeryd`, but NOT +# by a plain `celeryd` or by an anchored `regex:^celeryd$`. +# # SAFE_KERNEL_PROCESSES: Add safe kernel process patterns (supports exact, wildcard*, regex:pattern) # Built-in defaults include: ksgxd, hwrng, usb-storage, vdev_autotrim, kvm-pit*, and ZFS/DRBD patterns SAFE_KERNEL_PROCESSES="regex:^card[0-9]+-crtc[0-9]+$,regex:^drbd_[wrs]_.+,regex:^kmmpd-drbd[0-9]+$" # SAFE_BRACKET_PROCESSES: Add safe bracket [process] names # Built-in defaults include: systemd, cron, sshd:, rsyslogd, dbus-daemon, ZFS processes, NFS processes +# Example to silence the celery worker above: SAFE_BRACKET_PROCESSES="celeryd*" # SAFE_BRACKET_PROCESSES="" -# SAFE_PROCESSES: Allowlist for the suspicious-process scan. A process is never -# flagged if any token of its command line (or that token's basename) matches an -# entry here, even if it also matches SUSPICIOUS_PROCESSES. Matching is anchored to -# the start of each token: a plain entry matches any token that STARTS WITH it -# (e.g. "ssh" also matches "sshd"), so use "regex:^name$" if you need an exact match. -# "name*" wildcard and "regex:pattern" are also supported. Example: SAFE_PROCESSES="ffmpeg". +# SAFE_PROCESSES: Allowlist for the suspicious-process scan ONLY. It does NOT +# silence the bracketed "kernel-style" warning above; use SAFE_KERNEL_PROCESSES or +# SAFE_BRACKET_PROCESSES for those. A process is never flagged by the +# suspicious-process scan if any token of its command line (or that token's +# basename) matches an entry here, even if it also matches SUSPICIOUS_PROCESSES. +# Matching is anchored to the start of each token: a plain entry matches any token +# that STARTS WITH it (e.g. "ssh" also matches "sshd"), so use "regex:^name$" if you +# need an exact match. "name*" wildcard and "regex:pattern" are also supported. +# Example: SAFE_PROCESSES="ffmpeg". # SAFE_PROCESSES="" # ---------------------------------------------------------------------- @@ -298,7 +313,10 @@ METRICS_PATH=${BASE_DIR}/metrics BACKUP_CLUSTER_CONFIG=true BACKUP_PVE_FIREWALL=true BACKUP_VZDUMP_CONFIG=true -BACKUP_PVE_ACL=true # Access control (users/roles/groups/ACL; realms when configured) +# BACKUP_PVE_ACL=false also excludes the credential files /etc/pve/priv/{shadow,token,tfa}.cfg +# (password hashes, API token secrets, TFA secrets) from the /etc/pve snapshot. NOTE: those secrets +# also live inside the cluster database config.db; to exclude them entirely also set BACKUP_CLUSTER_CONFIG=false. +BACKUP_PVE_ACL=true # Access control (users/roles/groups/ACL + priv credentials; realms when configured) BACKUP_PVE_JOBS=true BACKUP_PVE_SCHEDULES=true BACKUP_PVE_REPLICATION=true @@ -322,7 +340,9 @@ BACKUP_PBS_METRIC_SERVERS=true # metricserver.cfg BACKUP_PBS_TRAFFIC_CONTROL=true # traffic-control.cfg BACKUP_PBS_NOTIFICATIONS=true # notifications.cfg (targets/matchers/endpoints) BACKUP_PBS_NOTIFICATIONS_PRIV=true # notifications-priv.cfg (secrets/credentials for endpoints) -BACKUP_USER_CONFIGS=true +# BACKUP_USER_CONFIGS=false also excludes the credential files token.cfg, shadow.json, token.shadow +# and tfa.json (API token secrets, password hashes, TFA secrets), not just user.cfg/acl.cfg/domains.cfg. +BACKUP_USER_CONFIGS=true # PBS users/ACLs/realms + their credentials (tokens, hashes, TFA) BACKUP_REMOTE_CONFIGS=true BACKUP_SYNC_JOBS=true BACKUP_VERIFICATION_JOBS=true diff --git a/internal/logging/logger.go b/internal/logging/logger.go index 19ca0a6..6c0d173 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "os" + "sort" "strings" "sync" "time" @@ -23,6 +24,35 @@ type Logger struct { errorCount int64 issueLines []string // Captured WARNING/ERROR/CRITICAL lines for end-of-run summary exitFunc func(int) + secrets []secretForm // registered secret values scrubbed from every log line +} + +// RegisterSecret records a secret value so it is masked out of every subsequent +// log line (stdout, log file, and the end-of-run issue summary), at any level. +// This is a defense-in-depth net on top of source-level redaction; empty/too +// short secrets are ignored, and both raw and URL-encoded forms are covered. +func (l *Logger) RegisterSecret(s string) { + forms := secretReplaceForms([]string{s}) + if len(forms) == 0 { + return + } + l.mu.Lock() + defer l.mu.Unlock() + for _, f := range forms { + dup := false + for _, existing := range l.secrets { + if existing.form == f.form { + dup = true + break + } + } + if !dup { + l.secrets = append(l.secrets, f) + } + } + sort.Slice(l.secrets, func(i, j int) bool { + return len(l.secrets[i].form) > len(l.secrets[j].form) + }) } // New creates a new logger. @@ -151,6 +181,12 @@ func (l *Logger) logWithLabel(level types.LogLevel, label string, colorOverride levelStr = label } message := fmt.Sprintf(format, args...) + if len(l.secrets) > 0 { + // Defense-in-depth: scrub any registered secret value (raw or + // URL-encoded) from the line before it hits stdout, the log file, or + // the issue summary (the log is shipped off-host to secondary/cloud). + message = applySecretForms(message, l.secrets) + } var colorCode string var resetCode string diff --git a/internal/logging/redact.go b/internal/logging/redact.go new file mode 100644 index 0000000..d430e50 --- /dev/null +++ b/internal/logging/redact.go @@ -0,0 +1,103 @@ +package logging + +import ( + "net/url" + "sort" + "strings" +) + +const ( + // secretVisibleSuffix is how many trailing characters of a secret stay + // visible (so an operator can correlate a masked value with their config). + secretVisibleSuffix = 4 + // secretMaskPrefix is a fixed-width asterisk run used as the masked prefix. + // Fixed width so the mask never reveals the real secret length. + secretMaskPrefix = "************" + // secretMinFullReveal: secrets at or below this length are masked entirely + // (no visible suffix), so a short/low-entropy secret is not half-revealed. + secretMinFullReveal = 8 + // secretMinRegister: secrets shorter than this are not redacted at all, to + // avoid masking innocent common substrings and full-revealing tiny values. + secretMinRegister = 6 +) + +// MaskSecret renders a secret as a fixed asterisk prefix plus a short visible +// suffix, e.g. "************wxyz". Secrets of length <= secretMinFullReveal are +// fully masked (no visible suffix); empty input returns "". +func MaskSecret(s string) string { + if s == "" { + return "" + } + r := []rune(s) + if len(r) <= secretMinFullReveal { + return secretMaskPrefix + } + return secretMaskPrefix + string(r[len(r)-secretVisibleSuffix:]) +} + +// secretForm is a concrete string to search for (a secret in raw or +// URL-query-encoded form) together with its precomputed masked replacement. +type secretForm struct { + form string + masked string +} + +// secretReplaceForms expands each secret into the concrete forms that may appear +// in a log line or error string (raw and URL-query-encoded), each paired with +// its MaskSecret rendering. Empty/too-short secrets are skipped. Forms are +// ordered longest-first so a shorter secret cannot partially mask a longer one. +func secretReplaceForms(secrets []string) []secretForm { + seen := make(map[string]struct{}) + var forms []secretForm + add := func(form, masked string) { + if form == "" { + return + } + if _, ok := seen[form]; ok { + return + } + seen[form] = struct{}{} + forms = append(forms, secretForm{form: form, masked: masked}) + } + for _, sec := range secrets { + sec = strings.TrimSpace(sec) + if len([]rune(sec)) < secretMinRegister { + continue + } + masked := MaskSecret(sec) + add(sec, masked) + // The same secret commonly appears URL-query-encoded inside *url.Error + // strings (e.g. a Gotify token in "?token=..."); a verbatim replace of + // the raw value would miss it, so cover the encoded form too. + if enc := url.QueryEscape(sec); enc != sec { + add(enc, masked) + } + } + sort.Slice(forms, func(i, j int) bool { + return len(forms[i].form) > len(forms[j].form) + }) + return forms +} + +// applySecretForms replaces every occurrence of each form in s with its mask. +func applySecretForms(s string, forms []secretForm) string { + if s == "" { + return s + } + for _, f := range forms { + if strings.Contains(s, f.form) { + s = strings.ReplaceAll(s, f.form, f.masked) + } + } + return s +} + +// RedactSecrets replaces every occurrence of each given secret (in raw and +// URL-query-encoded form) in s with its MaskSecret rendering. Use it to redact a +// known secret out of an error/log string at the source before wrapping/logging. +func RedactSecrets(s string, secrets ...string) string { + if s == "" || len(secrets) == 0 { + return s + } + return applySecretForms(s, secretReplaceForms(secrets)) +} diff --git a/internal/logging/redact_test.go b/internal/logging/redact_test.go new file mode 100644 index 0000000..8f66e0e --- /dev/null +++ b/internal/logging/redact_test.go @@ -0,0 +1,91 @@ +package logging + +import ( + "bytes" + "net/url" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/types" +) + +func TestMaskSecret(t *testing.T) { + cases := []struct{ in, want string }{ + {"", ""}, + {"short", secretMaskPrefix}, // <= 8 -> full mask + {"12345678", secretMaskPrefix}, // exactly 8 -> full mask + {"0123456789ABCDEF", secretMaskPrefix + "CDEF"}, // last 4 visible + } + for _, c := range cases { + if got := MaskSecret(c.in); got != c.want { + t.Errorf("MaskSecret(%q)=%q want %q", c.in, got, c.want) + } + } + sec := "supersecretvalue123" + if strings.Contains(MaskSecret(sec), sec) { + t.Fatalf("MaskSecret leaked the raw secret") + } +} + +func TestRedactSecrets(t *testing.T) { + sec := "1234567890ABCDEF" + got := RedactSecrets("token="+sec+" tail", sec) + if strings.Contains(got, sec) { + t.Fatalf("raw secret not redacted: %q", got) + } + if !strings.Contains(got, MaskSecret(sec)) { + t.Fatalf("expected mask in %q", got) + } + + // URL-encoded form (how a token appears inside a *url.Error, e.g. Gotify). + raw := "tok+en/val=longenough" + enc := url.QueryEscape(raw) + gotEnc := RedactSecrets(`Post "https://h/m?token=`+enc+`": refused`, raw) + if strings.Contains(gotEnc, enc) { + t.Fatalf("URL-encoded secret not redacted: %q", gotEnc) + } + if strings.Contains(gotEnc, raw) { + t.Fatalf("raw secret not redacted: %q", gotEnc) + } + + // Empty secret is a no-op (must not corrupt the string). + if out := RedactSecrets("hello world", ""); out != "hello world" { + t.Fatalf("empty secret should be a no-op, got %q", out) + } + // Too-short secret is skipped (avoid masking innocent substrings). + if out := RedactSecrets("abcabc", "abc"); out != "abcabc" { + t.Fatalf("short secret should be skipped, got %q", out) + } +} + +func TestLoggerScrubsRegisteredSecret(t *testing.T) { + var buf bytes.Buffer + l := New(types.LogLevelInfo, false) + l.SetOutput(&buf) + + sec := "bottoken0123456789secret" + l.RegisterSecret(sec) + l.Warning("api request failed: https://api/bot%s/x", sec) + + out := buf.String() + if strings.Contains(out, sec) { + t.Fatalf("log leaked the registered secret: %q", out) + } + if !strings.Contains(out, MaskSecret(sec)) { + t.Fatalf("expected masked secret in log: %q", out) + } +} + +func TestLoggerRegisterSecretIgnoresEmptyOrShort(t *testing.T) { + var buf bytes.Buffer + l := New(types.LogLevelInfo, false) + l.SetOutput(&buf) + + l.RegisterSecret("") // ignored + l.RegisterSecret("ab") // too short, ignored + l.Info("a normal line without secrets") + + if !strings.Contains(buf.String(), "a normal line without secrets") { + t.Fatalf("normal line corrupted: %q", buf.String()) + } +} diff --git a/internal/notify/gotify.go b/internal/notify/gotify.go index 44acc3b..d242c5a 100644 --- a/internal/notify/gotify.go +++ b/internal/notify/gotify.go @@ -136,7 +136,9 @@ func (g *GotifyNotifier) Send(ctx context.Context, data *NotificationData) (*Not resp, err := g.client.Do(req) if err != nil { - err = fmt.Errorf("gotify request failed: %w", err) + // The token is in the URL query, URL-encoded inside the *url.Error; + // RedactSecrets masks both the raw and URL-encoded forms. + err = fmt.Errorf("gotify request failed: %s", logging.RedactSecrets(err.Error(), g.config.Token)) g.logger.Warning("WARNING: %v", err) result.Success = false result.Error = err diff --git a/internal/notify/redact_notify_test.go b/internal/notify/redact_notify_test.go new file mode 100644 index 0000000..261d53b --- /dev/null +++ b/internal/notify/redact_notify_test.go @@ -0,0 +1,48 @@ +package notify + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +// TestGotifySendRedactsTokenOnTransportError (H11) verifies the Gotify token is +// not leaked into the transport-error returned by Send, in either its raw or +// URL-encoded form (it rides in the URL query, so *url.Error carries it encoded). +func TestGotifySendRedactsTokenOnTransportError(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + + server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + serverURL := server.URL + server.Close() // closed -> client.Do fails with connection refused + + const token = "tok+en/val=secret123" + notifier, err := NewGotifyNotifier(GotifyConfig{Enabled: true, ServerURL: serverURL, Token: token}, logger) + if err != nil { + t.Fatalf("new gotify: %v", err) + } + notifier.client = &http.Client{Timeout: 2 * time.Second} + + result, err := notifier.Send(context.Background(), createTestNotificationData()) + if err != nil { + t.Fatalf("Send returned error: %v", err) + } + if result.Success || result.Error == nil { + t.Fatalf("expected a transport failure, got success=%v err=%v", result.Success, result.Error) + } + + msg := result.Error.Error() + if strings.Contains(msg, token) { + t.Fatalf("raw token leaked in error: %q", msg) + } + if enc := url.QueryEscape(token); strings.Contains(msg, enc) { + t.Fatalf("URL-encoded token leaked in error: %q", msg) + } +} diff --git a/internal/notify/telegram.go b/internal/notify/telegram.go index f1cb14e..93a3f16 100644 --- a/internal/notify/telegram.go +++ b/internal/notify/telegram.go @@ -338,7 +338,9 @@ func (t *TelegramNotifier) sendToTelegram(ctx context.Context, botToken, chatID, // Send request resp, err := t.client.Do(req) if err != nil { - return fmt.Errorf("api request failed: %w", err) + // The bot token is embedded in the URL path, which a *url.Error carries + // verbatim; redact it before the error is wrapped/logged/propagated. + return fmt.Errorf("api request failed: %s", logging.RedactSecrets(err.Error(), botToken)) } defer func() { _ = resp.Body.Close() }() diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go index 2470e3e..ed65db5 100644 --- a/internal/notify/webhook.go +++ b/internal/notify/webhook.go @@ -378,9 +378,12 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr } - lastErr = fmt.Errorf("request failed: %w", err) - w.logger.Warning("⚠️ Request failed after %dms (attempt %d/%d): %v", - requestDuration.Milliseconds(), attempt+1, maxRetries+1, err) + // The endpoint URL (often itself the secret, e.g. Discord/Slack) is + // carried verbatim by *url.Error; redact it before wrap/log. + redacted := logging.RedactSecrets(err.Error(), endpoint.URL) + lastErr = fmt.Errorf("request failed: %s", redacted) + w.logger.Warning("⚠️ Request failed after %dms (attempt %d/%d): %s", + requestDuration.Milliseconds(), attempt+1, maxRetries+1, redacted) continue } diff --git a/internal/orchestrator/additional_helpers_test.go b/internal/orchestrator/additional_helpers_test.go index 6f28282..0d5e26b 100644 --- a/internal/orchestrator/additional_helpers_test.go +++ b/internal/orchestrator/additional_helpers_test.go @@ -599,6 +599,7 @@ func TestFinalizeAndCloseLogWithoutLogFile(t *testing.T) { type stubNotifierChannel struct { name string called bool + count int logger *logging.Logger warnOnNotify bool errorCount int @@ -613,6 +614,7 @@ func (s *stubNotifierChannel) Name() string { func (s *stubNotifierChannel) Notify(ctx context.Context, stats *BackupStats) error { s.called = true + s.count++ if stats != nil { s.errorCount = stats.ErrorCount s.warningCount = stats.WarningCount @@ -1151,6 +1153,10 @@ func TestCleanupPreviousExecutionArtifacts(t *testing.T) { logger: logger, } + origRoot := workspaceRoot + workspaceRoot = t.TempDir() + t.Cleanup(func() { workspaceRoot = origRoot }) + logDir := t.TempDir() o.logPath = logDir @@ -1185,10 +1191,15 @@ func TestCleanupPreviousExecutionArtifacts(t *testing.T) { } o.tempRegistry = reg - orphanDir := filepath.Join(registryDir, "orphan-temp") - if err := os.MkdirAll(orphanDir, 0o755); err != nil { + // A legitimate orphan workspace under the trusted root, with the marker, so the + // hardened CleanupOrphaned (issue #55) will remove it. + orphanDir := filepath.Join(workspaceRoot, "proxsave-orphan") + if err := os.MkdirAll(orphanDir, 0o700); err != nil { t.Fatalf("create orphan dir: %v", err) } + if err := os.WriteFile(filepath.Join(orphanDir, workspaceMarker), []byte("m"), 0o600); err != nil { + t.Fatalf("write orphan marker: %v", err) + } if err := reg.Register(orphanDir); err != nil { t.Fatalf("register orphan dir: %v", err) } diff --git a/internal/orchestrator/age_setup_workflow.go b/internal/orchestrator/age_setup_workflow.go index 339fcf6..d5a8855 100644 --- a/internal/orchestrator/age_setup_workflow.go +++ b/internal/orchestrator/age_setup_workflow.go @@ -143,7 +143,7 @@ func (o *Orchestrator) runAgeSetupWorkflow(ctx context.Context, candidatePath st return nil, nil, ErrAgeRecipientSetupAborted } - value, err := resolveAgeRecipientDraft(draft) + value, err := o.resolveAgeRecipientDraft(draft, targetPath) if err != nil { if o.logger != nil { o.logger.Warning("Encryption setup: %v", err) @@ -207,7 +207,7 @@ func (o *Orchestrator) runAgeSetupWorkflow(ctx context.Context, candidatePath st }, nil } -func resolveAgeRecipientDraft(draft *AgeRecipientDraft) (string, error) { +func (o *Orchestrator) resolveAgeRecipientDraft(draft *AgeRecipientDraft, recipientPath string) (string, error) { if draft == nil { return "", fmt.Errorf("recipient draft is required") } @@ -228,7 +228,15 @@ func resolveAgeRecipientDraft(draft *AgeRecipientDraft) (string, error) { if err := validatePassphraseStrength([]byte(passphrase)); err != nil { return "", err } - recipient, err := deriveDeterministicRecipientFromPassphrase(passphrase) + // Per-installation random salt (persisted next to the recipient file and + // embedded in archive manifests) instead of a global constant: this + // prevents cross-install precomputation and recipient correlation while + // keeping the asymmetric model (the passphrase is not needed at backup time). + salt, err := o.getOrCreatePassphraseSalt(recipientPath) + if err != nil { + return "", err + } + recipient, err := deriveDeterministicRecipientFromPassphraseWithSalt(passphrase, salt) if err != nil { return "", err } diff --git a/internal/orchestrator/backup_run_helpers.go b/internal/orchestrator/backup_run_helpers.go index b28308f..2015eab 100644 --- a/internal/orchestrator/backup_run_helpers.go +++ b/internal/orchestrator/backup_run_helpers.go @@ -140,18 +140,22 @@ func (o *Orchestrator) logBackupCollectionSummary(collStats *backup.CollectionSt collStats.DirsCreated) } -func (o *Orchestrator) applyBackupOptimizations(ctx context.Context, tempDir string) error { +func (o *Orchestrator) applyBackupOptimizations(ctx context.Context, tempDir string) (backup.OptimizationResult, error) { if !o.optimizationCfg.Enabled() { o.logger.Debug("Skipping optimization step (all features disabled)") - return nil + return backup.OptimizationResult{}, nil } fmt.Println() o.logger.Step("Backup optimizations on collected data") - if err := backup.ApplyOptimizations(ctx, o.logger, tempDir, o.optimizationCfg); err != nil { - o.logger.Warning("Backup optimizations completed with warnings: %v", err) + res, err := backup.ApplyOptimizations(ctx, o.logger, tempDir, o.optimizationCfg) + if err != nil { + // ApplyOptimizations only returns an error for an unsafe deduplication state + // (a tree that would lose fidelity on restore); abort rather than ship it. + // Benign prefilter issues are logged internally and do not surface here. + return backup.OptimizationResult{}, fmt.Errorf("backup optimizations: %w", err) } - return nil + return res, nil } func estimatedBackupSizeGB(bytesCollected int64) float64 { @@ -304,6 +308,7 @@ func (o *Orchestrator) newArchiveManifest(stats *BackupStats, archivePath, check ScriptVersion: stats.ScriptVersion, EncryptionMode: o.archiveEncryptionMode(), ClusterMode: stats.ClusterMode, + PassphraseSalt: o.passphraseSaltForManifest(), } } diff --git a/internal/orchestrator/backup_run_phases.go b/internal/orchestrator/backup_run_phases.go index 4e79ad4..71b0795 100644 --- a/internal/orchestrator/backup_run_phases.go +++ b/internal/orchestrator/backup_run_phases.go @@ -124,8 +124,8 @@ func (o *Orchestrator) finalizeFailedBackupStats(run *backupRunContext, runErr e func (o *Orchestrator) prepareBackupWorkspace(run *backupRunContext, workspace *backupWorkspace) error { o.logger.Debug("Creating temporary directory for collection output") - workspace.tempRoot = filepath.Join("/tmp", "proxsave") - if err := workspace.fs.MkdirAll(workspace.tempRoot, 0o755); err != nil { + workspace.tempRoot = workspaceRoot + if err := ensureSecureTempRoot(workspace.fs, workspace.tempRoot); err != nil { return fmt.Errorf("temp directory creation failed - path: %s: %w", workspace.tempRoot, err) } @@ -144,17 +144,30 @@ func (o *Orchestrator) prepareBackupWorkspace(run *backupRunContext, workspace * } func (o *Orchestrator) cleanupBackupWorkspace(workspace *backupWorkspace) { - if workspace.registry == nil { - if cleanupErr := workspace.fs.RemoveAll(workspace.tempDir); cleanupErr != nil { - o.logger.Warning("Failed to remove temp directory %s: %v", workspace.tempDir, cleanupErr) - } + if workspace.tempDir == "" { + return + } + // Always remove the staging workspace when the run finishes: it holds plaintext + // copies of sensitive files (shadow, SSL/SSH keys, ...) gathered before + // encryption, so it must not be left on disk after a successful (or failed) run + // (issue #53). The registry exists for crash recovery only; previously a + // non-nil registry caused the workspace to be preserved "until the next + // startup", leaving secrets at rest for the whole inter-run window. + if cleanupErr := workspace.fs.RemoveAll(workspace.tempDir); cleanupErr != nil { + // Keep it registered so the next run's orphan sweep retries the removal. + o.logger.Warning("Failed to remove temp directory %s: %v", workspace.tempDir, cleanupErr) return } - o.logger.Debug("Temporary workspace preserved at %s (will be removed at the next startup)", workspace.tempDir) + o.logger.Debug("Removed temporary workspace %s", workspace.tempDir) + if workspace.registry != nil { + if err := workspace.registry.Deregister(workspace.tempDir); err != nil { + o.logger.Debug("Failed to deregister temp directory %s: %v", workspace.tempDir, err) + } + } } func (o *Orchestrator) markBackupWorkspace(workspace *backupWorkspace) error { - markerPath := filepath.Join(workspace.tempDir, ".proxsave-marker") + markerPath := filepath.Join(workspace.tempDir, workspaceMarker) markerContent := fmt.Sprintf( "Created by PID %d on %s UTC\n", os.Getpid(), @@ -207,7 +220,20 @@ func (o *Orchestrator) collectBackupData(run *backupRunContext, workspace *backu return err } - return o.applyBackupOptimizations(run.ctx, workspace.tempDir) + optResult, err := o.applyBackupOptimizations(run.ctx, workspace.tempDir) + if err != nil { + return err + } + // Dedup/prefilter shrank the staged tree AFTER the collection stats were taken; + // correct the uncompressed-payload figure that the compression ratio divides by + // so reports/notifications/metrics reflect what is actually archived (issue #73). + // BytesCollected stays the honest "bytes read during collection" figure. + if optResult.BytesReclaimed > 0 { + if shipped := run.stats.BytesCollected - optResult.BytesReclaimed; shipped >= 0 { + run.stats.UncompressedSize = shipped + } + } + return nil } func (o *Orchestrator) validateCollectedBackupSize(stats *BackupStats) error { diff --git a/internal/orchestrator/backup_run_phases_test.go b/internal/orchestrator/backup_run_phases_test.go index dcfc270..a5c489a 100644 --- a/internal/orchestrator/backup_run_phases_test.go +++ b/internal/orchestrator/backup_run_phases_test.go @@ -3,13 +3,50 @@ package orchestrator import ( "context" "errors" + "os" + "path/filepath" "strings" "testing" "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" "github.com/tis24dev/proxsave/internal/types" ) +func TestCleanupBackupWorkspaceRemovesAndDeregisters(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + orch := New(logger, false) + + reg, err := NewTempDirRegistry(logger, filepath.Join(t.TempDir(), "registry.json")) + if err != nil { + t.Fatalf("NewTempDirRegistry: %v", err) + } + + tempDir := t.TempDir() + // Represent plaintext staged secrets that must not survive a finished run. + if err := os.WriteFile(filepath.Join(tempDir, "shadow"), []byte("hash"), 0o600); err != nil { + t.Fatal(err) + } + if err := reg.Register(tempDir); err != nil { + t.Fatalf("register: %v", err) + } + + orch.cleanupBackupWorkspace(&backupWorkspace{registry: reg, fs: osFS{}, tempDir: tempDir}) + + if _, err := os.Stat(tempDir); !os.IsNotExist(err) { + t.Fatalf("workspace must be removed when the run finishes (issue #53), stat err=%v", err) + } + entries, err := reg.loadEntries() + if err != nil { + t.Fatalf("loadEntries: %v", err) + } + for _, e := range entries { + if e.Path == tempDir { + t.Fatalf("workspace must be deregistered after removal; still present in %+v", entries) + } + } +} + func TestCreateBackupArchiveClassifiesAgeRecipientFailureAsEncryption(t *testing.T) { orch := New(newTestLogger(), false) orch.SetConfig(&config.Config{ diff --git a/internal/orchestrator/categories.go b/internal/orchestrator/categories.go index 40b47f7..ac01a89 100644 --- a/internal/orchestrator/categories.go +++ b/internal/orchestrator/categories.go @@ -312,6 +312,12 @@ func GetAllCategories() []Category { "./etc/auto.master", "./etc/auto.master.d/", "./etc/auto.*", + // Conventional LUKS/crypttab keyfile directories. Keyfiles at + // arbitrary absolute paths are collected dynamically but cannot be + // statically matched here; those still need the full/plain restore (#66). + "./etc/keys/", + "./etc/luks-keys/", + "./etc/cryptsetup-keys.d/", }, }, { @@ -372,6 +378,10 @@ func GetAllCategories() []Category { "./etc/cron.d/", "./etc/crontab", "./var/spool/cron/", + "./etc/cron.daily/", + "./etc/cron.hourly/", + "./etc/cron.weekly/", + "./etc/cron.monthly/", }, }, { @@ -395,6 +405,21 @@ func GetAllCategories() []Category { "./etc/nftables.d/", }, }, + { + ID: "accounts", + Name: "System Accounts & Auth (WARNING)", + Description: "OS users/groups/passwords and sudoers (/etc/passwd,group,shadow,gshadow,sudoers). " + + "Applied with a safe merge that preserves the current host root and system accounts; " + + "WARNING: review before applying and prefer a matching/fresh host.", + Type: CategoryTypeCommon, + Paths: []string{ + "./etc/passwd", + "./etc/group", + "./etc/shadow", + "./etc/gshadow", + "./etc/sudoers", + }, + }, { ID: "user_data", Name: "User Data (Home Directories)", diff --git a/internal/orchestrator/categories_coverage_test.go b/internal/orchestrator/categories_coverage_test.go new file mode 100644 index 0000000..18d9adb --- /dev/null +++ b/internal/orchestrator/categories_coverage_test.go @@ -0,0 +1,61 @@ +package orchestrator + +import "testing" + +// TestOrphanPathsNowCovered locks the #66/#67 coverage: previously-orphaned +// collected files now match a system-writable (non-export-only) category. +func TestOrphanPathsNowCovered(t *testing.T) { + all := GetAllCategories() + cases := map[string]string{ + "./etc/passwd": "accounts", + "./etc/group": "accounts", + "./etc/shadow": "accounts", + "./etc/gshadow": "accounts", + "./etc/sudoers": "accounts", + "./etc/cron.daily/logrot": "crontabs", + "./etc/cron.hourly/x": "crontabs", + "./etc/cron.weekly/x": "crontabs", + "./etc/cron.monthly/x": "crontabs", + "./etc/keys/luks.key": "storage_stack", + "./etc/luks-keys/disk": "storage_stack", + "./etc/cryptsetup-keys.d/k": "storage_stack", + } + for p, wantID := range cases { + cat := GetCategoryByID(wantID, all) + if cat == nil { + t.Fatalf("category %q not found", wantID) + } + if !PathMatchesCategory(p, *cat) { + t.Errorf("path %q should match category %q", p, wantID) + } + if cat.ExportOnly { + t.Errorf("category %q for %q must be system-writable (not ExportOnly)", wantID, p) + } + } +} + +// TestAccountsCategoryClassification verifies the sensitive accounts category is +// staged (safe merge apply), not export-only, and confined to Full/Custom modes. +func TestAccountsCategoryClassification(t *testing.T) { + all := GetAllCategories() + cat := GetCategoryByID("accounts", all) + if cat == nil { + t.Fatal("accounts category missing") + } + if cat.ExportOnly { + t.Error("accounts must not be ExportOnly") + } + if !isStagedCategoryID("accounts") { + t.Error("accounts must be staged so it applies via the safe merge step") + } + for _, c := range GetBaseModeCategories() { + if c.ID == "accounts" { + t.Error("accounts must NOT be in Base mode") + } + } + for _, c := range GetStorageModeCategories("dual") { + if c.ID == "accounts" { + t.Error("accounts must NOT be in Storage mode") + } + } +} diff --git a/internal/orchestrator/categories_test.go b/internal/orchestrator/categories_test.go new file mode 100644 index 0000000..7c157b5 --- /dev/null +++ b/internal/orchestrator/categories_test.go @@ -0,0 +1,68 @@ +package orchestrator + +import "testing" + +// TestPVEAccessControlCategoryMatchesRestoreConstants locks the authoritative +// pve_access_control file set. The same list is duplicated as exclusion patterns +// in internal/backup/collector_pve.go (pveACLPrivExcludePatterns) because the +// backup package cannot import this one (import cycle). If this test fails after +// adding/removing a PVE access-control file, update the backup-side exclusion too. +func TestPVEAccessControlCategoryMatchesRestoreConstants(t *testing.T) { + cat := GetCategoryByID("pve_access_control", GetAllCategories()) + if cat == nil { + t.Fatal("pve_access_control category not found") + } + + want := []string{ + "." + pveUserCfgPath, + "." + pveDomainsCfgPath, + "." + pveShadowCfgPath, + "." + pveTokenCfgPath, + "." + pveTFACfgPath, + } + + have := make(map[string]bool, len(cat.Paths)) + for _, p := range cat.Paths { + have[p] = true + } + for _, w := range want { + if !have[w] { + t.Errorf("pve_access_control category missing %q; keep it in sync with restore_access_control.go constants and internal/backup/collector_pve.go pveACLPrivExcludePatterns", w) + } + } + if len(cat.Paths) != len(want) { + t.Errorf("pve_access_control has %d paths, want %d (%v); a new access-control file must also be added to the backup-side exclusion", len(cat.Paths), len(want), cat.Paths) + } +} + +// TestPBSAccessControlCategoryContainsRestoreConstants locks the PBS access-control +// credential files. The same set is duplicated as exclusion patterns in +// internal/backup/collector_pbs.go (pbsUserConfigSecretExcludes) because the backup +// package cannot import this one. Subset check (the category also carries +// informational var/lib/proxsave-info JSON paths). +func TestPBSAccessControlCategoryContainsRestoreConstants(t *testing.T) { + cat := GetCategoryByID("pbs_access_control", GetAllCategories()) + if cat == nil { + t.Fatal("pbs_access_control category not found") + } + + want := []string{ + "." + pbsUserCfgPath, + "." + pbsDomainsCfgPath, + "." + pbsACLCfgPath, + "." + pbsTokenCfgPath, + "." + pbsShadowJSONPath, + "." + pbsTokenShadowPath, + "." + pbsTFAJSONPath, + } + + have := make(map[string]bool, len(cat.Paths)) + for _, p := range cat.Paths { + have[p] = true + } + for _, w := range want { + if !have[w] { + t.Errorf("pbs_access_control category missing %q; keep it in sync with restore_access_control.go constants and internal/backup/collector_pbs.go pbsUserConfigSecretExcludes", w) + } + } +} diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go index 55c24ee..5169592 100644 --- a/internal/orchestrator/decrypt.go +++ b/internal/orchestrator/decrypt.go @@ -428,9 +428,9 @@ func promptDestinationDir(ctx context.Context, reader *bufio.Reader, cfg *config func downloadRcloneBackup(ctx context.Context, remotePath string, logger *logging.Logger) (tmpPath string, cleanup func(), err error) { done := logging.DebugStart(logger, "download rclone backup", "remote=%s", remotePath) defer func() { done(err) }() - // Ensure /tmp/proxsave exists - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { + // Ensure /tmp/proxsave exists and is a safe, root-owned, non-symlink directory. + tempRoot := workspaceRoot + if err := ensureSecureTempRoot(restoreFS, tempRoot); err != nil { return "", nil, fmt.Errorf("failed to create temp directory: %w", err) } @@ -706,10 +706,18 @@ func copyRawArtifactsToWorkdirWithLogger(ctx context.Context, cand *backupCandid func decryptArchiveWithPrompts(ctx context.Context, reader *bufio.Reader, encryptedPath, outputPath string, logger *logging.Logger) error { ui := newCLIWorkflowUI(reader, logger) displayName := filepath.Base(encryptedPath) - return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret) + return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret, nil) } func parseIdentityInput(input string) ([]age.Identity, error) { + return parseIdentityInputWithSalts(input, nil) +} + +// parseIdentityInputWithSalts parses a private key directly, or treats the input +// as a passphrase and derives the candidate identities. extraSalts carries the +// per-installation salt read from the archive manifest so passphrase decryption +// works on any host; the fixed v1/legacy salts are always appended as fallback. +func parseIdentityInputWithSalts(input string, extraSalts []string) ([]age.Identity, error) { if strings.HasPrefix(strings.ToUpper(input), "AGE-SECRET-KEY-") { id, err := age.ParseX25519Identity(strings.ToUpper(input)) if err != nil { @@ -717,7 +725,19 @@ func parseIdentityInput(input string) ([]age.Identity, error) { } return []age.Identity{id}, nil } - return deriveDeterministicIdentitiesFromPassphrase(input) + return deriveDeterministicIdentitiesFromPassphraseWithExtraSalts(input, extraSalts) +} + +// manifestPassphraseSalts returns the per-installation salt recorded in a +// manifest (if any), to be tried first when deriving identities from a passphrase. +func manifestPassphraseSalts(m *backup.Manifest) []string { + if m == nil { + return nil + } + if salt := strings.TrimSpace(m.PassphraseSalt); salt != "" { + return []string{salt} + } + return nil } func decryptWithIdentity(src, dst string, identities ...age.Identity) (err error) { diff --git a/internal/orchestrator/decrypt_prepare_common.go b/internal/orchestrator/decrypt_prepare_common.go index f35813f..1f06dad 100644 --- a/internal/orchestrator/decrypt_prepare_common.go +++ b/internal/orchestrator/decrypt_prepare_common.go @@ -81,8 +81,8 @@ func preparePlainBundleCommon(ctx context.Context, cand *backupCandidate, versio cand.BundlePath = localPath } - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { + tempRoot := workspaceRoot + if err := ensureSecureTempRoot(restoreFS, tempRoot); err != nil { if rcloneCleanup != nil { rcloneCleanup() } diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index aadc971..719770b 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -160,7 +160,7 @@ func ensureWritablePathWithUI(ctx context.Context, ui DecryptWorkflowUI, targetP } } -func decryptArchiveWithSecretPrompt(ctx context.Context, encryptedPath, outputPath, displayName string, prompt func(ctx context.Context, displayName, previousError string) (string, error)) error { +func decryptArchiveWithSecretPrompt(ctx context.Context, encryptedPath, outputPath, displayName string, prompt func(ctx context.Context, displayName, previousError string) (string, error), extraSalts []string) error { promptError := "" for { secret, err := prompt(ctx, displayName, promptError) @@ -178,7 +178,7 @@ func decryptArchiveWithSecretPrompt(ctx context.Context, encryptedPath, outputPa continue } - identities, err := parseIdentityInput(secret) + identities, err := parseIdentityInputWithSalts(secret, extraSalts) resetString(&secret) if err != nil { promptError = fmt.Sprintf("Invalid key or passphrase: %v", err) @@ -209,8 +209,9 @@ func preparePlainBundleWithUI(ctx context.Context, cand *backupCandidate, versio done := logging.DebugStart(logger, "prepare plain bundle (ui)", "source=%v rclone=%v", cand.Source, cand.IsRclone) defer func() { done(err) }() + extraSalts := manifestPassphraseSalts(cand.Manifest) return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { - return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret) + return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret, extraSalts) }) } diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go index 6154e07..62db99b 100644 --- a/internal/orchestrator/deps_test.go +++ b/internal/orchestrator/deps_test.go @@ -24,6 +24,7 @@ type FakeFS struct { MkdirAllErr error MkdirTempErr error OpenFileErr map[string]error + RenameErr map[string]error Ownership map[string]FakeOwnership } @@ -39,6 +40,7 @@ func NewFakeFS() *FakeFS { StatErr: make(map[string]error), StatErrors: make(map[string]error), OpenFileErr: make(map[string]error), + RenameErr: make(map[string]error), Ownership: make(map[string]FakeOwnership), } } @@ -191,6 +193,9 @@ func (f *FakeFS) MkdirTemp(dir, pattern string) (string, error) { } func (f *FakeFS) Rename(oldpath, newpath string) error { + if err, ok := f.RenameErr[filepath.Clean(oldpath)]; ok { + return err + } return os.Rename(f.onDisk(oldpath), f.onDisk(newpath)) } diff --git a/internal/orchestrator/early_error_notification_test.go b/internal/orchestrator/early_error_notification_test.go new file mode 100644 index 0000000..d5b868d --- /dev/null +++ b/internal/orchestrator/early_error_notification_test.go @@ -0,0 +1,70 @@ +package orchestrator + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func newEarlyErrorState() *EarlyErrorState { + return &EarlyErrorState{ + Phase: "storage_init", + Error: errors.New("primary storage mount unavailable"), + ExitCode: types.ExitStorageError, + Timestamp: time.Now(), + } +} + +// TestDispatchEarlyErrorNotification_SendsToRegisteredChannel covers H12: an +// early-init failure must reach the registered notification channels exactly +// once (channels are now registered before the fallible init phases). +func TestDispatchEarlyErrorNotification_SendsToRegisteredChannel(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + stub := &stubNotifierChannel{name: "FakeEarly"} + o := &Orchestrator{ + logger: logger, + cfg: &config.Config{}, + notificationChannels: []NotificationChannel{stub}, + } + + stats := o.DispatchEarlyErrorNotification(context.Background(), newEarlyErrorState()) + if stats == nil { + t.Fatal("expected non-nil stats") + } + if stub.count != 1 { + t.Fatalf("expected the registered channel to be notified exactly once, got %d", stub.count) + } + if stats.LocalStatus != "error" || stats.ErrorCount != 1 { + t.Fatalf("expected error stats, got LocalStatus=%q ErrorCount=%d", stats.LocalStatus, stats.ErrorCount) + } + if stub.errorCount != 1 { + t.Fatalf("channel received ErrorCount=%d, want 1", stub.errorCount) + } +} + +// TestDispatchEarlyErrorNotification_DryRunDoesNotSend covers the mandatory +// guardrail: in dry-run no real notification is sent (the normal path already +// gates on !dryRun), but stats are still returned for support/log handling. +func TestDispatchEarlyErrorNotification_DryRunDoesNotSend(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + stub := &stubNotifierChannel{name: "FakeEarly"} + o := &Orchestrator{ + logger: logger, + cfg: &config.Config{}, + dryRun: true, + notificationChannels: []NotificationChannel{stub}, + } + + stats := o.DispatchEarlyErrorNotification(context.Background(), newEarlyErrorState()) + if stats == nil { + t.Fatal("dry-run should still return stats") + } + if stub.called { + t.Fatal("dry-run must not send notifications") + } +} diff --git a/internal/orchestrator/encryption.go b/internal/orchestrator/encryption.go index f5171f1..bd93caa 100644 --- a/internal/orchestrator/encryption.go +++ b/internal/orchestrator/encryption.go @@ -27,12 +27,12 @@ var ErrAgeRecipientSetupAborted = errors.New("encryption setup aborted by user") const ( // Note: dual salt for passphrase-derived keys — keep legacy for decrypting older archives. - passphraseRecipientSalt = "proxsave/age-passphrase/v1" - legacyPassphraseRecipientSalt = "proxmox-backup-go/age-passphrase/v1" - passphraseScryptN = 1 << 15 - passphraseScryptR = 8 - passphraseScryptP = 1 - minPassphraseLength = 12 + recipientSaltV1 = "proxsave/age-passphrase/v1" + legacyRecipientSalt = "proxmox-backup-go/age-passphrase/v1" + passphraseScryptN = 1 << 15 + passphraseScryptR = 8 + passphraseScryptP = 1 + minPassphraseLength = 12 ) var weakPassphraseList = []string{ @@ -560,7 +560,7 @@ func DeriveDeterministicRecipientFromPassphrase(passphrase string) (string, erro } func deriveDeterministicRecipientFromPassphrase(passphrase string) (string, error) { - return deriveDeterministicRecipientFromPassphraseWithSalt(passphrase, passphraseRecipientSalt) + return deriveDeterministicRecipientFromPassphraseWithSalt(passphrase, recipientSaltV1) } func deriveDeterministicRecipientFromPassphraseWithSalt(passphrase, salt string) (string, error) { @@ -589,7 +589,7 @@ func clampCurve25519Scalar(k []byte) { } func deriveCurve25519ScalarFromPassphrase(passphrase string) ([]byte, error) { - return deriveCurve25519ScalarFromPassphraseWithSalt(passphrase, passphraseRecipientSalt) + return deriveCurve25519ScalarFromPassphraseWithSalt(passphrase, recipientSaltV1) } func deriveCurve25519ScalarFromPassphraseWithSalt(passphrase, salt string) ([]byte, error) { @@ -602,7 +602,7 @@ func deriveCurve25519ScalarFromPassphraseWithSalt(passphrase, salt string) ([]by } func deriveDeterministicIdentityFromPassphrase(passphrase string) (age.Identity, error) { - return deriveDeterministicIdentityFromPassphraseWithSalt(passphrase, passphraseRecipientSalt) + return deriveDeterministicIdentityFromPassphraseWithSalt(passphrase, recipientSaltV1) } func deriveDeterministicIdentityFromPassphraseWithSalt(passphrase, salt string) (age.Identity, error) { @@ -619,11 +619,33 @@ func deriveDeterministicIdentityFromPassphraseWithSalt(passphrase, salt string) } func deriveDeterministicIdentitiesFromPassphrase(passphrase string) ([]age.Identity, error) { - salts := []string{passphraseRecipientSalt, legacyPassphraseRecipientSalt} - seen := make(map[string]struct{}, len(salts)) + return deriveDeterministicIdentitiesFromPassphraseWithExtraSalts(passphrase, nil) +} + +// deriveDeterministicIdentitiesFromPassphraseWithExtraSalts derives one identity +// per candidate salt. extraSalts (e.g. the per-installation salt embedded in an +// archive manifest) are tried first, followed by the fixed v1/legacy salts so +// archives produced before the per-install salt remain decryptable. Identities +// are deduped by their recipient. +func deriveDeterministicIdentitiesFromPassphraseWithExtraSalts(passphrase string, extraSalts []string) ([]age.Identity, error) { + salts := make([]string, 0, len(extraSalts)+2) + salts = append(salts, extraSalts...) + salts = append(salts, recipientSaltV1, legacyRecipientSalt) + + seenSalt := make(map[string]struct{}, len(salts)) + seenRec := make(map[string]struct{}, len(salts)) ids := make([]age.Identity, 0, len(salts)) for _, salt := range salts { + salt = strings.TrimSpace(salt) + if salt == "" { + continue + } + if _, ok := seenSalt[salt]; ok { + continue + } + seenSalt[salt] = struct{}{} + id, err := deriveDeterministicIdentityFromPassphraseWithSalt(passphrase, salt) if err != nil { return nil, err @@ -632,10 +654,10 @@ func deriveDeterministicIdentitiesFromPassphrase(passphrase string) ([]age.Ident if err != nil { return nil, err } - if _, ok := seen[rec]; ok { + if _, ok := seenRec[rec]; ok { continue } - seen[rec] = struct{}{} + seenRec[rec] = struct{}{} ids = append(ids, id) } return ids, nil diff --git a/internal/orchestrator/extensions.go b/internal/orchestrator/extensions.go index 619fc50..faa22f3 100644 --- a/internal/orchestrator/extensions.go +++ b/internal/orchestrator/extensions.go @@ -230,6 +230,20 @@ func (o *Orchestrator) DispatchEarlyErrorNotification(ctx context.Context, early stats.LogFilePath = logPath } + // Export a Prometheus "fail" metric (status=error) so textfile-based alerting + // fires on early-init failures too; otherwise the textfile keeps the last + // successful run's metrics. Self-gated by MetricsEnabled && !dryRun. + if o.shouldExportBackupMetrics(stats) { + o.ensureBackupStatsTiming(stats) + o.exportPrometheusBackupMetrics(stats) + } + + // Honor dry-run like the normal finalize path: never send real notifications. + if o.dryRun { + o.logger.Info("[DRY RUN] Would send early-error notification: %s", stats.LocalStatusSummary) + return stats + } + // Dispatch notifications with minimal stats. Early errors are already // represented in stats and may not be present in the log file yet. o.dispatchNotifications(ctx, stats) diff --git a/internal/orchestrator/fs_atomic.go b/internal/orchestrator/fs_atomic.go index c4e6f85..8a09fd2 100644 --- a/internal/orchestrator/fs_atomic.go +++ b/internal/orchestrator/fs_atomic.go @@ -184,75 +184,82 @@ func desiredOwnershipForAtomicWrite(destPath string) uidGid { return uidGid{} } -func writeFileAtomic(path string, data []byte, perm os.FileMode) error { - path = filepath.Clean(strings.TrimSpace(path)) - if path == "" || path == "." { - return fmt.Errorf("invalid path") +// syncDir fsyncs a directory so a rename within it becomes durable. Filesystems +// that do not support directory fsync (EINVAL/ENOTSUP) are tolerated. +func syncDir(dir string) error { + dir = filepath.Clean(strings.TrimSpace(dir)) + if dir == "" { + dir = "." + } + + df, err := restoreFS.Open(dir) + if err != nil { + return fmt.Errorf("open dir %s: %w", dir, err) + } + + syncErr := atomicFileSync(df) + closeErr := df.Close() + if syncErr != nil { + if errors.Is(syncErr, syscall.EINVAL) || errors.Is(syncErr, syscall.ENOTSUP) { + return closeErr + } + return fmt.Errorf("fsync dir %s: %w", dir, syncErr) + } + if closeErr != nil { + return fmt.Errorf("close dir %s: %w", dir, closeErr) + } + return nil +} + +// prepareAtomicTempFile is phase 1 of an atomic write: it writes data to a sibling +// temp file of path with the final ownership/permissions applied and flushed, but +// does NOT rename it into place, so no live file is touched yet. On any error before +// the temp is ready it removes the temp. It returns the temp path, the cleaned +// destination path, and the parent directory for commitAtomicTempFile (phase 2). +func prepareAtomicTempFile(path string, data []byte, perm os.FileMode) (tmpPath, cleanPath, dir string, err error) { + cleanPath = filepath.Clean(strings.TrimSpace(path)) + if cleanPath == "" || cleanPath == "." { + return "", "", "", fmt.Errorf("invalid path") } perm = modeBits(perm) if perm == 0 { perm = 0o644 } - dir := filepath.Dir(path) + dir = filepath.Dir(cleanPath) if err := ensureDirExistsWithInheritedMeta(dir); err != nil { - return err + return "", "", "", err } - owner := desiredOwnershipForAtomicWrite(path) + owner := desiredOwnershipForAtomicWrite(cleanPath) - tmpPath := fmt.Sprintf("%s.proxsave.tmp.%d", path, nowRestore().UnixNano()) + tmpPath = fmt.Sprintf("%s.proxsave.tmp.%d", cleanPath, nowRestore().UnixNano()) f, err := restoreFS.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL|os.O_TRUNC, 0o600) if err != nil { - return err - } - - syncDir := func(dir string) error { - dir = filepath.Clean(strings.TrimSpace(dir)) - if dir == "" { - dir = "." - } - - df, err := restoreFS.Open(dir) - if err != nil { - return fmt.Errorf("open dir %s: %w", dir, err) - } - - syncErr := atomicFileSync(df) - closeErr := df.Close() - if syncErr != nil { - if errors.Is(syncErr, syscall.EINVAL) || errors.Is(syncErr, syscall.ENOTSUP) { - return closeErr - } - return fmt.Errorf("fsync dir %s: %w", dir, syncErr) - } - if closeErr != nil { - return fmt.Errorf("close dir %s: %w", dir, closeErr) - } - return nil + return "", "", "", err } writeErr := func() error { if len(data) == 0 { return nil } - _, err := f.Write(data) - return err + _, werr := f.Write(data) + return werr }() if writeErr == nil { if atomicGeteuid() == 0 && owner.ok { - if err := atomicFileChown(f, owner.uid, owner.gid); err != nil { - writeErr = err + if cerr := atomicFileChown(f, owner.uid, owner.gid); cerr != nil { + writeErr = cerr } } if writeErr == nil { - if err := atomicFileChmod(f, perm); err != nil { - writeErr = err + if cerr := atomicFileChmod(f, perm); cerr != nil { + writeErr = cerr } } if writeErr == nil { - if err := atomicFileSync(f); err != nil { - writeErr = err + if serr := atomicFileSync(f); serr != nil { + writeErr = serr } } } @@ -260,20 +267,111 @@ func writeFileAtomic(path string, data []byte, perm os.FileMode) error { closeErr := f.Close() if writeErr != nil { _ = restoreFS.Remove(tmpPath) - return writeErr + return "", "", "", writeErr } if closeErr != nil { _ = restoreFS.Remove(tmpPath) - return closeErr + return "", "", "", closeErr } + return tmpPath, cleanPath, dir, nil +} - if err := restoreFS.Rename(tmpPath, path); err != nil { +// commitAtomicTempFile is phase 2 of an atomic write: it renames a prepared temp +// file into its final destination and fsyncs the parent directory. The returned +// committed flag lets a batch caller roll back precisely: committed==false means the +// rename failed and the destination is untouched (the temp was removed); committed== +// true with a non-nil error means the rename succeeded (the file is live) but the +// directory fsync failed. +func commitAtomicTempFile(tmpPath, cleanPath, dir string) (committed bool, err error) { + if err := restoreFS.Rename(tmpPath, cleanPath); err != nil { _ = restoreFS.Remove(tmpPath) - return err + return false, err } - if err := syncDir(dir); err != nil { + return true, err + } + return true, nil +} + +func writeFileAtomic(path string, data []byte, perm os.FileMode) error { + tmpPath, cleanPath, dir, err := prepareAtomicTempFile(path, data, perm) + if err != nil { return err } + _, err = commitAtomicTempFile(tmpPath, cleanPath, dir) + return err +} + +// atomicFileWrite is one entry of a writeFilesAtomic batch. original holds the +// current on-disk bytes used to roll the file back if a later file in the batch +// fails to commit. An empty original (the read-back representation of an absent or +// empty file) rolls back to an empty file, which is acceptable for the account DB. +type atomicFileWrite struct { + path string + data []byte + original []byte + perm os.FileMode +} + +// writeFilesAtomic writes a set of files all-or-nothing as far as is achievable +// without a journal. Phase 1 prepares every temp file; if any preparation fails no +// destination is touched, so the common disk-full / read-only / IO failures leave +// every live file unchanged. Phase 2 renames each temp into place; if a commit fails +// the already-committed files are rolled back to their originals. Residual window: a +// crash or power loss BETWEEN renames cannot be made atomic here (that needs a +// journal / recovery marker) — the window is narrowed to the cheap renames because +// all data is already fsynced and closed in phase 1. +func writeFilesAtomic(writes []atomicFileWrite) error { + type prepared struct { + tmpPath string + cleanPath string + dir string + w atomicFileWrite + } + + staged := make([]prepared, 0, len(writes)) + // Phase 1: prepare all temps. No destination file is touched yet. + for _, w := range writes { + tmpPath, cleanPath, dir, err := prepareAtomicTempFile(w.path, w.data, w.perm) + if err != nil { + for _, s := range staged { + _ = restoreFS.Remove(s.tmpPath) + } + return fmt.Errorf("prepare %s: %w", w.path, err) + } + staged = append(staged, prepared{tmpPath: tmpPath, cleanPath: cleanPath, dir: dir, w: w}) + } + + // Phase 2: commit (rename) each in order; roll back on failure. + for i, s := range staged { + committed, err := commitAtomicTempFile(s.tmpPath, s.cleanPath, s.dir) + if err == nil { + continue + } + + // Roll back the files already made live. When committed is true the rename of + // index i succeeded (only the dir fsync failed) so file i is live and must be + // reverted too; when false the rename failed and index i is untouched (its + // temp was already removed by commitAtomicTempFile). + last := i - 1 + if committed { + last = i + } + var rollbackFailed []string + for j := last; j >= 0; j-- { + if rbErr := writeFileAtomic(staged[j].cleanPath, staged[j].w.original, staged[j].w.perm); rbErr != nil { + rollbackFailed = append(rollbackFailed, staged[j].cleanPath) + } + } + // Remove the not-yet-committed temps (index i's temp is already gone). + for k := i + 1; k < len(staged); k++ { + _ = restoreFS.Remove(staged[k].tmpPath) + } + + if len(rollbackFailed) > 0 { + return fmt.Errorf("CRITICAL: commit of %s failed and rollback of %v also failed; the on-disk file set may be inconsistent and manual recovery is required: %w", s.cleanPath, rollbackFailed, err) + } + return fmt.Errorf("commit %s failed; already-written files were rolled back to their originals: %w", s.cleanPath, err) + } return nil } diff --git a/internal/orchestrator/fs_atomic_test.go b/internal/orchestrator/fs_atomic_test.go index 843d06e..cffcb1c 100644 --- a/internal/orchestrator/fs_atomic_test.go +++ b/internal/orchestrator/fs_atomic_test.go @@ -1,13 +1,40 @@ package orchestrator import ( + "errors" "os" "path/filepath" + "strings" "syscall" "testing" "time" ) +// rollbackFailFS wraps FakeFS to fail OpenFile for specific temp paths only AFTER a +// designated rename has been attempted (the forward commit phase has run). This lets a +// test break the rollback writes without also breaking the forward prepare, which share +// the same temp name under a pinned clock. +type rollbackFailFS struct { + *FakeFS + tripOnRename string + tripped bool + failOpenAfterTrip map[string]bool +} + +func (r *rollbackFailFS) Rename(oldpath, newpath string) error { + if filepath.Clean(oldpath) == r.tripOnRename { + r.tripped = true + } + return r.FakeFS.Rename(oldpath, newpath) +} + +func (r *rollbackFailFS) OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) { + if r.tripped && r.failOpenAfterTrip[filepath.Clean(path)] { + return nil, errors.New("forced rollback open failure: " + filepath.Clean(path)) + } + return r.FakeFS.OpenFile(path, flag, perm) +} + type statOverrideFS struct { FS Infos map[string]os.FileInfo @@ -228,6 +255,155 @@ func TestWriteFileAtomic_InheritsGroupFromParentWhenDestMissing(t *testing.T) { } } +// REFUTER: force committed==true (rename succeeds, directory fsync fails) at a middle +// index and assert (a) the live file at that index is reverted, (b) earlier committed +// files are reverted, (c) later untouched files stay original, (d) NO temp leaks, and +// (e) the rollback temp re-create does NOT collide on O_EXCL with the (now-consumed) +// original temp name despite the pinned fixed clock. +func TestWriteFilesAtomic_CommittedTrueDirFsyncFailRollsBackLiveFile(t *testing.T) { + fakeFS := NewFakeFS() + origFS := restoreFS + origTime := restoreTime + origSync := atomicFileSync + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + atomicFileSync = origSync + _ = os.RemoveAll(fakeFS.Root) + }) + + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + pA := "/etc/auth/a" + pB := "/etc/auth/b" + pC := "/etc/auth/c" + if err := fakeFS.MkdirAll("/etc/auth", 0o755); err != nil { + t.Fatalf("seed dir: %v", err) + } + for p, orig := range map[string]string{pA: "A-orig\n", pB: "B-orig\n", pC: "C-orig\n"} { + if err := fakeFS.WriteFile(p, []byte(orig), 0o644); err != nil { + t.Fatalf("seed %s: %v", p, err) + } + } + + // Fail ONLY B's dir fsync (the second directory-fsync call): A's commit (call 1) + // succeeds, B's commit renames OK but its dir fsync fails -> committed==true at + // index 1. All later dir fsyncs (calls 3+, performed by the rollback writes) must + // succeed so we exercise the rollback-SUCCESS path, not the failing-device path. + calls := 0 + atomicFileSync = func(f *os.File) error { + if f == nil { + return nil + } + info, err := f.Stat() + if err == nil && info != nil && info.IsDir() { + calls++ + if calls == 2 { + return errors.New("forced dir fsync failure") // B's dir fsync fails -> committed==true at index 1 + } + return nil + } + return nil + } + + writes := []atomicFileWrite{ + {path: pA, data: []byte("A-new\n"), original: []byte("A-orig\n"), perm: 0o644}, + {path: pB, data: []byte("B-new\n"), original: []byte("B-orig\n"), perm: 0o644}, + {path: pC, data: []byte("C-new\n"), original: []byte("C-orig\n"), perm: 0o644}, + } + err := writeFilesAtomic(writes) + if err == nil { + t.Fatal("expected an error when a committed file's dir fsync fails") + } + if strings.Contains(err.Error(), "CRITICAL") { + t.Fatalf("rollback should have succeeded (not CRITICAL): %v", err) + } + + // A committed then was rolled back; B committed (rename ok) then dir-fsync failed, + // so B is live and must ALSO be reverted; C never committed. + for p, want := range map[string]string{pA: "A-orig\n", pB: "B-orig\n", pC: "C-orig\n"} { + got, rerr := fakeFS.ReadFile(p) + if rerr != nil { + t.Fatalf("read %s: %v", p, rerr) + } + if string(got) != want { + t.Errorf("%s = %q, want rolled-back original %q", p, string(got), want) + } + } + + // No temp may leak. Scan the dir for any .proxsave.tmp. residue. + entries, derr := fakeFS.ReadDir("/etc/auth") + if derr != nil { + t.Fatalf("readdir: %v", derr) + } + for _, e := range entries { + if strings.Contains(e.Name(), ".proxsave.tmp.") { + t.Errorf("leaked temp file: %s", e.Name()) + } + } +} + +// REFUTER: rollback-of-rollback. Force a commit failure that triggers rollback, and +// make the rollback writes ALSO fail; assert a CRITICAL error, that ALL rollbacks are +// still attempted (no early return / panic / infinite loop), and the later temps are +// still cleaned. +func TestWriteFilesAtomic_RollbackAlsoFailsReturnsCritical(t *testing.T) { + fakeFS := NewFakeFS() + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + _ = os.RemoveAll(fakeFS.Root) + }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + pA := "/etc/auth/a" + pB := "/etc/auth/b" + pC := "/etc/auth/c" + if err := fakeFS.MkdirAll("/etc/auth", 0o755); err != nil { + t.Fatalf("seed dir: %v", err) + } + for _, p := range []string{pA, pB, pC} { + if err := fakeFS.WriteFile(p, []byte("orig\n"), 0o644); err != nil { + t.Fatalf("seed %s: %v", p, err) + } + } + + tmp := func(p string) string { + return filepath.Clean(p) + ".proxsave.tmp." + "10000000000" // time.Unix(10,0).UnixNano() + } + // C's rename fails -> rollback of A and B. Because the clock is pinned, the rollback + // temp name for A/B is IDENTICAL to the name used in the forward prepare/commit, so a + // static OpenFileErr keyed on that name would also break the forward prepare. To fail + // ONLY the rollback writes we wrap the FS and start rejecting OpenFile for A/B's temp + // after the forward commit phase has run (tripped once C's rename is requested). + cf := &rollbackFailFS{FakeFS: fakeFS, failOpenAfterTrip: map[string]bool{tmp(pA): true, tmp(pB): true}} + cf.RenameErr[tmp(pC)] = errors.New("forced C rename failure") + // Trip the gate when C's (failing) rename is attempted. + cf.tripOnRename = tmp(pC) + restoreFS = cf + + writes := []atomicFileWrite{ + {path: pA, data: []byte("A-new\n"), original: []byte("A-orig\n"), perm: 0o644}, + {path: pB, data: []byte("B-new\n"), original: []byte("B-orig\n"), perm: 0o644}, + {path: pC, data: []byte("C-new\n"), original: []byte("C-orig\n"), perm: 0o644}, + } + err := writeFilesAtomic(writes) + if err == nil { + t.Fatal("expected an error") + } + if !strings.Contains(err.Error(), "CRITICAL") { + t.Fatalf("expected CRITICAL error when rollbacks fail, got: %v", err) + } + // BOTH failed rollbacks must be named (proves the loop did not stop at the first). + if !strings.Contains(err.Error(), pA) || !strings.Contains(err.Error(), pB) { + t.Errorf("CRITICAL error should list both failed rollbacks %s and %s, got: %v", pA, pB, err) + } +} + func TestEnsureDirExistsWithInheritedMeta_CreatesNestedDirsWithInheritedMeta(t *testing.T) { fakeFS := NewFakeFS() origFS := restoreFS diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index 959391f..e00c93a 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -4,6 +4,7 @@ import ( "archive/tar" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -360,6 +361,11 @@ func (o *Orchestrator) SetIdentity(serverID, serverMAC string) { } // RunPreBackupChecks performs all pre-backup validation checks +// ErrBackupInProgress signals that the pre-backup lock check failed because +// another backup is already running. It is a benign concurrency skip (not a +// failure), so the caller exits 0 and does NOT send a failure notification. +var ErrBackupInProgress = errors.New("another backup is already in progress") + func (o *Orchestrator) RunPreBackupChecks(ctx context.Context) error { if o.checker == nil { o.logger.Debug("No checker configured, skipping pre-backup checks") @@ -419,6 +425,9 @@ func (o *Orchestrator) RunPreBackupChecks(ctx context.Context) error { lockResult := o.checker.CheckLockFile() logResult(lockResult) if !lockResult.Passed { + if lockResult.Code == checks.CheckCodeBackupInProgress { + return fmt.Errorf("pre-backup checks failed: %s: %w", lockResult.Message, ErrBackupInProgress) + } return fmt.Errorf("pre-backup checks failed: %s", lockResult.Message) } @@ -1155,6 +1164,7 @@ func applyCollectorOverrides(cc *backup.CollectorConfig, cfg *config.Config) { cc.BackupScriptRepository = cfg.BackupScriptRepository cc.BackupUserHomes = cfg.BackupUserHomes cc.BackupConfigFile = cfg.BackupConfigFile + cc.SystemRootPrefix = cfg.SystemRootPrefix cc.ScriptRepositoryPath = cfg.BaseDir if cfg.PxarDatastoreConcurrency > 0 { cc.PxarDatastoreConcurrency = cfg.PxarDatastoreConcurrency diff --git a/internal/orchestrator/orchestrator_test.go b/internal/orchestrator/orchestrator_test.go index 56e98a3..e5597ff 100644 --- a/internal/orchestrator/orchestrator_test.go +++ b/internal/orchestrator/orchestrator_test.go @@ -648,6 +648,7 @@ func TestApplyCollectorOverridesCopiesConfig(t *testing.T) { BackupScriptRepository: true, BackupUserHomes: true, BackupConfigFile: true, + SystemRootPrefix: "/mnt/testroot", BaseDir: "/opt/proxsave", PxarDatastoreConcurrency: 3, @@ -689,6 +690,9 @@ func TestApplyCollectorOverridesCopiesConfig(t *testing.T) { if cc.ScriptRepositoryPath != cfg.BaseDir { t.Fatalf("ScriptRepositoryPath = %s, want %s", cc.ScriptRepositoryPath, cfg.BaseDir) } + if cc.SystemRootPrefix != cfg.SystemRootPrefix { + t.Fatalf("SystemRootPrefix = %q, want %q", cc.SystemRootPrefix, cfg.SystemRootPrefix) + } if cc.PxarDatastoreConcurrency != cfg.PxarDatastoreConcurrency { t.Fatalf("PxarDatastoreConcurrency not copied correctly") } diff --git a/internal/orchestrator/passphrase_salt.go b/internal/orchestrator/passphrase_salt.go new file mode 100644 index 0000000..02a4443 --- /dev/null +++ b/internal/orchestrator/passphrase_salt.go @@ -0,0 +1,84 @@ +package orchestrator + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +// randomSaltNamespaceV2 namespaces the per-installation random salt used to +// derive a passphrase-based AGE recipient. The "v2" generation replaces the +// fixed salts (recipientSaltV1 / legacyRecipientSalt), which +// remain accepted at decrypt time for backward compatibility with older archives. +const randomSaltNamespaceV2 = "proxsave/age-passphrase/v2:" + +// passphraseSaltFilePath returns the salt file that sits next to a recipient file. +func passphraseSaltFilePath(recipientPath string) string { + return filepath.Join(filepath.Dir(recipientPath), "passphrase.salt") +} + +// getOrCreatePassphraseSalt returns the per-installation passphrase salt stored +// next to recipientPath, generating and persisting a fresh random one if absent. +// The salt is public (it only provides domain separation / anti-precomputation): +// it is stored 0600 next to the recipient and also embedded in each archive +// manifest so the passphrase alone can re-derive the identity on any host. +func (o *Orchestrator) getOrCreatePassphraseSalt(recipientPath string) (string, error) { + if strings.TrimSpace(recipientPath) == "" { + return "", fmt.Errorf("recipient path is required to resolve the passphrase salt") + } + fs := o.filesystem() + saltPath := passphraseSaltFilePath(recipientPath) + + data, err := fs.ReadFile(saltPath) + if err == nil { + if salt := strings.TrimSpace(string(data)); salt != "" { + return salt, nil + } + } else if !errors.Is(err, os.ErrNotExist) { + return "", fmt.Errorf("read passphrase salt %s: %w", saltPath, err) + } + + raw := make([]byte, 16) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("generate passphrase salt: %w", err) + } + salt := randomSaltNamespaceV2 + hex.EncodeToString(raw) + if err := fs.MkdirAll(filepath.Dir(saltPath), 0o700); err != nil { + return "", fmt.Errorf("create passphrase salt directory: %w", err) + } + if err := writeFileAtomicWithDeps(fs, o.clock, saltPath, []byte(salt+"\n"), 0o600); err != nil { + return "", fmt.Errorf("persist passphrase salt %s: %w", saltPath, err) + } + return salt, nil +} + +// readPassphraseSalt returns the persisted per-installation salt next to +// recipientPath, or "" if it is absent/unreadable. +func (o *Orchestrator) readPassphraseSalt(recipientPath string) string { + if strings.TrimSpace(recipientPath) == "" { + return "" + } + data, err := o.filesystem().ReadFile(passphraseSaltFilePath(recipientPath)) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// passphraseSaltForManifest returns the per-installation salt to embed in an +// archive manifest, or "" when encryption is off or no passphrase salt exists +// (X25519/SSH-only setups, or legacy installs still on the fixed salt). +func (o *Orchestrator) passphraseSaltForManifest() string { + if o == nil || o.cfg == nil || !o.cfg.EncryptArchive { + return "" + } + recipientPath := strings.TrimSpace(o.cfg.AgeRecipientFile) + if recipientPath == "" { + recipientPath = o.defaultAgeRecipientFile() + } + return o.readPassphraseSalt(recipientPath) +} diff --git a/internal/orchestrator/passphrase_salt_test.go b/internal/orchestrator/passphrase_salt_test.go new file mode 100644 index 0000000..21325fd --- /dev/null +++ b/internal/orchestrator/passphrase_salt_test.go @@ -0,0 +1,213 @@ +package orchestrator + +import ( + "bytes" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "filippo.io/age" + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/config" +) + +const testStrongPassphrase = "Str0ng-Passphrase!" + +func encryptToRecipient(t *testing.T, recipientStr string, plaintext []byte) []byte { + t.Helper() + rec, err := age.ParseX25519Recipient(recipientStr) + if err != nil { + t.Fatalf("parse recipient: %v", err) + } + var buf bytes.Buffer + w, err := age.Encrypt(&buf, rec) + if err != nil { + t.Fatalf("age.Encrypt: %v", err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatalf("write plaintext: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close age writer: %v", err) + } + return buf.Bytes() +} + +func TestGetOrCreatePassphraseSalt(t *testing.T) { + recipientPath := filepath.Join(t.TempDir(), "identity", "age", "recipient.txt") + o := &Orchestrator{} + + salt, err := o.getOrCreatePassphraseSalt(recipientPath) + if err != nil { + t.Fatalf("getOrCreatePassphraseSalt: %v", err) + } + if !strings.HasPrefix(salt, randomSaltNamespaceV2) { + t.Fatalf("salt %q missing prefix %q", salt, randomSaltNamespaceV2) + } + if salt == recipientSaltV1 || salt == legacyRecipientSalt { + t.Fatalf("salt collided with a fixed salt: %q", salt) + } + + saltPath := passphraseSaltFilePath(recipientPath) + info, err := os.Stat(saltPath) + if err != nil { + t.Fatalf("salt file not persisted: %v", err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Fatalf("salt file perm = %v, want 0600", perm) + } + + // Idempotent: a second call returns the same persisted salt. + salt2, err := o.getOrCreatePassphraseSalt(recipientPath) + if err != nil { + t.Fatalf("second getOrCreatePassphraseSalt: %v", err) + } + if salt2 != salt { + t.Fatalf("salt not stable across calls: %q vs %q", salt, salt2) + } + if got := o.readPassphraseSalt(recipientPath); got != salt { + t.Fatalf("readPassphraseSalt = %q, want %q", got, salt) + } +} + +func TestPassphraseSaltIsPerInstallation(t *testing.T) { + o1, o2 := &Orchestrator{}, &Orchestrator{} + r1 := filepath.Join(t.TempDir(), "recipient.txt") + r2 := filepath.Join(t.TempDir(), "recipient.txt") + + salt1, err := o1.getOrCreatePassphraseSalt(r1) + if err != nil { + t.Fatal(err) + } + salt2, err := o2.getOrCreatePassphraseSalt(r2) + if err != nil { + t.Fatal(err) + } + if salt1 == salt2 { + t.Fatalf("two installations produced the same salt: %q", salt1) + } + + rec1, err := deriveDeterministicRecipientFromPassphraseWithSalt(testStrongPassphrase, salt1) + if err != nil { + t.Fatal(err) + } + rec2, err := deriveDeterministicRecipientFromPassphraseWithSalt(testStrongPassphrase, salt2) + if err != nil { + t.Fatal(err) + } + if rec1 == rec2 { + t.Fatalf("same passphrase produced the same recipient across installs (correlatable): %q", rec1) + } + + recConst, err := deriveDeterministicRecipientFromPassphrase(testStrongPassphrase) + if err != nil { + t.Fatal(err) + } + if rec1 == recConst { + t.Fatalf("random-salt recipient equals the fixed-salt recipient: %q", rec1) + } +} + +// TestPassphraseRandomSaltRoundTripAndIsolation proves the per-install salt is +// actually required to decrypt: with the salt the archive decrypts, and with the +// fixed/legacy salts alone (no manifest salt) it does not. +func TestPassphraseRandomSaltRoundTripAndIsolation(t *testing.T) { + salt := randomSaltNamespaceV2 + "00112233445566778899aabbccddeeff" + recStr, err := deriveDeterministicRecipientFromPassphraseWithSalt(testStrongPassphrase, salt) + if err != nil { + t.Fatal(err) + } + plaintext := []byte("top secret backup bytes") + ciphertext := encryptToRecipient(t, recStr, plaintext) + + // With the per-install salt (as carried in the manifest) → success. + idsWith, err := parseIdentityInputWithSalts(testStrongPassphrase, []string{salt}) + if err != nil { + t.Fatalf("parseIdentityInputWithSalts: %v", err) + } + r, err := age.Decrypt(bytes.NewReader(ciphertext), idsWith...) + if err != nil { + t.Fatalf("decrypt with per-install salt failed: %v", err) + } + got, err := io.ReadAll(r) + if err != nil { + t.Fatalf("read decrypted: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("decrypted = %q, want %q", got, plaintext) + } + + // Without the salt (only the fixed v1/legacy salts) → must NOT decrypt. + idsConst, err := parseIdentityInput(testStrongPassphrase) + if err != nil { + t.Fatalf("parseIdentityInput: %v", err) + } + if _, err := age.Decrypt(bytes.NewReader(ciphertext), idsConst...); err == nil { + t.Fatalf("archive decrypted WITHOUT the per-install salt; the salt is not actually required") + } +} + +// TestLegacyConstantSaltArchiveStillDecrypts guarantees backward compatibility: +// archives produced before the per-install salt (derived from the fixed salt and +// carrying no manifest salt) keep decrypting via the constant fallbacks. +func TestLegacyConstantSaltArchiveStillDecrypts(t *testing.T) { + recStr, err := deriveDeterministicRecipientFromPassphrase(testStrongPassphrase) + if err != nil { + t.Fatal(err) + } + plaintext := []byte("legacy archive payload") + ciphertext := encryptToRecipient(t, recStr, plaintext) + + ids, err := parseIdentityInput(testStrongPassphrase) // no manifest salt, constants only + if err != nil { + t.Fatalf("parseIdentityInput: %v", err) + } + r, err := age.Decrypt(bytes.NewReader(ciphertext), ids...) + if err != nil { + t.Fatalf("legacy archive failed to decrypt: %v", err) + } + got, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("decrypted = %q, want %q", got, plaintext) + } +} + +func TestManifestPassphraseSalts(t *testing.T) { + if got := manifestPassphraseSalts(nil); got != nil { + t.Fatalf("nil manifest: got %v, want nil", got) + } + if got := manifestPassphraseSalts(&backup.Manifest{}); got != nil { + t.Fatalf("empty salt: got %v, want nil", got) + } + got := manifestPassphraseSalts(&backup.Manifest{PassphraseSalt: " the-salt "}) + if len(got) != 1 || got[0] != "the-salt" { + t.Fatalf("manifestPassphraseSalts = %v, want [the-salt]", got) + } +} + +func TestPassphraseSaltForManifest(t *testing.T) { + recipientPath := filepath.Join(t.TempDir(), "identity", "age", "recipient.txt") + o := &Orchestrator{cfg: &config.Config{EncryptArchive: true, AgeRecipientFile: recipientPath}} + + if got := o.passphraseSaltForManifest(); got != "" { + t.Fatalf("expected empty salt before setup, got %q", got) + } + + salt, err := o.getOrCreatePassphraseSalt(recipientPath) + if err != nil { + t.Fatal(err) + } + if got := o.passphraseSaltForManifest(); got != salt { + t.Fatalf("passphraseSaltForManifest = %q, want %q", got, salt) + } + + o.cfg.EncryptArchive = false + if got := o.passphraseSaltForManifest(); got != "" { + t.Fatalf("expected empty salt when encryption disabled, got %q", got) + } +} diff --git a/internal/orchestrator/pbs_api_apply.go b/internal/orchestrator/pbs_api_apply.go index 9851866..38d51cc 100644 --- a/internal/orchestrator/pbs_api_apply.go +++ b/internal/orchestrator/pbs_api_apply.go @@ -3,6 +3,7 @@ package orchestrator import ( "context" "encoding/json" + "errors" "fmt" "os" "sort" @@ -14,6 +15,23 @@ import ( var pbsAPIApplyGeteuid = os.Geteuid +// errPBSCleanRemoveIncomplete marks a Clean (1:1) restore where one or more stale +// objects absent from the backup could NOT be removed from PBS. The objects are +// left in place (the conservative outcome, nothing destroyed), but the caller +// must report the restore "with warnings" instead of a clean success, and must +// NOT paper over it with the file-based fallback (which would force-rewrite the +// .cfg and drop the object, bypassing PBS's refusal, e.g. for an in-use object). +var errPBSCleanRemoveIncomplete = errors.New("PBS Clean 1:1 incomplete: stale object(s) could not be removed") + +// pbsCleanRemoveResult returns nil when no clean-mode remove failed, otherwise an +// error wrapping errPBSCleanRemoveIncomplete listing the objects left behind. +func pbsCleanRemoveResult(kind string, failures []string) error { + if len(failures) == 0 { + return nil + } + return fmt.Errorf("%s remove failed for %d object(s) [%s]: %w", kind, len(failures), strings.Join(failures, ", "), errPBSCleanRemoveIncomplete) +} + func normalizeProxmoxCfgKey(key string) string { key = strings.ToLower(strings.TrimSpace(key)) key = strings.ReplaceAll(key, "_", "-") @@ -217,6 +235,7 @@ func applyPBSRemoteCfgViaAPI(ctx context.Context, logger *logging.Logger, stageR desired[name] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "remote", "list", "--output-format=json") if err != nil { @@ -232,6 +251,7 @@ func applyPBSRemoteCfgViaAPI(ctx context.Context, logger *logging.Logger, stageR } if _, err := runPBSManager(ctx, "remote", "remove", id); err != nil { logger.Warning("PBS API apply: remote remove %s failed (continuing): %v", id, err) + removeFailures = append(removeFailures, id) } } } @@ -253,7 +273,7 @@ func applyPBSRemoteCfgViaAPI(ctx context.Context, logger *logging.Logger, stageR } } - return nil + return pbsCleanRemoveResult("remote", removeFailures) } func applyPBSS3CfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -278,6 +298,7 @@ func applyPBSS3CfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot desired[id] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "s3", "endpoint", "list", "--output-format=json") if err != nil { @@ -293,6 +314,7 @@ func applyPBSS3CfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot } if _, err := runPBSManager(ctx, "s3", "endpoint", "remove", id); err != nil { logger.Warning("PBS API apply: s3 endpoint remove %s failed (continuing): %v", id, err) + removeFailures = append(removeFailures, id) } } } @@ -314,7 +336,7 @@ func applyPBSS3CfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot } } - return nil + return pbsCleanRemoveResult("s3 endpoint", removeFailures) } func applyPBSDatastoreCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -365,6 +387,7 @@ func applyPBSDatastoreCfgViaAPI(ctx context.Context, logger *logging.Logger, sta } } + var removeFailures []string if strict { current := make([]string, 0, len(currentPaths)) for name := range currentPaths { @@ -377,6 +400,7 @@ func applyPBSDatastoreCfgViaAPI(ctx context.Context, logger *logging.Logger, sta } if _, err := runPBSManager(ctx, "datastore", "remove", name); err != nil { logger.Warning("PBS API apply: datastore remove %s failed (continuing): %v", name, err) + removeFailures = append(removeFailures, name) } } } @@ -425,7 +449,7 @@ func applyPBSDatastoreCfgViaAPI(ctx context.Context, logger *logging.Logger, sta } } - return nil + return pbsCleanRemoveResult("datastore", removeFailures) } func applyPBSSyncCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -450,6 +474,7 @@ func applyPBSSyncCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoo desired[id] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "sync-job", "list", "--output-format=json") if err != nil { @@ -465,6 +490,7 @@ func applyPBSSyncCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoo } if _, err := runPBSManager(ctx, "sync-job", "remove", id); err != nil { logger.Warning("PBS API apply: sync-job remove %s failed (continuing): %v", id, err) + removeFailures = append(removeFailures, id) } } } @@ -486,7 +512,7 @@ func applyPBSSyncCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoo } } - return nil + return pbsCleanRemoveResult("sync-job", removeFailures) } func applyPBSVerificationCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -511,6 +537,7 @@ func applyPBSVerificationCfgViaAPI(ctx context.Context, logger *logging.Logger, desired[id] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "verify-job", "list", "--output-format=json") if err != nil { @@ -526,6 +553,7 @@ func applyPBSVerificationCfgViaAPI(ctx context.Context, logger *logging.Logger, } if _, err := runPBSManager(ctx, "verify-job", "remove", id); err != nil { logger.Warning("PBS API apply: verify-job remove %s failed (continuing): %v", id, err) + removeFailures = append(removeFailures, id) } } } @@ -547,7 +575,7 @@ func applyPBSVerificationCfgViaAPI(ctx context.Context, logger *logging.Logger, } } - return nil + return pbsCleanRemoveResult("verify-job", removeFailures) } func applyPBSPruneCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -572,6 +600,7 @@ func applyPBSPruneCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRo desired[id] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "prune-job", "list", "--output-format=json") if err != nil { @@ -587,6 +616,7 @@ func applyPBSPruneCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRo } if _, err := runPBSManager(ctx, "prune-job", "remove", id); err != nil { logger.Warning("PBS API apply: prune-job remove %s failed (continuing): %v", id, err) + removeFailures = append(removeFailures, id) } } } @@ -608,7 +638,7 @@ func applyPBSPruneCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRo } } - return nil + return pbsCleanRemoveResult("prune-job", removeFailures) } func applyPBSTrafficControlCfgViaAPI(ctx context.Context, logger *logging.Logger, stageRoot string, strict bool) error { @@ -633,6 +663,7 @@ func applyPBSTrafficControlCfgViaAPI(ctx context.Context, logger *logging.Logger desired[name] = s } + var removeFailures []string if strict { out, err := runPBSManager(ctx, "traffic-control", "list", "--output-format=json") if err != nil { @@ -648,6 +679,7 @@ func applyPBSTrafficControlCfgViaAPI(ctx context.Context, logger *logging.Logger } if _, err := runPBSManager(ctx, "traffic-control", "remove", name); err != nil { logger.Warning("PBS API apply: traffic-control remove %s failed (continuing): %v", name, err) + removeFailures = append(removeFailures, name) } } } @@ -669,7 +701,7 @@ func applyPBSTrafficControlCfgViaAPI(ctx context.Context, logger *logging.Logger } } - return nil + return pbsCleanRemoveResult("traffic-control", removeFailures) } func applyPBSNodeCfgViaAPI(ctx context.Context, stageRoot string) error { diff --git a/internal/orchestrator/pbs_api_apply_test.go b/internal/orchestrator/pbs_api_apply_test.go index a0ce640..fa566a3 100644 --- a/internal/orchestrator/pbs_api_apply_test.go +++ b/internal/orchestrator/pbs_api_apply_test.go @@ -608,8 +608,10 @@ func TestApplyPBSRemoteCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { "proxmox-backup-manager remote remove old": errors.New("cannot remove old"), } - if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + // Clean (1:1) mode with a failed remove must surface as errPBSCleanRemoveIncomplete + // (reported "with warnings"), while create/update still run (no abort). + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -711,8 +713,8 @@ func TestApplyPBSS3CfgViaAPI_CreateUpdateAndStrictCleanup(t *testing.T) { "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("already exists"), } - if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSS3CfgViaAPI error: %v", err) + if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -839,8 +841,8 @@ func TestApplyPBSDatastoreCfgViaAPI_CurrentPathsFallbacksAndStrictRemoveWarn(t * "proxmox-backup-manager datastore remove id1": errors.New("cannot remove id1"), } - if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -1027,8 +1029,8 @@ func TestApplyPBSSyncCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1": errors.New("already exists"), } - if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSSyncCfgViaAPI error: %v", err) + if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -1082,8 +1084,8 @@ func TestApplyPBSVerificationCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing "proxmox-backup-manager verify-job create v1 --store ds1": errors.New("already exists"), } - if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSVerificationCfgViaAPI error: %v", err) + if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -1139,8 +1141,8 @@ func TestApplyPBSPruneCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3": errors.New("already exists"), } - if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSPruneCfgViaAPI error: %v", err) + if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ @@ -1171,8 +1173,8 @@ func TestApplyPBSTrafficControlCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { "proxmox-backup-manager traffic-control remove old": errors.New("cannot remove old"), } - if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { - t.Fatalf("applyPBSTrafficControlCfgViaAPI error: %v", err) + if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, true); !errors.Is(err, errPBSCleanRemoveIncomplete) { + t.Fatalf("expected errPBSCleanRemoveIncomplete (clean 1:1 remove), got: %v", err) } want := []string{ diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index a8c2455..7940cdf 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -97,7 +97,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if apiAvailable { if err := pbsStagedApplyTrafficControlCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: traffic-control failed: %v", err) - if !pbsFallbackApplied(logger, "traffic-control.cfg", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "traffic-control.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "traffic-control.cfg", allowFileFallback, func() error { return applyPBSConfigFileFromStage(ctx, logger, stageRoot, "etc/proxmox-backup/traffic-control.cfg") }) { failedItems = append(failedItems, "traffic-control.cfg") @@ -130,7 +132,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if apiAvailable { if err := pbsStagedApplyS3CfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: s3.cfg failed: %v", err) - if !pbsFallbackApplied(logger, "s3.cfg", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "s3.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "s3.cfg", allowFileFallback, func() error { return applyPBSS3CfgFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "s3.cfg") @@ -138,7 +142,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, } if err := pbsStagedApplyDatastoreCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: datastore.cfg failed: %v", err) - if !pbsFallbackApplied(logger, "datastore.cfg", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "datastore.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "datastore.cfg", allowFileFallback, func() error { return applyPBSDatastoreCfgFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "datastore.cfg") @@ -162,7 +168,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if apiAvailable { if err := pbsStagedApplyRemoteCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: remote.cfg failed: %v", err) - if !pbsFallbackApplied(logger, "remote.cfg", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "remote.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "remote.cfg", allowFileFallback, func() error { return applyPBSRemoteCfgFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "remote.cfg") @@ -182,7 +190,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if apiAvailable { if err := pbsStagedApplySyncCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: sync jobs failed: %v", err) - if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "sync.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { return applyPBSJobConfigsFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "sync.cfg") @@ -190,7 +200,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, } if err := pbsStagedApplyVerificationCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: verification jobs failed: %v", err) - if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "verification.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { return applyPBSJobConfigsFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "verification.cfg") @@ -198,7 +210,9 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, } if err := pbsStagedApplyPruneCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: prune jobs failed: %v", err) - if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { + if errors.Is(err, errPBSCleanRemoveIncomplete) { + failedItems = append(failedItems, "prune.cfg (clean 1:1 incomplete)") + } else if !pbsFallbackApplied(logger, "job configs", allowFileFallback, func() error { return applyPBSJobConfigsFromStage(ctx, logger, stageRoot) }) { failedItems = append(failedItems, "prune.cfg") diff --git a/internal/orchestrator/pbs_staged_apply_maybeapply_test.go b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go index aba4e27..509225d 100644 --- a/internal/orchestrator/pbs_staged_apply_maybeapply_test.go +++ b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "os" + "strings" "testing" "github.com/tis24dev/proxsave/internal/logging" @@ -408,3 +409,66 @@ func TestMaybeApplyPBSConfigsFromStage_ApiErrorsTriggerFallbackOnlyInCleanMode(t } } } + +// TestMaybeApplyPBSConfigsFromStage_CleanRemoveIncomplete_SurfacesWithoutFallback +// (H10) checks that when a Clean (1:1) API apply succeeds at create/update but +// could not remove a stale object, the wrapper surfaces it as a failed item (so +// the restore reports "with warnings") WITHOUT invoking the destructive +// file-based fallback that would force-rewrite the .cfg and drop the object. +func TestMaybeApplyPBSConfigsFromStage_CleanRemoveIncomplete_SurfacesWithoutFallback(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + origDS := pbsStagedApplyDatastoreCfgViaAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + pbsStagedApplyDatastoreCfgViaAPIFn = origDS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { return nil } + + // Create/update succeed but a stale datastore could not be removed in Clean mode. + pbsStagedApplyDatastoreCfgViaAPIFn = func(context.Context, *logging.Logger, string, bool) error { + return pbsCleanRemoveResult("datastore", []string{"stale-ds"}) + } + + stageRoot := "/stage" + // datastore.cfg points at a SAFE, empty datastore path so the file fallback, IF + // wrongly invoked, would pass shouldApplyPBSDatastoreBlock and actually write the + // live file. A non-empty path like /tmp is deferred (never written), which would + // make the "file not written" assertion below pass even on a bypass regression. + safeDir := t.TempDir() + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("datastore: DS1\n path "+safeDir+"\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{{ID: "datastore_pbs"}}, + } + + err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false) + if err == nil { + t.Fatalf("expected a summary error so the restore reports 'with warnings'") + } + if !strings.Contains(err.Error(), "datastore.cfg (clean 1:1 incomplete)") { + t.Fatalf("expected the clean-1:1-incomplete item in the summary, got: %v", err) + } + + // Conservative outcome: the destructive file-based fallback must NOT run, so + // the live datastore.cfg is left untouched (the stale object stays). + if _, statErr := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); statErr == nil { + t.Fatalf("file fallback must NOT rewrite /etc/proxmox-backup/datastore.cfg on a clean-remove failure") + } +} diff --git a/internal/orchestrator/restore_accounts.go b/internal/orchestrator/restore_accounts.go new file mode 100644 index 0000000..91ffacc --- /dev/null +++ b/internal/orchestrator/restore_accounts.go @@ -0,0 +1,516 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +const ( + etcPasswdPath = "/etc/passwd" + etcGroupPath = "/etc/group" + etcShadowPath = "/etc/shadow" + etcGshadowPath = "/etc/gshadow" + etcSudoersPath = "/etc/sudoers" + + // Accounts/groups whose numeric id is below this threshold belong to the + // host/distro (root@uid0, daemon users, sudo/docker groups, ...). They are + // preserved from the CURRENT host, so restoring accounts never locks out the + // running machine. Only regular accounts (id >= threshold) are imported. + systemAccountIDThreshold = 1000 + + passwdMinFields = 7 + groupMinFields = 4 + shadowMinFields = 2 + + // lockedShadowSuffix, appended to a username, yields a well-formed but locked + // (password-less) shadow entry. Used when an imported passwd user has no shadow + // line in the backup, so passwd and shadow never desync. + lockedShadowSuffix = ":*:::::::" +) + +// maybeApplyAccountsFromStage is the wired, gated entry point for restoring OS +// account files from the staging tree (#67). The merge lives in +// applyAccountsFromStage so it can be unit-tested with an in-memory FS. +func maybeApplyAccountsFromStage(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot string, dryRun bool) (err error) { + if plan == nil || !plan.HasCategoryID("accounts") { + return nil + } + if strings.TrimSpace(stageRoot) == "" { + logging.DebugStep(logger, "accounts staged apply", "Skipped: staging directory not available") + return nil + } + done := logging.DebugStart(logger, "accounts staged apply", "dryRun=%v stage=%s", dryRun, stageRoot) + defer func() { done(err) }() + + if dryRun { + logger.Info("Dry run enabled: skipping staged system accounts apply") + return nil + } + if !isRealRestoreFS(restoreFS) { + logger.Debug("Skipping staged system accounts apply: non-system filesystem in use") + return nil + } + if accessControlApplyGeteuid() != 0 { + logger.Warning("Skipping staged system accounts apply: requires root privileges") + return nil + } + return applyAccountsFromStage(ctx, logger, stageRoot) +} + +func applyAccountsFromStage(ctx context.Context, logger *logging.Logger, stageRoot string) error { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return err + } + + stagedPasswd, hasPasswd, err := readStageFileOptional(stageRoot, "etc/passwd") + if err != nil { + return err + } + if !hasPasswd { + logging.DebugStep(logger, "accounts staged apply", "Skipped: etc/passwd not present in stage") + return nil + } + currentPasswd, err := readCurrentAccountFile(etcPasswdPath) + if err != nil { + return err + } + if strings.TrimSpace(currentPasswd) == "" { + // Never rewrite the account database without the host baseline (would drop + // root and system accounts). Skip rather than risk a lockout. + logger.Warning("Skipping system accounts restore: current /etc/passwd is empty/unreadable") + return nil + } + currentGroup, err := readCurrentAccountFile(etcGroupPath) + if err != nil { + return err + } + if strings.TrimSpace(currentGroup) == "" { + // Same anti-lockout rationale as the empty /etc/passwd guard above: never + // rewrite the group DB without the host baseline (it would drop root and all + // host system groups). Skip rather than risk dropping host group memberships. + logger.Warning("Skipping system accounts restore: current /etc/group is empty/unreadable") + return nil + } + + stagedShadow, _, err := readStageFileOptional(stageRoot, "etc/shadow") + if err != nil { + return err + } + currentShadow, err := readCurrentAccountFile(etcShadowPath) + if err != nil { + return err + } + if strings.TrimSpace(currentShadow) == "" { + // As above: without the host shadow baseline the rewrite would drop root and + // every host account's credentials. Skip to avoid a lockout. + logger.Warning("Skipping system accounts restore: current /etc/shadow is empty/unreadable") + return nil + } + stagedGroup, _, err := readStageFileOptional(stageRoot, "etc/group") + if err != nil { + return err + } + stagedGshadow, _, err := readStageFileOptional(stageRoot, "etc/gshadow") + if err != nil { + return err + } + currentGshadow, err := readCurrentAccountFile(etcGshadowPath) + if err != nil { + return err + } + + // Host identity maps. hostSystemUsers = names of host accounts with uid < threshold + // (incl. the uid 0 entry whatever its name), so a renamed root or a name-clash never + // clobbers a host account. hostGroupGID maps every host group name->gid and hostGIDs + // is the set of ALL host group gids: together they ensure the merge never overwrites + // an existing host group, never reuses a host gid for a new group, and never enrolls + // an imported user into an existing host group (system OR privileged) via primary gid. + hostSystemUsers := lowIDNames(currentPasswd, 2) + hostGroupGID := groupGIDsByName(currentGroup) + hostGIDs := gidValueSet(hostGroupGID) + + importedUsers, mergedPasswd := mergePasswd(currentPasswd, stagedPasswd, hostSystemUsers, hostGIDs) + mergedShadow := mergeShadow(currentShadow, stagedShadow, importedUsers) + importedGroups, mergedGroup := mergeGroup(currentGroup, stagedGroup, hostGroupGID, hostGIDs, importedUsers) + mergedGshadow := mergeGshadow(currentGshadow, stagedGshadow, importedGroups) + + // Write the four auth-DB files all-or-nothing: a failure partway through must never + // leave the host with an inconsistent passwd/shadow/group/gshadow set (lockout + // risk). The current* contents read above are the rollback source. + if err := writeFilesAtomic([]atomicFileWrite{ + {path: etcPasswdPath, data: []byte(mergedPasswd), original: []byte(currentPasswd), perm: 0o644}, + {path: etcShadowPath, data: []byte(mergedShadow), original: []byte(currentShadow), perm: 0o640}, + {path: etcGroupPath, data: []byte(mergedGroup), original: []byte(currentGroup), perm: 0o644}, + {path: etcGshadowPath, data: []byte(mergedGshadow), original: []byte(currentGshadow), perm: 0o640}, + }); err != nil { + return err + } + logger.Info("Restored system accounts: imported %d user(s) and %d group(s); host root, system accounts and group memberships preserved", len(importedUsers), len(importedGroups)) + + return applySudoersFromStage(ctx, logger, stageRoot) +} + +// applySudoersFromStage replaces /etc/sudoers with the EXACT staged bytes only if +// they pass `visudo -c`, otherwise the current file is kept untouched. +func applySudoersFromStage(ctx context.Context, logger *logging.Logger, stageRoot string) error { + stagedPath := filepath.Join(stageRoot, "etc/sudoers") + data, err := restoreFS.ReadFile(stagedPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("read staged sudoers: %w", err) + } + if strings.TrimSpace(string(data)) == "" { + return nil + } + if _, err := restoreCmd.Run(ctx, "visudo", "-c", "-f", stagedPath); err != nil { + logger.Warning("Skipping /etc/sudoers restore: staged sudoers failed validation (visudo -c): %v", err) + return nil + } + if err := writeFileAtomic(etcSudoersPath, data, 0o440); err != nil { + return err + } + logger.Info("Restored /etc/sudoers (validated with visudo -c)") + return nil +} + +func readCurrentAccountFile(path string) (string, error) { + data, err := restoreFS.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return "", nil + } + return "", err + } + return string(data), nil +} + +// mergePasswd keeps every current entry and imports each backup user that is a +// regular account (uid >= threshold, >= passwdMinFields, not NIS, not "root", and +// whose name does not collide with a host system account). Returns the imported +// names and the merged file. +func mergePasswd(current, backup string, hostSystemUsers map[string]bool, hostGIDs map[uint64]bool) (map[string]bool, string) { + lines := splitNonEmptyLines(current) + index := indexByName(lines) + imported := map[string]bool{} + for _, line := range splitNonEmptyLines(backup) { + if isNISLine(line) { + continue + } + parts := strings.Split(line, ":") + if len(parts) < passwdMinFields { + continue + } + name := parts[0] + if !isValidAccountName(name) || name == "root" || hostSystemUsers[name] { + continue + } + uid, ok := parseAccountID(parts[2]) + if !ok || uid < systemAccountIDThreshold { + continue + } + // Never import a regular user whose PRIMARY group is root (gid 0), an existing + // host group of ANY gid (system like sudo/shadow/disk OR a privileged regular + // group like docker at gid >= 1000), or an overflowed/garbage gid: a passwd + // primary gid grants that group's privileges on its own, bypassing the + // /etc/group member-merge protections. + pgid, ok := parseAccountID(parts[3]) + if !ok || pgid == 0 || hostGIDs[pgid] { + continue + } + imported[name] = true + upsert(&lines, index, name, line) + } + return imported, joinLines(lines) +} + +// mergeShadow keeps every current line and, for each imported user, sets the +// backup shadow line; if the backup lacks a (valid) line, a locked placeholder is +// written so an imported passwd user is never left without a shadow entry. +func mergeShadow(current, backup string, imported map[string]bool) string { + lines := splitNonEmptyLines(current) + index := indexByName(lines) + backupByName := byName(backup) + for _, name := range sortedKeys(imported) { + line, ok := backupByName[name] + if !ok || len(strings.Split(line, ":")) < shadowMinFields { + line = name + lockedShadowSuffix + } + upsert(&lines, index, name, line) + } + return joinLines(lines) +} + +// mergeGroup keeps every current group and NEVER overwrites an existing host group. +// - A backup group whose name is NOT on the host is imported whole, but only when +// it is a regular group (gid >= threshold) and its gid does not collide with any +// host group's gid (a collision would create two groups sharing a gid). +// - A backup group whose name IS on the host is never replaced: its host gid and +// existing members are preserved, and only the imported users are added as +// supplementary members, and only when the backup line is genuinely the same +// group (gid matches the host) and it is not the root group. This stops a backup +// from changing a host gid, dropping host members, or injecting members into a +// host privileged group (sudo/docker/...) via a spoofed or name-only line. +func mergeGroup(current, backup string, hostGroupGID map[string]uint64, hostGIDs map[uint64]bool, importedUsers map[string]bool) (map[string]bool, string) { + lines := splitNonEmptyLines(current) + index := indexByName(lines) + imported := map[string]bool{} + for _, line := range splitNonEmptyLines(backup) { + if isNISLine(line) { + continue + } + parts := strings.Split(line, ":") + if len(parts) < groupMinFields { + continue + } + name := parts[0] + if !isValidAccountName(name) { + continue + } + gid, ok := parseAccountID(parts[2]) + if !ok { + continue + } + hostGID, existsOnHost := hostGroupGID[name] + if !existsOnHost { + // Brand-new backup group: import only a regular group whose gid does not + // collide with an existing host group gid (and never root/system range). + if name == "root" || gid < systemAccountIDThreshold || hostGIDs[gid] { + continue + } + // Restrict members to users we are also importing, so a new backup group + // never silently enrolls an existing host account (mirrors the member + // filtering on the existing-host-group path below). + parts[3] = strings.Join(filterSet(groupMembers(parts), importedUsers), ",") + line = strings.Join(parts, ":") + imported[name] = true + upsert(&lines, index, name, line) + continue + } + // Existing host group: never overwrite. Merge imported members only, and only + // when the backup line is the same group (gid matches) and not the root group. + if gid == 0 || hostGID != gid { + continue + } + if i, ok := index[name]; ok { + add := filterSet(groupMembers(parts), importedUsers) + lines[i] = addGroupMembers(lines[i], add) + } + } + return imported, joinLines(lines) +} + +// mergeGshadow substitutes the backup gshadow line for each imported regular group +// (gshadow is optional/advisory; system-group member merges are reflected in +// /etc/group, which is authoritative). +func mergeGshadow(current, backup string, importedGroups map[string]bool) string { + lines := splitNonEmptyLines(current) + index := indexByName(lines) + backupByName := byName(backup) + for _, name := range sortedKeys(importedGroups) { + if line, ok := backupByName[name]; ok { + upsert(&lines, index, name, line) + } + } + return joinLines(lines) +} + +func lowIDNames(content string, idField int) map[string]bool { + out := map[string]bool{} + for _, line := range splitNonEmptyLines(content) { + if isNISLine(line) { + continue + } + parts := strings.Split(line, ":") + if idField >= len(parts) { + continue + } + id, ok := parseAccountID(parts[idField]) + if ok && id < systemAccountIDThreshold { + out[parts[0]] = true + } + } + return out +} + +func isNISLine(line string) bool { + return strings.HasPrefix(line, "+") || strings.HasPrefix(line, "-") +} + +// parseAccountID parses a uid/gid as a 32-bit unsigned value. There is no +// TrimSpace: a field with surrounding whitespace is non-canonical and we write +// the backup line verbatim, so it is rejected (ParseUint base-10 also rejects +// signs and whitespace). bitSize 32 closes the overflow vector (4294967296 ->0), +// and 0xFFFFFFFF (=(uid_t)-1, the nobody/error sentinel) is rejected explicitly. +func parseAccountID(s string) (uint64, bool) { + id, err := strconv.ParseUint(s, 10, 32) + if err != nil || id == 0xFFFFFFFF { + return 0, false + } + return id, true +} + +// isValidAccountName rejects empty/over-long names and any control byte (incl. +// NUL), DEL, whitespace, field separators (':' ',') or '/', so forged/malformed +// names are never written into the account database and cannot bypass the +// exact-match host-collision check (e.g. "daemon " with a trailing space). +func isValidAccountName(name string) bool { + if name == "" || len(name) > 32 { + return false + } + return strings.IndexFunc(name, func(r rune) bool { + return r < 0x21 || r == 0x7f || r == ':' || r == ',' || r == '/' + }) < 0 +} + +// groupGIDsByName maps each host group name to its numeric gid (used to verify a +// backup group is genuinely the same group before merging members into it). +func groupGIDsByName(content string) map[string]uint64 { + m := map[string]uint64{} + for _, line := range splitNonEmptyLines(content) { + if isNISLine(line) { + continue + } + parts := strings.Split(line, ":") + if len(parts) < 3 { + continue + } + if gid, ok := parseAccountID(parts[2]); ok { + m[parts[0]] = gid + } + } + return m +} + +// gidValueSet returns the set of gid VALUES from a name->gid map (every host group +// gid), used to block gid collisions and primary-gid enrolment into host groups. +func gidValueSet(m map[string]uint64) map[uint64]bool { + out := make(map[uint64]bool, len(m)) + for _, gid := range m { + out[gid] = true + } + return out +} + +func colonName(line string) string { + if i := strings.IndexByte(line, ':'); i >= 0 { + return line[:i] + } + return line +} + +func indexByName(lines []string) map[string]int { + m := make(map[string]int, len(lines)) + for i, l := range lines { + m[colonName(l)] = i + } + return m +} + +func byName(content string) map[string]string { + m := map[string]string{} + for _, l := range splitNonEmptyLines(content) { + m[colonName(l)] = l + } + return m +} + +func upsert(lines *[]string, index map[string]int, name, line string) { + if i, ok := index[name]; ok { + (*lines)[i] = line + return + } + index[name] = len(*lines) + *lines = append(*lines, line) +} + +func sortedKeys(m map[string]bool) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + sort.Strings(out) + return out +} + +func groupMembers(parts []string) []string { + if len(parts) < 4 { + return nil + } + field := strings.TrimSpace(parts[3]) + if field == "" { + return nil + } + return strings.Split(field, ",") +} + +func filterSet(names []string, allow map[string]bool) []string { + var out []string + for _, n := range names { + n = strings.TrimSpace(n) + if n != "" && allow[n] { + out = append(out, n) + } + } + return out +} + +// addGroupMembers adds names (not already present) to the members field (index 3) +// of a colon-separated group line, preserving order and the rest of the line. +func addGroupMembers(line string, add []string) string { + if len(add) == 0 { + return line + } + parts := strings.Split(line, ":") + for len(parts) < 4 { + parts = append(parts, "") + } + seen := map[string]bool{} + var members []string + for _, m := range strings.Split(parts[3], ",") { + m = strings.TrimSpace(m) + if m != "" && !seen[m] { + seen[m] = true + members = append(members, m) + } + } + for _, n := range add { + if !seen[n] { + seen[n] = true + members = append(members, n) + } + } + parts[3] = strings.Join(members, ",") + return strings.Join(parts, ":") +} + +func splitNonEmptyLines(content string) []string { + var out []string + for _, raw := range strings.Split(content, "\n") { + line := strings.TrimRight(raw, "\r") + if strings.TrimSpace(line) == "" { + continue + } + out = append(out, line) + } + return out +} + +func joinLines(lines []string) string { + if len(lines) == 0 { + return "" + } + return strings.Join(lines, "\n") + "\n" +} diff --git a/internal/orchestrator/restore_accounts_test.go b/internal/orchestrator/restore_accounts_test.go new file mode 100644 index 0000000..2f45938 --- /dev/null +++ b/internal/orchestrator/restore_accounts_test.go @@ -0,0 +1,620 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +type fakeAccountsCmd struct{ err error } + +func (f fakeAccountsCmd) Run(_ context.Context, _ string, _ ...string) ([]byte, error) { + return nil, f.err +} + +func TestMergePasswdPreservesRootSystemAndProtectsCollisions(t *testing.T) { + // Host: superuser renamed "superadmin" (uid 0), a system service "svc" (uid 200), + // and a regular user "alice" (uid 1000). + current := "superadmin:x:0:0:root:/root:/bin/bash\n" + + "svc:x:200:200:service:/var/svc:/usr/sbin/nologin\n" + + "alice:x:1000:1000::/home/alice:/bin/bash\n" + // Backup tries to clobber the renamed root and the system account with regular + // users of the same name, imports a real user, overwrites alice, and includes a + // NIS line and a truncated line. + backup := "superadmin:x:1500:1500:EVIL:/home/x:/bin/sh\n" + + "svc:x:1600:1600:EVIL:/home/y:/bin/sh\n" + + "bob:x:1001:1001::/home/bob:/bin/bash\n" + + "alice:x:1000:1000::/home/alice:/bin/zsh\n" + + "+::::::\n" + + "zoe:x:1500\n" + + hostSystem := lowIDNames(current, 2) + imported, merged := mergePasswd(current, backup, hostSystem, map[uint64]bool{}) + + if !strings.Contains(merged, "superadmin:x:0:0:root:") || strings.Contains(merged, "EVIL") { + t.Errorf("renamed root (uid 0) must be preserved and never clobbered:\n%s", merged) + } + if !strings.Contains(merged, "svc:x:200:200:service:") { + t.Errorf("host system account 'svc' must be preserved:\n%s", merged) + } + if imported["superadmin"] || imported["svc"] { + t.Errorf("name-colliding system accounts must not be imported: %v", imported) + } + if !strings.Contains(merged, "bob:x:1001:1001:") || !imported["bob"] { + t.Errorf("regular backup user 'bob' must be imported:\n%s", merged) + } + if !strings.Contains(merged, "alice:x:1000:1000::/home/alice:/bin/zsh") { + t.Errorf("existing regular user 'alice' must be overwritten from backup:\n%s", merged) + } + if strings.Contains(merged, "zoe") { + t.Errorf("truncated passwd line (<7 fields) must be rejected:\n%s", merged) + } + if strings.Contains(merged, "+::::::") || imported["+"] { + t.Errorf("NIS compat line must be ignored:\n%s", merged) + } +} + +func TestMergePasswdRejectsOverflowEscalationAndMalformedNames(t *testing.T) { + current := "root:x:0:0:root:/root:/bin/bash\ndaemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n" + hostSystem := lowIDNames(current, 2) + hostSystemGIDs := map[uint64]bool{0: true, 27: true} // root, sudo + backup := strings.Join([]string{ + "over:x:4294967296:1500::/home/over:/bin/sh", // uid overflow -> wraps to 0 on a 32-bit kernel + "sentinel:x:4294967295:1500::/h:/bin/sh", // uid == (uid_t)-1 sentinel + "gid0:x:1001:0::/home/gid0:/bin/sh", // regular user with PRIMARY group root + "psudo:x:2001:27::/h:/bin/sh", // PRIMARY gid = host sudo (escalation) + "gidover:x:1002:4294967296::/home/g:/bin/sh", // primary gid overflow + "wsp:x: 1700:1700::/h:/bin/sh", // whitespace uid field (non-canonical) + ":x:1003:1003::/home/empty:/bin/sh", // empty name + "daemon :x:1004:1004:EVIL:/root:/bin/sh", // whitespace name (collision-bypass attempt) + "ev/il:x:1900:1900::/h:/bin/sh", // '/' in name + "good:x:1005:1005::/home/good:/bin/bash", // legitimate regular user + }, "\n") + "\n" + + imported, merged := mergePasswd(current, backup, hostSystem, hostSystemGIDs) + + for _, bad := range []string{"over", "sentinel", "gid0", "psudo", "gidover", "EVIL", "ev/il", "4294967295"} { + if strings.Contains(merged, bad) { + t.Errorf("malformed/escalation entry %q must be rejected:\n%s", bad, merged) + } + } + if imported["over"] || imported["sentinel"] || imported["gid0"] || imported["psudo"] || + imported["gidover"] || imported["wsp"] || imported[""] || imported["daemon "] || imported["ev/il"] { + t.Errorf("rejected entries must not be imported: %v", imported) + } + if !imported["good"] || !strings.Contains(merged, "good:x:1005:1005:") { + t.Errorf("legitimate user 'good' must be imported:\n%s", merged) + } + if !strings.Contains(merged, "daemon:x:1:1:daemon:") { + t.Errorf("host 'daemon' must be preserved unchanged:\n%s", merged) + } +} + +func TestMergeGroupRejectsOverflowAndRootGroupMemberMerge(t *testing.T) { + current := "root:x:0:\nsudo:x:27:\n" + hostGroupGID := groupGIDsByName(current) + hostGIDs := gidValueSet(hostGroupGID) + backup := "root:x:0:bob\nover:x:4294967296:bob\nsudo:x:27:bob\nproj:x:1500:bob\n" + + importedGroups, merged := mergeGroup(current, backup, hostGroupGID, hostGIDs, map[string]bool{"bob": true}) + + if strings.Contains(merged, "root:x:0:bob") { + t.Errorf("imported user must NEVER be merged into the root group:\n%s", merged) + } + if strings.Contains(merged, "over") || importedGroups["over"] { + t.Errorf("gid-overflow group must be rejected:\n%s", merged) + } + if !strings.Contains(merged, "sudo:x:27:bob") { + t.Errorf("bob should be merged into the (non-root) sudo group:\n%s", merged) + } + if !strings.Contains(merged, "proj:x:1500:bob") || !importedGroups["proj"] { + t.Errorf("regular group 'proj' should be imported:\n%s", merged) + } +} + +func TestMergeShadowNeverLeavesImportedUserWithoutEntry(t *testing.T) { + current := "root:CURHASH:1::::::\nalice:ALICEHASH:1::::::\n" + // bob has a backup shadow line; carol is imported (passwd) but has NO backup shadow. + backup := "root:EVIL:1::::::\nbob:BOBHASH:1::::::\n" + merged := mergeShadow(current, backup, map[string]bool{"bob": true, "carol": true}) + + if !strings.Contains(merged, "root:CURHASH") || strings.Contains(merged, "EVIL") { + t.Errorf("root shadow must be preserved from host:\n%s", merged) + } + if !strings.Contains(merged, "bob:BOBHASH") { + t.Errorf("bob shadow must come from backup:\n%s", merged) + } + // carol must NOT be missing from shadow (no passwd<->shadow desync): locked placeholder. + if !strings.Contains(merged, "carol:*:::::::") { + t.Errorf("imported user 'carol' without backup shadow must get a locked placeholder, not be absent:\n%s", merged) + } +} + +func TestMergeGroupMergesSystemGroupMembersAndImportsRegular(t *testing.T) { + current := "root:x:0:\nsudo:x:27:\nalice:x:1000:\n" + // Backup adds bob (imported) and alice (not imported here) to sudo, imports a + // regular group, and references a system group the host lacks (docker). + backup := "sudo:x:27:bob,alice\nbob:x:1001:\ndocker:x:998:bob\n" + + hostGroupGID := groupGIDsByName(current) + hostGIDs := gidValueSet(hostGroupGID) + importedGroups, merged := mergeGroup(current, backup, hostGroupGID, hostGIDs, map[string]bool{"bob": true}) + + if !strings.Contains(merged, "sudo:x:27:bob") { + t.Errorf("imported user 'bob' must be merged into the host 'sudo' group, gid preserved:\n%s", merged) + } + if strings.Contains(merged, "alice") && strings.Contains(merged, "sudo:x:27:bob,alice") { + t.Errorf("non-imported user 'alice' must not be added to sudo:\n%s", merged) + } + if !strings.Contains(merged, "bob:x:1001:") || !importedGroups["bob"] { + t.Errorf("regular backup group 'bob' must be imported:\n%s", merged) + } + if strings.Contains(merged, "docker") { + t.Errorf("a system group absent on the host must not be imported (gid clash risk):\n%s", merged) + } +} + +func TestMergeGroupRejectsGidSpoofedSystemGroupInjection(t *testing.T) { + current := "root:x:0:\nsudo:x:27:\n" + hostGroupGID := groupGIDsByName(current) + hostGIDs := gidValueSet(hostGroupGID) + // Backup references the host 'sudo' group by NAME but with a spoofed gid (1234, + // not the host's 27), trying to inject an imported user into the real sudo group. + backup := "sudo:x:1234:mallory\n" + + _, merged := mergeGroup(current, backup, hostGroupGID, hostGIDs, map[string]bool{"mallory": true}) + + if strings.Contains(merged, "mallory") { + t.Errorf("gid-spoofed sudo line must NOT inject a member into the host sudo group:\n%s", merged) + } + if !strings.Contains(merged, "sudo:x:27:") { + t.Errorf("host sudo group must be preserved unchanged:\n%s", merged) + } + + // With the host's REAL gid, a legitimate imported member IS merged. + _, merged2 := mergeGroup(current, "sudo:x:27:realbob\n", hostGroupGID, hostGIDs, map[string]bool{"realbob": true}) + if !strings.Contains(merged2, "sudo:x:27:realbob") { + t.Errorf("gid-matching sudo line should merge the imported member:\n%s", merged2) + } +} + +func TestMergePasswdRejectsPrimaryGidIntoHostSystemGroup(t *testing.T) { + current := "root:x:0:0:root:/root:/bin/bash\n" + hostSystem := lowIDNames(current, 2) + // Set of ALL host group gids: system (root/sudo/shadow) AND a privileged regular + // group (docker at gid 1001). + hostGIDs := map[uint64]bool{0: true, 27: true, 42: true, 1001: true} + backup := "evil:x:1001:27::/home/evil:/bin/bash\n" + // primary gid = sudo + "reader:x:1002:42::/home/reader:/bin/bash\n" + // primary gid = shadow + "dock:x:1004:1001::/home/dock:/bin/bash\n" + // primary gid = host docker (>=1000) + "ok:x:1003:1003::/home/ok:/bin/bash\n" // private primary gid + + imported, merged := mergePasswd(current, backup, hostSystem, hostGIDs) + + if imported["evil"] || strings.Contains(merged, "evil") { + t.Errorf("user with primary gid=sudo must be rejected:\n%s", merged) + } + if imported["reader"] || strings.Contains(merged, "reader") { + t.Errorf("user with primary gid=shadow must be rejected:\n%s", merged) + } + if imported["dock"] || strings.Contains(merged, "dock:x:") { + t.Errorf("user with primary gid = a host privileged group (gid>=1000) must be rejected:\n%s", merged) + } + if !imported["ok"] || !strings.Contains(merged, "ok:x:1003:1003:") { + t.Errorf("user with a private primary gid must be imported:\n%s", merged) + } +} + +func TestMergeGroupNeverOverwritesExistingHostGroup(t *testing.T) { + // Host has a privileged regular group 'docker' at gid 1001 with a real member. + current := "root:x:0:\ndocker:x:1001:realops\n" + hostGroupGID := groupGIDsByName(current) + hostGIDs := gidValueSet(hostGroupGID) + + // Backup tries to overwrite docker: change gid to 5000 and replace members. + _, merged := mergeGroup(current, "docker:x:5000:attacker\n", hostGroupGID, hostGIDs, map[string]bool{"attacker": true}) + if !strings.Contains(merged, "docker:x:1001:realops") { + t.Errorf("existing host group 'docker' (gid+members) must be preserved, not overwritten:\n%s", merged) + } + if strings.Contains(merged, "5000") || strings.Contains(merged, "attacker") { + t.Errorf("backup must not change host group gid or inject members via a gid-mismatched line:\n%s", merged) + } + + // A brand-new backup group whose gid collides with the host docker gid is skipped. + importedGroups, merged2 := mergeGroup(current, "team:x:1001:bob\n", hostGroupGID, hostGIDs, map[string]bool{"bob": true}) + if strings.Contains(merged2, "team") || importedGroups["team"] { + t.Errorf("a new backup group colliding with an existing host gid must be skipped:\n%s", merged2) + } +} + +func TestApplyAccountsFromStageEndToEnd(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{err: nil} + + stage := "/stage" + w := func(p, c string, m os.FileMode) { + t.Helper() + if err := fakeFS.WriteFile(p, []byte(c), m); err != nil { + t.Fatalf("write %s: %v", p, err) + } + } + w(etcPasswdPath, "root:x:0:0:root:/root:/bin/bash\nalice:x:1000:1000::/home/alice:/bin/bash\n", 0o644) + w(etcShadowPath, "root:CURHASH:1::::::\nalice:ALICEHASH:1::::::\n", 0o640) + w(etcGroupPath, "root:x:0:\nsudo:x:27:\n", 0o644) + w(etcGshadowPath, "root:*::\n", 0o640) + w(etcSudoersPath, "root ALL=(ALL) ALL\n", 0o440) + + w(stage+"/etc/passwd", "root:x:0:0:EVIL:/root:/bin/sh\nbob:x:1001:1001::/home/bob:/bin/bash\n", 0o644) + w(stage+"/etc/shadow", "root:EVIL:1::::::\nbob:BOBHASH:1::::::\n", 0o640) + w(stage+"/etc/group", "sudo:x:27:bob\nbob:x:1001:\n", 0o644) + w(stage+"/etc/gshadow", "bob:!::\n", 0o640) + w(stage+"/etc/sudoers", "root ALL=(ALL) ALL\nbob ALL=(ALL) NOPASSWD: ALL\n", 0o440) + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), stage); err != nil { + t.Fatalf("applyAccountsFromStage: %v", err) + } + + passwd := readFake(t, fakeFS, etcPasswdPath) + shadow := readFake(t, fakeFS, etcShadowPath) + group := readFake(t, fakeFS, etcGroupPath) + sudoers := readFake(t, fakeFS, etcSudoersPath) + + if !strings.Contains(passwd, "root:x:0:0:root:") || strings.Contains(passwd, "EVIL") { + t.Errorf("root preserved in passwd:\n%s", passwd) + } + if !strings.Contains(passwd, "bob:x:1001:1001:") { + t.Errorf("bob merged into passwd:\n%s", passwd) + } + if !strings.Contains(shadow, "root:CURHASH") || strings.Contains(shadow, "EVIL") { + t.Errorf("root shadow preserved:\n%s", shadow) + } + if !strings.Contains(shadow, "bob:BOBHASH") { + t.Errorf("bob shadow merged:\n%s", shadow) + } + // passwd<->shadow consistency: every passwd name must have a shadow line. + assertPasswdShadowConsistent(t, passwd, shadow) + if !strings.Contains(group, "sudo:x:27:bob") { + t.Errorf("bob added to sudo group:\n%s", group) + } + if !strings.Contains(sudoers, "bob ALL=(ALL) NOPASSWD") { + t.Errorf("validated sudoers applied exactly:\n%s", sudoers) + } +} + +func TestApplyAccountsDesyncPreventionMissingStagedShadow(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{} + + _ = fakeFS.WriteFile(etcPasswdPath, []byte("root:x:0:0:root:/root:/bin/bash\n"), 0o644) + _ = fakeFS.WriteFile(etcShadowPath, []byte("root:CURHASH:1::::::\n"), 0o640) + _ = fakeFS.WriteFile(etcGroupPath, []byte("root:x:0:\n"), 0o644) + _ = fakeFS.WriteFile(etcGshadowPath, []byte("root:*::\n"), 0o640) + // Stage has a new user in passwd but NO staged shadow at all. + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyAccountsFromStage: %v", err) + } + passwd := readFake(t, fakeFS, etcPasswdPath) + shadow := readFake(t, fakeFS, etcShadowPath) + if !strings.Contains(passwd, "bob:x:1001:") { + t.Fatalf("bob should be imported:\n%s", passwd) + } + if !strings.Contains(shadow, "bob:*:::::::") { + t.Errorf("bob must have a locked shadow placeholder (no desync), got shadow:\n%s", shadow) + } + assertPasswdShadowConsistent(t, passwd, shadow) +} + +func TestApplyAccountsSkipsWhenCurrentPasswdEmpty(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + _ = fakeFS.WriteFile(etcPasswdPath, []byte("\n"), 0o644) // empty/unreadable host baseline + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyAccountsFromStage: %v", err) + } + if got := readFake(t, fakeFS, etcPasswdPath); strings.Contains(got, "bob") { + t.Errorf("must not write accounts when current /etc/passwd is empty (anti-lockout), got:\n%s", got) + } +} + +func TestApplyAccountsSkipsWhenCurrentGroupEmpty(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + _ = fakeFS.WriteFile(etcPasswdPath, []byte("root:x:0:0:root:/root:/bin/bash\n"), 0o644) + _ = fakeFS.WriteFile(etcGroupPath, []byte("\n"), 0o644) // empty/unreadable host group baseline + // Non-empty shadow so the empty-group guard is the ONLY thing preventing a rewrite + // (otherwise the empty-shadow guard would mask this case and weaken the anchor). + _ = fakeFS.WriteFile(etcShadowPath, []byte("root:CURHASH:1::::::\n"), 0o640) + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + _ = fakeFS.WriteFile("/stage/etc/group", []byte("team:x:1001:bob\n"), 0o644) + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyAccountsFromStage: %v", err) + } + if got := readFake(t, fakeFS, etcPasswdPath); strings.Contains(got, "bob") { + t.Errorf("must not rewrite accounts when current /etc/group is empty (anti-lockout), passwd:\n%s", got) + } +} + +func TestApplyAccountsSkipsWhenCurrentShadowEmpty(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + _ = fakeFS.WriteFile(etcPasswdPath, []byte("root:x:0:0:root:/root:/bin/bash\n"), 0o644) + _ = fakeFS.WriteFile(etcGroupPath, []byte("root:x:0:\n"), 0o644) + _ = fakeFS.WriteFile(etcShadowPath, []byte("\n"), 0o640) // empty/unreadable host shadow baseline + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyAccountsFromStage: %v", err) + } + if got := readFake(t, fakeFS, etcShadowPath); strings.Contains(got, "bob") { + t.Errorf("must not rewrite accounts when current /etc/shadow is empty (anti-lockout), shadow:\n%s", got) + } +} + +// TestMergeGroupNewGroupDropsNonImportedMembers checks that a brand-new backup +// group does not silently enroll an existing host account: its member list is +// restricted to users actually being imported (mirrors the existing-host-group +// member filtering). +func TestMergeGroupNewGroupDropsNonImportedMembers(t *testing.T) { + current := "root:x:0:\nalice:x:1000:\n" // alice is a host user, NOT being imported + hostGroupGID := groupGIDsByName(current) + hostGIDs := gidValueSet(hostGroupGID) + + importedGroups, merged := mergeGroup(current, "team:x:3000:bob,alice\n", hostGroupGID, hostGIDs, map[string]bool{"bob": true}) + if !importedGroups["team"] { + t.Fatalf("brand-new regular group 'team' should be imported:\n%s", merged) + } + var teamMembers string + for _, line := range strings.Split(merged, "\n") { + if strings.HasPrefix(line, "team:") { + if f := strings.Split(line, ":"); len(f) >= 4 { + teamMembers = f[3] + } + } + } + if teamMembers != "bob" { + t.Errorf("new backup group must keep only the imported member 'bob', got members %q (host user 'alice' must not be enrolled):\n%s", teamMembers, merged) + } +} + +func TestApplySudoersSkipsOnVisudoFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{err: os.ErrInvalid} + + const current = "root ALL=(ALL) ALL\n" + _ = fakeFS.WriteFile(etcSudoersPath, []byte(current), 0o440) + _ = fakeFS.WriteFile("/stage/etc/sudoers", []byte("garbage !!! invalid\n"), 0o440) + + if err := applySudoersFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applySudoersFromStage should not error on invalid sudoers: %v", err) + } + if got := readFake(t, fakeFS, etcSudoersPath); got != current { + t.Errorf("current /etc/sudoers must be kept when staged fails visudo:\n%s", got) + } +} + +func TestMaybeApplyAccountsFromStageGates(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{} + + const baseline = "root:x:0:0:root:/root:/bin/bash\n" + reset := func() { + _ = fakeFS.WriteFile(etcPasswdPath, []byte(baseline), 0o644) + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + } + withAccounts := &RestorePlan{StagedCategories: []Category{{ID: "accounts"}}} + notWritten := func(label string) { + t.Helper() + if got := readFake(t, fakeFS, etcPasswdPath); strings.Contains(got, "bob") { + t.Errorf("%s: must not apply accounts, got passwd:\n%s", label, got) + } + } + + reset() + if err := maybeApplyAccountsFromStage(context.Background(), newTestLogger(), nil, "/stage", false); err != nil { + t.Fatalf("nil plan: %v", err) + } + notWritten("nil plan") + + reset() + if err := maybeApplyAccountsFromStage(context.Background(), newTestLogger(), &RestorePlan{}, "/stage", false); err != nil { + t.Fatalf("no accounts category: %v", err) + } + notWritten("no accounts category") + + reset() + if err := maybeApplyAccountsFromStage(context.Background(), newTestLogger(), withAccounts, "/stage", true); err != nil { + t.Fatalf("dryRun: %v", err) + } + notWritten("dryRun") + + reset() + // FakeFS is not a real system FS -> isRealRestoreFS gate must skip the apply. + if err := maybeApplyAccountsFromStage(context.Background(), newTestLogger(), withAccounts, "/stage", false); err != nil { + t.Fatalf("non-real FS: %v", err) + } + notWritten("non-real FS") +} + +func assertPasswdShadowConsistent(t *testing.T, passwd, shadow string) { + t.Helper() + shadowNames := map[string]bool{} + for _, l := range splitNonEmptyLines(shadow) { + shadowNames[colonName(l)] = true + } + for _, l := range splitNonEmptyLines(passwd) { + if name := colonName(l); !shadowNames[name] { + t.Errorf("passwd<->shadow desync: user %q in passwd has no shadow line", name) + } + } +} + +// seedAccountFiles writes the four host auth-DB files (+ a staged set that imports +// bob) into fakeFS and pins the clock so the per-file temp names are predictable. +// Returns the four original contents for byte-for-byte rollback assertions. +func seedAccountFiles(t *testing.T, fakeFS *FakeFS) (passwd0, shadow0, group0, gshadow0 string) { + t.Helper() + passwd0 = "root:x:0:0:root:/root:/bin/bash\n" + shadow0 = "root:CURHASH:1::::::\n" + group0 = "root:x:0:\nsudo:x:27:\n" + gshadow0 = "root:*::\n" + _ = fakeFS.WriteFile(etcPasswdPath, []byte(passwd0), 0o644) + _ = fakeFS.WriteFile(etcShadowPath, []byte(shadow0), 0o640) + _ = fakeFS.WriteFile(etcGroupPath, []byte(group0), 0o644) + _ = fakeFS.WriteFile(etcGshadowPath, []byte(gshadow0), 0o640) + _ = fakeFS.WriteFile("/stage/etc/passwd", []byte("bob:x:1001:1001::/home/bob:/bin/bash\n"), 0o644) + _ = fakeFS.WriteFile("/stage/etc/shadow", []byte("bob:BOBHASH:1::::::\n"), 0o640) + _ = fakeFS.WriteFile("/stage/etc/group", []byte("bob:x:1001:\n"), 0o644) + return passwd0, shadow0, group0, gshadow0 +} + +func accountTempPath(t *testing.T, path string) string { + t.Helper() + return fmt.Sprintf("%s.proxsave.tmp.%d", filepath.Clean(path), nowRestore().UnixNano()) +} + +// TestApplyAccountsPreparePhaseFailureLeavesAllOriginals: a failure while preparing a +// temp file (the common disk-full / read-only / IO case) must leave ALL four live +// auth-DB files untouched (no partial commit) and leave no temp behind. +func TestApplyAccountsPreparePhaseFailureLeavesAllOriginals(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{} + origTime := restoreTime + t.Cleanup(func() { restoreTime = origTime }) + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + passwd0, shadow0, group0, gshadow0 := seedAccountFiles(t, fakeFS) + + // Fail the THIRD file's (group) temp creation. With a pinned clock the four temps + // share the same nanosecond suffix but differ by path prefix, so this hits group. + fakeFS.OpenFileErr[filepath.Clean(accountTempPath(t, etcGroupPath))] = errors.New("forced temp-create failure") + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err == nil { + t.Fatal("expected an error when a prepare-phase temp creation fails") + } + + if got := readFake(t, fakeFS, etcPasswdPath); got != passwd0 { + t.Errorf("passwd must be unchanged after a prepare-phase failure, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcShadowPath); got != shadow0 { + t.Errorf("shadow must be unchanged after a prepare-phase failure, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcGroupPath); got != group0 { + t.Errorf("group must be unchanged after a prepare-phase failure, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcGshadowPath); got != gshadow0 { + t.Errorf("gshadow must be unchanged after a prepare-phase failure, got:\n%s", got) + } + // The temps prepared before the failure (passwd, shadow) must have been cleaned up. + for _, p := range []string{etcPasswdPath, etcShadowPath, etcGshadowPath} { + if _, err := fakeFS.Stat(accountTempPath(t, p)); !os.IsNotExist(err) { + t.Errorf("temp for %s should not remain after prepare-phase failure (stat err=%v)", p, err) + } + } +} + +// TestApplyAccountsCommitRollbackRestoresCommittedFiles: if a rename fails partway +// through the commit phase, the files already committed must be rolled back to their +// originals so the auth-DB set stays consistent. +func TestApplyAccountsCommitRollbackRestoresCommittedFiles(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + restoreCmd = fakeAccountsCmd{} + origTime := restoreTime + t.Cleanup(func() { restoreTime = origTime }) + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + passwd0, shadow0, group0, gshadow0 := seedAccountFiles(t, fakeFS) + + // Fail the group rename (index 2): passwd+shadow commit first, then group's rename + // fails -> the two committed files must be rolled back to their originals. + fakeFS.RenameErr[filepath.Clean(accountTempPath(t, etcGroupPath))] = errors.New("forced rename failure") + + if err := applyAccountsFromStage(context.Background(), newTestLogger(), "/stage"); err == nil { + t.Fatal("expected an error when a commit-phase rename fails") + } + + if got := readFake(t, fakeFS, etcPasswdPath); got != passwd0 { + t.Errorf("passwd must be rolled back to its original after a commit failure, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcShadowPath); got != shadow0 { + t.Errorf("shadow must be rolled back to its original after a commit failure, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcGroupPath); got != group0 { + t.Errorf("group (failed commit) must be untouched, got:\n%s", got) + } + if got := readFake(t, fakeFS, etcGshadowPath); got != gshadow0 { + t.Errorf("gshadow (never committed) must be untouched, got:\n%s", got) + } + for _, p := range []string{etcPasswdPath, etcShadowPath, etcGroupPath, etcGshadowPath} { + if _, err := fakeFS.Stat(accountTempPath(t, p)); !os.IsNotExist(err) { + t.Errorf("temp for %s should not remain after commit rollback (stat err=%v)", p, err) + } + } +} + +func readFake(t *testing.T, fs *FakeFS, path string) string { + t.Helper() + data, err := fs.ReadFile(path) + if err != nil { + t.Fatalf("read %s: %v", path, err) + } + return string(data) +} diff --git a/internal/orchestrator/restore_archive_extract.go b/internal/orchestrator/restore_archive_extract.go index e5fb043..c90d60e 100644 --- a/internal/orchestrator/restore_archive_extract.go +++ b/internal/orchestrator/restore_archive_extract.go @@ -4,11 +4,15 @@ package orchestrator import ( "archive/tar" "context" + "encoding/json" "fmt" "io" "os" + "path" "path/filepath" + "strings" + "github.com/tis24dev/proxsave/internal/backup" "github.com/tis24dev/proxsave/internal/logging" ) @@ -68,6 +72,20 @@ func extractArchiveNative(ctx context.Context, opts restoreArchiveOptions) (err extractionLog.writeSummary(stats) logRestoreExtractionSummary(opts, stats) + // Turn deduplicated symlinks back into regular files by rebuilding them from the + // archive, so selective restore never leaves a dangling link and full restore + // preserves the original file type (issue #70). Safe on every extraction: it + // never deletes and is a no-op when no dedup manifest is present. + if err := materializeDedupSymlinks(ctx, opts.archivePath, opts.destRoot, opts.logger); err != nil { + // On the staged path (failOnPartialExtraction) an incompletely reconstructed + // dedup tree must not be applied to the live system; elsewhere it is a + // recoverable warning and the (kept) manifest lets a re-run finish. + if opts.failOnPartialExtraction { + return err + } + opts.logger.Warning("Dedup materialization incomplete: %v", err) + } + // When the caller cannot safely act on a partial result (the staged restore // path, which would otherwise apply an incomplete tree of PVE/PBS/network/ // secret config to the live system), surface per-entry extraction failures as @@ -152,6 +170,17 @@ func processRestoreArchiveEntries(ctx context.Context, tarReader *tar.Reader, op if skipRestoreArchiveEntry(header, opts, selectiveMode, extractionLog, &stats) { continue } + // A hardlink aliases an existing on-disk file; in a selective restore its + // target must belong to a selected category. A cross-category hardlink + // (e.g. an in-category name aliasing /etc/shadow) is never legitimate, so + // refuse it. Symlinks are intentionally NOT constrained this way: their + // targets legitimately point outside the category. + if selectiveMode && header.Typeflag == tar.TypeLink && !restoreEntryMatchesCategories(header.Linkname, opts.categories) { + opts.logger.Warning("Refusing hardlink %s: target %s is outside the selected categories", header.Name, header.Linkname) + stats.filesFailed++ + extractionLog.recordSkipped(header.Name, "hardlink target outside selected categories") + continue + } if err := extractTarEntry(tarReader, header, opts.destRoot, opts.logger); err != nil { opts.logger.Warning("Failed to extract %s: %v", header.Name, err) stats.filesFailed++ @@ -167,7 +196,17 @@ func processRestoreArchiveEntries(ctx context.Context, tarReader *tar.Reader, op return stats, nil } +func isDedupManifestEntry(name string) bool { + clean := strings.TrimPrefix(filepath.ToSlash(filepath.Clean(name)), "/") + return clean == backup.DedupManifestRelPath +} + func skipRestoreArchiveEntry(header *tar.Header, opts restoreArchiveOptions, selectiveMode bool, extractionLog *restoreExtractionLog, stats *restoreExtractionStats) bool { + // The dedup manifest is always extracted, regardless of selected categories, so + // the post-extraction pass can materialize deduplicated symlinks (issue #70). + if isDedupManifestEntry(header.Name) { + return false + } if opts.skipFn != nil && opts.skipFn(header.Name) { stats.filesSkipped++ extractionLog.recordSkipped(header.Name, "skipped by restore policy") @@ -259,3 +298,208 @@ func logRestoreExtractionSummary(opts restoreArchiveOptions, stats restoreExtrac opts.logger.Info("Detailed restore log: %s", opts.logFilePath) } } + +type materializeTarget struct { + path string // absolute duplicate path under destRoot (currently a symlink) + mode os.FileMode +} + +// materializeDedupSymlinks reads the dedup manifest written at backup time and +// replaces each recorded symlink with a regular file rebuilt from the BACKUP ARCHIVE +// content. Reading the canonical bytes from the archive (never from the possibly +// stale on-disk/live target, and never deleting the symlink) is what makes a +// selective/staged restore safe: a selected duplicate is reconstructed even when its +// dedup canonical's category was not selected or its on-disk copy failed to extract, +// and it never picks up stale live content (issue #70). It is a no-op when no +// manifest is present (deduplication was off or found no duplicates). +func materializeDedupSymlinks(ctx context.Context, archivePath, destRoot string, logger *logging.Logger) error { + manifestTarget, _, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, backup.DedupManifestRelPath) + if err != nil { + return nil + } + data, err := restoreFS.ReadFile(manifestTarget) + if err != nil { + return nil // no dedup manifest: nothing to materialize + } + var entries []backup.DedupManifestEntry + if err := json.Unmarshal(data, &entries); err != nil { + // Corrupt manifest: nothing can be materialized, but do not leave the garbage + // (force-extracted under var/lib/proxsave-info) lingering on the restored system. + logger.Warning("Dedup manifest unreadable; skipping symlink materialization: %v", err) + removeDedupManifest(manifestTarget) + return nil + } + + // Map each canonical archive path to the extracted duplicate symlinks that need + // its content. Only duplicates actually present on disk are considered. + needByCanonical := map[string][]materializeTarget{} + for _, entry := range entries { + if strings.TrimSpace(entry.Path) == "" { + continue + } + target, _, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entry.Path) + if err != nil { + continue + } + info, err := restoreFS.Lstat(target) + if err != nil { + continue // duplicate not extracted: its own category was not selected + } + if info.Mode()&os.ModeSymlink == 0 { + continue // already a regular file (the target was extracted too) + } + linkTarget, err := restoreFS.Readlink(target) + if err != nil { + continue + } + // Canonical archive-relative path, derived purely lexically as the inverse of + // the relative link replaceWithSymlink wrote at backup time. + canonicalRel := dedupCleanArchivePath(path.Join(path.Dir(filepath.ToSlash(entry.Path)), filepath.ToSlash(linkTarget))) + needByCanonical[canonicalRel] = append(needByCanonical[canonicalRel], materializeTarget{ + path: target, + mode: os.FileMode(entry.Mode).Perm(), + }) + } + + if len(needByCanonical) > 0 { + materialized, missing, completed := materializeFromArchive(ctx, archivePath, needByCanonical, logger) + if !completed { + // The archive scan was cut short (context canceled / open or read error): + // keep the manifest so a re-run can finish, rather than dropping it and + // stranding un-materialized symlinks with no way to recover. Surface it so a + // staged restore that cannot tolerate a partial result fails closed instead + // of applying an incompletely reconstructed tree (BH-002). + logger.Warning("Dedup: materialization did not complete (%d rebuilt so far); keeping the manifest for a retry", materialized) + return fmt.Errorf("dedup materialization incomplete: %d file(s) rebuilt before the archive scan stopped; manifest kept for retry", materialized) + } + if materialized > 0 || missing > 0 { + logger.Info("Dedup: materialized %d deduplicated file(s) from the archive; %d left as link(s) due to missing canonical content", materialized, missing) + } + } + + // Drop the manifest (and the now-empty proxsave-info dir we may have force-created) + // so it does not linger on the restored system. + removeDedupManifest(manifestTarget) + return nil +} + +// removeDedupManifest deletes the materialized-then-consumed dedup manifest and, if +// force-extraction created an otherwise-empty var/lib/proxsave-info directory on the +// destination, removes that too (Remove on a non-empty dir fails and is a no-op). +func removeDedupManifest(manifestTarget string) { + _ = restoreFS.Remove(manifestTarget) + _ = restoreFS.Remove(filepath.Dir(manifestTarget)) +} + +// dedupCleanArchivePath normalizes a name to the archive-relative slash form used +// for manifest/target matching (no leading "./" or "/"). +func dedupCleanArchivePath(name string) string { + return strings.TrimPrefix(path.Clean(filepath.ToSlash(name)), "/") +} + +// materializeFromArchive streams the (already decrypted) archive once and rebuilds +// each pending duplicate from its canonical's bytes, reading one canonical at a time +// (bounded memory). A duplicate whose canonical is absent from the archive is left +// as a symlink (never deleted). It returns how many were materialized, how many were +// left as links (canonical genuinely missing from the archive), and whether the scan +// ran to completion. completed is false when the archive could not be opened/read or +// the scan was canceled mid-way; the caller then keeps the manifest for a retry +// instead of dropping it and stranding un-materialized symlinks. +func materializeFromArchive(ctx context.Context, archivePath string, needByCanonical map[string][]materializeTarget, logger *logging.Logger) (materialized, missing int, completed bool) { + file, err := restoreFS.Open(archivePath) + if err != nil { + logger.Warning("Dedup: could not open the archive to rebuild deduplicated files; left as links: %v", err) + return 0, 0, false + } + defer func() { _ = file.Close() }() + reader, err := createDecompressionReader(ctx, file, archivePath) + if err != nil { + logger.Warning("Dedup: could not read the archive to rebuild deduplicated files; left as links: %v", err) + return 0, 0, false + } + defer func() { _ = reader.Close() }() + + found := map[string]bool{} + writeOK := true + tr := tar.NewReader(reader) + for len(found) < len(needByCanonical) { + if err := ctx.Err(); err != nil { + logger.Warning("Dedup: archive scan canceled while rebuilding deduplicated files: %v", err) + return materialized, 0, false + } + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + logger.Warning("Dedup: error reading the archive while rebuilding deduplicated files: %v", err) + return materialized, 0, false + } + if header.Typeflag != tar.TypeReg { + continue + } + name := dedupCleanArchivePath(header.Name) + dups, ok := needByCanonical[name] + if !ok { + continue + } + content, err := io.ReadAll(tr) + if err != nil { + logger.Warning("Dedup: failed to read canonical %q from the archive: %v", name, err) + continue // leave its duplicates as links (counted as missing below) + } + found[name] = true + for _, d := range dups { + if werr := writeMaterializedFile(d.path, content, d.mode); werr != nil { + logger.Warning("Dedup: failed to materialize %s from archive: %v", name, werr) + writeOK = false // a transient write failure: keep the manifest for a retry + continue + } + materialized++ + } + } + + for name, dups := range needByCanonical { + if !found[name] { + logger.Warning("Dedup: canonical %q is missing from the archive; %d file(s) left as symlink(s)", name, len(dups)) + missing += len(dups) + } + } + // completed=false on a write failure so the caller keeps the manifest and a re-run + // can finish the still-symlinked duplicate(s); a genuinely missing canonical + // (corrupt backup, not retryable) does not block manifest cleanup. + return materialized, missing, writeOK +} + +// writeMaterializedFile atomically replaces a path (typically a dedup symlink) with +// a regular file holding content, via a sibling temp + rename so a crash never +// leaves the path missing. +func writeMaterializedFile(target string, content []byte, mode os.FileMode) error { + if mode == 0 { + mode = 0o600 + } + tmp, err := restoreFS.CreateTemp(filepath.Dir(target), restoreTempPattern) + if err != nil { + return fmt.Errorf("create temp: %w", err) + } + tmpPath := tmp.Name() + if _, err := tmp.Write(content); err != nil { + _ = tmp.Close() + _ = restoreFS.Remove(tmpPath) + return fmt.Errorf("write temp: %w", err) + } + if err := atomicFileChmod(tmp, mode.Perm()); err != nil { + _ = tmp.Close() + _ = restoreFS.Remove(tmpPath) + return fmt.Errorf("chmod temp: %w", err) + } + if err := tmp.Close(); err != nil { + _ = restoreFS.Remove(tmpPath) + return fmt.Errorf("close temp: %w", err) + } + if err := restoreFS.Rename(tmpPath, target); err != nil { + _ = restoreFS.Remove(tmpPath) + return fmt.Errorf("replace symlink with file: %w", err) + } + return nil +} diff --git a/internal/orchestrator/restore_dedup_materialize_test.go b/internal/orchestrator/restore_dedup_materialize_test.go new file mode 100644 index 0000000..f842a30 --- /dev/null +++ b/internal/orchestrator/restore_dedup_materialize_test.go @@ -0,0 +1,267 @@ +package orchestrator + +import ( + "archive/tar" + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func TestIsDedupManifestEntry(t *testing.T) { + cases := map[string]bool{ + "./" + backup.DedupManifestRelPath: true, + backup.DedupManifestRelPath: true, + "/" + backup.DedupManifestRelPath: true, + "etc/pve/user.cfg": false, + "var/lib/proxsave-info/commands/pve/pve_users.json": false, + } + for in, want := range cases { + if got := isDedupManifestEntry(in); got != want { + t.Errorf("isDedupManifestEntry(%q) = %v, want %v", in, got, want) + } + } +} + +func TestSkipRestoreArchiveEntryAlwaysExtractsDedupManifest(t *testing.T) { + hdr := &tar.Header{Name: "./" + backup.DedupManifestRelPath} + // Selective restore with a category the manifest path does not match: the dedup + // manifest is force-extracted regardless so materialization can run (issue #70). + opts := restoreArchiveOptions{categories: []Category{{ID: "pve_access_control"}}} + var stats restoreExtractionStats + if skipRestoreArchiveEntry(hdr, opts, true, &restoreExtractionLog{}, &stats) { + t.Fatal("dedup manifest must never be skipped, even in selective restore") + } +} + +func TestMaterializeDedupSymlinksFullRestore(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + root := t.TempDir() + archive := writeTarArchiveForTest(t, root, map[string]string{"a/one.cfg": "payload"}) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "a"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(destRoot, "a", "one.cfg"), []byte("payload"), 0o640); err != nil { + t.Fatal(err) + } + if err := os.Symlink("one.cfg", filepath.Join(destRoot, "a", "two.cfg")); err != nil { + t.Fatal(err) + } + writeDedupManifestForTest(t, destRoot, []backup.DedupManifestEntry{{Path: "a/two.cfg", Mode: 0o640}}) + + materializeDedupSymlinks(context.Background(), archive, destRoot, logging.New(types.LogLevelError, false)) + + two := filepath.Join(destRoot, "a", "two.cfg") + info, err := os.Lstat(two) + if err != nil { + t.Fatalf("lstat materialized file: %v", err) + } + if info.Mode()&os.ModeSymlink != 0 { + t.Fatal("expected a/two.cfg to be a regular file after materialization, got symlink") + } + if info.Mode().Perm() != 0o640 { + t.Fatalf("expected mode 0640, got %o", info.Mode().Perm()) + } + if data, err := os.ReadFile(two); err != nil || string(data) != "payload" { + t.Fatalf("materialized content mismatch: %q err=%v", data, err) + } + if _, err := os.Stat(filepath.Join(destRoot, filepath.FromSlash(backup.DedupManifestRelPath))); !os.IsNotExist(err) { + t.Fatalf("dedup manifest should be removed after materialization, stat err=%v", err) + } +} + +// TestMaterializeDedupCrossCategoryRebuildsFromArchive is the #70 regression guard: +// a selective restore that selects the symlinked DUPLICATE but not its canonical +// TARGET's category must rebuild the duplicate from the archive content, NOT delete +// the user-selected file. +func TestMaterializeDedupCrossCategoryRebuildsFromArchive(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + root := t.TempDir() + // The archive holds the canonical a/one.cfg (category A, NOT selected). + archive := writeTarArchiveForTest(t, root, map[string]string{"a/one.cfg": "payload"}) + + destRoot := t.TempDir() + // Simulate selective extraction of ONLY category B: the duplicate symlink exists, + // but its canonical a/one.cfg was not extracted. + if err := os.MkdirAll(filepath.Join(destRoot, "b"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Symlink("../a/one.cfg", filepath.Join(destRoot, "b", "two.cfg")); err != nil { + t.Fatal(err) + } + writeDedupManifestForTest(t, destRoot, []backup.DedupManifestEntry{{Path: "b/two.cfg", Mode: 0o640}}) + + materializeDedupSymlinks(context.Background(), archive, destRoot, logging.New(types.LogLevelError, false)) + + two := filepath.Join(destRoot, "b", "two.cfg") + info, err := os.Lstat(two) + if err != nil { + t.Fatalf("selected duplicate must NOT be deleted on cross-category selective restore: %v", err) + } + if info.Mode()&os.ModeSymlink != 0 { + t.Fatal("expected b/two.cfg rebuilt as a regular file from the archive, got symlink") + } + if data, err := os.ReadFile(two); err != nil || string(data) != "payload" { + t.Fatalf("rebuilt content mismatch: %q err=%v", data, err) + } +} + +// TestMaterializeDedupReturnsErrorWhenIncomplete is the #8 guard: when the archive +// scan cannot complete (here a canceled context with a duplicate still to rebuild), +// materializeDedupSymlinks must RETURN an error so a staged restore that cannot +// tolerate a partial result fails closed instead of applying it, and the manifest is +// kept for a retry. +func TestMaterializeDedupReturnsErrorWhenIncomplete(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + root := t.TempDir() + archive := writeTarArchiveForTest(t, root, map[string]string{"a/one.cfg": "payload"}) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "b"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Symlink("../a/one.cfg", filepath.Join(destRoot, "b", "two.cfg")); err != nil { + t.Fatal(err) + } + writeDedupManifestForTest(t, destRoot, []backup.DedupManifestEntry{{Path: "b/two.cfg", Mode: 0o640}}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // canceled before the archive scan: materialization cannot complete + + if err := materializeDedupSymlinks(ctx, archive, destRoot, logging.New(types.LogLevelError, false)); err == nil { + t.Fatal("materializeDedupSymlinks must return an error when materialization is incomplete (canceled scan)") + } + // The manifest must be kept (not removed) so a re-run can finish. + if _, statErr := os.Stat(filepath.Join(destRoot, filepath.FromSlash(backup.DedupManifestRelPath))); statErr != nil { + t.Fatalf("dedup manifest should be kept for retry on incomplete materialization, stat err=%v", statErr) + } +} + +// TestMaterializeDedupMissingCanonicalKeepsSymlink: if the canonical is genuinely +// absent from the archive (corrupt backup), the symlink is kept (no deletion of the +// user-selected file). +func TestMaterializeDedupMissingCanonicalKeepsSymlink(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + root := t.TempDir() + archive := writeTarArchiveForTest(t, root, map[string]string{"unrelated.cfg": "x"}) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "b"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Symlink("../a/one.cfg", filepath.Join(destRoot, "b", "two.cfg")); err != nil { + t.Fatal(err) + } + writeDedupManifestForTest(t, destRoot, []backup.DedupManifestEntry{{Path: "b/two.cfg", Mode: 0o640}}) + + materializeDedupSymlinks(context.Background(), archive, destRoot, logging.New(types.LogLevelError, false)) + + info, err := os.Lstat(filepath.Join(destRoot, "b", "two.cfg")) + if err != nil { + t.Fatalf("symlink must be kept when the canonical is missing from the archive, not deleted: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatal("expected the symlink to remain when the canonical is unavailable, got a regular file") + } +} + +// TestMaterializeDedupUsesArchiveNotStaleDisk guards the HIGH fast-path defect: even +// when a (stale) file exists on disk at the canonical path, the duplicate must be +// rebuilt from the ARCHIVE bytes, never from the on-disk/live content. +func TestMaterializeDedupUsesArchiveNotStaleDisk(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + root := t.TempDir() + archive := writeTarArchiveForTest(t, root, map[string]string{"a/one.cfg": "FRESH-from-archive"}) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "a"), 0o755); err != nil { + t.Fatal(err) + } + // A STALE canonical exists on disk (e.g. the live pre-restore file, or a failed + // extraction left the old bytes). + if err := os.WriteFile(filepath.Join(destRoot, "a", "one.cfg"), []byte("STALE-on-disk"), 0o640); err != nil { + t.Fatal(err) + } + if err := os.Symlink("one.cfg", filepath.Join(destRoot, "a", "two.cfg")); err != nil { + t.Fatal(err) + } + writeDedupManifestForTest(t, destRoot, []backup.DedupManifestEntry{{Path: "a/two.cfg", Mode: 0o640}}) + + materializeDedupSymlinks(context.Background(), archive, destRoot, logging.New(types.LogLevelError, false)) + + data, err := os.ReadFile(filepath.Join(destRoot, "a", "two.cfg")) + if err != nil { + t.Fatalf("read materialized: %v", err) + } + if string(data) != "FRESH-from-archive" { + t.Fatalf("duplicate must be rebuilt from the archive, not from stale disk content: got %q", data) + } +} + +func writeDedupManifestForTest(t *testing.T, destRoot string, entries []backup.DedupManifestEntry) { + t.Helper() + data, err := json.Marshal(entries) + if err != nil { + t.Fatal(err) + } + dest := filepath.Join(destRoot, filepath.FromSlash(backup.DedupManifestRelPath)) + if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(dest, data, 0o600); err != nil { + t.Fatal(err) + } +} + +// writeTarArchiveForTest writes a plain (uncompressed) .tar with the given regular +// files (archive-relative paths -> content) and returns its path. +func writeTarArchiveForTest(t *testing.T, dir string, files map[string]string) string { + t.Helper() + p := filepath.Join(dir, "backup.tar") + f, err := os.Create(p) + if err != nil { + t.Fatal(err) + } + defer func() { _ = f.Close() }() + tw := tar.NewWriter(f) + for name, content := range files { + hdr := &tar.Header{ + Name: "./" + name, + Mode: 0o640, + Size: int64(len(content)), + Typeflag: tar.TypeReg, + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte(content)); err != nil { + t.Fatal(err) + } + } + if err := tw.Close(); err != nil { + t.Fatal(err) + } + return p +} diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go index 0c2c689..6f44791 100644 --- a/internal/orchestrator/restore_firewall.go +++ b/internal/orchestrator/restore_firewall.go @@ -245,12 +245,17 @@ func maybeApplyPVEFirewallWithUI( return nil } - if err := firewallRestartService(ctx); err != nil { - logger.Warning("PVE firewall restore: reload/restart failed: %v", err) + restartErr := firewallRestartService(ctx) + if restartErr != nil { + logger.Warning("PVE firewall restore: reload/restart failed: %v", restartErr) } if rollbackHandle == nil { - logger.Info("PVE firewall restore applied (no rollback timer armed).") + if restartErr != nil { + logger.Warning("PVE firewall restore applied but the service restart failed and no rollback timer is armed; verify firewall state and access from the local console/IPMI.") + } else { + logger.Info("PVE firewall restore applied (no rollback timer armed).") + } return nil } @@ -266,6 +271,17 @@ func maybeApplyPVEFirewallWithUI( "Keep firewall changes?", int(remaining.Seconds()), ) + if restartErr != nil { + // Surface the failed restart at the decision point: the new rules may not be + // active yet and the live firewall state is uncertain. The default stays + // "Rollback" (defaultYes=false); an explicit Keep is still honored. + commitMessage = fmt.Sprintf( + "WARNING: the firewall service restart FAILED (%v).\n"+ + "The new rules may not be active yet and the live firewall state is uncertain.\n"+ + "If unsure, choose Rollback.\n\n", + restartErr, + ) + commitMessage + } commit, err := ui.ConfirmAction(ctx, "Commit firewall changes", commitMessage, "Keep", "Rollback", remaining, false) if err != nil { if errors.Is(err, input.ErrInputAborted) || errors.Is(err, context.Canceled) { diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go index 0a09410..80b513f 100644 --- a/internal/orchestrator/restore_firewall_additional_test.go +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -203,11 +203,13 @@ type scriptedConfirmAction struct { type scriptedRestoreWorkflowUI struct { *fakeRestoreWorkflowUI - script []scriptedConfirmAction - calls int + script []scriptedConfirmAction + calls int + messages []string } func (s *scriptedRestoreWorkflowUI) ConfirmAction(ctx context.Context, title, message, yesLabel, noLabel string, timeout time.Duration, defaultYes bool) (bool, error) { + s.messages = append(s.messages, message) if s.calls >= len(s.script) { return false, fmt.Errorf("unexpected ConfirmAction call %d (title=%q)", s.calls+1, strings.TrimSpace(title)) } @@ -1349,6 +1351,17 @@ func TestMaybeApplyPVEFirewallWithUI_AdditionalBranches(t *testing.T) { if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { t.Fatalf("expected rollback marker removed; stat err=%v", err) } + // H09: the failed restart must be surfaced at the commit decision point so + // the operator does not keep changes blindly (the default is still Rollback). + bannerShown := false + for _, m := range ui.messages { + if strings.Contains(m, "restart FAILED") { + bannerShown = true + } + } + if !bannerShown { + t.Fatalf("expected the commit prompt to warn about the failed restart; messages=%v", ui.messages) + } }) t.Run("commit context canceled returns canceled error", func(t *testing.T) { diff --git a/internal/orchestrator/restore_workflow_ui_extract.go b/internal/orchestrator/restore_workflow_ui_extract.go index 1ab1786..3093b98 100644 --- a/internal/orchestrator/restore_workflow_ui_extract.go +++ b/internal/orchestrator/restore_workflow_ui_extract.go @@ -306,6 +306,9 @@ func (w *restoreUIWorkflowRun) applyStagedCategories() error { }}, {name: "PVE SDN staged apply", run: func() error { return maybeApplyPVESDNFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) }}, {name: "Access control staged apply", run: w.applyAccessControlFromStage}, + {name: "System accounts staged apply", run: func() error { + return maybeApplyAccountsFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) + }}, {name: "Notifications staged apply", run: func() error { return maybeApplyNotificationsFromStage(w.ctx, w.logger, w.plan, w.stageRoot, w.cfg.DryRun) }}, diff --git a/internal/orchestrator/staging.go b/internal/orchestrator/staging.go index b3ebdaf..054a34f 100644 --- a/internal/orchestrator/staging.go +++ b/internal/orchestrator/staging.go @@ -23,6 +23,7 @@ func isStagedCategoryID(id string) bool { "pbs_notifications", "pve_access_control", "pbs_access_control", + "accounts", "pve_firewall", "pve_ha", "pve_sdn": diff --git a/internal/orchestrator/temp_registry.go b/internal/orchestrator/temp_registry.go index f0b7126..1f69d8f 100644 --- a/internal/orchestrator/temp_registry.go +++ b/internal/orchestrator/temp_registry.go @@ -17,8 +17,70 @@ const ( defaultRegistryEnvVar = "PROXMOX_TEMP_REGISTRY_PATH" defaultRegistryPath = "/var/run/proxsave/temp-dirs.json" registryFallbackDir = "proxsave" + workspaceMarker = ".proxsave-marker" ) +// workspaceRoot is the shared root under which all ProxSave temp workspaces +// are created (MkdirTemp children). CleanupOrphaned only removes paths contained +// here, and the backup/decrypt paths validate it before use. It is a var (not a +// const) so tests can point it at a scratch directory. +var workspaceRoot = "/tmp/proxsave" + +// ensureSecureTempRoot validates (and creates if missing) the shared temp root so +// it cannot be hijacked by an attacker who pre-creates /tmp/proxsave as a symlink +// or a world-writable / non-root-owned directory before ProxSave runs (issue #54). +func ensureSecureTempRoot(fsys FS, path string) error { + info, err := fsys.Lstat(path) + if err != nil { + if os.IsNotExist(err) { + return fsys.MkdirAll(path, 0o700) + } + return fmt.Errorf("stat temp root %s: %w", path, err) + } + if info == nil { + // Defensive: a well-behaved FS returns a non-nil FileInfo on success; if it + // does not, fall back to ensuring the directory exists. + return fsys.MkdirAll(path, 0o700) + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to use temp root %s: it is a symlink", path) + } + if !info.IsDir() { + return fmt.Errorf("refusing to use temp root %s: not a directory", path) + } + if info.Mode().Perm()&0o022 != 0 { + return fmt.Errorf("refusing to use temp root %s: group/world-writable (mode %#o)", path, info.Mode().Perm()) + } + if st, ok := info.Sys().(*syscall.Stat_t); ok { + if st.Uid != 0 && int(st.Uid) != os.Geteuid() { + return fmt.Errorf("refusing to use temp root %s: owned by uid %d, not root/self", path, st.Uid) + } + } + return nil +} + +// workspacePathIsRemovable reports whether path is a genuine ProxSave temp +// workspace that CleanupOrphaned may RemoveAll: it must be a non-symlink +// directory contained directly under workspaceRoot and carry the marker file +// written before a workspace is registered (issue #55). This prevents a poisoned +// registry (or a controlled PROXMOX_TEMP_REGISTRY_PATH) from deleting arbitrary +// paths. +func workspacePathIsRemovable(path string) bool { + clean := filepath.Clean(path) + root := filepath.Clean(workspaceRoot) + if clean == root || !strings.HasPrefix(clean, root+string(os.PathSeparator)) { + return false + } + info, err := os.Lstat(clean) + if err != nil || info.Mode()&os.ModeSymlink != 0 || !info.IsDir() { + return false + } + if _, err := os.Lstat(filepath.Join(clean, workspaceMarker)); err != nil { + return false + } + return true +} + type tempDirRecord struct { Path string `json:"path"` PID int `json:"pid"` @@ -102,6 +164,13 @@ func (r *TempDirRegistry) CleanupOrphaned(maxAge time.Duration) (int, error) { alive := processAlive(entry.PID) if stale || !alive { + if !workspacePathIsRemovable(entry.Path) { + if r.logger != nil { + r.logger.Warning("Refusing to remove registry entry %s: not a ProxSave workspace under %s; dropping untrusted entry", entry.Path, workspaceRoot) + } + // Drop the untrusted entry without touching the filesystem path. + continue + } if r.logger != nil { r.logger.Debug("Cleaning orphaned temp dir %s (pid=%d)...", entry.Path, entry.PID) } diff --git a/internal/orchestrator/temp_registry_test.go b/internal/orchestrator/temp_registry_test.go index 37eafc4..135e798 100644 --- a/internal/orchestrator/temp_registry_test.go +++ b/internal/orchestrator/temp_registry_test.go @@ -52,31 +52,44 @@ func TestTempDirRegistryRegisterAndDeregister(t *testing.T) { } func TestTempDirRegistryCleanupOrphaned(t *testing.T) { + origRoot := workspaceRoot + workspaceRoot = t.TempDir() + t.Cleanup(func() { workspaceRoot = origRoot }) + regPath := filepath.Join(t.TempDir(), "temp-dirs.json") registry, err := NewTempDirRegistry(newTestLogger(), regPath) if err != nil { t.Fatalf("NewTempDirRegistry failed: %v", err) } - staleDir := filepath.Join(t.TempDir(), "stale") - if err := os.MkdirAll(staleDir, 0o755); err != nil { + // A legitimate workspace: under the trusted root and carrying the marker. + staleDir := filepath.Join(workspaceRoot, "proxsave-stale") + if err := os.MkdirAll(staleDir, 0o700); err != nil { t.Fatalf("mkdir stale dir: %v", err) } + if err := os.WriteFile(filepath.Join(staleDir, workspaceMarker), []byte("m"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + // A poisoned entry pointing outside the trusted root: must NOT be deleted (#55). + outsideDir := filepath.Join(t.TempDir(), "outside") + if err := os.MkdirAll(outsideDir, 0o700); err != nil { + t.Fatalf("mkdir outside dir: %v", err) + } - if err := registry.Register(staleDir); err != nil { - t.Fatalf("register stale dir: %v", err) + for _, dir := range []string{staleDir, outsideDir} { + if err := registry.Register(dir); err != nil { + t.Fatalf("register %s: %v", dir, err) + } } entries, err := registry.loadEntries() if err != nil { t.Fatalf("load entries: %v", err) } - if len(entries) != 1 { - t.Fatalf("expected 1 entry, got %d", len(entries)) + for i := range entries { + entries[i].CreatedAt = time.Now().Add(-48 * time.Hour) + entries[i].PID = -1 } - - entries[0].CreatedAt = time.Now().Add(-48 * time.Hour) - entries[0].PID = -1 if err := registry.saveEntries(entries); err != nil { t.Fatalf("save entries: %v", err) } @@ -86,11 +99,15 @@ func TestTempDirRegistryCleanupOrphaned(t *testing.T) { t.Fatalf("cleanup orphaned: %v", err) } if cleaned != 1 { - t.Fatalf("expected 1 directory cleaned, got %d", cleaned) + t.Fatalf("expected exactly 1 directory cleaned (the contained workspace), got %d", cleaned) } if _, err := os.Stat(staleDir); !os.IsNotExist(err) { - t.Fatalf("expected stale dir to be removed, err=%v", err) + t.Fatalf("expected contained workspace to be removed, err=%v", err) + } + // The out-of-root entry must be left on disk (only dropped from the registry). + if _, err := os.Stat(outsideDir); err != nil { + t.Fatalf("out-of-root dir must NOT be removed by CleanupOrphaned: %v", err) } entries, err = registry.loadEntries() @@ -98,6 +115,53 @@ func TestTempDirRegistryCleanupOrphaned(t *testing.T) { t.Fatalf("load entries: %v", err) } if len(entries) != 0 { - t.Fatalf("expected registry to be empty after cleanup, got %d", len(entries)) + t.Fatalf("expected registry empty after cleanup (workspace removed, untrusted entry dropped), got %d: %+v", len(entries), entries) } } + +func TestEnsureSecureTempRoot(t *testing.T) { + t.Run("creates missing root 0700", func(t *testing.T) { + root := filepath.Join(t.TempDir(), "proxsave") + if err := ensureSecureTempRoot(osFS{}, root); err != nil { + t.Fatalf("ensureSecureTempRoot: %v", err) + } + info, err := os.Lstat(root) + if err != nil { + t.Fatalf("lstat: %v", err) + } + if info.Mode().Perm() != 0o700 { + t.Fatalf("expected created root mode 0700, got %#o", info.Mode().Perm()) + } + }) + + t.Run("accepts existing root-owned 0755 dir", func(t *testing.T) { + root := t.TempDir() // created 0700 by default; relax to 0755 + if err := os.Chmod(root, 0o755); err != nil { + t.Fatal(err) + } + if err := ensureSecureTempRoot(osFS{}, root); err != nil { + t.Fatalf("expected existing 0755 dir accepted, got %v", err) + } + }) + + t.Run("rejects symlink", func(t *testing.T) { + realDir := t.TempDir() + link := filepath.Join(t.TempDir(), "proxsave-link") + if err := os.Symlink(realDir, link); err != nil { + t.Fatal(err) + } + if err := ensureSecureTempRoot(osFS{}, link); err == nil { + t.Fatal("expected ensureSecureTempRoot to reject a symlinked temp root (issue #54)") + } + }) + + t.Run("rejects world-writable dir", func(t *testing.T) { + root := t.TempDir() + if err := os.Chmod(root, 0o777); err != nil { + t.Fatal(err) + } + if err := ensureSecureTempRoot(osFS{}, root); err == nil { + t.Fatal("expected ensureSecureTempRoot to reject a world-writable temp root (issue #54)") + } + }) +}