A Convenient Hack for Tensor Shape Assertions
2024 May 31
If you're working with multidimensional tensors (e.g. in numpy or pytorch), a helpful pattern is to use tuple unpacking to get the sizes of the various dimensions, like this:
batch, chan, w, h = x.shape
Sometimes you already know some of these dimensions and want to assert that they have the correct values. Here is a convenient way to do that. Define the following class and a single instance of it:
class _MustBe:
    """ class for asserting that a dimension must have a certain value.
        the class itself is private, one should import a particular object,
        "must_be" in order to use the functionality. example code:
        `batch, chan, must_be[32], must_be[32] = image.shape` """
    def __setitem__(self, key, value):
        assert key == value, "must_be[%d] does not match dimension %d" % (key, value)
must_be = _MustBe()
This hack overrides index assignment and replaces it with an assertion. To use it, import must_be from the file where you defined it. Now you can do stuff like this:
batch, must_be[3] = v.shape
must_be[batch], l, n = A.shape
must_be[batch], must_be[n], m = B.shape
...
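For a self-contained check of the pattern above, here is a minimal sketch that uses plain tuples in place of `v.shape`, `A.shape` and `B.shape`. The concrete sizes and the helper `shapes_match` are made up for illustration. It also leans on the fact that Python assigns unpacking targets left to right, so a name bound earlier in the same statement can already be used inside a later `must_be[...]`.

```python
class _MustBe:
    """Index assignment on this object asserts equality instead of storing."""
    def __setitem__(self, key, value):
        assert key == value, "must_be[%d] does not match dimension %d" % (key, value)

must_be = _MustBe()

# Hypothetical shapes standing in for v.shape, A.shape and B.shape.
v_shape = (8, 3)
A_shape = (8, 5, 7)
B_shape = (8, 7, 2)

batch, must_be[3] = v_shape              # binds batch = 8, checks the last dim is 3
must_be[batch], l, n = A_shape           # checks the batch dim, binds l = 5 and n = 7
must_be[batch], must_be[n], m = B_shape  # checks batch and n, binds m = 2

# Targets are assigned left to right, so `side` is bound before must_be[side] runs.
side, must_be[side] = (32, 32)           # passes only when both dims are equal

def shapes_match(shape):
    """Hypothetical helper: True if `shape` is (batch, n, anything)."""
    try:
        must_be[batch], must_be[n], _ = shape
    except AssertionError:
        return False
    return True

ok = shapes_match((8, 7, 4))   # batch and n agree
bad = shapes_match((8, 6, 4))  # second dim is 6, not n = 7
```

Since the check is a plain `assert`, it is skipped when Python runs with `-O`, so the shapes go unchecked in that mode.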
update, 2024 Sep 29:
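Sometimes it's convenient to be able to match multiple dimensions at once. This version of the class also accepts a tuple as the key, so it can be combined with starred unpacking: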
class _MustBe:
    """ class for asserting that a dimension must have a certain value.
        the class itself is private, one should import a particular object,
        "must_be" in order to use the functionality. example code:
        `batch, chan, must_be[32], must_be[32] = image.shape`
        `*must_be[batch, 20, 20], chan = tens.shape` """
    def __setitem__(self, key, value):
        if isinstance(key, tuple):
            # a starred target receives a list, so convert before comparing
            assert key == tuple(value), "must_be[%s] does not match dimension %s" % (str(key), str(value))
        else:
            assert key == value, "must_be[%d] does not match dimension %d" % (key, value)
must_be = _MustBe()
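The tuple branch calls `tuple(value)` because a starred target always receives a list, whatever the type of the right-hand side. Here is a minimal sketch of that behavior, with a plain tuple standing in for `tens.shape` and made-up sizes:

```python
class _MustBe:
    """Same class as above: scalar keys compare directly, tuple keys compare as tuples."""
    def __setitem__(self, key, value):
        if isinstance(key, tuple):
            assert key == tuple(value), "must_be[%s] does not match dimension %s" % (str(key), str(value))
        else:
            assert key == value, "must_be[%d] does not match dimension %d" % (key, value)

must_be = _MustBe()

batch = 4
tens_shape = (4, 20, 20, 3)  # hypothetical stand-in for tens.shape

# The starred target collects the leading dims into the list [4, 20, 20],
# which __setitem__ converts to a tuple before comparing with the key.
*must_be[batch, 20, 20], chan = tens_shape

# A mismatch in any of the grouped dims raises AssertionError.
try:
    *must_be[batch, 20, 20], _ = (4, 20, 21, 3)
    mismatch_caught = False
except AssertionError as err:
    mismatch_caught = True
    message = str(err)
```

The error message reports the whole group of dimensions, so you can see which one is off.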