Skip to content

Commit e89ab35

Browse files
committed
[SM6.10][Bugfix] MultiplyAdd - Convert Bias vector to output vector type
If the Bias vector type passed to MultiplyAdd differs from the output vector type, convert it to the output vector type before calling the `dx.op.linAlgMatVecMulAdd` op. The interpretation for the bias vector is always set to the output vector interpretation on the op. To accommodate this, the MultiplyAdd functions have been split into two variants: one where the bias interpretation matches the output vector interpretation and no conversion is necessary, and a second where the types don't match and the bias vector is first converted to the output vector type before being passed to `__builtin_LinAlg_MatrixVectorMultiplyAdd`. Adds the `hlsl::__detail::TypeTraits` struct to enable mapping of HLSL scalar types to component type enum values. Fixes #8390
1 parent d831cb4 commit e89ab35

2 files changed

Lines changed: 192 additions & 31 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 138 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,19 @@ template <ComponentEnum CompTy> struct ComponentTypeTraits {
165165
static const uint ElementsPerScalar = 4;
166166
};
167167

168+
// Primary template mapping a scalar type T to a ComponentEnum value.
// Any T without an explicit specialization reports
// dxil::ComponentType::Invalid as its CompType.
template <typename T> struct TypeTraits {
169+
static const ComponentEnum CompType =
170+
(ComponentEnum)dxil::ComponentType::Invalid;
171+
};
172+
168173
// Declares a matched pair of specializations for one scalar mapping:
//  - ComponentTypeTraits<enum_val>: enum -> type, flagged as a native scalar
//    with one element per scalar.
//  - TypeTraits<type>: type -> enum (CompType = enum_val).
#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type) \
169174
template <> struct ComponentTypeTraits<enum_val> { \
170175
using Type = type; \
171176
static const bool IsNativeScalar = true; \
172177
static const uint ElementsPerScalar = 1; \
178+
}; \
179+
template <> struct TypeTraits<type> { \
180+
static const ComponentEnum CompType = enum_val; \
173181
};
174182

175183
#if __HLSL_ENABLE_16_BIT
@@ -498,14 +506,60 @@ Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
498506
// MultiplyAdd overload: plain vector input, plain vector bias, no conversion.
// Enabled when InputElTy is arithmetic and __detail::TypeTraits maps BiasElTy
// and OutputElTy to the same component type. Bias is handed to the builtin
// unchanged; the input interpretation is the component type mapped from
// InputElTy and the bias interpretation is the one mapped from OutputElTy.
// Returns the vector<OutputElTy, M> written by the builtin.
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
499507
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
500508
// clang-format off
501-
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
509+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
510+
__detail::TypeTraits<BiasElTy>::CompType ==
511+
__detail::TypeTraits<OutputElTy>::CompType,
512+
vector<OutputElTy, M> >::type
502513
// clang-format on
503514
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
504515
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
505516
vector<OutputElTy, M> Result;
506-
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
507-
hlsl::is_signed<OutputElTy>::value,
508-
Vec, MatrixDT, Bias, MatrixDT);
517+
__builtin_LinAlg_MatrixVectorMultiplyAdd(
518+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, Vec,
519+
__detail::TypeTraits<InputElTy>::CompType, Bias, __detail::TypeTraits<OutputElTy>::CompType);
520+
return Result;
521+
}
522+
523+
// MultiplyAdd overload: plain vector input, plain vector bias, with bias
// conversion. Enabled when InputElTy is arithmetic and __detail::TypeTraits
// maps BiasElTy and OutputElTy to different component types. Bias is first
// converted into a vector<OutputElTy, M> via __builtin_LinAlg_Convert, and the
// converted vector is passed to the builtin with the output component type as
// its interpretation.
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
524+
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
525+
// clang-format off
526+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
527+
__detail::TypeTraits<BiasElTy>::CompType !=
528+
__detail::TypeTraits<OutputElTy>::CompType,
529+
vector<OutputElTy, M> >::type
530+
// clang-format on
531+
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
532+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
533+
vector<OutputElTy, M> BiasVecConv;
534+
__builtin_LinAlg_Convert(BiasVecConv, Bias,
535+
__detail::TypeTraits<BiasElTy>::CompType,
536+
__detail::TypeTraits<OutputElTy>::CompType);
537+
vector<OutputElTy, M> Result;
538+
__builtin_LinAlg_MatrixVectorMultiplyAdd(
539+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, Vec,
540+
__detail::TypeTraits<InputElTy>::CompType, BiasVecConv,
541+
__detail::TypeTraits<OutputElTy>::CompType);
542+
return Result;
543+
}
544+
545+
// MultiplyAdd overload: InterpretedVector input, plain vector bias, no
// conversion. Enabled when VecK equals
// __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value and
// __detail::TypeTraits maps BiasElTy and OutputElTy to the same component
// type. The input data and interpretation are taken from InterpVec; Bias is
// passed through unchanged with the output component type as its
// interpretation.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
546+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
547+
ComponentEnum MatrixDT>
548+
// clang-format off
549+
typename hlsl::enable_if<
550+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
551+
__detail::TypeTraits<BiasElTy>::CompType ==
552+
__detail::TypeTraits<OutputElTy>::CompType,
553+
vector<OutputElTy, M> >::type
554+
// clang-format on
555+
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
556+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
557+
vector<BiasElTy, M> Bias) {
558+
vector<OutputElTy, M> Result;
559+
__builtin_LinAlg_MatrixVectorMultiplyAdd(
560+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
561+
InterpVec.Data, InterpVec.Interpretation, Bias,
562+
__detail::TypeTraits<OutputElTy>::CompType);
509563
return Result;
510564
}
511565

@@ -514,55 +568,121 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
514568
ComponentEnum MatrixDT>
515569
// clang-format off
516570
typename hlsl::enable_if<
517-
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
571+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
572+
__detail::TypeTraits<BiasElTy>::CompType !=
573+
__detail::TypeTraits<OutputElTy>::CompType,
518574
vector<OutputElTy, M> >::type
519575
// clang-format on
520576
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
521577
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
522578
vector<BiasElTy, M> Bias) {
579+
580+
vector<OutputElTy, M> BiasVecConv;
581+
__builtin_LinAlg_Convert(BiasVecConv, Bias,
582+
__detail::TypeTraits<BiasElTy>::CompType,
583+
__detail::TypeTraits<OutputElTy>::CompType);
584+
523585
vector<OutputElTy, M> Result;
524586
__builtin_LinAlg_MatrixVectorMultiplyAdd(
525587
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
526-
InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT);
588+
InterpVec.Data, InterpVec.Interpretation, BiasVecConv,
589+
__detail::TypeTraits<OutputElTy>::CompType);
590+
return Result;
591+
}
592+
593+
// MultiplyAdd overload: plain vector input, VectorRef bias, no conversion.
// Enabled when InputElTy is arithmetic and the component type mapped from
// OutputElTy equals BiasInterp. The bias is loaded from BiasRef.Buf at
// BiasRef.Offset directly as vector<OutputElTy, M> and passed to the builtin
// with BiasInterp as its interpretation.
// NOTE(review): this overload passes MatrixDT as the input vector
// interpretation, whereas the other plain-vector MultiplyAdd overloads pass
// __detail::TypeTraits<InputElTy>::CompType — confirm this difference is
// intentional.
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasInterp,
594+
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
595+
// clang-format off
596+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
597+
__detail::TypeTraits<OutputElTy>::CompType == BiasInterp,
598+
vector<OutputElTy, M> >::type
599+
// clang-format on
600+
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
601+
vector<InputElTy, K> Vec, VectorRef<BiasInterp, M> BiasRef) {
602+
using BiasOutputVecTy = vector<OutputElTy, M>;
603+
BiasOutputVecTy BiasVec =
604+
BiasRef.Buf.template Load<BiasOutputVecTy>(BiasRef.Offset);
605+
606+
BiasOutputVecTy Result;
607+
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
608+
hlsl::is_signed<OutputElTy>::value,
609+
Vec, MatrixDT, BiasVec, BiasInterp);
527610
return Result;
528611
}
529612

530-
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
613+
// MultiplyAdd overload: plain vector input, VectorRef bias, with bias
// conversion. Enabled when InputElTy is arithmetic and the component type
// mapped from OutputElTy differs from BiasInterp. The bias is loaded using the
// scalar type from __detail::ComponentTypeTraits<BiasInterp> and the scalar
// count from __detail::ScalarCountFromPackedComponents<BiasInterp, M>, then
// converted to vector<OutputElTy, M> via __builtin_LinAlg_Convert before being
// passed to the builtin with the output component type as its interpretation.
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasInterp,
531614
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
532615
// clang-format off
533-
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
616+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
617+
__detail::TypeTraits<OutputElTy>::CompType != BiasInterp,
534618
vector<OutputElTy, M> >::type
535619
// clang-format on
536620
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
537-
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef) {
621+
vector<InputElTy, K> Vec, VectorRef<BiasInterp, M> BiasRef) {
538622
using BiasVecTy =
539-
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
623+
vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
624+
__detail::ScalarCountFromPackedComponents<BiasInterp, M>::Value>;
540625
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
626+
627+
vector<OutputElTy, M> BiasVecConv;
628+
ComponentEnum OutputCompType = __detail::TypeTraits<OutputElTy>::CompType;
629+
__builtin_LinAlg_Convert(BiasVecConv, BiasVec, BiasInterp, OutputCompType);
630+
541631
vector<OutputElTy, M> Result;
542-
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
543-
hlsl::is_signed<OutputElTy>::value,
544-
Vec, MatrixDT, BiasVec, BiasElTy);
632+
__builtin_LinAlg_MatrixVectorMultiplyAdd(
633+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, Vec,
634+
__detail::TypeTraits<InputElTy>::CompType, BiasVecConv, OutputCompType);
545635
return Result;
546636
}
547637

548638
// MultiplyAdd overload: InterpretedVector input, VectorRef bias, no
// conversion. Enabled when VecK equals
// __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value and the
// component type mapped from OutputElTy equals BiasInterp. The bias is loaded
// from BiasRef.Buf at BiasRef.Offset directly as vector<OutputElTy, M> and
// passed to the builtin with BiasInterp as its interpretation; the input data
// and interpretation come from InterpVec.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
549-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
639+
ComponentEnum BiasInterp, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
550640
ComponentEnum MatrixDT>
551641
// clang-format off
552642
typename hlsl::enable_if<
553-
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
643+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
644+
__detail::TypeTraits<OutputElTy>::CompType == BiasInterp,
554645
vector<OutputElTy, M> >::type
555646
// clang-format on
556647
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
557648
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
558-
VectorRef<BiasElTy, M> BiasRef) {
649+
VectorRef<BiasInterp, M> BiasRef) {
650+
using BiasOutputVecTy = vector<OutputElTy, M>;
651+
BiasOutputVecTy BiasVec =
652+
BiasRef.Buf.template Load<BiasOutputVecTy>(BiasRef.Offset);
653+
654+
vector<OutputElTy, M> Result;
655+
__builtin_LinAlg_MatrixVectorMultiplyAdd(
656+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
657+
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasInterp);
658+
return Result;
659+
}
660+
661+
// MultiplyAdd overload: InterpretedVector input, VectorRef bias, with bias
// conversion. Enabled when VecK equals
// __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value and the
// component type mapped from OutputElTy differs from BiasInterp. The bias is
// loaded using the scalar type from __detail::ComponentTypeTraits<BiasInterp>
// and the scalar count from
// __detail::ScalarCountFromPackedComponents<BiasInterp, M>, converted to
// vector<OutputElTy, M> via __builtin_LinAlg_Convert, and then passed to the
// builtin with the output component type as its interpretation.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
662+
ComponentEnum BiasInterp, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
663+
ComponentEnum MatrixDT>
664+
// clang-format off
665+
typename hlsl::enable_if<
666+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
667+
__detail::TypeTraits<OutputElTy>::CompType != BiasInterp,
668+
vector<OutputElTy, M> >::type
669+
// clang-format on
670+
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
671+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
672+
VectorRef<BiasInterp, M> BiasRef) {
559673
using BiasVecTy =
560-
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
674+
vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
675+
__detail::ScalarCountFromPackedComponents<BiasInterp, M>::Value>;
561676
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
677+
678+
ComponentEnum OutputCompType = __detail::TypeTraits<OutputElTy>::CompType;
679+
vector<OutputElTy, M> BiasVecConv;
680+
__builtin_LinAlg_Convert(BiasVecConv, BiasVec, BiasInterp, OutputCompType);
681+
562682
vector<OutputElTy, M> Result;
563683
__builtin_LinAlg_MatrixVectorMultiplyAdd(
564684
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
565-
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy);
685+
InterpVec.Data, InterpVec.Interpretation, BiasVecConv, OutputCompType);
566686
return Result;
567687
}
568688

0 commit comments

Comments
 (0)