2828namespace spirv
2929{
3030
31- //! General Decls
32- template<uint32_t StorageClass, typename T>
33- using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
34-
3531// The holy operation that makes addrof possible
3632template<uint32_t StorageClass, typename T>
3733[[vk::ext_instruction(spv::OpCopyObject)]]
3834pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
3935
40- //! Std 450 Extended set operations
36+ // TODO: Generate extended instructions
37+ //! Std 450 Extended set instructions
4138template<typename SquareMatrix>
4239[[vk::ext_instruction(34, /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
4340SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
@@ -88,37 +85,58 @@ def gen(grammer_path, output_path):
8885
8986 writer .write ("\n //! Builtins\n namespace builtin\n {\n " )
9087 for b in builtins :
91- builtin_type = None
88+ b_name = b ["enumerant" ]
89+ b_type = None
90+ b_cap = None
9291 is_output = False
93- builtin_name = b ["enumerant" ]
94- match builtin_name :
95- case "HelperInvocation" : builtin_type = "bool"
96- case "VertexIndex" : builtin_type = "uint32_t"
97- case "InstanceIndex" : builtin_type = "uint32_t"
98- case "NumWorkgroups" : builtin_type = "uint32_t3"
99- case "WorkgroupId" : builtin_type = "uint32_t3"
100- case "LocalInvocationId" : builtin_type = "uint32_t3"
101- case "GlobalInvocationId" : builtin_type = "uint32_t3"
102- case "LocalInvocationIndex" : builtin_type = "uint32_t"
103- case "SubgroupEqMask" : builtin_type = "uint32_t4"
104- case "SubgroupGeMask" : builtin_type = "uint32_t4"
105- case "SubgroupGtMask" : builtin_type = "uint32_t4"
106- case "SubgroupLeMask" : builtin_type = "uint32_t4"
107- case "SubgroupLtMask" : builtin_type = "uint32_t4"
108- case "SubgroupSize" : builtin_type = "uint32_t"
109- case "NumSubgroups" : builtin_type = "uint32_t"
110- case "SubgroupId" : builtin_type = "uint32_t"
111- case "SubgroupLocalInvocationId" : builtin_type = "uint32_t"
92+ match b_name :
93+ case "HelperInvocation" : b_type = "bool"
94+ case "VertexIndex" : b_type = "uint32_t"
95+ case "InstanceIndex" : b_type = "uint32_t"
96+ case "NumWorkgroups" : b_type = "uint32_t3"
97+ case "WorkgroupId" : b_type = "uint32_t3"
98+ case "LocalInvocationId" : b_type = "uint32_t3"
99+ case "GlobalInvocationId" : b_type = "uint32_t3"
100+ case "LocalInvocationIndex" : b_type = "uint32_t"
101+ case "SubgroupEqMask" :
102+ b_type = "uint32_t4"
103+ b_cap = "GroupNonUniformBallot"
104+ case "SubgroupGeMask" :
105+ b_type = "uint32_t4"
106+ b_cap = "GroupNonUniformBallot"
107+ case "SubgroupGtMask" :
108+ b_type = "uint32_t4"
109+ b_cap = "GroupNonUniformBallot"
110+ case "SubgroupLeMask" :
111+ b_type = "uint32_t4"
112+ b_cap = "GroupNonUniformBallot"
113+ case "SubgroupLtMask" :
114+ b_type = "uint32_t4"
115+ b_cap = "GroupNonUniformBallot"
116+ case "SubgroupSize" :
117+ b_type = "uint32_t"
118+ b_cap = "GroupNonUniform"
119+ case "NumSubgroups" :
120+ b_type = "uint32_t"
121+ b_cap = "GroupNonUniform"
122+ case "SubgroupId" :
123+ b_type = "uint32_t"
124+ b_cap = "GroupNonUniform"
125+ case "SubgroupLocalInvocationId" :
126+ b_type = "uint32_t"
127+ b_cap = "GroupNonUniform"
112128 case "Position" :
113- builtin_type = "float32_t4"
129+ b_type = "float32_t4"
114130 is_output = True
115131 case _: continue
132+ if b_cap != None :
133+ writer .write ("[[vk::ext_capability(spv::Capability" + b_cap + ")]]\n " )
116134 if is_output :
117- writer .write ("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n " )
118- writer .write ("static " + builtin_type + " " + builtin_name + ";\n " )
135+ writer .write ("[[vk::ext_builtin_output(spv::BuiltIn" + b_name + ")]]\n " )
136+ writer .write ("static " + b_type + " " + b_name + ";\n " )
119137 else :
120- writer .write ("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n " )
121- writer .write ("static const " + builtin_type + " " + builtin_name + ";\n " )
138+ writer .write ("[[vk::ext_builtin_input(spv::BuiltIn" + b_name + ")]]\n " )
139+ writer .write ("static const " + b_type + " " + b_name + ";\n \n " )
122140 writer .write ("}\n " )
123141
124142 writer .write ("\n //! Execution Modes\n namespace execution_mode\n {" )
@@ -142,28 +160,28 @@ def gen(grammer_path, output_path):
142160
143161 match instruction ["class" ]:
144162 case "Atomic" :
145- processInst (writer , instruction , InstOptions () )
146- processInst (writer , instruction , InstOptions ( shape = Shape .PTR_TEMPLATE ) )
163+ processInst (writer , instruction )
164+ processInst (writer , instruction , Shape .PTR_TEMPLATE )
147165 case "Memory" :
148- processInst (writer , instruction , InstOptions ( shape = Shape .PTR_TEMPLATE ) )
149- processInst (writer , instruction , InstOptions ( shape = Shape .BDA ) )
166+ processInst (writer , instruction , Shape .PTR_TEMPLATE )
167+ processInst (writer , instruction , Shape .BDA )
150168 case "Barrier" | "Bit" :
151- processInst (writer , instruction , InstOptions () )
169+ processInst (writer , instruction )
152170 case "Reserved" :
153171 match instruction ["opname" ]:
154172 case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT" :
155- processInst (writer , instruction , InstOptions () )
173+ processInst (writer , instruction )
156174 case "Non-Uniform" :
157175 match instruction ["opname" ]:
158176 case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual" :
159- processInst (writer , instruction , InstOptions ( result_ty = "bool" ) )
177+ processInst (writer , instruction , result_ty = "bool" )
160178 case "OpGroupNonUniformBallot" :
161- processInst (writer , instruction , InstOptions ( result_ty = "uint32_t4" ,op_ty = "bool" ) )
179+ processInst (writer , instruction , result_ty = "uint32_t4" ,prefered_op_ty = "bool" )
162180 case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract" :
163- processInst (writer , instruction , InstOptions ( result_ty = "bool" ,op_ty = "uint32_t4" ) )
181+ processInst (writer , instruction , result_ty = "bool" ,prefered_op_ty = "uint32_t4" )
164182 case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB" :
165- processInst (writer , instruction , InstOptions ( result_ty = "uint32_t" ,op_ty = "uint32_t4" ) )
166- case _: processInst (writer , instruction , InstOptions () )
183+ processInst (writer , instruction , result_ty = "uint32_t" ,prefered_op_ty = "uint32_t4" )
184+ case _: processInst (writer , instruction )
167185 case _: continue # TODO
168186
169187 writer .write (foot )
@@ -173,12 +191,11 @@ class Shape(Enum):
173191 PTR_TEMPLATE = 1 , # TODO: this is a DXC Workaround
174192 BDA = 2 , # PhysicalStorageBuffer Result Type
175193
176- class InstOptions (NamedTuple ):
177- shape : Shape = Shape .DEFAULT
178- result_ty : Optional [str ] = None
179- op_ty : Optional [str ] = None
180-
181- def processInst (writer : io .TextIOWrapper , instruction , options : InstOptions ):
194+ def processInst (writer : io .TextIOWrapper ,
195+ instruction ,
196+ shape : Shape = Shape .DEFAULT ,
197+ result_ty : Optional [str ] = None ,
198+ prefered_op_ty : Optional [str ] = None ):
182199 templates = []
183200 caps = []
184201 conds = []
@@ -193,10 +210,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
193210 if cap == "Shader" : continue
194211 caps .append (cap )
195212
196- if options . shape == Shape .PTR_TEMPLATE :
213+ if shape == Shape .PTR_TEMPLATE :
197214 templates .append ("typename P" )
198215 conds .append ("is_spirv_type_v<P>" )
199- elif options . shape == Shape .BDA :
216+ elif shape == Shape .BDA :
200217 caps .append ("PhysicalStorageBufferAddresses" )
201218
202219 # split upper case words
@@ -226,10 +243,10 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
226243
227244 if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
228245 if len (result_types ) == 0 :
229- if options . result_ty == None :
246+ if result_ty == None :
230247 result_types = ["T" ]
231248 else :
232- result_types = [options . result_ty ]
249+ result_types = [result_ty ]
233250 else :
234251 assert len (result_types ) == 0
235252 result_types = ["void" ]
@@ -261,8 +278,8 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
261278 final_templates .append ("typename N" )
262279
263280 op_ty = "T"
264- if options . op_ty != None :
265- op_ty = options . op_ty
281+ if prefered_op_ty != None :
282+ op_ty = prefered_op_ty
266283 elif rt != "void" :
267284 op_ty = rt
268285
@@ -276,9 +293,9 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
276293 case "IdRef" :
277294 match operand ["name" ]:
278295 case "'Pointer'" :
279- if options . shape == Shape .PTR_TEMPLATE :
296+ if shape == Shape .PTR_TEMPLATE :
280297 args .append ("P " + operand_name )
281- elif options . shape == Shape .BDA :
298+ elif shape == Shape .BDA :
282299 if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
283300 final_templates = ["typename T" ] + final_templates
284301 args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
@@ -302,7 +319,7 @@ def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
302319 case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
303320 case "MemoryAccess" :
304321 assert len (overload_caps ) <= 1
305- if options . shape != Shape .BDA :
322+ if shape != Shape .BDA :
306323 writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
307324 writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
308325 writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
@@ -326,7 +343,7 @@ def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name,
326343 writer .write (" " + fn_name + "(" + ", " .join (args ) + ");\n \n " )
327344
328345def ignore (op_name ):
329- print ("\033 [93mWARNING \033 [0m: instruction " + op_name + " ignored" )
346+ print ("\033 [94mIGNORED \033 [0m: " + op_name )
330347
331348if __name__ == "__main__" :
332349 script_dir_path = os .path .abspath (os .path .dirname (__file__ ))
0 commit comments