"""Demo of PyTorch tensor reshaping, reductions, and elementwise vs. matrix products."""
import torch

# 12 consecutive integers (0..11) viewed as a 2 x 2 x 3 tensor.
a = torch.arange(12)
a = a.reshape((2, 2, 3))
print(a)
print(a.shape)

# 4 x 4 matrix holding 0..15 in row-major order.
b = torch.arange(16).reshape(4, 4)
print(b)
print(b.sum(dim=0))  # reduce over rows -> one sum per column
print(b.sum(dim=1))  # reduce over columns -> one sum per row
print(b.sum())       # sum of every element (0-d tensor)

# 4 x 4 matrix holding 1..16, i.e. b + 1.
c = torch.arange(1, 17).reshape(4, 4)
print(c)
print(b * c)  # elementwise (Hadamard) product
# NOTE: this is Python's builtin sum, not torch.sum. It iterates over the
# first dimension and adds the rows together, so the result is the column
# sums of b * c (same as (b * c).sum(dim=0)), not a scalar. Use
# (b * c).sum() if the total of all elementwise products is wanted.
print(sum(b * c))
print(b @ c)           # matrix product
print(torch.mm(b, c))  # same matrix product via the explicit 2-D function