@@ -195,6 +195,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
195195 if options .shape == Shape .PTR_TEMPLATE :
196196 templates .append ("typename P" )
197197 conds .append ("is_spirv_type_v<P>" )
198+ elif options .shape == Shape .BDA :
199+ caps .append ("PhysicalStorageBufferAddresses" )
198200
199201 # split upper case words
200202 matches = [(m .group (1 ), m .span (1 )) for m in re .finditer (r'([A-Z])[A-Z][a-z]' , fn_name )]
@@ -242,74 +244,74 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
242244 case "float16_t" : overload_caps .append ("Float16" )
243245 case "float64_t" : overload_caps .append ("Float64" )
244246
245- op_ty = "T"
246- if options . op_ty != None :
247- op_ty = options . op_ty
248- elif rt != "void" :
249- op_ty = rt
250-
251- if ( not "typename T" in templates ) and ( rt == "T" ):
252- templates = [ "typename T" ] + templates
253-
254- args = []
255- for operand in operands :
256- operand_name = operand [ "name" ]. strip ( "'" ) if "name" in operand else None
257- operand_name = operand_name [ 0 ]. lower () + operand_name [ 1 :] if ( operand_name != None ) else ""
258- match operand [ "kind" ]:
259- case "IdRef" :
260- match operand [ "name" ]:
261- case "'Pointer'" :
262- if options . shape == Shape . PTR_TEMPLATE :
263- args . append ( "P " + operand_name )
264- elif options .shape == Shape . BDA :
265- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ):
266- templates = [ "typename T" ] + templates
267- overload_caps . append ( "PhysicalStorageBufferAddresses" )
268- args . append ( "pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
269- else :
270- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ) :
271- templates = [ "typename T" ] + templates
272- args . append ( "[[vk::ext_reference]] " + op_ty + " " + operand_name )
273- case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
274- if ( not "typename T" in templates ) and ( rt == "T" or op_ty == "T" ) :
275- templates = [ "typename T" ] + templates
276- args . append ( op_ty + " " + operand_name )
277- case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
278- args .append ("uint32_t " + operand_name )
279- case "'Predicate'" : args . append ( "bool " + operand_name )
280- case "'ClusterSize'" :
281- if "quantifier" in operand and operand [ "quantifier" ] == "?" : continue # TODO: overload
282- else : return # TODO
283- case _: return # TODO
284- case "IdScope" : args . append ( "uint32_t " + operand_name . lower () + "Scope" )
285- case "IdMemorySemantics" : args . append ( " uint32_t " + operand_name )
286- case "GroupOperation" : args .append ("[[vk::ext_literal ]] uint32_t " + operand_name )
287- case "MemoryAccess " :
288- if options . shape != Shape . BDA :
289- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess" ])
290- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ] )
291- writeInst ( writer , templates + [ "uint32_t alignment" ], overload_caps , op_name , fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
292- case _: return # TODO
293-
294- writeInst ( writer , templates , overload_caps , op_name , fn_name , conds , rt , args )
295-
296-
297- def writeInst ( writer : io . TextIOWrapper , templates , caps , op_name , fn_name , conds , result_type , args ):
298- if len ( caps ) > 0 :
299- for cap in caps :
300- if (( "Float16" in cap and result_type != "float16_t" ) or
301- ( "Float32" in cap and result_type != "float32_t" ) or
302- ( "Float64" in cap and result_type != "float64_t" ) or
303- ( "Int16" in cap and result_type != "int16_t" and result_type != "uint16_t" ) or
304- ( "Int64" in cap and result_type != "int64_t" and result_type != "uint64_t" )): continue
305-
306- final_fn_name = fn_name
307- if ( len ( caps ) > 1 ): final_fn_name = fn_name + "_" + cap
308- writeInstInner ( writer , templates , cap , op_name , final_fn_name , conds , result_type , args )
309- else :
310- writeInstInner ( writer , templates , None , op_name , fn_name , conds , result_type , args )
311-
312- def writeInstInner (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
247+ for cap in overload_caps or [ None ]:
248+ final_fn_name = fn_name + "_" + cap if ( len ( overload_caps ) > 1 ) else fn_name
249+ final_templates = templates . copy ()
250+
251+ if ( not "typename T" in final_templates ) and ( rt == "T" ):
252+ final_templates = [ "typename T" ] + final_templates
253+
254+ if len ( overload_caps ) > 0 :
255+ if (( "Float16" in cap and rt != "float16_t" ) or
256+ ( "Float32" in cap and rt != "float32_t" ) or
257+ ( "Float64" in cap and rt != "float64_t" ) or
258+ ( "Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
259+ ( "Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
260+
261+ if "Vector" in cap :
262+ rt = "vector<" + rt + ", N> "
263+ final_templates . append ( "typename N" )
264+
265+ op_ty = "T"
266+ if options .op_ty != None :
267+ op_ty = options . op_ty
268+ elif rt != "void" :
269+ op_ty = rt
270+
271+ args = []
272+ for operand in operands :
273+ operand_name = operand [ "name" ]. strip ( "'" ) if "name" in operand else None
274+ operand_name = operand_name [ 0 ]. lower () + operand_name [ 1 :] if ( operand_name != None ) else ""
275+ match operand [ "kind" ] :
276+ case "IdRef" :
277+ match operand [ "name" ]:
278+ case "'Pointer'" :
279+ if options . shape == Shape . PTR_TEMPLATE :
280+ args .append ("P " + operand_name )
281+ elif options . shape == Shape . BDA :
282+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ) :
283+ final_templates = [ "typename T" ] + final_templates
284+ args . append ( "pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
285+ else :
286+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ):
287+ final_templates = [ "typename T" ] + final_templates
288+ args .append ("[[vk::ext_reference ]] " + op_ty + " " + operand_name )
289+ case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert' " :
290+ if ( not "typename T" in final_templates ) and ( rt == "T" or op_ty == "T" ) :
291+ final_templates = [ "typename T" ] + final_templates
292+ args . append ( op_ty + " " + operand_name )
293+ case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
294+ args . append ( "uint32_t " + operand_name )
295+ case "'Predicate'" : args . append ( "bool " + operand_name )
296+ case "'ClusterSize'" :
297+ if "quantifier" in operand and operand [ "quantifier" ] == "?" : continue # TODO: overload
298+ else : return # TODO
299+ case _: return # TODO
300+ case "IdScope" : args . append ( "uint32_t " + operand_name . lower () + "Scope" )
301+ case "IdMemorySemantics" : args . append ( " uint32_t " + operand_name )
302+ case "GroupOperation" : args . append ( "[[vk::ext_literal]] uint32_t " + operand_name )
303+ case "MemoryAccess" :
304+ assert len ( overload_caps ) <= 1
305+ if options . shape != Shape . BDA :
306+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess" ])
307+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308+ writeInst ( writer , final_templates + [ "uint32_t alignment" ], cap , op_name , final_fn_name , conds , rt , args + [ "[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309+ case _: return # TODO
310+
311+ writeInst ( writer , final_templates , cap , op_name , final_fn_name , conds , rt , args )
312+
313+
314+ def writeInst (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
313315 if len (templates ) > 0 :
314316 writer .write ("template<" + ", " .join (templates ) + ">\n " )
315317 if (cap != None ):
0 commit comments