• Tistory
    • 태그
    • 위치로그
    • 방명록
    • 관리자
    • 글쓰기
Carousel 01
Carousel 02
Previous Next

[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자2

전공관련/Deep Learning 2019. 12. 17. 17:35




[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자

지난난 글에서 사용한한 방법은 저장 단계에서 적용해야 하는 문제가 있어서 

이번에는 다시 state_dict를 읽어서 key에서 module을 제거하고 다시 넣어주는 방법을 사용.


우선 state_dict의 key에서 module 을 제거한 새로운 dict를 반환하는 함수 선언

def get_state_dict(origin_dict):
    old_keys = origin_dict.keys()
    new_dict = {}

    for ii in old_keys:
        temp_key = str(ii)
        if temp_key[0:7] == "module.":
            new_key = temp_key[7:]
        else:
            new_key = temp_key

        new_dict[new_key] = origin_dict[temp_key]
    return new_dict

 

이후 네트워크 가중치를 읽을때 위 함수를 통해 다시 읽어주면 문제없이 사용 가능하다.

# 위쪽은 생략

net = CNN_Network(num_classes=1000)
net.to(device)
net.eval()

checkpoint = torch.load('model_path/model_file.pth.tar', map_location=device)
checkpoint_dict = get_state_dict(checkpoint["state_dict"])
net.load_state_dict(checkpoint_dict)

 

 

(21.01.06. 수정)

위의 get_state_dict를 람다식을 이용하여 깔끔하게 구현된 코드를 확인하여 기록.

(출처 : github.com/biubug6/Pytorch_Retinaface / RetinaFace)

 

def remove_prefix(state_dict, prefix):
    ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
    
# 사용할때는
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
저작자표시 (새창열림)

'전공관련 > Deep Learning' 카테고리의 다른 글

[Onnx] pytorch model을 onnx로 변환하여 사용하자  (1) 2020.02.26
[Pytorch] Custom Dataloader를 사용하자  (0) 2019.12.23
[Pytorch] pytorch 와 tensorboard를 같이 써보자.  (0) 2019.11.15
[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자  (0) 2019.03.18
[Pytorch] tensor의 차원을 바꿔보자  (0) 2019.03.15
블로그 이미지

매직블럭

작은 지식들 그리고 기억 한조각

,

카테고리

  • 살다보니.. (449)
    • 주절거림 (3)
    • 취미생활 (36)
      • 지식과 지혜 (3)
      • 풍경이 되어 (4)
      • Memories (17)
      • 엥겔지수를 높여라 (2)
    • mathematics (6)
      • Matrix Computation (2)
      • RandomProcesses (3)
    • English.. (8)
    • Programming (147)
      • C, C++, MFC (51)
      • C# (1)
      • OpenCV (17)
      • Python (58)
      • Git, Docker (3)
      • Matlab (4)
      • Windows (3)
      • Kinect V2 (2)
      • 기타 etc. (8)
    • 전공관련 (80)
      • Algorithm (6)
      • Deep Learning (54)
      • 실습 프로그램 (4)
      • 주워들은 용어정리 (8)
      • 기타 etc. (8)
    • Computer (118)
      • Utility (21)
      • Windows (31)
      • Mac (4)
      • Ubuntu, Linux (58)
      • NAS (2)
      • Embedded, Mobile (2)
    • IT, Device (41)
      • 제품 사용기, 개봉기 (14)
      • 스마트 체험단 신청 (27)
    • Wish List (3)
    • TISTORY TIP (5)
    • 미분류. 수정중 (1)

태그목록

  • 매트랩
  • Deep Learning
  • CStdioFile
  • 갤럭시노트3
  • 큐슈
  • 에누리닷컴
  • 딥러닝
  • LIBSVM
  • ReadString
  • review
  • 포르투갈
  • random variable
  • 크롬
  • 칼로리 대폭발
  • SVM
  • 후쿠오카
  • DSLR
  • portugal
  • 일본
  • Computer Tip
  • ColorMeRad
  • utility
  • matlab
  • 스마트체험단
  • Convolutional Neural Networks
  • function
  • matlab function
  • 매트랩 함수
  • 오봉자싸롱
  • DeepLearning

달력

«   2025/07   »
일 월 화 수 목 금 토
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
07-12 16:29

LATEST FROM OUR BLOG

RSS 구독하기

BLOG VISITORS

  • Total :
  • Today :
  • Yesterday :

Copyright © 2015 Socialdev. All Rights Reserved.

티스토리툴바