@@ -109,8 +109,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
109109 SS << " -DELEM_TYPE=uint" ;
110110 break ;
111111 }
112- if (Params.EmulateTest )
113- SS << " -DEMULATE_TEST" ;
114112 if (Params.Enable16Bit )
115113 SS << " -enable-16bit-types" ;
116114 if (ExtraDefines)
@@ -283,12 +281,6 @@ static VariantCompType makeExpected(ComponentType CompType, MatrixDim M,
283281 }
284282}
285283
286- static void logCompiledButSkipping () {
287- hlsl_test::LogCommentFmt (
288- L" Shader compiled OK; skipping execution (no SM 6.10 device)" );
289- WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
290- }
291-
292284class DxilConf_SM610_LinAlg {
293285public:
294286 BEGIN_TEST_CLASS (DxilConf_SM610_LinAlg)
@@ -317,36 +309,16 @@ class DxilConf_SM610_LinAlg {
317309 TEST_METHOD (ElementAccess_Wave_16x16_F16);
318310
319311private:
320- D3D_SHADER_MODEL createDevice ();
321-
322312 CComPtr<ID3D12Device> D3DDevice;
323313 dxc::SpecificDllLoader DxcSupport;
324314 bool VerboseLogging = false ;
325- bool EmulateTest = false ;
326315 bool Initialized = false ;
327- bool CompileOnly = false ;
328316 std::optional<D3D12SDKSelector> D3D12SDK;
329317
330318 WEX::TestExecution::SetVerifyOutput VerifyOutput{
331319 WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
332320};
333321
334- // / Attempts to create a device. If shaders are being emulated then a SM6.8
335- // / device is attempted. Otherwise a SM6.10 device is attempted
336- D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice () {
337- if (EmulateTest) {
338- if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_8, false ))
339- return D3D_SHADER_MODEL_6_8;
340-
341- return D3D_SHADER_MODEL_NONE;
342- }
343-
344- if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false ))
345- return D3D_SHADER_MODEL_6_10;
346-
347- return D3D_SHADER_MODEL_NONE;
348- }
349-
350322bool DxilConf_SM610_LinAlg::setupClass () {
351323 if (!Initialized) {
352324 Initialized = true ;
@@ -355,28 +327,18 @@ bool DxilConf_SM610_LinAlg::setupClass() {
355327 D3D12SDK = D3D12SDKSelector ();
356328 WEX::TestExecution::RuntimeParameters::TryGetValue (L" VerboseLogging" ,
357329 VerboseLogging);
358- WEX::TestExecution::RuntimeParameters::TryGetValue (L" EmulateTest" ,
359- EmulateTest);
360- D3D_SHADER_MODEL SupportedSM = createDevice ();
361-
362- if (EmulateTest) {
363- hlsl_test::LogWarningFmt (L" EmulateTest flag set. Tests are NOT REAL" );
364- if (SupportedSM != D3D_SHADER_MODEL_6_8) {
365- hlsl_test::LogErrorFmt (
366- L" Device creation failed. Expected a driver supporting SM6.8" );
367- return false ;
368- }
369- }
370330
331+ if (!D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false )) {
371332#ifdef _HLK_CONF
372- if (SupportedSM != D3D_SHADER_MODEL_6_10) {
373333 hlsl_test::LogErrorFmt (
374334 L" Device creation failed. Expected a driver supporting SM6.10" );
335+ #else
336+ hlsl_test::LogWarningFmt (
337+ L" Device creation failed. Expected a driver supporting SM6.10" );
338+ WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
339+ #endif
375340 return false ;
376341 }
377- #endif
378-
379- CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
380342 }
381343
382344 return true ;
@@ -388,27 +350,24 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
388350 if (D3DDevice && D3DDevice->GetDeviceRemovedReason () == S_OK)
389351 return true ;
390352
391- // Device is expected to be null. No point in recreating it
392- if (CompileOnly)
393- return true ;
394-
395353 hlsl_test::LogCommentFmt (L" Device was lost!" );
396354 D3DDevice.Release ();
397355
398356 hlsl_test::LogCommentFmt (L" Recreating device" );
399357
400- // !CompileOnly implies we expect it to succeeded
401- return createDevice () != D3D_SHADER_MODEL_NONE;
358+ return D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false );
402359}
403360
404361static const char LoadStoreShader[] = R"(
405362 RWByteAddressBuffer Input : register(u0);
406363 RWByteAddressBuffer Output : register(u1);
407364
408- #ifndef EMULATE_TEST
409365 [WaveSize(4, 64)]
410366 [numthreads(NUMTHREADS, 1, 1)]
411- void main() {
367+ void main(uint threadID : SV_GroupIndex) {
368+ if (WaveReadLaneFirst(threadID) != 0)
369+ return;
370+
412371 __builtin_LinAlgMatrix
413372 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
414373 Mat;
@@ -417,45 +376,26 @@ static const char LoadStoreShader[] = R"(
417376 __builtin_LinAlg_MatrixStoreToDescriptor(
418377 Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
419378 }
420- #else
421- [numthreads(NUMTHREADS, 1, 1)]
422- void main() {
423- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
424- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
425- }
426- }
427- #endif
428379)" ;
429380
430381static void runLoadStoreRoundtrip (ID3D12Device *Device,
431382 dxc::SpecificDllLoader &DxcSupport,
432- const MatrixParams &Params, bool Verbose,
433- bool CompileOnly) {
383+ const MatrixParams &Params, bool Verbose) {
434384 const size_t NumElements = Params.totalElements ();
435385 const size_t BufferSize = Params.totalBytes ();
436386
437- std::string Target = " cs_6_10" ;
438- if (Params.EmulateTest )
439- Target = " cs_6_8" ;
440-
441387 // TODO: these should be varied by test to ensure full coverage
442388 std::stringstream ExtraDefs;
443389 ExtraDefs << " -DOFFSET=" << 0 ;
444390
445391 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
446392
447- // Always verify the shader compiles.
448- compileShader (DxcSupport, LoadStoreShader, Target.c_str (), Args, Verbose);
449-
450- if (CompileOnly) {
451- logCompiledButSkipping ();
452- return ;
453- }
393+ compileShader (DxcSupport, LoadStoreShader, " cs_6_10" , Args, Verbose);
454394
455395 auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 1 );
456396
457397 // Construct the ShaderOp: two UAV buffers, load from one, store to other.
458- auto Op = createComputeOp (LoadStoreShader, Target. c_str () , " UAV(u0), UAV(u1)" ,
398+ auto Op = createComputeOp (LoadStoreShader, " cs_6_10 " , " UAV(u0), UAV(u1)" ,
459399 Args.c_str ());
460400 addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
461401 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
@@ -488,64 +428,46 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
488428 Params.Layout = LinalgMatrixLayout::RowMajor;
489429 Params.NumThreads = 64 ;
490430 Params.Enable16Bit = true ;
491- Params.EmulateTest = EmulateTest;
492- runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging,
493- CompileOnly);
431+ runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging);
494432}
495433
496434static const char SplatStoreShader[] = R"(
497435 RWByteAddressBuffer Output : register(u0);
498436
499- #ifndef EMULATE_TEST
500437 [WaveSize(4, 64)]
501438 [numthreads(NUMTHREADS, 1, 1)]
502- void main() {
439+ void main(uint threadID : SV_GroupIndex) {
440+ if (WaveReadLaneFirst(threadID) != 0)
441+ return;
442+
503443 __builtin_LinAlgMatrix
504444 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
505445 Mat;
506446 __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
507447 __builtin_LinAlg_MatrixStoreToDescriptor(
508448 Mat, Output, 0, STRIDE, LAYOUT, 128);
509449 }
510- #else
511- [numthreads(NUMTHREADS, 1, 1)]
512- void main() {
513- ELEM_TYPE fill = FILL_VALUE;
514- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
515- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, fill);
516- }
517- }
518- #endif
519450)" ;
520451
521452static void runSplatStore (ID3D12Device *Device,
522453 dxc::SpecificDllLoader &DxcSupport,
523454 const MatrixParams &Params, float FillValue,
524- bool Verbose, bool CompileOnly ) {
455+ bool Verbose) {
525456 const size_t NumElements = Params.totalElements ();
526457 const size_t BufferSize = Params.totalBytes ();
527- std::string Target = " cs_6_10" ;
528- if (Params.EmulateTest )
529- Target = " cs_6_8" ;
530458
531459 std::stringstream ExtraDefs;
532460 ExtraDefs << " -DFILL_VALUE=" << FillValue;
533461
534462 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
535463
536- // Always verify the shader compiles.
537- compileShader (DxcSupport, SplatStoreShader, Target.c_str (), Args, Verbose);
538-
539- if (CompileOnly) {
540- logCompiledButSkipping ();
541- return ;
542- }
464+ compileShader (DxcSupport, SplatStoreShader, " cs_6_10" , Args, Verbose);
543465
544466 auto Expected =
545467 makeExpected (Params.CompType , Params.M , Params.N , FillValue, false );
546468
547- auto Op = createComputeOp (SplatStoreShader, Target. c_str (), " UAV(u0) " ,
548- Args.c_str ());
469+ auto Op =
470+ createComputeOp (SplatStoreShader, " cs_6_10 " , " UAV(u0) " , Args.c_str ());
549471 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
550472 addRootUAV (Op.get (), 0 , " Output" );
551473
@@ -568,9 +490,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
568490 Params.Layout = LinalgMatrixLayout::RowMajor;
569491 Params.NumThreads = 64 ;
570492 Params.Enable16Bit = true ;
571- Params.EmulateTest = EmulateTest;
572- runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging,
573- CompileOnly);
493+ runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
574494}
575495
576496static const char ElementAccessShader[] = R"(
@@ -583,10 +503,12 @@ static const char ElementAccessShader[] = R"(
583503 return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
584504 }
585505
586- #ifndef EMULATE_TEST
587506 [WaveSize(4, 64)]
588507 [numthreads(NUMTHREADS, 1, 1)]
589- void main(uint threadIndex : SV_GroupIndex) {
508+ void main(uint threadID : SV_GroupIndex) {
509+ if (WaveReadLaneFirst(threadID) != 0)
510+ return;
511+
590512 __builtin_LinAlgMatrix
591513 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
592514 Mat;
@@ -604,30 +526,15 @@ static const char ElementAccessShader[] = R"(
604526
605527 // Save the matrix length that this thread saw. The length is written
606528 // to the output right after the matrix, offset by the thread index
607- uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
529+ uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadID * sizeof(uint));
608530 uint Len = __builtin_LinAlg_MatrixLength(Mat);
609531 Output.Store<uint>(LenIdx, Len);
610532 }
611- #else
612- [numthreads(NUMTHREADS, 1, 1)]
613- void main(uint threadIndex : SV_GroupIndex) {
614- uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
615- Output.Store<uint>(LenIdx, M_DIM * N_DIM / NUMTHREADS);
616-
617- if (threadIndex != 0)
618- return;
619-
620- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
621- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
622- }
623- }
624- #endif
625533)" ;
626534
627535static void runElementAccess (ID3D12Device *Device,
628536 dxc::SpecificDllLoader &DxcSupport,
629- const MatrixParams &Params, bool Verbose,
630- bool CompileOnly) {
537+ const MatrixParams &Params, bool Verbose) {
631538 const size_t NumElements = Params.totalElements ();
632539 const size_t NumThreads = Params.NumThreads ;
633540 const size_t InputBufSize = Params.totalBytes ();
@@ -639,24 +546,15 @@ static void runElementAccess(ID3D12Device *Device,
639546 const size_t OutputBufSize =
640547 NumElements * ElementSize + NumThreads * sizeof (uint32_t );
641548
642- std::string Target = " cs_6_10" ;
643- if (Params.EmulateTest )
644- Target = " cs_6_8" ;
645-
646549 std::stringstream ExtraDefs;
647550 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
648551
649- compileShader (DxcSupport, ElementAccessShader, Target.c_str (), Args, Verbose);
650-
651- if (CompileOnly) {
652- logCompiledButSkipping ();
653- return ;
654- }
552+ compileShader (DxcSupport, ElementAccessShader, " cs_6_10" , Args, Verbose);
655553
656554 auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 1 );
657555
658- auto Op = createComputeOp (ElementAccessShader, Target. c_str () ,
659- " UAV(u0), UAV(u1) " , Args.c_str ());
556+ auto Op = createComputeOp (ElementAccessShader, " cs_6_10 " , " UAV(u0), UAV(u1) " ,
557+ Args.c_str ());
660558 addUAVBuffer (Op.get (), " Input" , InputBufSize, false , " byname" );
661559 addUAVBuffer (Op.get (), " Output" , OutputBufSize, true );
662560 addRootUAV (Op.get (), 0 , " Input" );
@@ -701,8 +599,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
701599 Params.Layout = LinalgMatrixLayout::RowMajor;
702600 Params.NumThreads = 64 ;
703601 Params.Enable16Bit = true ;
704- Params.EmulateTest = EmulateTest;
705- runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
602+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
706603}
707604
708605} // namespace LinAlg
0 commit comments