A Convenient Hack for Tensor Shape Assertions

2024 May 31

If you're working with multidimensional tensors (e.g. in numpy or pytorch), a helpful pattern is to use pattern matching (tuple unpacking) to get the sizes of the various dimensions. Like this:
batch, chan, w, h = x.shape
Sometimes you already know some of these dimensions and want to assert that they have the correct values. Here is a convenient way to do that. Define the following class and a single instance of it:
class _MustBe:
  """ class for asserting that a dimension must have a certain value.
      the class itself is private, one should import a particular object,
      "must_be" in order to use the functionality. example code:
      `batch, chan, must_be[32], must_be[32] = image.shape` """
  def __setitem__(self, key, value):
    assert key == value, "must_be[%d] does not match dimension %d" % (key, value)
must_be = _MustBe()
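This hack overrides index assignment and replaces it with an assertion: when must_be[k] appears as an unpacking target, Python calls __setitem__(k, dim) with the dimension it would have assigned. To use it, import must_be from the file where you defined it. Now you can do stuff like this: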
batch, must_be[3] = v.shape
must_be[batch], l, n = A.shape
must_be[batch], must_be[n], m = B.shape
...
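To see the whole thing run end to end, here is a minimal self-contained sketch. It repeats the class from above and uses plain tuples as stand-ins for tensor shapes; the concrete sizes and the names v_shape, A_shape, B_shape and check_mismatch are made up for illustration. Since x.shape in numpy is a tuple, the unpacking should behave the same way on real arrays.

```python
class _MustBe:
    """Asserts that an unpacked dimension has the expected value."""
    def __setitem__(self, key, value):
        assert key == value, "must_be[%d] does not match dimension %d" % (key, value)

must_be = _MustBe()

# Hypothetical shapes: plain tuples standing in for v.shape, A.shape, B.shape.
v_shape = (8, 3)
A_shape = (8, 5, 7)
B_shape = (8, 7, 4)

# Targets are assigned left to right, so `batch` is bound first,
# then must_be[3] checks that the second dimension equals 3.
batch, must_be[3] = v_shape

# `batch` is already bound, so it acts as the expected value here,
# while l and n are newly bound from the remaining dimensions.
must_be[batch], l, n = A_shape

# Both batch and n are checked; only m is newly bound.
must_be[batch], must_be[n], m = B_shape

def check_mismatch():
    """Unpack a shape whose second dimension is 4, not 3, and return the error text."""
    try:
        _, must_be[3] = (8, 4)
    except AssertionError as err:
        return str(err)
    return None

mismatch_message = check_mismatch()
print(batch, l, n, m)
print(mismatch_message)
```

Two things worth keeping in mind. Because targets are assigned left to right, a name used inside must_be[...] has to be bound by an earlier statement or by an earlier target in the same statement. And since the check is a plain assert, it is skipped when Python runs with the -O flag, so treat it as a debugging aid rather than input validation.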