MLIR Pass Rewriting

I am trying to implement a CNOT-propagation pass in MLIR within Catalyst. An example input .mlir file is as follows:

module {
  func.func @my_circuit( ) -> (!quantum.bit, !quantum.bit) {
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    %3 = quantum.custom "PauliX"() %1 : !quantum.bit
    %out_qubits:2 = quantum.custom "CNOT"() %3, %2 : !quantum.bit, !quantum.bit
    return %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit
  }
}

// Expected Output
//module {
//  func.func @my_circuit( ) -> (!quantum.bit, !quantum.bit) {
//    %0 = quantum.alloc( 2) : !quantum.reg
//    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
//    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
//    %out_qubits:2 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit
//    %3 = quantum.custom "PauliX"() %out_qubits#0 : !quantum.bit
//    %4 = quantum.custom "PauliX"() %out_qubits#1 : !quantum.bit
//    return %3, %4 : !quantum.bit, !quantum.bit
//  }
//}

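The idea is that an X gate on the control qubit propagates through the CNOT: applying X to the control before a CNOT is equivalent to applying X to both the control and the target after it.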
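One way to sanity-check this identity is to test it on basis states. The standalone C++ sketch below is not Catalyst or MLIR code; Basis, cnot, xCtrl and xTarg are hypothetical helpers. CNOT and X only permute computational basis states (with no phases), so two circuits built from them are equal exactly when they agree on all four basis states. The sketch also checks that an X on the target simply commutes with CNOT, so only an X on the control needs to be propagated.

```cpp
#include <cassert>
#include <utility>

// A two-qubit computational basis state: (control bit, target bit).
using Basis = std::pair<int, int>;

// Each gate maps a basis state to a basis state (a permutation, no phases).
Basis cnot(Basis s) { return {s.first, s.second ^ s.first}; }
Basis xCtrl(Basis s) { return {s.first ^ 1, s.second}; }
Basis xTarg(Basis s) { return {s.first, s.second ^ 1}; }

// Checks: X on control, then CNOT  ==  CNOT, then X on both qubits.
bool xOnControlPropagates() {
    for (int c = 0; c < 2; ++c) {
        for (int t = 0; t < 2; ++t) {
            const Basis s{c, t};
            if (cnot(xCtrl(s)) != xTarg(xCtrl(cnot(s))))
                return false;
        }
    }
    return true;
}

// Checks: X on target, then CNOT  ==  CNOT, then X on target (they commute).
bool xOnTargetCommutes() {
    for (int c = 0; c < 2; ++c) {
        for (int t = 0; t < 2; ++t) {
            const Basis s{c, t};
            if (cnot(xTarg(s)) != xTarg(cnot(s)))
                return false;
        }
    }
    return true;
}
```

By contrast, X on the control does not commute with CNOT: on the state (0, 0), X-then-CNOT gives (1, 1) while CNOT-then-X gives (1, 0), which is why the rewrite has to add the second X on the target.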

Here’s the rewrite logic I have:

auto cnotOp = cast<quantum::CustomOp>(op);
Operation *definingOp = cnotOp.getInQubits().front().getDefiningOp();
auto xOp = cast<quantum::CustomOp>(definingOp);

mlir::Location opLoc = op->getLoc();
// Create new CNOT operation with original non-X input
SmallVector<mlir::Value> cnotInQubits;
cnotInQubits.push_back(xOp.getInQubits().front()); // Use input to X gate
cnotInQubits.push_back(cnotOp.getInQubits().back());

auto newCnotOp = rewriter.create<quantum::CustomOp>(
    opLoc,
    cnotOp.getOutQubits().getTypes(),
    ValueRange{},
    cnotOp.getParams(),
    cnotInQubits,
    "CNOT",
    nullptr,
    ValueRange{},
    ValueRange{});

// Create X gates operating on CNOT outputs
auto xOp1 = rewriter.create<quantum::CustomOp>(
    opLoc,
    newCnotOp.getOutQubits().front().getType(),
    ValueRange{},
    xOp.getParams(),
    newCnotOp.getOutQubits().front(),
    "PauliX",
    nullptr,
    ValueRange{},
    ValueRange{});

auto xOp2 = rewriter.create<quantum::CustomOp>(
    opLoc,
    newCnotOp.getOutQubits().back().getType(),
    ValueRange{},
    xOp.getParams(),
    newCnotOp.getOutQubits().back(),
    "PauliX",
    nullptr,
    ValueRange{},
    ValueRange{});

SmallVector<mlir::Value> newOp;
newOp.push_back(newCnotOp.getOutQubits().front());
newOp.push_back(xOp1.getOutQubits().front());
newOp.push_back(xOp2.getOutQubits().front());
rewriter.replaceOp(cnotOp, newOp);
rewriter.eraseOp(xOp); // Remove original X gate

return success();

The trace output I get is:

catalyst-cli: /home/patel/projects/catalyst/mlir/llvm-project/mlir/lib/IR/PatternMatch.cpp:135: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: ./build/bin/catalyst-cli ../c3.mlir --tool=opt --catalyst-pipeline=pipe(cnot-propagation{func-name=my_circuit}) --mlir-print-ir-after-all --mlir-print-stacktrace-on-diagnostic
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  catalyst-cli 0x00005ec2a24a60f7
1  catalyst-cli 0x00005ec2a24a3d1e
2  catalyst-cli 0x00005ec2a24a677a
3  libc.so.6    0x000072eb78842520
4  libc.so.6    0x000072eb788969fc pthread_kill + 300
5  libc.so.6    0x000072eb78842476 raise + 22
6  libc.so.6    0x000072eb788287f3 abort + 211
7  libc.so.6    0x000072eb7882871b
8  libc.so.6    0x000072eb78839e96
9  catalyst-cli 0x00005ec2a23d9a26
10 catalyst-cli 0x00005ec29e4ac952
11 catalyst-cli 0x00005ec2a225d9a9
12 catalyst-cli 0x00005ec2a225a2cf
13 catalyst-cli 0x00005ec2a22307a8
14 catalyst-cli 0x00005ec2a222cf7c
15 catalyst-cli 0x00005ec29e4ab83b
16 catalyst-cli 0x00005ec2a2282904
17 catalyst-cli 0x00005ec2a2282f31
18 catalyst-cli 0x00005ec2a22854b2
19 catalyst-cli 0x00005ec29cf6a68e
20 catalyst-cli 0x00005ec29cf6b916
21 catalyst-cli 0x00005ec29cf6f84c
22 libc.so.6    0x000072eb78829d90
23 libc.so.6    0x000072eb78829e40 __libc_start_main + 128
24 catalyst-cli 0x00005ec29cf65ee5
[1]    1897019 IOT instruction (core dumped)  ./build/bin/catalyst-cli ../c3.mlir --tool=opt  --mlir-print-ir-after-all 

I am very new to MLIR and LLVM in general, so my rewrite logic might be unnecessarily complicated; if so, feel free to suggest improvements.

Hi @patelvyom, thanks for your interest in Catalyst!

It looks like your rewrite triggered an MLIR assertion. From your log:

mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.

CNOT is an operation on 2 qubits, so the original CNOT op has 2 results, but your code

SmallVector<mlir::Value> newOp;
newOp.push_back(newCnotOp.getOutQubits().front());
newOp.push_back(xOp1.getOutQubits().front());
newOp.push_back(xOp2.getOutQubits().front());
rewriter.replaceOp(cnotOp, newOp);

is trying to replace it with 3 values.
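Concretely, replaceOp needs exactly one value per result of the old CNOT, and those values should be the outputs of the two new PauliX gates rather than the new CNOT's outputs, because downstream users such as the return must now see the qubits after the X gates. In your code that would mean pushing only xOp1.getOutQubits().front() and xOp2.getOutQubits().front() into newOp before calling rewriter.replaceOp(cnotOp, newOp); after that, the original X op's result should have no remaining users, so rewriter.eraseOp(xOp) is safe. The toy C++ model below only illustrates that bookkeeping; ToyOp, canReplace and propagateX are hypothetical, with plain integer SSA ids instead of the MLIR API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR (not MLIR): each SSA value is an integer id, and an op
// lists the ids it consumes (operands) and produces (results).
struct ToyOp {
    std::string name;
    std::vector<int> operands;
    std::vector<int> results;
};

// Mirrors the precondition asserted by RewriterBase::replaceOp:
// exactly one replacement value per result of the op being replaced.
bool canReplace(const ToyOp &op, const std::vector<int> &newValues) {
    return op.results.size() == newValues.size();
}

struct Rewrite {
    ToyOp newCnot, xOnCtrl, xOnTarg;
    std::vector<int> replacements; // take over the old CNOT's results, in order
};

// X(q0); CNOT(x_out, q1)  ==>  CNOT(q0, q1); X(ctrl_out); X(targ_out)
Rewrite propagateX(const ToyOp &xOp, const ToyOp &cnotOp, int &nextId) {
    Rewrite r;
    const int ctrlOut = nextId++;
    const int targOut = nextId++;
    // The new CNOT reads the X gate's input, bypassing the X gate entirely.
    r.newCnot = ToyOp{"CNOT", {xOp.operands[0], cnotOp.operands[1]}, {ctrlOut, targOut}};
    r.xOnCtrl = ToyOp{"PauliX", {ctrlOut}, {nextId++}};
    r.xOnTarg = ToyOp{"PauliX", {targOut}, {nextId++}};
    // Users of the old CNOT (e.g. the return) must now see the X outputs.
    r.replacements = {r.xOnCtrl.results[0], r.xOnTarg.results[0]};
    return r;
}
```

With ids taken from the example IR (%1 = 1, %2 = 2, the X output %3 = 3, the CNOT outputs = 4 and 5, fresh ids starting at 6), the new CNOT consumes {1, 2} and the two replacement values are the X outputs {8, 9}; the original three-value list fails the arity check.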

@David_Ittah Thanks for your response!

I tried replacing just the CNOT operation, which results in a type-casting error:

// SmallVector<mlir::Value> newOp;
// newOp.push_back(newCnotOp.getOutQubits().front());
// newOp.push_back(xOp1.getOutQubits().front());
// newOp.push_back(xOp2.getOutQubits().front());
rewriter.replaceOp(cnotOp, newCnotOp);
catalyst-cli: /home/patel/projects/catalyst/mlir/llvm-project/llvm/include/llvm/Support/Casting.h:578: decltype(auto) llvm::cast(From *) [To = catalyst::quantum::CustomOp, From = mlir::Operation]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Do you have any hints? I am trying to replace an X gate on qubit 1 followed by CNOT_12 (control 1, target 2) with CNOT_12 followed by X_1 and X_2.

Update: I was able to fix the type-casting error by replacing cast with dyn_cast_or_null. However, now I am getting a segmentation fault with no clear error message.

Here’s the latest source code for reference:

mlir::LogicalResult matchAndRewrite(CustomOp op, mlir::PatternRewriter &rewriter) const override
{
    LLVM_DEBUG(dbgs() << "Rewriting the following operation:\n" << op << "\n");
    StringRef opGateName = op.getGateName();
    if (opGateName != "CNOT") {
        return failure();
    }

    auto parentOp = dyn_cast_or_null<CustomOp>(op.getInQubits().front().getDefiningOp());
    StringRef parentOpGateName = parentOp.getGateName();
    if (!PropagationOps.contains(parentOpGateName)) {
        return failure();
    }
    if (op.getInQubits().size() != 2 || parentOp.getOutQubits().size() != 1) {
        return failure();
    }

    mlir::Value inCtrlQubit = op.getInQubits().front();
    mlir::Value inTargQubit = op.getInQubits().back();
    mlir::Value parentOutQubit = parentOp.getOutQubits().front();

    bool foundCtrlMatch = (inCtrlQubit == parentOutQubit) ? true : false;
    bool foundNonCtrlMatch = (inTargQubit == parentOutQubit) ? true : false;

    // If neither control nor target match with parent's output, pattern doesn't match
    if (!foundNonCtrlMatch && !foundCtrlMatch) {
        return failure();
    }

    dbgs() << "CNOT propagation pattern matched\n";
    auto cnotOp = op;
    auto xOp = dyn_cast_or_null<quantum::CustomOp>(op.getInQubits().front().getDefiningOp());
    TypeRange outQubitsTypes = cnotOp.getOutQubits().getTypes();
    mlir::Location opLoc = op.getLoc();
    // Create new CNOT operation with original non-X input
    SmallVector<mlir::Value> cnotInQubits;
    cnotInQubits.push_back(xOp.getInQubits().front()); // Use input to X gate
    cnotInQubits.push_back(cnotOp.getInQubits().back());

    auto newCnotOp = rewriter.create<quantum::CustomOp>(
        opLoc,
        outQubitsTypes,
        ValueRange{},
        cnotOp.getParams(),
        ValueRange(cnotInQubits),
        "CNOT",
        nullptr,
        ValueRange{},
        ValueRange{});
    op.replaceAllUsesWith(newCnotOp);
    return success();
}

Error:

CNOT propagation pattern matched
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: ./mlir/build/bin/catalyst-cli c3.mlir --tool=opt --catalyst-pipeline=pipe(cnot-propagation) --mlir-print-ir-after-all
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  catalyst-cli 0x00006110070ba077 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 39
1  catalyst-cli 0x00006110070b7c2e llvm::sys::RunSignalHandlers() + 238
2  catalyst-cli 0x00006110070ba72a
3  libc.so.6    0x00007abbd3042520
4  catalyst-cli 0x0000611001e31794 catalyst::quantum::CustomOp::getGateName() + 4
5  catalyst-cli 0x000061100255ea99
6  catalyst-cli 0x0000611006dacc79
7  catalyst-cli 0x0000611006da958f mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) + 911
8  catalyst-cli 0x0000611006d7ca88
9  catalyst-cli 0x0000611006d7925c mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 2092
10 catalyst-cli 0x000061100255e29b
11 catalyst-cli 0x0000611006dd1f74 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 676
12 catalyst-cli 0x0000611006dd2711 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 337
13 catalyst-cli 0x0000611006dd4e02 mlir::PassManager::run(mlir::Operation*) + 914
14 catalyst-cli 0x0000611000eaa309 runPipeline(mlir::PassManager&, catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, catalyst::driver::Pipeline&, bool, mlir::ModuleOp) + 105
15 catalyst-cli 0x0000611000eaab50 runLowering(catalyst::driver::CompilerOptions const&, mlir::MLIRContext*, mlir::ModuleOp, catalyst::driver::CompilerOutput&, mlir::TimingScope&) + 1344
16 catalyst-cli 0x0000611000eabc54 QuantumDriverMain(catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, mlir::DialectRegistry&) + 2884
17 catalyst-cli 0x0000611000eafb85 QuantumDriverMainFromCL(int, char**) + 3349
18 libc.so.6    0x00007abbd3029d90
19 libc.so.6    0x00007abbd3029e40 __libc_start_main + 128
20 catalyst-cli 0x0000611000ea5f25 _start + 37
[1]    2319695 segmentation fault (core dumped)  ./mlir/build/bin/catalyst-cli c3.mlir --tool=opt  --mlir-print-ir-after-all

There are different kinds of casts available in MLIR (they come from LLVM's casting utilities):

  • cast<OpTy>(Operation *) asserts that the type of the Operation* is what you’re trying to cast to (i.e. OpTy)
  • dyn_cast<OpTy>(Operation *) doesn’t make this assertion; instead it returns a null op if the type is not the expected one
  • dyn_cast_or_null<OpTy>(Operation *) is like dyn_cast, but also accepts nullptr as a valid input (this lets you chain it with functions that may themselves return null, such as getDefiningOp())

Only use the last two forms if you then check the result for null; otherwise you may end up calling methods on a null op, which leads to segfaults like yours.

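So in short, the problem is that you’re treating an op as something it isn’t. Most likely, once your pattern has created the new CNOT, the greedy driver visits it too; its first input qubit is defined by quantum.extract rather than a CustomOp, so dyn_cast_or_null returns a null op and the unchecked parentOp.getGateName() call dereferences it. That matches the crash inside CustomOp::getGateName() right after "CNOT propagation pattern matched" in your log.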

I'm not sure what your entire code looks like, but if you were iterating over arbitrary operations you could do something like this:

if (auto gate = dyn_cast<quantum::CustomOp>(op)) {
  if (gate.getGateName() != "CNOT")
    return;

  // getDefiningOp() returns null when the qubit is a block argument, and plain
  // dyn_cast asserts on a null input, so use dyn_cast_or_null here.
  auto parent = gate.getInQubits().front().getDefiningOp();
  if (auto parentGate = dyn_cast_or_null<quantum::CustomOp>(parent)) {
    if (parentGate.getGateName() != "PauliX")
      return;

    // now you know you have a successful match
  }
}
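To see the guard-every-cast structure in isolation, here is a standalone C++ sketch in which std::dynamic_cast stands in for llvm::dyn_cast_or_null. LLVM uses its own RTTI (classof) rather than C++ RTTI, but the null-on-mismatch behavior is analogous. Op, ExtractOp, CustomOp, dynCastOrNull and matchesXBeforeCnot are hypothetical toy names, not the Catalyst classes, and a null firstInputDef models a qubit that is a block argument.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical toy op hierarchy; polymorphic so dynamic_cast works.
struct Op {
    virtual ~Op() = default;
};

struct ExtractOp : Op {};

struct CustomOp : Op {
    std::string gateName;
    const Op *firstInputDef; // defining op of the first input qubit; null models a block argument
    explicit CustomOp(std::string name, const Op *def = nullptr)
        : gateName(std::move(name)), firstInputDef(def) {}
};

// Analogue of dyn_cast_or_null: accepts null and returns null on a type mismatch.
template <typename To>
const To *dynCastOrNull(const Op *op) {
    return op ? dynamic_cast<const To *>(op) : nullptr;
}

// True only for "PauliX feeding the first input of a CNOT".
// Every cast result is checked before any member is accessed.
bool matchesXBeforeCnot(const Op *op) {
    const auto *gate = dynCastOrNull<CustomOp>(op);
    if (!gate || gate->gateName != "CNOT")
        return false;
    const auto *parent = dynCastOrNull<CustomOp>(gate->firstInputDef);
    if (!parent) // defined by an extract op, or a block argument
        return false;
    return parent->gateName == "PauliX";
}
```

A CNOT whose first input comes straight from an extract (exactly the rewritten CNOT in your example) is rejected instead of crashing.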