DDP 환경을 구성하고 scratch 부터 학습할때는 잘 동작하던 코드에서
backbone에 pretrained weight를 load 하니 바로 OOM이 발생하는 상황을 만나버렸다.
초기 상황 구성은 아래와 같았다.
# OOM이 발생한 코드
model = get_model()
checkpoint_path = r"~~~"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
model = torch.nn.parallel.DistributedDataParallel(~~~)
단순히 초기 weight 만 읽어왔을 뿐인데 OOM이 터져버리니 당황스럽긴 한데
메모리 동작하는 꼴을 보니 0번 GPU 가 메모리 터져서 발생한 상황인 듯 했다.
추정키로는 각 GPU에 복사되어야 할 정보들이 0번 GPU로 먼저 다 올라오면서
감당이 안되고 OOM 이 발생한 것 같다.
이런 상황을 해결하기 위해서는 각 GPU로 map_location을 지정하여 직접 load 해 주면 문제가 해결된다.
# 수정된 코드
model = get_model()
checkpoint_path = r"~~~"
checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.local_ranl))
model.load_state_dict(checkpoint)
model = torch.nn.parallel.DistributedDataParallel(~~~)
'전공관련 > Deep Learning' 카테고리의 다른 글
[Pytorch] Sequential 모듈 내 레이어에 접근하자 (0) | 2023.09.06 |
---|---|
[Pytorch] No audio I/O backend is available. 에러를 해결하자. (0) | 2023.08.02 |
[Pytorch] Boolean value of Tensor with more than one value is ambiguous 에러를 해결하자. (0) | 2023.06.07 |
[MXNet] 데이터 리스트를 만들고 rec 파일로 만들어 보자 (0) | 2022.04.01 |
[Pytorch] pycharm 환경에서 torch.distributed.launch를 실행하자 (1) | 2021.07.13 |