Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 29 additions & 144 deletions src/core/include/math/hal/intnat/transformnat-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,8 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformIterative(const VecTy

omegaFactor = omega.ModMul(oddVal, modulus, mu);
evenVal = (*result)[indexEven];
oddVal = evenVal;
oddVal += omegaFactor;
if (oddVal >= modulus) {
oddVal -= modulus;
}

if (evenVal < omegaFactor) {
evenVal += modulus;
}
evenVal -= omegaFactor;
oddVal = evenVal.ModAddFast(omegaFactor, modulus);
evenVal = evenVal.ModSubFast(omegaFactor, modulus);

(*result)[indexEven] = oddVal;
(*result)[indexOdd] = evenVal;
Expand Down Expand Up @@ -220,15 +212,8 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
omegaFactor = (*element)[indexHi];
omegaFactor.ModMulFastEq(omega, modulus, mu);

hiVal = loVal + omegaFactor;
if (hiVal >= modulus) {
hiVal -= modulus;
}

if (loVal < omegaFactor) {
loVal += modulus;
}
loVal -= omegaFactor;
hiVal = loVal.ModAddFast(omegaFactor, modulus);
loVal = loVal.ModSubFast(omegaFactor, modulus);

(*element)[indexLo] = hiVal;
(*element)[indexHi] = loVal;
Expand All @@ -254,7 +239,7 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverse(const Ve
result->SetModulus(modulus);

uint32_t i, m, j1, j2, indexOmega, indexLo, indexHi;
IntType omega, omegaFactor, loVal, hiVal, zero(0);
IntType omega, omegaFactor, loVal, hiVal;

for (i = 0; i < n; ++i) {
(*result)[i] = element[i];
Expand All @@ -272,25 +257,12 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverse(const Ve
indexHi = indexLo + t;
loVal = (*result)[indexLo];
omegaFactor = (*result)[indexHi];
if (omegaFactor != zero) {
omegaFactor.ModMulFastEq(omega, modulus, mu);

hiVal = loVal + omegaFactor;
if (hiVal >= modulus) {
hiVal -= modulus;
}

if (loVal < omegaFactor) {
loVal += modulus;
}
loVal -= omegaFactor;

(*result)[indexLo] = hiVal;
(*result)[indexHi] = loVal;
}
else {
(*result)[indexHi] = loVal;
}
// Unconditional compute avoids the data-dependent `!= zero` skip.
omegaFactor.ModMulFastEq(omega, modulus, mu);
hiVal = loVal.ModAddFast(omegaFactor, modulus);
loVal = loVal.ModSubFast(omegaFactor, modulus);
(*result)[indexLo] = hiVal;
(*result)[indexHi] = loVal;
}
}
t >>= 1;
Expand Down Expand Up @@ -329,22 +301,8 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
auto omegaFactor{(*element)[j1 + t]};
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
auto loVal{(*element)[j1 + 0]};
#if defined(__GNUC__) && !defined(__clang__)
auto hiVal{loVal + omegaFactor};
if (hiVal >= modulus)
hiVal -= modulus;
if (loVal < omegaFactor)
loVal += modulus;
loVal -= omegaFactor;
(*element)[j1 + 0] = hiVal;
(*element)[j1 + t] = loVal;
#else
// fixes Clang slowdown issue, but requires lowVal be less than modulus
(*element)[j1 + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0);
if (omegaFactor > loVal)
loVal += modulus;
(*element)[j1 + t] = loVal - omegaFactor;
#endif
(*element)[j1 + 0] = loVal.ModAddFast(omegaFactor, modulus);
(*element)[j1 + t] = loVal.ModSubFast(omegaFactor, modulus);
}
}
}
Expand All @@ -355,21 +313,8 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
auto preconOmega{preconRootOfUnityTable[(i >> 1) + n]};
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
auto loVal{(*element)[i + 0]};
#if defined(__GNUC__) && !defined(__clang__)
auto hiVal{loVal + omegaFactor};
if (hiVal >= modulus)
hiVal -= modulus;
if (loVal < omegaFactor)
loVal += modulus;
loVal -= omegaFactor;
(*element)[i + 0] = hiVal;
(*element)[i + 1] = loVal;
#else
(*element)[i + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0);
if (omegaFactor > loVal)
loVal += modulus;
(*element)[i + 1] = loVal - omegaFactor;
#endif
(*element)[i + 0] = loVal.ModAddFast(omegaFactor, modulus);
(*element)[i + 1] = loVal.ModSubFast(omegaFactor, modulus);
}
}

Expand All @@ -394,7 +339,7 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverse(const Ve

uint32_t indexOmega, indexHi;
NativeInteger preconOmega;
IntType omega, omegaFactor, loVal, hiVal, zero(0);
IntType omega, omegaFactor, loVal, hiVal;

uint32_t t = (n >> 1);
uint32_t logt1 = GetMSB(t);
Expand All @@ -410,25 +355,11 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverse(const Ve
indexHi = indexLo + t;
loVal = (*result)[indexLo];
omegaFactor = (*result)[indexHi];
if (omegaFactor != zero) {
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);

hiVal = loVal + omegaFactor;
if (hiVal >= modulus) {
hiVal -= modulus;
}

if (loVal < omegaFactor) {
loVal += modulus;
}
loVal -= omegaFactor;

(*result)[indexLo] = hiVal;
(*result)[indexHi] = loVal;
}
else {
(*result)[indexHi] = loVal;
}
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
hiVal = loVal.ModAddFast(omegaFactor, modulus);
loVal = loVal.ModSubFast(omegaFactor, modulus);
(*result)[indexLo] = hiVal;
(*result)[indexHi] = loVal;
}
}
}
Expand Down Expand Up @@ -461,17 +392,8 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
hiVal = (*element)[indexHi];
loVal = (*element)[indexLo];

omegaFactor = loVal;
if (omegaFactor < hiVal) {
omegaFactor += modulus;
}

omegaFactor -= hiVal;

loVal += hiVal;
if (loVal >= modulus) {
loVal -= modulus;
}
omegaFactor = loVal.ModSubFast(hiVal, modulus);
loVal = loVal.ModAddFast(hiVal, modulus);

omegaFactor.ModMulFastEq(omega, modulus, mu);

Expand Down Expand Up @@ -545,24 +467,11 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
auto preconOmega{preconRootOfUnityInverseTable[(i + n) >> 1]};
auto loVal{(*element)[i + 0]};
auto hiVal{(*element)[i + 1]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
omegaFactor += modulus;
omegaFactor -= hiVal;
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
auto omegaFactor{loVal.ModSubFast(hiVal, modulus)};
loVal = loVal.ModAddFast(hiVal, modulus);
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
(*element)[i + 0] = loVal;
(*element)[i + 1] = omegaFactor;
#else
auto omegaFactor{loVal + (hiVal > loVal ? modulus : 0) - hiVal};
loVal += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
(*element)[i + 0] = loVal;
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
(*element)[i + 1] = omegaFactor;
#endif
}
}
// inner stages
Expand All @@ -573,23 +482,11 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
for (uint32_t j1{i << logt}, j2{j1 + t}; j1 < j2; ++j1) {
auto loVal{(*element)[j1 + 0]};
auto hiVal{(*element)[j1 + t]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
omegaFactor += modulus;
omegaFactor -= hiVal;
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
auto omegaFactor{loVal.ModSubFast(hiVal, modulus)};
loVal = loVal.ModAddFast(hiVal, modulus);
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
(*element)[j1 + 0] = loVal;
(*element)[j1 + t] = omegaFactor;
#else
(*element)[j1 + 0] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal;
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
(*element)[j1 + t] = omegaFactor;
#endif
}
}
}
Expand All @@ -601,23 +498,11 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
for (uint32_t j1{0}; j1 < j2; ++j1) {
auto loVal{(*element)[j1]};
auto hiVal{(*element)[j1 + j2]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
omegaFactor += modulus;
omegaFactor -= hiVal;
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
auto omegaFactor{loVal.ModSubFast(hiVal, modulus)};
loVal = loVal.ModAddFast(hiVal, modulus);
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + 0] = loVal;
(*element)[j1 + j2] = omegaFactor;
#else
(*element)[j1] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal;
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + j2] = omegaFactor;
#endif
}
// perform remaining n/2 scalar multiplies by (n inverse)
for (uint32_t i = 0; i < j2; ++i)
Expand Down
33 changes: 12 additions & 21 deletions src/core/include/math/hal/intnat/ubintnat.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "math/hal/integer.h"
#include "math/nbtheory.h"

#include "utils/constanttime.h"
#include "utils/debug.h"
#include "utils/exception.h"
#include "utils/inttypes.h"
Expand Down Expand Up @@ -735,11 +736,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
* @return is the result of the modulus addition operation.
*/
NativeIntegerT ModAddFast(const NativeIntegerT& b, const NativeIntegerT& modulus) const {
auto r{m_value + b.m_value};
auto& mv{modulus.m_value};
if (r >= mv)
r -= mv;
return {r};
return {lbcrypto::ct::SubIfGE(static_cast<NativeInt>(m_value + b.m_value), modulus.m_value)};
}
/**
* Modulus addition where operands are < modulus. In-place variant.
Expand All @@ -749,10 +746,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
* @return is the result of the modulus addition operation.
*/
NativeIntegerT& ModAddFastEq(const NativeIntegerT& b, const NativeIntegerT& modulus) {
auto& mv{modulus.m_value};
m_value += b.m_value;
if (m_value >= mv)
m_value -= mv;
m_value = lbcrypto::ct::SubIfGE(static_cast<NativeInt>(m_value + b.m_value), modulus.m_value);
return *this;
}

Expand Down Expand Up @@ -909,9 +903,7 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
* @return is the result of the modulus subtraction operation.
*/
NativeIntegerT ModSubFast(const NativeIntegerT& b, const NativeIntegerT& modulus) const {
if (m_value < b.m_value)
return {m_value + modulus.m_value - b.m_value};
return {m_value - b.m_value};
return {lbcrypto::ct::ModSubFast(m_value, b.m_value, modulus.m_value)};
}

/**
Expand All @@ -922,9 +914,8 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
* @return is the result of the modulus subtraction operation.
*/
NativeIntegerT& ModSubFastEq(const NativeIntegerT& b, const NativeIntegerT& modulus) {
if (m_value < b.m_value)
return *this = m_value + modulus.m_value - b.m_value;
return *this = m_value - b.m_value;
m_value = lbcrypto::ct::ModSubFast(m_value, b.m_value, modulus.m_value);
return *this;
}

/**
Expand Down Expand Up @@ -1463,9 +1454,9 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
*/
NativeIntegerT ModMulFastConst(const NativeIntegerT& b, const NativeIntegerT& modulus,
const NativeIntegerT& bInv) const {
NativeInt q = MultDHi(m_value, bInv.m_value) + 1;
auto yprime = static_cast<SignedNativeInt>(m_value * b.m_value - q * modulus.m_value);
return {yprime >= 0 ? yprime : yprime + modulus.m_value};
NativeInt q = MultDHi(m_value, bInv.m_value) + 1;
SignedNativeInt y = static_cast<SignedNativeInt>(m_value * b.m_value - q * modulus.m_value);
return {static_cast<NativeInt>(lbcrypto::ct::AddIfNeg(y, static_cast<SignedNativeInt>(modulus.m_value)))};
}

/**
Expand All @@ -1479,9 +1470,9 @@ class NativeIntegerT final : public lbcrypto::BigIntegerInterface<NativeIntegerT
*/
NativeIntegerT& ModMulFastConstEq(const NativeIntegerT& b, const NativeIntegerT& modulus,
const NativeIntegerT& bInv) {
NativeInt q = MultDHi(m_value, bInv.m_value) + 1;
auto yprime = static_cast<SignedNativeInt>(m_value * b.m_value - q * modulus.m_value);
m_value = static_cast<NativeInt>(yprime >= 0 ? yprime : yprime + modulus.m_value);
NativeInt q = MultDHi(m_value, bInv.m_value) + 1;
SignedNativeInt y = static_cast<SignedNativeInt>(m_value * b.m_value - q * modulus.m_value);
m_value = static_cast<NativeInt>(lbcrypto::ct::AddIfNeg(y, static_cast<SignedNativeInt>(modulus.m_value)));
return *this;
}

Expand Down
Loading
Loading