2023-09-29

Does PyTorch support stride_tricks as in numpy.lib.stride_tricks.as_strided?

You can do useful things by changing the strides of an array in NumPy, like this:

import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(15).reshape(3,5)

print(a)
# [[ 0  1  2  3  4]
#  [ 5  6  7  8  9]
#  [10 11 12 13 14]]

b = as_strided(a, shape=(3,3,3), strides=(a.strides[-1],)+a.strides)

print(b)
# [[[ 0  1  2]
#   [ 5  6  7]
#   [10 11 12]]

#  [[ 1  2  3]
#   [ 6  7  8]
#   [11 12 13]]

#  [[ 2  3  4]
#   [ 7  8  9]
#   [12 13 14]]]


# Get 3x3 sums of a, for example
print(b.sum(axis=(1,2)))
# [54 63 72]
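The as_strided call works because every index of the view is turned into a memory offset by a dot product with the strides. Below is a minimal pure-Python sketch of that arithmetic. It assumes the int64 layout from the example, where a.strides == (40, 8) bytes, i.e. (5, 1) in elements, so the new strides (8, 40, 8) become (1, 5, 1). The helpers `offset` and `strided_view` are hypothetical illustrations, not NumPy functions, and they use element strides rather than NumPy's byte strides.

```python
def offset(index, strides):
    """Flat element offset of a multi-index: sum of index * stride."""
    return sum(i * s for i, s in zip(index, strides))


def strided_view(buffer, shape, strides):
    """Materialize a strided view over a flat buffer as nested lists.

    Hypothetical helper that mimics as_strided with strides counted in
    elements (NumPy counts bytes). It copies values; it is not a view.
    """
    def build(prefix, dim):
        if dim == len(shape):
            return buffer[offset(prefix, strides)]
        return [build(prefix + (k,), dim + 1) for k in range(shape[dim])]

    return build((), 0)


flat = list(range(15))                    # np.arange(15), flattened
a_strides = (5, 1)                        # a.strides // a.itemsize
b_strides = (a_strides[-1],) + a_strides  # (1, 5, 1), as in the NumPy call

b = strided_view(flat, (3, 3, 3), b_strides)
print(b[1])  # [[1, 2, 3], [6, 7, 8], [11, 12, 13]]
print([sum(sum(row) for row in block) for block in b])  # [54, 63, 72]
```

Because b[k][i][j] reads flat[k + 5*i + j], i.e. a[i, j + k], each block k is the 3x3 window of a starting at column k.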

I searched for a similar method in PyTorch and found torch.as_strided, but it does not support strides that make an element reachable through multiple indices, as the warning in its documentation says:

The constructed view of the storage must only refer to elements within the storage or a runtime error will be thrown, and if the view is “overlapped” (with multiple indices referring to the same element in memory) its behavior is undefined.

In particular, this means the behavior is undefined for the example above, where some elements are reachable through multiple indices.
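To make "overlapped" concrete: the view b has 27 indices but the underlying buffer only has 15 elements, so some elements must be shared. The stdlib-only sketch below counts how many view indices alias each buffer position, then simulates an in-place `b += 1` applied one view element at a time. The sequential write loop is an assumed, simplified model, not how NumPy or PyTorch actually execute the operation. Real vectorized kernels may read several elements before writing any, so the result of writing through an overlapping view depends on visit order.

```python
from collections import Counter
from itertools import product

shape = (3, 3, 3)
strides = (1, 5, 1)  # element strides of the overlapping view b


def offset(index):
    # Flat offset of a view index: dot product of index and strides.
    return sum(i * s for i, s in zip(index, strides))


# How many view indices map to each buffer position?
aliases = Counter(offset(idx) for idx in product(*map(range, shape)))
print(aliases[2])  # 3: b[0,0,2], b[1,0,1] and b[2,0,0] all refer to a[0,2]
print(aliases[0])  # 1: only b[0,0,0] refers to a[0,0]

# Simulated "b += 1", visiting view elements one at a time on a fresh buffer.
buffer = list(range(15))
for idx in product(*map(range, shape)):
    buffer[offset(idx)] += 1

# Under this ordering, each element was incremented once per alias, not once.
print([buffer[x] - x for x in range(5)])  # [1, 2, 3, 2, 1]
```

Reading through such a view (as in the sum) is unambiguous; it is writes that become order-dependent, which is the case the "undefined" wording most obviously covers.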

Is there a way to make this work with documented, specified behavior? If not, why does PyTorch not support it?


