Skip to content

Commit e4ddf78

Browse files
authored
Handle implicit operator overloading (#4063)
Currently, for a binary operator, we check whether the RHS expr defines an operator using the operator overloading or not regardless the binary operator to handle. If the RHS expr defines such operator, we build the operator using the defined one. However, it can result in trying to build an operator that is not actually defined by the RHS expr. This commit lets it decide whether it has to build a new operator or use some pre-defined one based on the type of the binary operator to handle and the operator defined by the RHS expr.
1 parent 110a5be commit e4ddf78

8 files changed

Lines changed: 170 additions & 46 deletions

File tree

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,11 @@ bool GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
431431
llvm::StringRef &group);
432432
bool GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S);
433433

434+
bool IsUserDefinedRecordType(clang::QualType type);
435+
bool DoesTypeDefineOverloadedOperator(clang::QualType typeWithOperator,
436+
clang::OverloadedOperatorKind opc,
437+
clang::QualType paramType);
438+
434439
/// <summary>Adds a function declaration to the specified class record.</summary>
435440
/// <param name="context">ASTContext that owns declarations.</param>
436441
/// <param name="recordDecl">Record declaration in which to add function.</param>

tools/clang/lib/AST/HlslTypes.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,49 @@ bool IsHLSLSubobjectType(clang::QualType type) {
580580
return GetHLSLSubobjectKind(type, kind, hgType);
581581
}
582582

583+
bool IsUserDefinedRecordType(clang::QualType type) {
584+
if (const auto *rt = type->getAs<RecordType>()) {
585+
// HLSL specific types
586+
if (hlsl::IsHLSLResourceType(type) || hlsl::IsHLSLVecMatType(type) ||
587+
isa<ExtVectorType>(type.getTypePtr()) || type->isBuiltinType() ||
588+
type->isArrayType()) {
589+
return false;
590+
}
591+
592+
// SubpassInput or SubpassInputMS type
593+
if (rt->getDecl()->getName() == "SubpassInput" ||
594+
rt->getDecl()->getName() == "SubpassInputMS") {
595+
return false;
596+
}
597+
return true;
598+
}
599+
600+
return false;
601+
}
602+
603+
bool DoesTypeDefineOverloadedOperator(clang::QualType typeWithOperator,
604+
clang::OverloadedOperatorKind opc,
605+
clang::QualType paramType) {
606+
if (const RecordType *recordType = typeWithOperator->getAs<RecordType>()) {
607+
if (const CXXRecordDecl *cxxRecordDecl =
608+
dyn_cast<CXXRecordDecl>(recordType->getDecl())) {
609+
for (const auto *method : cxxRecordDecl->methods()) {
610+
if (!method->isUserProvided() || method->getNumParams() != 1)
611+
continue;
612+
// It must be an implicit assignment.
613+
if (opc == OO_Equal &&
614+
typeWithOperator != method->getParamDecl(0)->getOriginalType() &&
615+
typeWithOperator == paramType) {
616+
continue;
617+
}
618+
if (method->getOverloadedOperator() == opc)
619+
return true;
620+
}
621+
}
622+
}
623+
return false;
624+
}
625+
583626
bool GetHLSLSubobjectKind(clang::QualType type, DXIL::SubobjectKind &subobjectKind, DXIL::HitGroupType &hgType) {
584627
hgType = (DXIL::HitGroupType)(-1);
585628
type = type.getCanonicalType();

tools/clang/lib/Sema/SemaExpr.cpp

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,49 +1553,6 @@ static ExprResult BuildCookedLiteralOperatorCall(Sema &S, Scope *Scope,
15531553
return S.BuildLiteralOperatorCall(R, OpNameInfo, Args, LitEndLoc);
15541554
}
15551555

1556-
// HLSL Change Starts
1557-
static bool IsUserDefinedRecordType(QualType type) {
1558-
if (const auto *rt = type->getAs<RecordType>()) {
1559-
// HLSL specific types
1560-
if (hlsl::IsHLSLResourceType(type) || hlsl::IsHLSLVecMatType(type) ||
1561-
isa<ExtVectorType>(type.getTypePtr()) || type->isBuiltinType() ||
1562-
type->isArrayType()) {
1563-
return false;
1564-
}
1565-
1566-
// SubpassInput or SubpassInputMS type
1567-
if (rt->getDecl()->getName() == "SubpassInput" ||
1568-
rt->getDecl()->getName() == "SubpassInputMS") {
1569-
return false;
1570-
}
1571-
return true;
1572-
}
1573-
1574-
return false;
1575-
}
1576-
1577-
static bool DoesTypeDefineOverloadedOperator(QualType type) {
1578-
if (const RecordType *recordType = type->getAs<RecordType>()) {
1579-
if (const CXXRecordDecl *cxxRecordDecl =
1580-
dyn_cast<CXXRecordDecl>(recordType->getDecl())) {
1581-
bool found_overloaded_operator = false;
1582-
for (const auto *method : cxxRecordDecl->methods()) {
1583-
if (method->getOverloadedOperator() != OO_None)
1584-
found_overloaded_operator = true;
1585-
}
1586-
return found_overloaded_operator;
1587-
}
1588-
}
1589-
return false;
1590-
}
1591-
1592-
static bool IsUserDefinedRecordTypeWithOverloadedOperator(QualType type) {
1593-
if (!IsUserDefinedRecordType(type))
1594-
return false;
1595-
return DoesTypeDefineOverloadedOperator(type);
1596-
}
1597-
// HLSL Change Ends
1598-
15991556
/// ActOnStringLiteral - The specified tokens were lexed as pasted string
16001557
/// fragments (e.g. "foo" "bar" L"baz"). The result string has to handle string
16011558
/// concatenation ([C99 5.1.1.2, translation phase #6]), so it may come from
@@ -10435,7 +10392,12 @@ ExprResult Sema::CreateBuiltinBinOp(SourceLocation OpLoc,
1043510392

1043610393
// HLSL Change Starts
1043710394
// Handle HLSL binary operands differently
10438-
if (getLangOpts().HLSL) {
10395+
if (getLangOpts().HLSL &&
10396+
(!getLangOpts().EnableOperatorOverloading ||
10397+
!hlsl::IsUserDefinedRecordType(LHSExpr->getType())) ||
10398+
!hlsl::DoesTypeDefineOverloadedOperator(
10399+
LHSExpr->getType(), clang::BinaryOperator::getOverloadedOperator(Opc),
10400+
RHSExpr->getType())) {
1043910401
hlsl::CheckBinOpForHLSL(*this, OpLoc, Opc, LHS, RHS, ResultTy, CompLHSTy, CompResultTy);
1044010402
if (!ResultTy.isNull() && Opc == BO_Comma) {
1044110403
// In C/C++, the RHS value kind should propagate. In HLSL, it should yield an r-value.
@@ -10921,8 +10883,10 @@ ExprResult Sema::BuildBinOp(Scope *S, SourceLocation OpLoc,
1092110883
// methods or not.
1092210884
if (getLangOpts().CPlusPlus &&
1092310885
(!getLangOpts().HLSL || getLangOpts().EnableOperatorOverloading) &&
10924-
IsUserDefinedRecordTypeWithOverloadedOperator(LHSExpr->getType()) &&
10925-
IsUserDefinedRecordType(RHSExpr->getType())) {
10886+
hlsl::IsUserDefinedRecordType(LHSExpr->getType()) &&
10887+
hlsl::DoesTypeDefineOverloadedOperator(
10888+
LHSExpr->getType(), clang::BinaryOperator::getOverloadedOperator(Opc),
10889+
RHSExpr->getType())) {
1092610890
// If either expression is type-dependent, always build an
1092710891
// overloaded op.
1092810892
if (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent())
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %clang_cc1 -fsyntax-only -Wno-unused-value -ffreestanding -verify -enable-operator-overloading %s
2+
// RUN: %clang_cc1 -fsyntax-only -Wno-unused-value -ffreestanding -verify -HV 2021 %s
3+
4+
// This test checks that when we use undefined overloaded operator
5+
// dxcompiler generates error and no crashes are observed.
6+
7+
struct S1 {
8+
float a;
9+
10+
float operator+(float x) {
11+
return a + x;
12+
}
13+
};
14+
15+
struct S2 {
16+
S1 s1;
17+
};
18+
19+
void main(float4 pos: SV_Position) {
20+
S1 s1;
21+
S2 s2;
22+
pos.x = s2.s1 + 0.1;
23+
pos.x = s2.s1 + s1; // expected-error {{invalid operands to binary expression ('S1' and 'S1')}}
24+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %dxc -E main -T ps_6_0 -fcgl -enable-operator-overloading %s | FileCheck %s
2+
3+
// CHECK: define internal void {{.*}}(%struct.MyArray* %this, %struct.MyArray* noalias sret %agg.result, %struct.MyArray* %RHS)
4+
5+
#define MAX_SIZE 100
6+
7+
struct MyArray {
8+
float4 A[MAX_SIZE];
9+
10+
void splat(float f) {
11+
for (int i = 0; i < MAX_SIZE; i++)
12+
A[i] = f;
13+
};
14+
15+
MyArray operator+(MyArray RHS) {
16+
MyArray OutArray;
17+
for (int i = 0; i < MAX_SIZE; i++)
18+
OutArray.A[i] = A[i] + RHS.A[i];
19+
return OutArray;
20+
};
21+
};
22+
23+
24+
float4 main(float4 col1: COLOR0, float4 col2: COLOR2, int ix : I) : SV_Target {
25+
MyArray A, B;
26+
B.splat(col1);
27+
A = A + B;
28+
return A.A[ix];
29+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %dxc -E main -T ps_6_0 -HV 2021 %s | FileCheck %s
2+
3+
// CHECK: define void @main()
4+
5+
struct S1 {
6+
float a;
7+
};
8+
9+
struct S2 {
10+
S1 s1;
11+
Texture2D<float> tex;
12+
};
13+
14+
void main(float4 pos: SV_Position) {
15+
S1 s1;
16+
S2 s2;
17+
s2.s1 = s1;
18+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %dxc -E main -T ps_6_0 -HV 2021 %s | FileCheck %s
2+
3+
// CHECK: define void @main()
4+
5+
struct S {
6+
float a;
7+
8+
void operator=(float x) {
9+
a = x;
10+
}
11+
12+
float operator+(float x) {
13+
return a + x;
14+
}
15+
};
16+
17+
struct Number {
18+
int n;
19+
20+
void operator=(float x) {
21+
n = x;
22+
}
23+
};
24+
25+
int main(float4 pos: SV_Position) : SV_Target {
26+
S s1;
27+
S s2;
28+
s1 = s2;
29+
s1 = 0.2;
30+
s1 = s1 + 0.1;
31+
32+
Number a = {pos.x};
33+
Number b = {pos.y};
34+
a = pos.x;
35+
return a.n;
36+
}

tools/clang/unittests/HLSL/VerifierTest.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class VerifierTest TEST_CLASS_DERIVATION {
4747
TEST_METHOD(RunCppErrors)
4848
TEST_METHOD(RunCppErrorsHV2015)
4949
TEST_METHOD(RunOperatorOverloadingForNewDelete)
50+
TEST_METHOD(RunOperatorOverloadingNotDefinedBinaryOp)
5051
TEST_METHOD(RunCXX11Attributes)
5152
TEST_METHOD(RunEnums)
5253
TEST_METHOD(RunFunctions)
@@ -198,6 +199,10 @@ TEST_F(VerifierTest, RunOperatorOverloadingForNewDelete) {
198199
CheckVerifiesHLSL(L"overloading-new-delete-errors.hlsl");
199200
}
200201

202+
TEST_F(VerifierTest, RunOperatorOverloadingNotDefinedBinaryOp) {
203+
CheckVerifiesHLSL(L"use-undefined-overloaded-operator.hlsl");
204+
}
205+
201206
TEST_F(VerifierTest, RunCXX11Attributes) {
202207
CheckVerifiesHLSL(L"cxx11-attributes.hlsl");
203208
}

0 commit comments

Comments
 (0)