diff --git a/dialects/arith/src/ops/unary.rs b/dialects/arith/src/ops/unary.rs index 373daf039..4d6255eac 100644 --- a/dialects/arith/src/ops/unary.rs +++ b/dialects/arith/src/ops/unary.rs @@ -28,7 +28,7 @@ macro_rules! infer_return_ty_for_unary_op { /// Increment #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Incr { @@ -44,7 +44,7 @@ has_no_effects!(Incr); /// Negation #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Neg { @@ -60,7 +60,7 @@ has_no_effects!(Neg); /// Modular inverse #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Inv { @@ -76,7 +76,7 @@ has_no_effects!(Inv); /// log2(operand) #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Ilog2 { @@ -92,7 +92,7 @@ has_no_effects!(Ilog2); /// pow2(operand) #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Pow2 { @@ -108,7 +108,7 @@ has_no_effects!(Pow2); /// Logical NOT #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] @@ -125,7 +125,7 @@ has_no_effects!(Not); /// Bitwise NOT #[operation ( dialect = ArithDialect, - traits(UnaryOp, SameOperandsAndResultType), + traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct Bnot { diff --git a/dialects/hir/src/ops/primop.rs b/dialects/hir/src/ops/primop.rs index ad89c4108..7d9c72a66 100644 --- a/dialects/hir/src/ops/primop.rs +++ b/dialects/hir/src/ops/primop.rs @@ -4,7 +4,7 @@ use crate::HirDialect; #[operation( dialect = HirDialect, - traits(SameOperandsAndResultType), + traits(SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface) )] pub struct MemGrow { diff --git a/dialects/hir/src/ops/spills.rs b/dialects/hir/src/ops/spills.rs index bd2eb93a4..b684ef3ff 100644 --- a/dialects/hir/src/ops/spills.rs +++ b/dialects/hir/src/ops/spills.rs @@ -5,7 +5,7 @@ use crate::HirDialect; #[operation( dialect = HirDialect, - traits(SameOperandsAndResultType), + traits(SameTypeOperands, SameOperandsAndResultType), implements(MemoryEffectOpInterface, SpillLike) )] pub struct Spill { @@ -34,7 +34,7 @@ impl EffectOpInterface for Spill { #[operation( dialect = HirDialect, - traits(SameOperandsAndResultType), + traits(SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface, MemoryEffectOpInterface, ReloadLike) )] pub struct Reload { diff --git a/hir/src/derive.rs b/hir/src/derive.rs index 3fddbcfd4..71f059c25 100644 --- a/hir/src/derive.rs +++ b/hir/src/derive.rs @@ -1,11 +1,16 @@ pub use midenc_hir_macros::operation; /// This macro is used to generate the boilerplate for operation trait implementations. +/// Super traits have to be declared as a comma separated list of traits, instead of the traditional +/// "+" separated list of traits. +/// Example: +/// +/// pub trait SomeTrait: SuperTraitA, SuperTraitB {} #[macro_export] macro_rules! derive { ( $(#[$outer:meta])* - $vis:vis trait $OpTrait:ident { + $vis:vis trait $OpTrait:ident $(:)? $( $ParentTrait:ident ),* $(,)? { $( $OpTraitItem:item )* @@ -21,7 +26,7 @@ macro_rules! derive { ) => { $crate::__derive_op_trait! { $(#[$outer])* - $vis trait $OpTrait { + $vis trait $OpTrait : $( $ParentTrait , )* { $( $OpTraitItem:item )* @@ -65,7 +70,7 @@ macro_rules! derive { macro_rules! __derive_op_trait { ( $(#[$outer:meta])* - $vis:vis trait $OpTrait:ident { + $vis:vis trait $OpTrait:ident $(:)? $( $ParentTrait:ident ),* $(,)? { $( $OpTraitItem:item )* @@ -78,7 +83,7 @@ macro_rules! __derive_op_trait { } ) => { $(#[$outer])* - $vis trait $OpTrait { + $vis trait $OpTrait : $( $ParentTrait + )* { $( $OpTraitItem )* @@ -87,12 +92,19 @@ macro_rules! __derive_op_trait { impl $crate::Verify for T { #[inline] fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + $( + <$crate::Operation as $crate::Verify>::verify(self.as_operation(), context)?; + )* <$crate::Operation as $crate::Verify>::verify(self.as_operation(), context) } } impl $crate::Verify for $crate::Operation { fn should_verify(&self, _context: &$crate::Context) -> bool { + $( + self.implements::() + && + )* self.implements::() } @@ -136,7 +148,8 @@ mod tests { use crate::{ attributes::Overflow, - dialects::test::{self, Add}, + dialects::test::{self, Add, InvalidOpsWithReturn}, + pass::{Nesting, PassManager}, Builder, BuilderExt, Context, Op, Operation, Report, Spanned, }; @@ -191,25 +204,64 @@ mod tests { )); } - #[ignore = "until https://github.com/0xMiden/compiler/issues/378 is fixed"] #[test] #[should_panic = "expected 'u32', got 'i64'"] fn derived_op_verifier_test() { use crate::{SourceSpan, Type}; let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::U32, Type::I64]); + + context.get_or_register_dialect::(); + context.registered_dialects(); + let (lhs, invalid_rhs) = { let block = block.borrow(); let lhs = block.get_argument(0).upcast::(); let rhs = block.get_argument(1).upcast::(); (lhs, rhs) }; - let mut builder = context.builder(); + + let mut builder = context.clone().builder(); builder.set_insertion_point_to_end(block); // Try to create instance of AddOp with mismatched operand types let op_builder = builder.create::(SourceSpan::default()); let op = op_builder(lhs, invalid_rhs, Overflow::Wrapping); - let _op = op.unwrap(); + let op = op.unwrap(); + + // Construct a pass manager with the default pass pipeline + let mut pm = PassManager::on::(context.clone(), Nesting::Implicit); + // Run pass pipeline + pm.run(op.as_operation_ref()).unwrap(); + } + + /// Fails if [`InvalidOpsWithReturn`] is created successfully. [`InvalidOpsWithReturn`] is a + /// struct that has differing types in its result and arguments, despite implementing the + /// [`SameOperandsAndResultType`] trait. + #[test] + #[should_panic = "expected 'i32', got 'u64'"] + fn same_operands_and_result_type_verifier_test() { + use crate::{SourceSpan, Type}; + + let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::I32, Type::I32]); + let (lhs, rhs) = { + let block = block.borrow(); + let lhs = block.get_argument(0).upcast::(); + let rhs = block.get_argument(1).upcast::(); + (lhs, rhs) + }; + let mut builder = context.clone().builder(); + builder.set_insertion_point_to_end(block); + // Try to create instance of AddOp with mismatched operand types + let op_builder = builder.create::(SourceSpan::default()); + let op = op_builder(lhs, rhs); + let op = op.unwrap(); + + // Construct a pass manager with the default pass pipeline + let mut pm = PassManager::on::(context.clone(), Nesting::Implicit); + // Run pass pipeline + pm.run(op.as_operation_ref()).unwrap(); } } diff --git a/hir/src/dialects/test.rs b/hir/src/dialects/test.rs index 40b40b0eb..64d698f27 100644 --- a/hir/src/dialects/test.rs +++ b/hir/src/dialects/test.rs @@ -115,6 +115,7 @@ impl DialectRegistration for TestDialect { fn register_operations(info: &mut DialectInfo) { info.register_operation::(); + info.register_operation::(); info.register_operation::(); info.register_operation::(); info.register_operation::(); diff --git a/hir/src/dialects/test/ops/binary.rs b/hir/src/dialects/test/ops/binary.rs index 534b1be75..e48954c03 100644 --- a/hir/src/dialects/test/ops/binary.rs +++ b/hir/src/dialects/test/ops/binary.rs @@ -74,3 +74,25 @@ impl InferTypeOpInterface for Shl { Ok(()) } } + +/// Invalid operation that breaks the SameOperandsAndResultType trait (used for testing). +#[operation( + dialect = TestDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct InvalidOpsWithReturn { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyUnsignedInteger, +} + +impl InferTypeOpInterface for InvalidOpsWithReturn { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type(Type::U64); + Ok(()) + } +} diff --git a/hir/src/ir/operation/builder.rs b/hir/src/ir/operation/builder.rs index 251de660a..dbd5b66e0 100644 --- a/hir/src/ir/operation/builder.rs +++ b/hir/src/ir/operation/builder.rs @@ -256,12 +256,6 @@ where unsafe { UnsafeIntrusiveEntityRef::from_raw(op.container().cast()) } }; - // Run op-specific verification - { - let op: super::EntityRef = op.borrow(); - op.verify(self.builder.context())?; - } - // Insert op at current insertion point, if set if self.builder.insertion_point().is_valid() { self.builder.insert(self.op); diff --git a/hir/src/ir/traits/types.rs b/hir/src/ir/traits/types.rs index 11305176f..ff1462e47 100644 --- a/hir/src/ir/traits/types.rs +++ b/hir/src/ir/traits/types.rs @@ -4,7 +4,7 @@ use core::fmt; use midenc_hir_type::PointerType; use midenc_session::diagnostics::Severity; -use crate::{derive, Context, Op, Operation, Report, Type}; +use crate::{derive, ir::value::Value, Context, Op, Operation, Report, Type}; /// OpInterface to compute the return type(s) of an operation. pub trait InferTypeOpInterface: Op { @@ -24,10 +24,10 @@ derive! { pub trait SameTypeOperands {} verify { - fn operands_are_the_same_type(op: &Operation, _context: &Context) -> Result<(), Report> { + fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { let mut operands = op.operands().iter(); if let Some(first_operand) = operands.next() { - let (_expected_ty, _set_by) = { + let (expected_ty, set_by) = { let operand = first_operand.borrow(); let value = operand.value(); (value.ty().clone(), value.span()) @@ -36,29 +36,29 @@ derive! { for operand in operands { let operand = operand.borrow(); let value = operand.value(); - let _value_ty = value.ty(); - // if value_ty != &expected_ty { - // return Err(context - // .session - // .diagnostics - // .diagnostic(Severity::Error) - // .with_message(::alloc::format!("invalid operation {}", op.name())) - // .with_primary_label( - // op.span(), - // "this operation expects all operands to be of the same type" - // ) - // .with_secondary_label( - // set_by, - // "inferred the expected type from this value" - // ) - // .with_secondary_label( - // value.span(), - // "which differs from this value" - // ) - // .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) - // .into_report() - // ); - // } + let value_ty = value.ty(); + if value_ty != &expected_ty { + return Err(context + .session() + .diagnostics + .diagnostic(Severity::Error) + .with_message(::alloc::format!("invalid operation {}", op.name())) + .with_primary_label( + op.span, + "this operation expects all operands to be of the same type" + ) + .with_secondary_label( + set_by, + "inferred the expected type from this value" + ) + .with_secondary_label( + value.span(), + "which differs from this value" + ) + .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) + .into_report() + ); + } } } @@ -68,12 +68,67 @@ derive! { } derive! { - /// Op expects all operands and results to be of the same type + /// Op expects all operands and results to be of the same type. + /// NOTE: Operations that implements this trait must also explicitely implement + /// [`SameTypeOperands`]. This can be achieved by using the "traits" filed in the [#operation] + /// macro. + /// Like so: /// - /// TODO(pauls): Implement verification for this. Ideally we could require `SameTypeOperands` - /// as a super trait, check the operands using its implementation, and then check the results - /// separately - pub trait SameOperandsAndResultType {} + /// #[operation ( + /// dialect = ArithDialect, + /// traits(UnaryOp, SameTypeOperands, SameOperandsAndResultType), + /// implements(InferTypeOpInterface, MemoryEffectOpInterface) + /// )] + /// pub struct SomeOp { + /// (...) + /// } + pub trait SameOperandsAndResultType: SameTypeOperands {} + + verify { + fn operands_and_result_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { + let mut operands = op.operands().iter(); + if let Some(first_operand) = operands.next() { + let (expected_ty, set_by) = { + let operand = first_operand.borrow(); + let value = operand.value(); + (value.ty().clone(), value.span()) + }; + + let results = op.results().iter(); + + for result in results { + let result = result.borrow(); + let value = result.as_value_ref().borrow(); + let result_ty = result.ty(); + + if result_ty != &expected_ty { + return Err(context + .session() + .diagnostics + .diagnostic(Severity::Error) + .with_message(::alloc::format!("invalid operation result {}", op.name())) + .with_primary_label( + op.span, + "this operation expects the operands and the results to be of the same type" + ) + .with_secondary_label( + set_by, + "inferred the expected type from this value" + ) + .with_secondary_label( + value.span(), + "which differs from this value" + ) + .with_help(format!("expected '{expected_ty}', got '{result_ty}'")) + .into_report() + ); + } + }; + } + + Ok(()) + } + } } /// An operation trait that indicates it expects a variable number of operands, matching the given diff --git a/hir/src/pass/manager.rs b/hir/src/pass/manager.rs index c0e2b9ad5..5ab019b01 100644 --- a/hir/src/pass/manager.rs +++ b/hir/src/pass/manager.rs @@ -845,6 +845,13 @@ impl OpToOpPassAdaptor { instrumentor: Option>, parent_info: Option<&PipelineParentInfo>, ) -> Result<(), Report> { + + if verify { + // We run an initial recursive verification, since this is the first verification done + // to the operations + Self::verify(&op, true)?; + } + assert!( instrumentor.is_none() || parent_info.is_some(), "expected parent info if instrumentor is provided"