diff --git a/src/core/include/math/hal/intnat/transformnat-impl.h b/src/core/include/math/hal/intnat/transformnat-impl.h index 543d6b152..ce77dbffd 100644 --- a/src/core/include/math/hal/intnat/transformnat-impl.h +++ b/src/core/include/math/hal/intnat/transformnat-impl.h @@ -160,16 +160,8 @@ void NumberTheoreticTransformNat::ForwardTransformIterative(const VecTy omegaFactor = omega.ModMul(oddVal, modulus, mu); evenVal = (*result)[indexEven]; - oddVal = evenVal; - oddVal += omegaFactor; - if (oddVal >= modulus) { - oddVal -= modulus; - } - - if (evenVal < omegaFactor) { - evenVal += modulus; - } - evenVal -= omegaFactor; + oddVal = evenVal.ModAddFast(omegaFactor, modulus); + evenVal = evenVal.ModSubFast(omegaFactor, modulus); (*result)[indexEven] = oddVal; (*result)[indexOdd] = evenVal; @@ -220,15 +212,8 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverseInPlace(c omegaFactor = (*element)[indexHi]; omegaFactor.ModMulFastEq(omega, modulus, mu); - hiVal = loVal + omegaFactor; - if (hiVal >= modulus) { - hiVal -= modulus; - } - - if (loVal < omegaFactor) { - loVal += modulus; - } - loVal -= omegaFactor; + hiVal = loVal.ModAddFast(omegaFactor, modulus); + loVal = loVal.ModSubFast(omegaFactor, modulus); (*element)[indexLo] = hiVal; (*element)[indexHi] = loVal; @@ -254,7 +239,7 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverse(const Ve result->SetModulus(modulus); uint32_t i, m, j1, j2, indexOmega, indexLo, indexHi; - IntType omega, omegaFactor, loVal, hiVal, zero(0); + IntType omega, omegaFactor, loVal, hiVal; for (i = 0; i < n; ++i) { (*result)[i] = element[i]; @@ -272,25 +257,12 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverse(const Ve indexHi = indexLo + t; loVal = (*result)[indexLo]; omegaFactor = (*result)[indexHi]; - if (omegaFactor != zero) { - omegaFactor.ModMulFastEq(omega, modulus, mu); - - hiVal = loVal + omegaFactor; - if (hiVal >= modulus) { - hiVal -= modulus; - } - - if (loVal < omegaFactor) { - loVal += modulus; - } - loVal -= omegaFactor; - - (*result)[indexLo] = hiVal; - (*result)[indexHi] = loVal; - } - else { - (*result)[indexHi] = loVal; - } + // Unconditional compute avoids the data-dependent `!= zero` skip. + omegaFactor.ModMulFastEq(omega, modulus, mu); + hiVal = loVal.ModAddFast(omegaFactor, modulus); + loVal = loVal.ModSubFast(omegaFactor, modulus); + (*result)[indexLo] = hiVal; + (*result)[indexHi] = loVal; } } t >>= 1; @@ -329,22 +301,8 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverseInPlace(c auto omegaFactor{(*element)[j1 + t]}; omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); auto loVal{(*element)[j1 + 0]}; -#if defined(__GNUC__) && !defined(__clang__) - auto hiVal{loVal + omegaFactor}; - if (hiVal >= modulus) - hiVal -= modulus; - if (loVal < omegaFactor) - loVal += modulus; - loVal -= omegaFactor; - (*element)[j1 + 0] = hiVal; - (*element)[j1 + t] = loVal; -#else - // fixes Clang slowdown issue, but requires lowVal be less than modulus - (*element)[j1 + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0); - if (omegaFactor > loVal) - loVal += modulus; - (*element)[j1 + t] = loVal - omegaFactor; -#endif + (*element)[j1 + 0] = loVal.ModAddFast(omegaFactor, modulus); + (*element)[j1 + t] = loVal.ModSubFast(omegaFactor, modulus); } } } @@ -355,21 +313,8 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverseInPlace(c auto preconOmega{preconRootOfUnityTable[(i >> 1) + n]}; omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); auto loVal{(*element)[i + 0]}; -#if defined(__GNUC__) && !defined(__clang__) - auto hiVal{loVal + omegaFactor}; - if (hiVal >= modulus) - hiVal -= modulus; - if (loVal < omegaFactor) - loVal += modulus; - loVal -= omegaFactor; - (*element)[i + 0] = hiVal; - (*element)[i + 1] = loVal; -#else - (*element)[i + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0); - if (omegaFactor > loVal) - loVal += modulus; - (*element)[i + 1] = loVal - omegaFactor; -#endif + (*element)[i + 0] = loVal.ModAddFast(omegaFactor, modulus); + (*element)[i + 1] = loVal.ModSubFast(omegaFactor, modulus); } } @@ -394,7 +339,7 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverse(const Ve uint32_t indexOmega, indexHi; NativeInteger preconOmega; - IntType omega, omegaFactor, loVal, hiVal, zero(0); + IntType omega, omegaFactor, loVal, hiVal; uint32_t t = (n >> 1); uint32_t logt1 = GetMSB(t); @@ -410,25 +355,11 @@ void NumberTheoreticTransformNat::ForwardTransformToBitReverse(const Ve indexHi = indexLo + t; loVal = (*result)[indexLo]; omegaFactor = (*result)[indexHi]; - if (omegaFactor != zero) { - omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); - - hiVal = loVal + omegaFactor; - if (hiVal >= modulus) { - hiVal -= modulus; - } - - if (loVal < omegaFactor) { - loVal += modulus; - } - loVal -= omegaFactor; - - (*result)[indexLo] = hiVal; - (*result)[indexHi] = loVal; - } - else { - (*result)[indexHi] = loVal; - } + omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); + hiVal = loVal.ModAddFast(omegaFactor, modulus); + loVal = loVal.ModSubFast(omegaFactor, modulus); + (*result)[indexLo] = hiVal; + (*result)[indexHi] = loVal; } } } @@ -461,17 +392,8 @@ void NumberTheoreticTransformNat::InverseTransformFromBitReverseInPlace hiVal = (*element)[indexHi]; loVal = (*element)[indexLo]; - omegaFactor = loVal; - if (omegaFactor < hiVal) { - omegaFactor += modulus; - } - - omegaFactor -= hiVal; - - loVal += hiVal; - if (loVal >= modulus) { - loVal -= modulus; - } + omegaFactor = loVal.ModSubFast(hiVal, modulus); + loVal = loVal.ModAddFast(hiVal, modulus); omegaFactor.ModMulFastEq(omega, modulus, mu); @@ -545,24 +467,11 @@ void NumberTheoreticTransformNat::InverseTransformFromBitReverseInPlace auto preconOmega{preconRootOfUnityInverseTable[(i + n) >> 1]}; auto loVal{(*element)[i + 0]}; auto hiVal{(*element)[i + 1]}; -#if defined(__GNUC__) && !defined(__clang__) - auto omegaFactor{loVal}; - if (omegaFactor < hiVal) - omegaFactor += modulus; - omegaFactor -= hiVal; - loVal += hiVal; - if (loVal >= modulus) - loVal -= modulus; + auto omegaFactor{loVal.ModSubFast(hiVal, modulus)}; + loVal = loVal.ModAddFast(hiVal, modulus); omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); (*element)[i + 0] = loVal; (*element)[i + 1] = omegaFactor; -#else - auto omegaFactor{loVal + (hiVal > loVal ? modulus : 0) - hiVal}; - loVal += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0); - (*element)[i + 0] = loVal; - omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); - (*element)[i + 1] = omegaFactor; -#endif } } // inner stages @@ -573,23 +482,11 @@ void NumberTheoreticTransformNat::InverseTransformFromBitReverseInPlace for (uint32_t j1{i << logt}, j2{j1 + t}; j1 < j2; ++j1) { auto loVal{(*element)[j1 + 0]}; auto hiVal{(*element)[j1 + t]}; -#if defined(__GNUC__) && !defined(__clang__) - auto omegaFactor{loVal}; - if (omegaFactor < hiVal) - omegaFactor += modulus; - omegaFactor -= hiVal; - loVal += hiVal; - if (loVal >= modulus) - loVal -= modulus; + auto omegaFactor{loVal.ModSubFast(hiVal, modulus)}; + loVal = loVal.ModAddFast(hiVal, modulus); omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); (*element)[j1 + 0] = loVal; (*element)[j1 + t] = omegaFactor; -#else - (*element)[j1 + 0] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0); - auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal; - omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); - (*element)[j1 + t] = omegaFactor; -#endif } } } @@ -601,23 +498,11 @@ void NumberTheoreticTransformNat::InverseTransformFromBitReverseInPlace for (uint32_t j1{0}; j1 < j2; ++j1) { auto loVal{(*element)[j1]}; auto hiVal{(*element)[j1 + j2]}; -#if defined(__GNUC__) && !defined(__clang__) - auto omegaFactor{loVal}; - if (omegaFactor < hiVal) - omegaFactor += modulus; - omegaFactor -= hiVal; - loVal += hiVal; - if (loVal >= modulus) - loVal -= modulus; + auto omegaFactor{loVal.ModSubFast(hiVal, modulus)}; + loVal = loVal.ModAddFast(hiVal, modulus); omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv); (*element)[j1 + 0] = loVal; (*element)[j1 + j2] = omegaFactor; -#else - (*element)[j1] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0); - auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal; - omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv); - (*element)[j1 + j2] = omegaFactor; -#endif } // perform remaining n/2 scalar multiplies by (n inverse) for (uint32_t i = 0; i < j2; ++i) diff --git a/src/core/include/math/hal/intnat/ubintnat.h b/src/core/include/math/hal/intnat/ubintnat.h index 2454b5ac1..c2a861de0 100644 --- a/src/core/include/math/hal/intnat/ubintnat.h +++ b/src/core/include/math/hal/intnat/ubintnat.h @@ -41,6 +41,7 @@ #include "math/hal/integer.h" #include "math/nbtheory.h" +#include "utils/constanttime.h" #include "utils/debug.h" #include "utils/exception.h" #include "utils/inttypes.h" @@ -735,11 +736,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface= mv) - r -= mv; - return {r}; + return {lbcrypto::ct::SubIfGE(static_cast(m_value + b.m_value), modulus.m_value)}; } /** * Modulus addition where operands are < modulus. In-place variant. @@ -749,10 +746,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface= mv) - m_value -= mv; + m_value = lbcrypto::ct::SubIfGE(static_cast(m_value + b.m_value), modulus.m_value); return *this; } @@ -909,9 +903,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface(m_value * b.m_value - q * modulus.m_value); - return {yprime >= 0 ? yprime : yprime + modulus.m_value}; + NativeInt q = MultDHi(m_value, bInv.m_value) + 1; + SignedNativeInt y = static_cast(m_value * b.m_value - q * modulus.m_value); + return {static_cast(lbcrypto::ct::AddIfNeg(y, static_cast(modulus.m_value)))}; } /** @@ -1479,9 +1470,9 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface(m_value * b.m_value - q * modulus.m_value); - m_value = static_cast(yprime >= 0 ? yprime : yprime + modulus.m_value); + NativeInt q = MultDHi(m_value, bInv.m_value) + 1; + SignedNativeInt y = static_cast(m_value * b.m_value - q * modulus.m_value); + m_value = static_cast(lbcrypto::ct::AddIfNeg(y, static_cast(modulus.m_value))); return *this; } diff --git a/src/core/include/utils/constanttime.h b/src/core/include/utils/constanttime.h new file mode 100644 index 000000000..963ff8069 --- /dev/null +++ b/src/core/include/utils/constanttime.h @@ -0,0 +1,138 @@ +//================================================================================== +// BSD 2-Clause License +// +// Copyright (c) 2014-2026, NJIT, Duality Technologies Inc. and other contributors +// +// All rights reserved. +// +// Author TPOC: contact@openfhe.org +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//================================================================================== + +/* + Branch-free helpers for modular arithmetic on the decrypt hot path. + + These routines are drop-in replacements for the common conditional + Barrett / sign-correction idioms that appear throughout the integer + backend. The intent is to remove data-dependent branches (and thus + a known class of timing side channels) from every primitive that + is reachable from Decrypt. + + Methodology follows the PermNet-RM constant-time encoder + (https://github.com/BAder82t/PermNet-RM, Issaei 2026): branchless + bitmask select plus a per-call compiler barrier that survives + -O0..-Ofast on GCC/Clang. + + Each helper is a short inline template so it compiles to the same + three or four instructions the optimizer would emit for the + original ternary on typical x86-64 / arm64 builds, but *without* + depending on optimizer heuristics for the branchlessness. +*/ + +#ifndef LBCRYPTO_UTILS_CONSTANTTIME_H +#define LBCRYPTO_UTILS_CONSTANTTIME_H + +#include +#include + +namespace lbcrypto { +namespace ct { + +// Compiler barrier: opaque the given scalar to the optimizer so that +// aggressive passes cannot reconstruct a data-dependent branch via +// value-range propagation. Expands to nothing on compilers that do +// not support the inline-asm form; on those compilers the branchless +// expression survives on its own at -O0..-O3 today, and the guard is +// here as a future-proofing reminder. +#if defined(__GNUC__) || defined(__clang__) +#define OPENFHE_CT_OPAQUE(x) __asm__ volatile("" : "+r"(x)) +#else +#define OPENFHE_CT_OPAQUE(x) ((void)0) +#endif + +// Return x + m if x is negative, x otherwise, without a branch. +// Requires S to be a signed integral type; m is interpreted in the +// signed domain so callers passing an unsigned m should ensure +// m < 2^(bits(S)-1) (true for all RLWE moduli in OpenFHE). +template +inline S AddIfNeg(S x, S m) noexcept { + static_assert(std::is_signed_v, "AddIfNeg requires a signed type"); + constexpr int kSignShift = static_cast(sizeof(S) * 8 - 1); + // Arithmetic right shift replicates the sign bit across the word. + const S mask = x >> kSignShift; // -1 if x < 0, else 0 + S y = x + (mask & m); + OPENFHE_CT_OPAQUE(y); + return y; +} + +// Return x - m if x >= m, x otherwise, without a branch. +// Precondition: x < 2*m, m < 2^(bits(U)-1). All OpenFHE RLWE moduli +// satisfy the latter; the former is guaranteed by every caller of +// this routine (the canonical Barrett / Montgomery post-reduction). +template +inline U SubIfGE(U x, U m) noexcept { + static_assert(std::is_unsigned_v, "SubIfGE requires an unsigned type"); + constexpr int kTopBit = static_cast(sizeof(U) * 8 - 1); + const U diff = x - m; // underflows (wraps) iff x < m + // Top bit of diff is 1 iff underflow happened iff x < m. + const U under = diff >> kTopBit; + const U mask = U(0) - under; // all-ones if underflow, else zero + U y = diff + (mask & m); + OPENFHE_CT_OPAQUE(y); + return y; +} + +// Return a - b mod m, without a branch. +// Precondition: a, b in [0, m), m < 2^(bits(U)-1). +template +inline U ModSubFast(U a, U b, U m) noexcept { + static_assert(std::is_unsigned_v, "ModSubFast requires an unsigned type"); + constexpr int kTopBit = static_cast(sizeof(U) * 8 - 1); + const U diff = a - b; // underflows iff a < b + const U under = diff >> kTopBit; + const U mask = U(0) - under; + U y = diff + (mask & m); + OPENFHE_CT_OPAQUE(y); + return y; +} + +// Return x - m if x > halfQ (the centered-lift threshold), x otherwise. +// Used for signed modular reduction of plaintext-encoded values. +// Precondition: x in [0, m), halfQ = m / 2. +template +inline U SubIfAboveHalf(U x, U m, U halfQ) noexcept { + static_assert(std::is_unsigned_v, "SubIfAboveHalf requires an unsigned type"); + constexpr int kTopBit = static_cast(sizeof(U) * 8 - 1); + // x > halfQ <=> halfQ - x underflows <=> (halfQ - x) >> top is 1. + const U diff = halfQ - x; + const U above = diff >> kTopBit; // 1 if x > halfQ, else 0 + const U mask = U(0) - above; // all-ones if x > halfQ + U y = x - (mask & m); + OPENFHE_CT_OPAQUE(y); + return y; +} + +} // namespace ct +} // namespace lbcrypto + +#endif // LBCRYPTO_UTILS_CONSTANTTIME_H