@@ -11762,7 +11762,7 @@ static bool IsValidVectorAndMatrixDimensions(Sema &S, CallExpr *CE,
1176211762 unsigned OutputVectorSize,
1176311763 unsigned MatrixK, unsigned MatrixM,
1176411764 bool isInputPacked) {
11765- // Check is output vector size is equals to matrix dimension M
11765+ // Check if output vector size equals to matrix dimension M
1176611766 if (OutputVectorSize != MatrixM) {
1176711767 Expr *OutputVector = CE->getArg(kMatVecMulOutputVectorIdx);
1176811768 S.Diags.Report(
@@ -11772,6 +11772,10 @@ static bool IsValidVectorAndMatrixDimensions(Sema &S, CallExpr *CE,
1177211772 return false;
1177311773 }
1177411774
11775+ // Check if input vector size equals to matrix dimension K in the unpacked
11776+ // case.
11777+ // Check if input vector size equals the smallest number that can hold
11778+ // matrix dimension K values
1177511779 const unsigned PackingFactor = isInputPacked ? 4 : 1;
1177611780 unsigned MinInputVectorSize = (MatrixK + PackingFactor - 1) / PackingFactor;
1177711781 if (InputVectorSize != MinInputVectorSize) {
@@ -11808,8 +11812,9 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1180811812 << "IsOutputUnsigned";
1180911813 return;
1181011814 }
11811- // Check if IsOutputUnsigned flag matches output vector type.
11812- // Must be true for unsigned int outputs, false for signed int/float outputs.
11815+
11816+ // Check if output vector is unsigned int, signed int or float
11817+ // Check if the isUnsigned flag is set correctly
1181311818 Expr *OutputVector = CE->getArg(kMatVecMulOutputVectorIdx);
1181411819 unsigned OutputVectorSizeValue = 0;
1181511820 if (IsHLSLVecType(OutputVector->getType())) {
@@ -11824,7 +11829,9 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1182411829 << "Output Vector";
1182511830 return;
1182611831 }
11827-
11832+ // Check if IsOutputUnsigned flag matches output vector type.
11833+ // Must be true for unsigned int outputs, false for signed int/float
11834+ // outputs.
1182811835 if (IsOutputUnsignedFlagValue &&
1182911836 !OutputVectorTypePtr->isUnsignedIntegerType()) {
1183011837 DXASSERT_NOMSG(OutputVectorTypePtr->isSignedIntegerType() ||
@@ -11844,7 +11851,7 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1184411851 }
1184511852 }
1184611853
11847- // Find InputVectorType and Size and check IsUnsigned
11854+ // Check if isInputUnsigned parameter is a constant
1184811855 bool IsInputUnsignedFlagValue = false;
1184911856 Expr *IsInputUnsignedExpr = CE->getArg(kMatVecMulIsInputUnsignedIdx);
1185011857 llvm::APSInt IsInputUnsignedExprVal;
@@ -11858,6 +11865,7 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1185811865 return;
1185911866 }
1186011867
11868+ // Check if input vector32/16bit is unsigned int, signed int or float
1186111869 Expr *InputVector = CE->getArg(kMatVecMulInputVectorIdx);
1186211870 unsigned InputVectorSizeValue = 0;
1186311871 if (IsHLSLVecType(InputVector->getType())) {
@@ -11884,6 +11892,7 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1188411892 return;
1188511893 }
1188611894
11895+ // Check if the isUnsigned flag is set correctly
1188711896 if (IsInputUnsignedFlagValue &&
1188811897 !InputVectorTypePtr->isUnsignedIntegerType()) {
1188911898 DXASSERT_NOMSG(InputVectorTypePtr->isSignedIntegerType() ||
@@ -11996,6 +12005,7 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1199612005 }
1199712006
1199812007 // Get MatrixInterpretation, check if it is constant
12008+ // Make sure it is a valid value
1199912009 Expr *MatrixInterpretationExpr =
1200012010 CE->getArg(kMatVecMulMatrixInterpretationIdx);
1200112011 llvm::APSInt MatrixInterpretationExprVal;
@@ -12019,7 +12029,7 @@ static void CheckCommonMulandMulAddParameters(Sema &S, CallExpr *CE,
1201912029 return;
1202012030 }
1202112031
12022- // Get MatrixLayout, check if it is constant
12032+ // Get MatrixLayout, check if it is constant and valid value
1202312033 Expr *MatrixLayoutExpr = CE->getArg(kMatVecMulMatrixLayoutIdx);
1202412034 llvm::APSInt MatrixLayoutExprVal;
1202512035 unsigned MatrixLayoutValue = 0;
0 commit comments