@@ -185,6 +185,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
185185 op_name = instruction ["opname" ]
186186 fn_name = op_name [2 ].lower () + op_name [3 :]
187187 result_types = []
188+ exts = instruction ["extensions" ] if "extensions" in instruction else []
188189
189190 if "capabilities" in instruction and len (instruction ["capabilities" ]) > 0 :
190191 for cap in instruction ["capabilities" ]:
@@ -223,56 +224,55 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
223224 case "Bit" :
224225 if len (result_types ) == 0 : conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
225226
226- if "operands" in instruction :
227- operands = instruction ["operands" ]
228- if operands [0 ]["kind" ] == "IdResultType" :
229- operands = operands [2 :]
230- if len (result_types ) == 0 :
231- if options .result_ty == None :
232- result_types = ["T" ]
233- else :
234- result_types = [options .result_ty ]
235- else :
236- assert len (result_types ) == 0
237- result_types = ["void" ]
238-
239- for rt in result_types :
240- overload_caps = caps .copy ()
241- match rt :
242- case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
243- case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
244- case "float16_t" : overload_caps .append ("Float16" )
245- case "float64_t" : overload_caps .append ("Float64" )
246-
247- for cap in overload_caps or [None ]:
248- final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
249- final_templates = templates .copy ()
250-
251- if (not "typename T" in final_templates ) and (rt == "T" ):
252- final_templates = ["typename T" ] + final_templates
253-
254- if len (overload_caps ) > 0 :
255- if (("Float16" in cap and rt != "float16_t" ) or
256- ("Float32" in cap and rt != "float32_t" ) or
257- ("Float64" in cap and rt != "float64_t" ) or
258- ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
259- ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
260-
261- if "Vector" in cap :
262- rt = "vector<" + rt + ", N> "
263- final_templates .append ("typename N" )
227+ if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
228+ if len (result_types ) == 0 :
229+ if options .result_ty == None :
230+ result_types = ["T" ]
231+ else :
232+ result_types = [options .result_ty ]
233+ else :
234+ assert len (result_types ) == 0
235+ result_types = ["void" ]
236+
237+ for rt in result_types :
238+ overload_caps = caps .copy ()
239+ match rt :
240+ case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
241+ case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
242+ case "float16_t" : overload_caps .append ("Float16" )
243+ case "float64_t" : overload_caps .append ("Float64" )
244+
245+ for cap in overload_caps or [None ]:
246+ final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
247+ final_templates = templates .copy ()
248+
249+ if (not "typename T" in final_templates ) and (rt == "T" ):
250+ final_templates = ["typename T" ] + final_templates
251+
252+ if len (overload_caps ) > 0 :
253+ if (("Float16" in cap and rt != "float16_t" ) or
254+ ("Float32" in cap and rt != "float32_t" ) or
255+ ("Float64" in cap and rt != "float64_t" ) or
256+ ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
257+ ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
264258
265- op_ty = "T"
266- if options .op_ty != None :
267- op_ty = options .op_ty
268- elif rt != "void" :
269- op_ty = rt
270-
271- args = []
272- for operand in operands :
259+ if "Vector" in cap :
260+ rt = "vector<" + rt + ", N> "
261+ final_templates .append ("typename N" )
262+
263+ op_ty = "T"
264+ if options .op_ty != None :
265+ op_ty = options .op_ty
266+ elif rt != "void" :
267+ op_ty = rt
268+
269+ args = []
270+ if "operands" in instruction :
271+ for operand in instruction ["operands" ]:
273272 operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
274273 operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
275274 match operand ["kind" ]:
275+ case "IdResult" | "IdResultType" : continue
276276 case "IdRef" :
277277 match operand ["name" ]:
278278 case "'Pointer'" :
@@ -295,34 +295,38 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
295295 case "'Predicate'" : args .append ("bool " + operand_name )
296296 case "'ClusterSize'" :
297297 if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
298- else : return # TODO
299- case _: return # TODO
298+ else : return ignore ( op_name ) # TODO
299+ case _: return ignore ( op_name ) # TODO
300300 case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
301301 case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
302302 case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
303303 case "MemoryAccess" :
304304 assert len (overload_caps ) <= 1
305305 if options .shape != Shape .BDA :
306- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308- writeInst (writer , final_templates + ["uint32_t alignment" ], cap , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309- case _: return # TODO
306+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308+ writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
309+ case _: return ignore ( op_name ) # TODO
310310
311- writeInst (writer , final_templates , cap , op_name , final_fn_name , conds , rt , args )
311+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args )
312312
313313
314- def writeInst (writer : io .TextIOWrapper , templates , cap , op_name , fn_name , conds , result_type , args ):
314+ def writeInst (writer : io .TextIOWrapper , templates , cap , exts , op_name , fn_name , conds , result_type , args ):
315315 if len (templates ) > 0 :
316316 writer .write ("template<" + ", " .join (templates ) + ">\n " )
317- if ( cap != None ) :
317+ if cap != None :
318318 writer .write ("[[vk::ext_capability(spv::Capability" + cap + ")]]\n " )
319+ for ext in exts :
320+ writer .write ("[[vk::ext_extension(\" " + ext + "\" )]]\n " )
319321 writer .write ("[[vk::ext_instruction(spv::" + op_name + ")]]\n " )
320322 if len (conds ) > 0 :
321323 writer .write ("enable_if_t<" + " && " .join (conds ) + ", " + result_type + ">" )
322324 else :
323325 writer .write (result_type )
324326 writer .write (" " + fn_name + "(" + ", " .join (args ) + ");\n \n " )
325327
328+ def ignore (op_name ):
329+ print ("\033 [93mWARNING\033 [0m: instruction " + op_name + " ignored" )
326330
327331if __name__ == "__main__" :
328332 script_dir_path = os .path .abspath (os .path .dirname (__file__ ))
0 commit comments