Skip to content

Commit 58163b0

Browse files
authored
Add support for specifying overload arg index in extension function (#3510)
Any extension function that includes the specical string "$o" in its name needs to have the $o replaced with the type name of the overload. Previously we used a default heuristic to select the overload type from a function argument. This commit adds support for explicitly setting the argument to use for the overload name by appending a ":<ArgIndex>" to the overload marker. For example, using a name like "my_special_function.$o:3" would take the overload type from the third function argument.
1 parent 9eb7f4f commit 58163b0

2 files changed

Lines changed: 115 additions & 11 deletions

File tree

lib/HLSL/HLOperationLowerExtension.cpp

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,13 +1142,21 @@ class ExtensionName {
11421142
return name.size() > 0;
11431143
}
11441144

1145+
typedef unsigned OverloadArgIndex;
1146+
static constexpr OverloadArgIndex DefaultOverloadIndex = std::numeric_limits<OverloadArgIndex>::max();
1147+
11451148
// Choose the (return value or argument) type that determines the overload type
11461149
// for the intrinsic call.
1147-
// For now we take the return type as the overload. If the return is void we
1148-
// take the first (non-opcode) argument as the overload type. We could extend the
1149-
// $o sytnax in the extension name to explicitly specify the overload slot (e.g.
1150-
// $o:3 would say the overload type is determined by parameter 3.
1151-
static Type *SelectOverloadSlot(CallInst *CI) {
1150+
// If the overload arg index was explicitly specified (see ParseOverloadArgIndex)
1151+
// then we use that arg to pick the overload name. Otherwise we pick a default
1152+
// where we take the return type as the overload. If the return is void we
1153+
// take the first (non-opcode) argument as the overload type.
1154+
static Type *SelectOverloadSlot(CallInst *CI, OverloadArgIndex ArgIndex) {
1155+
if (ArgIndex != DefaultOverloadIndex)
1156+
{
1157+
return CI->getArgOperand(ArgIndex)->getType();
1158+
}
1159+
11521160
Type *ty = CI->getType();
11531161
if (ty->isVoidTy()) {
11541162
if (CI->getNumArgOperands() > 1)
@@ -1158,8 +1166,8 @@ class ExtensionName {
11581166
return ty;
11591167
}
11601168

1161-
static Type *GetOverloadType(CallInst *CI) {
1162-
Type *ty = SelectOverloadSlot(CI);
1169+
static Type *GetOverloadType(CallInst *CI, OverloadArgIndex ArgIndex) {
1170+
Type *ty = SelectOverloadSlot(CI, ArgIndex);
11631171
if (ty->isVectorTy())
11641172
ty = ty->getVectorElementType();
11651173

@@ -1174,19 +1182,77 @@ class ExtensionName {
11741182
return typeName;
11751183
}
11761184

1177-
static std::string GetOverloadTypeName(CallInst *CI) {
1178-
Type *ty = GetOverloadType(CI);
1185+
static std::string GetOverloadTypeName(CallInst *CI, OverloadArgIndex ArgIndex) {
1186+
Type *ty = GetOverloadType(CI, ArgIndex);
11791187
return GetTypeName(ty);
11801188
}
11811189

1190+
// Parse the arg index out of the overload marker (if any).
1191+
//
1192+
// The function names use a $o to indicate that the function is overloaded
1193+
// and we should replace $o with the overload type. The extension name can
1194+
// explicitly set which arg to use for the overload type by adding a colon
1195+
// and a number after the $o (e.g. $o:3 would say the overload type is
1196+
// determined by parameter 3).
1197+
//
1198+
// If we find an arg index after the overload marker we update the size
1199+
// of the marker to include the full parsed string size so that it can
1200+
// be replaced with the selected overload type.
1201+
//
1202+
static OverloadArgIndex ParseOverloadArgIndex(
1203+
const std::string& functionName,
1204+
size_t OverloadMarkerStartIndex,
1205+
size_t *pOverloadMarkerSize)
1206+
{
1207+
assert(OverloadMarkerStartIndex != std::string::npos);
1208+
size_t StartIndex = OverloadMarkerStartIndex + *pOverloadMarkerSize;
1209+
1210+
// Check if we have anything after the overload marker to parse.
1211+
if (StartIndex >= functionName.size())
1212+
{
1213+
return DefaultOverloadIndex;
1214+
}
1215+
1216+
// Does it start with a ':' ?
1217+
if (functionName[StartIndex] != ':')
1218+
{
1219+
return DefaultOverloadIndex;
1220+
}
1221+
1222+
// Skip past the :
1223+
++StartIndex;
1224+
1225+
// Collect all the digits.
1226+
std::string Digits;
1227+
Digits.reserve(functionName.size() - StartIndex);
1228+
for (size_t i = StartIndex; i < functionName.size(); ++i)
1229+
{
1230+
char c = functionName[i];
1231+
if (!isdigit(c))
1232+
{
1233+
break;
1234+
}
1235+
Digits.push_back(c);
1236+
}
1237+
1238+
if (Digits.empty())
1239+
{
1240+
return DefaultOverloadIndex;
1241+
}
1242+
1243+
*pOverloadMarkerSize = *pOverloadMarkerSize + std::strlen(":") + Digits.size();
1244+
return std::stoi(Digits);
1245+
}
1246+
11821247
// Find the occurence of the overload marker $o and replace it the the overload type name.
11831248
static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
11841249
const char *OverloadMarker = "$o";
1185-
const size_t OverloadMarkerLength = 2;
1250+
size_t OverloadMarkerLength = 2;
11861251

11871252
size_t pos = functionName.find(OverloadMarker);
11881253
if (pos != std::string::npos) {
1189-
std::string typeName = GetOverloadTypeName(CI);
1254+
OverloadArgIndex ArgIndex = ParseOverloadArgIndex(functionName, pos, &OverloadMarkerLength);
1255+
std::string typeName = GetOverloadTypeName(CI, ArgIndex);
11901256
functionName.replace(pos, OverloadMarkerLength, typeName);
11911257
}
11921258
}

tools/clang/unittests/HLSL/ExtensionTest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ static const HLSL_INTRINSIC_ARGUMENT TestMyTexture2DOp[] = {
145145
{ "val", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
146146
};
147147

148+
// float = test_overload(float a, uint b, double c)
149+
static const HLSL_INTRINSIC_ARGUMENT TestOverloadArgs[] = {
150+
{ "test_overload", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_NUMERIC, 1, IA_C },
151+
{ "a", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_FLOAT, 1, IA_C },
152+
{ "b", AR_QUAL_IN, 2, LITEMPLATE_ANY, 2, LICOMPTYPE_UINT, 1, IA_C },
153+
{ "c", AR_QUAL_IN, 3, LITEMPLATE_SCALAR, 3, LICOMPTYPE_DOUBLE, 1, IA_C },
154+
};
155+
156+
148157
struct Intrinsic {
149158
LPCWSTR hlslName;
150159
const char *dxilName;
@@ -175,6 +184,9 @@ Intrinsic Intrinsics[] = {
175184
// counterpart for testing purposes.
176185
{L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, false, -1, countof(TestUnsigned), TestUnsigned}},
177186
{L"wave_proc", DEFAULT_NAME, "r", { 16, false, true, true, -1, countof(WaveProcArgs), WaveProcArgs }},
187+
{L"test_o_1", "test_o_1.$o:1", "r", { 18, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
188+
{L"test_o_2", "test_o_2.$o:2", "r", { 19, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
189+
{L"test_o_3", "test_o_3.$o:3", "r", { 20, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
178190
};
179191

180192
Intrinsic BufferIntrinsics[] = {
@@ -530,6 +542,7 @@ class ExtensionTest : public ::testing::Test {
530542
TEST_METHOD(ResourceExtensionIntrinsicCustomLowering1)
531543
TEST_METHOD(ResourceExtensionIntrinsicCustomLowering2)
532544
TEST_METHOD(ResourceExtensionIntrinsicCustomLowering3)
545+
TEST_METHOD(CustomOverloadArg1)
533546
};
534547

535548
TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -1182,3 +1195,28 @@ TEST_F(ExtensionTest, ResourceExtensionIntrinsicCustomLowering3) {
11821195
};
11831196
CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
11841197
}
1198+
1199+
TEST_F(ExtensionTest, CustomOverloadArg1) {
1200+
// Test that we pick the overload name based on the first arg.
1201+
Compiler c(m_dllSupport);
1202+
c.RegisterIntrinsicTable(new TestIntrinsicTable());
1203+
auto result = c.Compile(
1204+
"float main() : SV_Target {\n"
1205+
" float o1 = test_o_1(1.0f, 2u, 4.0);\n"
1206+
" float o2 = test_o_2(1.0f, 2u, 4.0);\n"
1207+
" float o3 = test_o_3(1.0f, 2u, 4.0);\n"
1208+
" return o1 + o2 + o3;\n"
1209+
"}\n",
1210+
{ L"/Vd" }, {}
1211+
);
1212+
CheckOperationResultMsgs(result, {}, true, false);
1213+
std::string disassembly = c.Disassemble();
1214+
1215+
// The function name should match the first arg (float)
1216+
LPCSTR expected[] = {
1217+
"call float @test_o_1.float(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
1218+
"call float @test_o_2.i32(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
1219+
"call float @test_o_3.double(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
1220+
};
1221+
CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, false);
1222+
}

0 commit comments

Comments
 (0)