전공관련/Deep Learning
[Pytorch] tensor의 차원을 바꿔보자
매직블럭
2019. 3. 15. 16:14
사용하다 보면 tensor의 차원을 바꿔야 할 경우가 있다.
ex) single data로 테스트를 하려는데 nn.module의 입력은 4차원 텐서일 경우 등.
이럴 경우 차원을 늘리거나 줄여야 할 때는 torch.squeeze() 와 torch.unsqueeze() 함수를 이용한다.
# a => torch.Size([1, 128, 128])
a = torch.squeeze(a) # a => torch.Size([128, 128])
a = torch.unsqueeze(a, 0) # a => torch.Size([1, 128, 128])
a = torch.unsqueeze(a, 0) # a => torch.Size([1, 1, 128, 128])
a = torch.unsqueeze(a, 3) # a => torch.Size([1, 1, 128, 1, 128])
a = torch.squeeze(a, 3) # a => torch.Size([1, 1, 128, 128])