@@ -165,11 +165,19 @@ template <ComponentEnum CompTy> struct ComponentTypeTraits {
165165 static const uint ElementsPerScalar = 4 ;
166166};
167167
// Primary template mapping an element type T to a ComponentEnum value.
// Any T without an explicit specialization reports
// dxil::ComponentType::Invalid as its CompType.
168+ template <typename T> struct TypeTraits {
169+ static const ComponentEnum CompType =
170+ (ComponentEnum)dxil::ComponentType::Invalid;
171+ };
172+
// Declares the two-way mapping between a component enum value and a scalar
// type. For a given (enum_val, type) pair the macro emits:
//   * ComponentTypeTraits<enum_val>, with Type = type, IsNativeScalar = true
//     and ElementsPerScalar = 1;
//   * TypeTraits<type>, with CompType = enum_val (the reverse lookup).
// The final specialization is closed by the last line of the macro, which
// carries no line continuation.
168173#define __MATRIX_SCALAR_COMPONENT_MAPPING (enum_val, type ) \
169174 template <> struct ComponentTypeTraits <enum_val> { \
170175 using Type = type; \
171176 static const bool IsNativeScalar = true ; \
172177 static const uint ElementsPerScalar = 1 ; \
178+ }; \
179+ template <> struct TypeTraits <type> { \
180+ static const ComponentEnum CompType = enum_val; \
173181 };
174182
175183#if __HLSL_ENABLE_16_BIT
@@ -498,14 +506,60 @@ Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
// MultiplyAdd overload: thread-scope A matrix, plain input vector, plain bias
// vector whose element type maps to the same component type as OutputElTy.
// In this diff view, '-' lines are the superseded form and '+' lines the
// replacement.
//
// Enabled (via hlsl::enable_if) only when InputElTy is arithmetic and
// TypeTraits<BiasElTy>::CompType == TypeTraits<OutputElTy>::CompType.
// Returns the vector<OutputElTy, M> filled in by
// __builtin_LinAlg_MatrixVectorMultiplyAdd.
498506template <typename OutputElTy, typename InputElTy, typename BiasElTy,
499507 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
500508// clang-format off
501- typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
509+ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
510+ __detail::TypeTraits<BiasElTy>::CompType ==
511+ __detail::TypeTraits<OutputElTy>::CompType,
512+ vector<OutputElTy, M> >::type
502513// clang-format on
503514MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
504515 vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
505516 vector<OutputElTy, M> Result;
// The superseded call passed MatrixDT for both the input and bias
// interpretation arguments.
506- __builtin_LinAlg_MatrixVectorMultiplyAdd (Result, MatrixA.__handle ,
507- hlsl::is_signed<OutputElTy>::value,
508- Vec, MatrixDT, Bias, MatrixDT);
// The replacement derives the interpretations from the element types instead:
// InputElTy for Vec and OutputElTy for Bias. Bias is passed through unchanged.
517+ __builtin_LinAlg_MatrixVectorMultiplyAdd (
518+ Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value, Vec,
519+ __detail::TypeTraits<InputElTy>::CompType, Bias, __detail::TypeTraits<OutputElTy>::CompType);
520+ return Result;
521+ }
522+
// MultiplyAdd overload: thread-scope A matrix, plain input vector, plain bias
// vector whose element type maps to a DIFFERENT component type than
// OutputElTy.
//
// Enabled only when InputElTy is arithmetic and
// TypeTraits<BiasElTy>::CompType != TypeTraits<OutputElTy>::CompType, so it
// is mutually exclusive with the overload that requires equality.
// Returns the vector<OutputElTy, M> filled in by
// __builtin_LinAlg_MatrixVectorMultiplyAdd.
523+ template <typename OutputElTy, typename InputElTy, typename BiasElTy,
524+ SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
525+ // clang-format off
526+ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
527+ __detail::TypeTraits<BiasElTy>::CompType !=
528+ __detail::TypeTraits<OutputElTy>::CompType,
529+ vector<OutputElTy, M> >::type
530+ // clang-format on
531+ MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
532+ vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
// Convert the bias into a vector<OutputElTy, M> first, so the multiply-add
// below receives a bias whose element type matches the result.
533+ vector<OutputElTy, M> BiasVecConv;
534+ __builtin_LinAlg_Convert (BiasVecConv, Bias,
535+ __detail::TypeTraits<BiasElTy>::CompType,
536+ __detail::TypeTraits<OutputElTy>::CompType);
537+ vector<OutputElTy, M> Result;
538+ __builtin_LinAlg_MatrixVectorMultiplyAdd (
539+ Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value, Vec,
540+ __detail::TypeTraits<InputElTy>::CompType, BiasVecConv,
541+ __detail::TypeTraits<OutputElTy>::CompType);
542+ return Result;
543+ }
544+
// MultiplyAdd overload: thread-scope A matrix, InterpretedVector input, plain
// bias vector whose element type maps to the same component type as
// OutputElTy.
//
// Enabled only when VecK equals
// ScalarCountFromPackedComponents<InputInterp, K>::Value and
// TypeTraits<BiasElTy>::CompType == TypeTraits<OutputElTy>::CompType.
// The input data and its interpretation are taken from InterpVec; Bias is
// passed through unchanged with OutputElTy's component type.
545+ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
546+ typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
547+ ComponentEnum MatrixDT>
548+ // clang-format off
549+ typename hlsl::enable_if<
550+ VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
551+ __detail::TypeTraits<BiasElTy>::CompType ==
552+ __detail::TypeTraits<OutputElTy>::CompType,
553+ vector<OutputElTy, M> >::type
554+ // clang-format on
555+ MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
556+ InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
557+ vector<BiasElTy, M> Bias) {
558+ vector<OutputElTy, M> Result;
559+ __builtin_LinAlg_MatrixVectorMultiplyAdd (
560+ Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value,
561+ InterpVec.Data , InterpVec.Interpretation , Bias,
562+ __detail::TypeTraits<OutputElTy>::CompType);
509563 return Result;
510564}
511565
@@ -514,55 +568,121 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
514568 ComponentEnum MatrixDT>
515569// clang-format off
516570typename hlsl::enable_if<
517- VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
571+ VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
572+ __detail::TypeTraits<BiasElTy>::CompType !=
573+ __detail::TypeTraits<OutputElTy>::CompType,
518574 vector<OutputElTy, M> >::type
519575// clang-format on
520576MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
521577 InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
522578 vector<BiasElTy, M> Bias) {
579+
580+ vector<OutputElTy, M> BiasVecConv;
581+ __builtin_LinAlg_Convert (BiasVecConv, Bias,
582+ __detail::TypeTraits<BiasElTy>::CompType,
583+ __detail::TypeTraits<OutputElTy>::CompType);
584+
523585 vector<OutputElTy, M> Result;
524586 __builtin_LinAlg_MatrixVectorMultiplyAdd (
525587 Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value,
526- InterpVec.Data , InterpVec.Interpretation , Bias, MatrixDT);
588+ InterpVec.Data , InterpVec.Interpretation , BiasVecConv,
589+ __detail::TypeTraits<OutputElTy>::CompType);
590+ return Result;
591+ }
592+
// MultiplyAdd overload: thread-scope A matrix, plain input vector, bias
// supplied by reference (VectorRef) whose interpretation already equals
// OutputElTy's component type.
//
// Enabled only when InputElTy is arithmetic and
// TypeTraits<OutputElTy>::CompType == BiasInterp. Because the types match,
// the bias is loaded directly as vector<OutputElTy, M> and no conversion
// step is performed.
593+ template <typename OutputElTy, typename InputElTy, ComponentEnum BiasInterp,
594+ SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
595+ // clang-format off
596+ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
597+ __detail::TypeTraits<OutputElTy>::CompType == BiasInterp,
598+ vector<OutputElTy, M> >::type
599+ // clang-format on
600+ MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
601+ vector<InputElTy, K> Vec, VectorRef<BiasInterp, M> BiasRef) {
602+ using BiasOutputVecTy = vector<OutputElTy, M>;
// Read the bias from BiasRef.Buf at BiasRef.Offset, typed as the output
// vector.
603+ BiasOutputVecTy BiasVec =
604+ BiasRef.Buf .template Load <BiasOutputVecTy>(BiasRef.Offset );
605+
606+ BiasOutputVecTy Result;
// NOTE(review): this call passes MatrixDT as the interpretation of Vec, while
// the other plain-vector overloads visible in this diff pass
// __detail::TypeTraits<InputElTy>::CompType. Confirm whether this difference
// is intentional.
607+ __builtin_LinAlg_MatrixVectorMultiplyAdd (Result, MatrixA.__handle ,
608+ hlsl::is_signed<OutputElTy>::value,
609+ Vec, MatrixDT, BiasVec, BiasInterp);
527610 return Result;
528611}
529612
// MultiplyAdd overload: thread-scope A matrix, plain input vector, bias
// supplied by reference (VectorRef) whose interpretation DIFFERS from
// OutputElTy's component type. In this diff view, '-' lines are the
// superseded form and '+' lines the replacement; the change also renames the
// BiasElTy template parameter to BiasInterp.
//
// Enabled only when InputElTy is arithmetic and
// TypeTraits<OutputElTy>::CompType != BiasInterp.
530- template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy ,
613+ template <typename OutputElTy, typename InputElTy, ComponentEnum BiasInterp ,
531614 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
532615// clang-format off
533- typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
616+ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value &&
617+ __detail::TypeTraits<OutputElTy>::CompType != BiasInterp,
534618 vector<OutputElTy, M> >::type
535619// clang-format on
536620MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
537- vector<InputElTy, K> Vec, VectorRef<BiasElTy , M> BiasRef) {
621+ vector<InputElTy, K> Vec, VectorRef<BiasInterp , M> BiasRef) {
// The load type uses the scalar type that ComponentTypeTraits associates with
// BiasInterp. In the replacement form the element count comes from
// ScalarCountFromPackedComponents<BiasInterp, M> rather than M directly.
538622 using BiasVecTy =
539- vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
623+ vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
624+ __detail::ScalarCountFromPackedComponents<BiasInterp, M>::Value>;
540625 BiasVecTy BiasVec = BiasRef.Buf .template Load <BiasVecTy>(BiasRef.Offset );
626+
// Convert the loaded bias from BiasInterp to OutputElTy's component type
// before the multiply-add.
627+ vector<OutputElTy, M> BiasVecConv;
628+ ComponentEnum OutputCompType = __detail::TypeTraits<OutputElTy>::CompType;
629+ __builtin_LinAlg_Convert (BiasVecConv, BiasVec, BiasInterp, OutputCompType);
630+
541631 vector<OutputElTy, M> Result;
542- __builtin_LinAlg_MatrixVectorMultiplyAdd (Result, MatrixA. __handle ,
543- hlsl::is_signed<OutputElTy>::value,
544- Vec, MatrixDT, BiasVec, BiasElTy );
632+ __builtin_LinAlg_MatrixVectorMultiplyAdd (
633+ Result, MatrixA. __handle , hlsl::is_signed<OutputElTy>::value, Vec ,
634+ __detail::TypeTraits<InputElTy>::CompType, BiasVecConv, OutputCompType );
545635 return Result;
546636}
547637
// MultiplyAdd overload: thread-scope A matrix, InterpretedVector input, bias
// supplied by reference (VectorRef) whose interpretation already equals
// OutputElTy's component type. In this diff view, '-' lines are the
// superseded form and '+' lines the replacement.
//
// Enabled only when VecK equals
// ScalarCountFromPackedComponents<InputInterp, K>::Value and
// TypeTraits<OutputElTy>::CompType == BiasInterp.
548638template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
549- ComponentEnum BiasElTy , SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
639+ ComponentEnum BiasInterp , SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
550640 ComponentEnum MatrixDT>
551641// clang-format off
552642typename hlsl::enable_if<
553- VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
643+ VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
644+ __detail::TypeTraits<OutputElTy>::CompType == BiasInterp,
555554 vector<OutputElTy, M> >::type
555646// clang-format on
556647MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
557648 InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
558- VectorRef<BiasElTy, M> BiasRef) {
649+ VectorRef<BiasInterp, M> BiasRef) {
// The bias is loaded directly as vector<OutputElTy, M> from BiasRef.Buf at
// BiasRef.Offset; no conversion step is performed in this overload.
650+ using BiasOutputVecTy = vector<OutputElTy, M>;
651+ BiasOutputVecTy BiasVec =
652+ BiasRef.Buf .template Load <BiasOutputVecTy>(BiasRef.Offset );
653+
654+ vector<OutputElTy, M> Result;
// The input data and its interpretation both come from InterpVec.
655+ __builtin_LinAlg_MatrixVectorMultiplyAdd (
656+ Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value,
657+ InterpVec.Data , InterpVec.Interpretation , BiasVec, BiasInterp);
658+ return Result;
659+ }
660+
// MultiplyAdd overload: thread-scope A matrix, InterpretedVector input, bias
// supplied by reference (VectorRef) whose interpretation DIFFERS from
// OutputElTy's component type. In this diff view, '-' lines are the
// superseded form and '+' lines the replacement.
//
// Enabled only when VecK equals
// ScalarCountFromPackedComponents<InputInterp, K>::Value and
// TypeTraits<OutputElTy>::CompType != BiasInterp.
661+ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
662+ ComponentEnum BiasInterp, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
663+ ComponentEnum MatrixDT>
664+ // clang-format off
665+ typename hlsl::enable_if<
666+ VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value &&
667+ __detail::TypeTraits<OutputElTy>::CompType != BiasInterp,
668+ vector<OutputElTy, M> >::type
669+ // clang-format on
670+ MultiplyAdd (Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
671+ InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
672+ VectorRef<BiasInterp, M> BiasRef) {
// The load type uses the scalar type that ComponentTypeTraits associates with
// BiasInterp. In the replacement form the element count comes from
// ScalarCountFromPackedComponents<BiasInterp, M> rather than M directly.
559673 using BiasVecTy =
560- vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
674+ vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
675+ __detail::ScalarCountFromPackedComponents<BiasInterp, M>::Value>;
561676 BiasVecTy BiasVec = BiasRef.Buf .template Load <BiasVecTy>(BiasRef.Offset );
677+
// Convert the loaded bias from BiasInterp to OutputElTy's component type
// before the multiply-add.
678+ ComponentEnum OutputCompType = __detail::TypeTraits<OutputElTy>::CompType;
679+ vector<OutputElTy, M> BiasVecConv;
680+ __builtin_LinAlg_Convert (BiasVecConv, BiasVec, BiasInterp, OutputCompType);
681+
562682 vector<OutputElTy, M> Result;
563683 __builtin_LinAlg_MatrixVectorMultiplyAdd (
564684 Result, MatrixA.__handle , hlsl::is_signed<OutputElTy>::value,
565- InterpVec.Data , InterpVec.Interpretation , BiasVec, BiasElTy );
685+ InterpVec.Data , InterpVec.Interpretation , BiasVecConv, OutputCompType );
566686 return Result;
567687}
568688
0 commit comments