|
15 | 15 |
|
16 | 16 | #include "clang/Sema/SemaHLSL.h" |
17 | 17 | #include "VkConstantsTables.h" |
| 18 | +#include "dxc/DXIL/DxilConstants.h" |
18 | 19 | #include "dxc/DXIL/DxilFunctionProps.h" |
19 | 20 | #include "dxc/DXIL/DxilShaderModel.h" |
20 | 21 | #include "dxc/DXIL/DxilUtil.h" |
@@ -11681,74 +11682,50 @@ static const unsigned kMatVecMulMatrixStrideIdx = 12; |
11681 | 11682 | // MatVecAdd |
11682 | 11683 | const unsigned kMatVecMulAddBiasInterpretation = 15; |
11683 | 11684 |
|
11684 | | -enum MatrixLayout { |
11685 | | - MATRIX_LAYOUT_ROW_MAJOR = 0, |
11686 | | - MATRIX_LAYOUT_COLUMN_MAJOR = 1, |
11687 | | - MATRIX_LAYOUT_MUL_OPTIMAL = 2, |
11688 | | - MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3 |
11689 | | -}; |
11690 | | - |
11691 | | -bool IsValidMatrixLayoutForMulandMulAddOps(unsigned Layout) { |
11692 | | - return Layout <= static_cast<unsigned>( |
11693 | | - MatrixLayout::MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL); |
| 11685 | +static bool IsValidMatrixLayoutForMulandMulAddOps(unsigned Layout) { |
| 11686 | + return Layout <= |
| 11687 | + static_cast<unsigned>(DXIL::LinalgMatrixLayout::OuterProductOptimal); |
11694 | 11688 | } |
11695 | 11689 |
|
11696 | | -bool IsOptimalTypeMatrixLayout(unsigned Layout) { |
11697 | | - return (Layout == (static_cast<unsigned>( |
11698 | | - MatrixLayout::MATRIX_LAYOUT_MUL_OPTIMAL)) || |
11699 | | - (Layout == (static_cast<unsigned>( |
11700 | | - MatrixLayout::MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL)))); |
| 11690 | +static bool IsOptimalTypeMatrixLayout(unsigned Layout) { |
| 11691 | + return ( |
| 11692 | + Layout == (static_cast<unsigned>(DXIL::LinalgMatrixLayout::MulOptimal)) || |
| 11693 | + (Layout == |
| 11694 | + (static_cast<unsigned>(DXIL::LinalgMatrixLayout::OuterProductOptimal)))); |
11701 | 11695 | } |
11702 | 11696 |
|
11703 | | -bool IsValidTransposeForMatrixLayout(unsigned Layout, bool Transposed) { |
11704 | | - switch (static_cast<MatrixLayout>(Layout)) { |
11705 | | - case MatrixLayout::MATRIX_LAYOUT_ROW_MAJOR: |
11706 | | - case MatrixLayout::MATRIX_LAYOUT_COLUMN_MAJOR: |
| 11697 | +static bool IsValidTransposeForMatrixLayout(unsigned Layout, bool Transposed) { |
| 11698 | + switch (static_cast<DXIL::LinalgMatrixLayout>(Layout)) { |
| 11699 | + case DXIL::LinalgMatrixLayout::RowMajor: |
| 11700 | + case DXIL::LinalgMatrixLayout::ColumnMajor: |
11707 | 11701 | return !Transposed; |
11708 | 11702 |
|
11709 | 11703 | default: |
11710 | 11704 | return true; |
11711 | 11705 | } |
11712 | 11706 | } |
11713 | 11707 |
|
11714 | | -enum DataType { |
11715 | | - DATA_TYPE_SINT16 = 2, // ComponentType::I16 |
11716 | | - DATA_TYPE_UINT16 = 3, // ComponentType::U16 |
11717 | | - DATA_TYPE_SINT32 = 4, // ComponentType::I32 |
11718 | | - DATA_TYPE_UINT32 = 5, // ComponentType::U32 |
11719 | | - DATA_TYPE_FLOAT16 = 8, // ComponentType::F16 |
11720 | | - DATA_TYPE_FLOAT32 = 9, // ComponentType::F32 |
11721 | | - DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32 |
11722 | | - DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32 |
11723 | | - DATA_TYPE_UINT8 = 19, // ComponentType::U8 |
11724 | | - DATA_TYPE_SINT8 = 20, // ComponentType::I8 |
11725 | | - DATA_TYPE_FLOAT8_E4M3 = 21, // ComponentType::F8_E4M3 |
11726 | | - // (1 sign, 4 exp, 3 mantissa bits) |
11727 | | - DATA_TYPE_FLOAT8_E5M2 = 22, // ComponentType::F8_E5M2 |
11728 | | - // (1 sign, 5 exp, 2 mantissa bits) |
11729 | | -}; |
11730 | | - |
11731 | | -bool IsPackedType(unsigned type) { |
11732 | | - return (type == static_cast<unsigned>(DATA_TYPE_SINT8_T4_PACKED) || |
11733 | | - type == static_cast<unsigned>(DATA_TYPE_UINT8_T4_PACKED)); |
| 11708 | +static bool IsPackedType(unsigned type) { |
| 11709 | + return (type == static_cast<unsigned>(DXIL::ComponentType::PackedS8x32) || |
| 11710 | + type == static_cast<unsigned>(DXIL::ComponentType::PackedU8x32)); |
11734 | 11711 | } |
11735 | 11712 |
|
11736 | 11713 | static bool IsValidLinalgTypeInterpretation(uint32_t Input, bool InRegister) { |
11737 | 11714 |
|
11738 | 11715 | switch (Input) { |
11739 | | - case DATA_TYPE_SINT16: |
11740 | | - case DATA_TYPE_UINT16: |
11741 | | - case DATA_TYPE_SINT32: |
11742 | | - case DATA_TYPE_UINT32: |
11743 | | - case DATA_TYPE_FLOAT16: |
11744 | | - case DATA_TYPE_FLOAT32: |
11745 | | - case DATA_TYPE_UINT8: |
11746 | | - case DATA_TYPE_SINT8: |
11747 | | - case DATA_TYPE_FLOAT8_E4M3: |
11748 | | - case DATA_TYPE_FLOAT8_E5M2: |
| 11716 | + case DXIL::ComponentType::I16: |
| 11717 | + case DXIL::ComponentType::U16: |
| 11718 | + case DXIL::ComponentType::I32: |
| 11719 | + case DXIL::ComponentType::U32: |
| 11720 | + case DXIL::ComponentType::F16: |
| 11721 | + case DXIL::ComponentType::F32: |
| 11722 | + case DXIL::ComponentType::U8: |
| 11723 | + case DXIL::ComponentType::I8: |
| 11724 | + case DXIL::ComponentType::F8_E4M3: |
| 11725 | + case DXIL::ComponentType::F8_E5M2: |
11749 | 11726 | return true; |
11750 | | - case DATA_TYPE_SINT8_T4_PACKED: |
11751 | | - case DATA_TYPE_UINT8_T4_PACKED: |
| 11727 | + case DXIL::ComponentType::PackedS8x32: |
| 11728 | + case DXIL::ComponentType::PackedU8x32: |
11752 | 11729 | return InRegister; |
11753 | 11730 | default: |
11754 | 11731 | return false; |
@@ -12038,9 +12015,9 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE, |
12038 | 12015 | diag::err_hlsl_linalg_matrix_layout_invalid) |
12039 | 12016 | << std::to_string(MatrixLayoutValue) |
12040 | 12017 | << std::to_string( |
12041 | | - static_cast<unsigned>(MatrixLayout::MATRIX_LAYOUT_ROW_MAJOR)) |
| 12018 | + static_cast<unsigned>(DXIL::LinalgMatrixLayout::RowMajor)) |
12042 | 12019 | << std::to_string(static_cast<unsigned>( |
12043 | | - MatrixLayout::MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL)); |
| 12020 | + DXIL::LinalgMatrixLayout::OuterProductOptimal)); |
12044 | 12021 | return; |
12045 | 12022 | } |
12046 | 12023 | } else { |
@@ -12184,14 +12161,13 @@ static void CheckOuterProductAccumulateCall(Sema &S, FunctionDecl *FD, |
12184 | 12161 | if (MatrixLayoutExpr->isIntegerConstantExpr(MatrixLayoutExprVal, S.Context)) { |
12185 | 12162 | MatrixLayoutValue = MatrixLayoutExprVal.getLimitedValue(); |
12186 | 12163 | if (MatrixLayoutValue != |
12187 | | - static_cast<unsigned>( |
12188 | | - MatrixLayout::MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL)) { |
| 12164 | + static_cast<unsigned>(DXIL::LinalgMatrixLayout::OuterProductOptimal)) { |
12189 | 12165 | S.Diags.Report( |
12190 | 12166 | MatrixLayoutExpr->getExprLoc(), |
12191 | 12167 | diag:: |
12192 | 12168 | err_hlsl_linalg_outer_prod_acc_matrix_layout_must_be_outer_prod_acc_optimal) |
12193 | 12169 | << std::to_string(static_cast<unsigned>( |
12194 | | - MatrixLayout::MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL)); |
| 12170 | + DXIL::LinalgMatrixLayout::OuterProductOptimal)); |
12195 | 12171 | return; |
12196 | 12172 | } |
12197 | 12173 | } else { |
|
0 commit comments