Skip to content

Commit 7e0c57d

Browse files
committed
Improve SIMD binary op lowering and add vector tests
1 parent 0667aec commit 7e0c57d

3 files changed

Lines changed: 176 additions & 18 deletions

File tree

crates/compiler/src/cranelift_backend.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6709,15 +6709,20 @@ impl CraneliftBackend {
67096709
impl HirType {
67106710
#[allow(dead_code)]
67116711
fn is_float(&self) -> bool {
6712-
matches!(self, HirType::F32 | HirType::F64)
6712+
match self {
6713+
HirType::F32 | HirType::F64 => true,
6714+
HirType::Vector(elem_ty, _) => elem_ty.is_float(),
6715+
_ => false,
6716+
}
67136717
}
67146718

67156719
#[allow(dead_code)]
67166720
fn is_signed(&self) -> bool {
6717-
matches!(
6718-
self,
6719-
HirType::I8 | HirType::I16 | HirType::I32 | HirType::I64 | HirType::I128
6720-
)
6721+
match self {
6722+
HirType::I8 | HirType::I16 | HirType::I32 | HirType::I64 | HirType::I128 => true,
6723+
HirType::Vector(elem_ty, _) => elem_ty.is_signed(),
6724+
_ => false,
6725+
}
67216726
}
67226727
}
67236728

crates/compiler/src/ssa.rs

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,7 +1883,7 @@ impl SsaBuilder {
18831883
let right_val = self.translate_expression(block_id, right)?;
18841884
let result_type = self.convert_type(&expr.ty);
18851885

1886-
let hir_op = self.convert_binary_op(op);
1886+
let hir_op = self.convert_binary_op(op, &left_with_type.ty);
18871887

18881888
// For comparisons, use the operand type (not Bool result type) for the instruction
18891889
let inst_type = match hir_op {
@@ -4350,28 +4350,102 @@ impl SsaBuilder {
43504350
fn convert_binary_op(
43514351
&self,
43524352
op: &zyntax_typed_ast::typed_ast::BinaryOp,
4353+
operand_ty: &Type,
43534354
) -> crate::hir::BinaryOp {
43544355
use crate::hir::BinaryOp as HirOp;
43554356
use zyntax_typed_ast::typed_ast::BinaryOp as FrontendOp;
43564357

4358+
let operand_hir_ty = self.convert_type(operand_ty);
4359+
let is_float_like = match &operand_hir_ty {
4360+
HirType::F32 | HirType::F64 => true,
4361+
HirType::Vector(elem_ty, _) => matches!(&**elem_ty, HirType::F32 | HirType::F64),
4362+
_ => false,
4363+
};
4364+
43574365
match op {
4358-
FrontendOp::Add => HirOp::Add,
4359-
FrontendOp::Sub => HirOp::Sub,
4360-
FrontendOp::Mul => HirOp::Mul,
4366+
FrontendOp::Add => {
4367+
if is_float_like {
4368+
HirOp::FAdd
4369+
} else {
4370+
HirOp::Add
4371+
}
4372+
}
4373+
FrontendOp::Sub => {
4374+
if is_float_like {
4375+
HirOp::FSub
4376+
} else {
4377+
HirOp::Sub
4378+
}
4379+
}
4380+
FrontendOp::Mul => {
4381+
if is_float_like {
4382+
HirOp::FMul
4383+
} else {
4384+
HirOp::Mul
4385+
}
4386+
}
43614387
FrontendOp::MatMul => HirOp::Mul,
4362-
FrontendOp::Div => HirOp::Div,
4363-
FrontendOp::Rem => HirOp::Rem,
4388+
FrontendOp::Div => {
4389+
if is_float_like {
4390+
HirOp::FDiv
4391+
} else {
4392+
HirOp::Div
4393+
}
4394+
}
4395+
FrontendOp::Rem => {
4396+
if is_float_like {
4397+
HirOp::FRem
4398+
} else {
4399+
HirOp::Rem
4400+
}
4401+
}
43644402
FrontendOp::BitAnd => HirOp::And,
43654403
FrontendOp::BitOr => HirOp::Or,
43664404
FrontendOp::BitXor => HirOp::Xor,
43674405
FrontendOp::Shl => HirOp::Shl,
43684406
FrontendOp::Shr => HirOp::Shr,
4369-
FrontendOp::Eq => HirOp::Eq,
4370-
FrontendOp::Ne => HirOp::Ne,
4371-
FrontendOp::Lt => HirOp::Lt,
4372-
FrontendOp::Le => HirOp::Le,
4373-
FrontendOp::Gt => HirOp::Gt,
4374-
FrontendOp::Ge => HirOp::Ge,
4407+
FrontendOp::Eq => {
4408+
if is_float_like {
4409+
HirOp::FEq
4410+
} else {
4411+
HirOp::Eq
4412+
}
4413+
}
4414+
FrontendOp::Ne => {
4415+
if is_float_like {
4416+
HirOp::FNe
4417+
} else {
4418+
HirOp::Ne
4419+
}
4420+
}
4421+
FrontendOp::Lt => {
4422+
if is_float_like {
4423+
HirOp::FLt
4424+
} else {
4425+
HirOp::Lt
4426+
}
4427+
}
4428+
FrontendOp::Le => {
4429+
if is_float_like {
4430+
HirOp::FLe
4431+
} else {
4432+
HirOp::Le
4433+
}
4434+
}
4435+
FrontendOp::Gt => {
4436+
if is_float_like {
4437+
HirOp::FGt
4438+
} else {
4439+
HirOp::Gt
4440+
}
4441+
}
4442+
FrontendOp::Ge => {
4443+
if is_float_like {
4444+
HirOp::FGe
4445+
} else {
4446+
HirOp::Ge
4447+
}
4448+
}
43754449
_ => HirOp::Add, // Default
43764450
}
43774451
}
@@ -6493,7 +6567,7 @@ impl SsaBuilder {
64936567
));
64946568
}
64956569

6496-
let hir_op = self.convert_binary_op(&bin.op);
6570+
let hir_op = self.convert_binary_op(&bin.op, &bin.left.ty);
64976571
let result_id = HirId::new();
64986572
func.values.insert(
64996573
result_id,

crates/compiler/tests/cranelift_backend_tests.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,26 @@ fn test_arithmetic_operations() {
206206
.expect("Failed to compile division function");
207207
}
208208

209+
/// Test SIMD vector arithmetic compilation (integer lanes)
210+
#[test]
211+
fn test_vector_i32x4_add_compilation() {
212+
let mut backend = CraneliftBackend::new().expect("Failed to create backend");
213+
let func = create_vector_arithmetic_function("vec_i32x4_add", BinaryOp::Add, HirType::I32, 4);
214+
backend
215+
.compile_function(func.id, &func)
216+
.expect("Failed to compile i32x4 add function");
217+
}
218+
219+
/// Test SIMD vector arithmetic compilation (floating lanes)
220+
#[test]
221+
fn test_vector_f32x4_add_compilation() {
222+
let mut backend = CraneliftBackend::new().expect("Failed to create backend");
223+
let func = create_vector_arithmetic_function("vec_f32x4_add", BinaryOp::FAdd, HirType::F32, 4);
224+
backend
225+
.compile_function(func.id, &func)
226+
.expect("Failed to compile f32x4 add function");
227+
}
228+
209229
/// Test comparison operations
210230
#[test]
211231
fn test_comparison_operations() {
@@ -359,6 +379,65 @@ fn create_comparison_function(name: &str, op: BinaryOp) -> HirFunction {
359379
func
360380
}
361381

382+
/// Helper function to create vector arithmetic operations
383+
fn create_vector_arithmetic_function(
384+
name: &str,
385+
op: BinaryOp,
386+
elem_ty: HirType,
387+
lanes: u32,
388+
) -> HirFunction {
389+
let name = create_test_string(name);
390+
let vec_ty = HirType::Vector(Box::new(elem_ty), lanes);
391+
392+
let sig = HirFunctionSignature {
393+
params: vec![
394+
HirParam {
395+
id: HirId::new(),
396+
name: create_test_string("a"),
397+
ty: vec_ty.clone(),
398+
attributes: ParamAttributes::default(),
399+
},
400+
HirParam {
401+
id: HirId::new(),
402+
name: create_test_string("b"),
403+
ty: vec_ty.clone(),
404+
attributes: ParamAttributes::default(),
405+
},
406+
],
407+
returns: vec![vec_ty.clone()],
408+
type_params: vec![],
409+
const_params: vec![],
410+
lifetime_params: vec![],
411+
is_variadic: false,
412+
is_async: false,
413+
effects: vec![],
414+
is_pure: false,
415+
};
416+
417+
let mut func = HirFunction::new(name, sig);
418+
419+
let entry_block_id = func.entry_block;
420+
let param_a = func.create_value(vec_ty.clone(), HirValueKind::Parameter(0));
421+
let param_b = func.create_value(vec_ty.clone(), HirValueKind::Parameter(1));
422+
423+
let result = func.create_value(vec_ty.clone(), HirValueKind::Instruction);
424+
let inst = HirInstruction::Binary {
425+
op,
426+
result,
427+
ty: vec_ty,
428+
left: param_a,
429+
right: param_b,
430+
};
431+
432+
let block = func.blocks.get_mut(&entry_block_id).unwrap();
433+
block.add_instruction(inst);
434+
block.set_terminator(HirTerminator::Return {
435+
values: vec![result],
436+
});
437+
438+
func
439+
}
440+
362441
/// Create a function with control flow (if-else)
363442
fn create_control_flow_function() -> HirFunction {
364443
let name = create_test_string("abs");

0 commit comments

Comments
 (0)