@@ -26,7 +26,7 @@ use crate::{
2626 } ,
2727 dx12:: {
2828 borrow_optional_interface_temporarily, shader_compilation, suballocation,
29- DynamicStorageBufferOffsets , Event ,
29+ DynamicStorageBufferOffsets , Event , ShaderCacheKey , ShaderCacheValue ,
3030 } ,
3131 AccelerationStructureEntries , TlasInstance ,
3232} ;
@@ -203,6 +203,7 @@ impl super::Device {
203203 null_rtv_handle,
204204 mem_allocator,
205205 compiler_container,
206+ shader_cache : Default :: default ( ) ,
206207 counters : Default :: default ( ) ,
207208 } )
208209 }
@@ -304,50 +305,85 @@ impl super::Device {
304305 } ;
305306
306307 //TODO: reuse the writer
307- let mut source = String :: new ( ) ;
308- let mut writer = hlsl:: Writer :: new ( & mut source, naga_options, & pipeline_options) ;
309- let reflection_info = {
308+ let ( source, entry_point) = {
309+ let mut source = String :: new ( ) ;
310+ let mut writer = hlsl:: Writer :: new ( & mut source, naga_options, & pipeline_options) ;
311+
310312 profiling:: scope!( "naga::back::hlsl::write" ) ;
311- writer
313+ let mut reflection_info = writer
312314 . write ( & module, & info, frag_ep. as_ref ( ) )
313- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "HLSL: {e:?}" ) ) ) ?
315+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "HLSL: {e:?}" ) ) ) ?;
316+
317+ assert_eq ! ( reflection_info. entry_point_names. len( ) , 1 ) ;
318+
319+ let entry_point = reflection_info
320+ . entry_point_names
321+ . pop ( )
322+ . unwrap ( )
323+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
324+
325+ ( source, entry_point)
314326 } ;
315327
328+ log:: info!(
329+ "Naga generated shader for {:?} at {:?}:\n {}" ,
330+ entry_point,
331+ naga_stage,
332+ source
333+ ) ;
334+
335+ let key = ShaderCacheKey {
336+ source,
337+ entry_point,
338+ stage : naga_stage,
339+ shader_model : naga_options. shader_model ,
340+ } ;
341+
342+ {
343+ let mut shader_cache = self . shader_cache . lock ( ) ;
344+ let nr_of_shaders_compiled = shader_cache. nr_of_shaders_compiled ;
345+ if let Some ( value) = shader_cache. entries . get_mut ( & key) {
346+ value. last_used = nr_of_shaders_compiled;
347+ return Ok ( value. shader . clone ( ) ) ;
348+ }
349+ }
350+
351+ let source_name = stage. module . raw_name . as_deref ( ) ;
352+
316353 let full_stage = format ! (
317354 "{}_{}" ,
318355 naga_stage. to_hlsl_str( ) ,
319356 naga_options. shader_model. to_str( )
320357 ) ;
321358
322- let raw_ep = reflection_info. entry_point_names [ 0 ]
323- . as_ref ( )
324- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
325-
326- let source_name = stage. module . raw_name . as_deref ( ) ;
327-
328- let result = self . compiler_container . compile (
359+ let compiled_shader = self . compiler_container . compile (
329360 self ,
330- & source,
361+ & key . source ,
331362 source_name,
332- raw_ep ,
363+ & key . entry_point ,
333364 stage_bit,
334365 & full_stage,
335- ) ;
366+ ) ? ;
336367
337- let log_level = if result. is_ok ( ) {
338- log:: Level :: Info
339- } else {
340- log:: Level :: Error
341- } ;
368+ {
369+ let mut shader_cache = self . shader_cache . lock ( ) ;
370+ shader_cache. nr_of_shaders_compiled += 1 ;
371+ let nr_of_shaders_compiled = shader_cache. nr_of_shaders_compiled ;
372+ let value = ShaderCacheValue {
373+ last_used : nr_of_shaders_compiled,
374+ shader : compiled_shader. clone ( ) ,
375+ } ;
376+ shader_cache. entries . insert ( key, value) ;
342377
343- log:: log!(
344- log_level,
345- "Naga generated shader for {:?} at {:?}:\n {}" ,
346- raw_ep,
347- naga_stage,
348- source
349- ) ;
350- result
378+ // Retain all entries that have been used since we compiled the last 100 shaders.
379+ if shader_cache. entries . len ( ) > 200 {
380+ shader_cache
381+ . entries
382+ . retain ( |_, v| v. last_used >= nr_of_shaders_compiled - 100 ) ;
383+ }
384+ }
385+
386+ Ok ( compiled_shader)
351387 }
352388
353389 pub fn raw_device ( & self ) -> & Direct3D12 :: ID3D12Device {
@@ -1818,11 +1854,6 @@ impl crate::Device for super::Device {
18181854 }
18191855 . map_err ( |err| crate :: PipelineError :: Linkage ( shader_stages, err. to_string ( ) ) ) ?;
18201856
1821- unsafe { blob_vs. destroy ( ) } ;
1822- if let Some ( blob_fs) = blob_fs {
1823- unsafe { blob_fs. destroy ( ) } ;
1824- } ;
1825-
18261857 if let Some ( label) = desc. label {
18271858 raw. set_name ( label) ?;
18281859 }
@@ -1880,8 +1911,6 @@ impl crate::Device for super::Device {
18801911 }
18811912 } ;
18821913
1883- unsafe { blob_cs. destroy ( ) } ;
1884-
18851914 let raw: Direct3D12 :: ID3D12PipelineState = pair. map_err ( |err| {
18861915 crate :: PipelineError :: Linkage ( wgt:: ShaderStages :: COMPUTE , err. to_string ( ) )
18871916 } ) ?;
0 commit comments