
How do I create an arbitrary unitary gate for an arbitrary number of qubits in Cirq?

Note: A solution for creating a single-qubit gate is given by Thomas W in this post: How do I create my own unitary matrices that I can apply to a circuit in Cirq?

glS
user3886914



You have two options, depending on your needs:

  1. Use cirq.MatrixGate.
     • Pros: you can easily instantiate the gate from a unitary matrix.
     • Cons: it is not easy to customize.
  2. Create your own Gate class.
     • Pros: full flexibility; you can customize its features, diagram info, etc.
     • Cons: it's a bit more involved.

Using cirq.MatrixGate

cirq.MatrixGate creates a gate directly from an arbitrary unitary matrix; no subclassing is required. In the example below, a 2-qubit QFT gate is created.

    import numpy as np
    import cirq

    QFT2 = np.array([[1, 1, 1, 1],
                     [1, 1j, -1, -1j],
                     [1, -1, 1, -1],
                     [1, -1j, -1, 1j]]) * 0.5
    my_qft2 = cirq.MatrixGate(QFT2)
    q = cirq.LineQubit.range(2)
    print(cirq.Circuit(my_qft2(q[0], q[1])))

Resulting in:

          ┌                                       ┐
          │ 0.5+0.j   0.5+0.j   0.5+0.j   0.5+0.j │
    0: ───│ 0.5+0.j   0. +0.5j -0.5+0.j   0. -0.5j│───
          │ 0.5+0.j  -0.5+0.j   0.5+0.j  -0.5+0.j │
          │ 0.5+0.j   0. -0.5j -0.5+0.j   0. +0.5j│
          └                                       ┘
          │
    1: ───#2──────────────────────────────────────────

Creating your own gate class

You will have to subclass cirq.Gate and implement at least one of the following methods: _num_qubits_, _qid_shape_ or num_qubits.

Then you can implement "magic" methods to support Cirq protocols, such as cirq.SupportsUnitary (implement the _unitary_ method) or cirq.SupportsCircuitDiagramInfo (implement the _circuit_diagram_info_ method).

Here is an example of a gate that applies X to each of an arbitrary number of qubits:

    from typing import Iterable, Union

    import numpy as np

    import cirq

    # Defined locally instead of importing from cirq.type_workarounds,
    # so the example does not depend on that internal module.
    NotImplementedType = type(NotImplemented)


    class MultiXGate(cirq.Gate):
        """Applies an X gate to each of `num_qubits` qubits."""

        def __init__(self, num_qubits: int):
            if num_qubits <= 0:
                raise ValueError("num_qubits should be > 0")
            # Stored under a private name: assigning to self.num_qubits would
            # shadow the cirq.Gate.num_qubits() method.
            self._n = num_qubits

        # this is mandatory (or alternatively, _qid_shape_ or num_qubits)
        def _num_qubits_(self) -> int:
            return self._n

        ## These are not mandatory but pretty important

        def _circuit_diagram_info_(self, args) -> Union[str, Iterable[str],
                                                        cirq.CircuitDiagramInfo]:
            # One label per qubit wire
            return ["multi-X"] * self._n

        def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
            # Tensor product X ⊗ X ⊗ ... ⊗ X over all qubits
            x = cirq.unitary(cirq.X)
            return cirq.kron(*([x] * self._n))


    if __name__ == '__main__':
        q = cirq.LineQubit.range(4)
        circuit = cirq.Circuit(MultiXGate(len(q))(*q))
        print(circuit)
        # Sanity check: same unitary as four separate X gates
        c2 = cirq.Circuit([cirq.X(qubit) for qubit in q])
        assert np.allclose(circuit.unitary(), c2.unitary())

This will print the following:

    0: ───multi-X───
          │
    1: ───multi-X───
          │
    2: ───multi-X───
          │
    3: ───multi-X───
Balint Pato