@@ -67,10 +67,10 @@ struct LinAlgHeaderIncludeHandler : public IDxcIncludeHandler {
6767namespace CoopVecHelpers {
6868
6969template <typename EltTy>
70- static std::vector<uint8_t > CreateAllOnesInputMatrix (uint32_t Width,
71- uint32_t Height) {
70+ static std::vector<uint8_t > CreateAllOnesInputMatrix (size_t Width,
71+ size_t Height) {
7272 std::vector<EltTy> InputMatrix (Width * Height);
73- for (uint32_t i = 0 ; i < Width * Height; i++) {
73+ for (size_t i = 0 ; i < Width * Height; i++) {
7474 if constexpr (std::is_same_v<EltTy, uint8_t > ||
7575 std::is_same_v<EltTy, int8_t >) {
7676 InputMatrix[i] = 1 ;
@@ -92,15 +92,15 @@ static std::vector<uint8_t> CreateAllOnesInputMatrix(uint32_t Width,
9292}
9393
9494template <typename EltTy>
95- static std::vector<uint8_t > CreateInputVector (uint32_t NumThreads,
96- uint32_t EltsPerThread) {
95+ static std::vector<uint8_t > CreateInputVector (size_t NumThreads,
96+ size_t EltsPerThread) {
9797 std::vector<EltTy> InputVector (NumThreads * EltsPerThread);
9898 std::fill (InputVector.begin (), InputVector.end (), EltTy (0 ));
9999 if (EltsPerThread < 2 ) {
100100 WEX::Logging::Log::Error (L" EltsPerThread must be at least 2" );
101101 return std::vector<uint8_t >();
102102 }
103- for (uint32_t TID = 0 ; TID < NumThreads; TID++) {
103+ for (size_t TID = 0 ; TID < NumThreads; TID++) {
104104 if constexpr (std::is_same_v<EltTy, uint8_t > ||
105105 std::is_same_v<EltTy, int8_t >) {
106106 InputVector[TID * EltsPerThread + 0 ] = 1 ;
@@ -125,7 +125,7 @@ static std::vector<uint8_t> CreateInputVector(uint32_t NumThreads,
125125}
126126
127127template <typename EltTy>
128- static std::vector<uint8_t > CreateInputBias (uint32_t NumElts) {
128+ static std::vector<uint8_t > CreateInputBias (size_t NumElts) {
129129 std::vector<EltTy> InputBias (NumElts);
130130 if constexpr (std::is_same_v<EltTy, uint8_t > ||
131131 std::is_same_v<EltTy, int8_t >) {
@@ -248,7 +248,7 @@ static std::wstring MatrixLayoutToHlslLayoutString(
248248
249249// This multiplier is used to compute the row/column stride for a matrix
250250// given it's element size.
251- static int
251+ static size_t
252252GetStrideMultiplierForMatrixDataType (D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
253253 switch (DataType) {
254254 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
@@ -271,7 +271,7 @@ GetStrideMultiplierForMatrixDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
271271 }
272272}
273273
274- static int GetNumPackedElementsForInputDataType (
274+ static size_t GetNumPackedElementsForInputDataType (
275275 D3D12_LINEAR_ALGEBRA_DATATYPE InputInterpretation) {
276276 // Int8 packed types are the only ones that have more than 1 element per
277277 // shader variable
@@ -724,7 +724,7 @@ struct TestVector {
724724
725725 // Create destination matrix info
726726 ConvertInfo.DestInfo .DestSize = 0 ; // Will be populated by driver
727- int DestEltSize = 0 ;
727+ UINT DestEltSize = 0 ;
728728 switch (DestDataType) {
729729 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
730730 case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED:
@@ -798,14 +798,14 @@ struct TestVector {
798798 sizeof (float ));
799799
800800 if (IsMatrixFP32) {
801- for (int VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
801+ for (size_t VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
802802 const DirectX::PackedVector::HALF *InputBiasFP16 =
803803 Bias.getVector <DirectX::PackedVector::HALF>(0 );
804- for (int OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
804+ for (size_t OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
805805 ++OutputIdx) {
806806 float Acc = 0 ;
807807
808- for (int InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
808+ for (size_t InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
809809 ++InputIdx) {
810810 float InputElem;
811811 if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
@@ -829,13 +829,13 @@ struct TestVector {
829829 }
830830 }
831831 } else if (MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) {
832- for (int VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
832+ for (size_t VecIdx = 0 ; VecIdx < InputVector.getNumVectors (); ++VecIdx) {
833833 const int32_t *InputBiasI32 = Bias.getVector <int32_t >(0 );
834- for (int OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
834+ for (size_t OutputIdx = 0 ; OutputIdx < Matrix.getNumVectors ();
835835 ++OutputIdx) {
836836 int Acc = 0 ;
837837
838- for (int InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
838+ for (size_t InputIdx = 0 ; InputIdx < Matrix.getVectorSize ();
839839 ++InputIdx) {
840840 int InputElem;
841841 if (InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
0 commit comments