[Deep Learning Journal] Loading a PyTorch Model Offline


2022. 3. 29. 22:35 · joonas · 3 min read
  • Error message
  • Solution 1
  • Solution 2

Previous post - [Deep Learning Journal] Exploring Conv2d


Error message

When you try to load a model like VGG from the hub, the connection sometimes fails like this.

import torchvision
model = torchvision.models.vgg16_bn(pretrained=True)
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
---------------------------------------------------------------------------
gaierror                                  Traceback (most recent call last)
/opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1349                 h.request(req.get_method(), req.selector, req.data, headers,
-> 1350                           encode_chunked=req.has_header('Transfer-encoding'))
   1351             except OSError as err: # timeout error

/opt/conda/lib/python3.7/http/client.py in request(self, method, url, body, headers, encode_chunked)
   1280         """Send a complete request to the server."""
-> 1281         self._send_request(method, url, body, headers, encode_chunked)
   1282 

/opt/conda/lib/python3.7/http/client.py in _send_request(self, method, url, body, headers, encode_chunked)
   1326             body = _encode(body, 'body')
-> 1327         self.endheaders(body, encode_chunked=encode_chunked)
   1328 

/opt/conda/lib/python3.7/http/client.py in endheaders(self, message_body, encode_chunked)
   1275             raise CannotSendHeader()
-> 1276         self._send_output(message_body, encode_chunked=encode_chunked)
   1277 

/opt/conda/lib/python3.7/http/client.py in _send_output(self, message_body, encode_chunked)
   1035         del self._buffer[:]
-> 1036         self.send(msg)
   1037 

/opt/conda/lib/python3.7/http/client.py in send(self, data)
    975             if self.auto_open:
--> 976                 self.connect()
    977             else:

/opt/conda/lib/python3.7/http/client.py in connect(self)
   1442 
-> 1443             super().connect()
   1444 

/opt/conda/lib/python3.7/http/client.py in connect(self)
    947         self.sock = self._create_connection(
--> 948             (self.host,self.port), self.timeout, self.source_address)
    949         self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

/opt/conda/lib/python3.7/socket.py in create_connection(address, timeout, source_address)
    706     err = None
--> 707     for res in getaddrinfo(host, port, 0, SOCK_STREAM):
    708         af, socktype, proto, canonname, sa = res

/opt/conda/lib/python3.7/socket.py in getaddrinfo(host, port, family, type, proto, flags)
    751     addrlist = []
--> 752     for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
    753         af, socktype, proto, canonname, sa = res

gaierror: [Errno -3] Temporary failure in name resolution

During handling of the above exception, another exception occurred:

URLError                                  Traceback (most recent call last)
/tmp/ipykernel_33/2336501538.py in <module>
      1 import collections
      2 
----> 3 model = torchvision.models.vgg16_bn(pretrained=True)
      4 model.to(device)

/opt/conda/lib/python3.7/site-packages/torchvision/models/vgg.py in vgg16_bn(pretrained, progress, **kwargs)
    166         progress (bool): If True, displays a progress bar of the download to stderr
    167     """
--> 168     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
    169 
    170 

/opt/conda/lib/python3.7/site-packages/torchvision/models/vgg.py in _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs)
     98     if pretrained:
     99         state_dict = load_state_dict_from_url(model_urls[arch],
--> 100                                               progress=progress)
    101         model.load_state_dict(state_dict)
    102     return model

/opt/conda/lib/python3.7/site-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    569             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    570             hash_prefix = r.group(1) if r else None
--> 571         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    572 
    573     if _is_legacy_zip_format(cached_file):

/opt/conda/lib/python3.7/site-packages/torch/hub.py in download_url_to_file(url, dst, hash_prefix, progress)
    435     # certificates in older Python
    436     req = Request(url, headers={"User-Agent": "torch.hub"})
--> 437     u = urlopen(req)
    438     meta = u.info()
    439     if hasattr(meta, 'getheaders'):

/opt/conda/lib/python3.7/urllib/request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    220     else:
    221         opener = _opener
--> 222     return opener.open(url, data, timeout)
    223 
    224 def install_opener(opener):

/opt/conda/lib/python3.7/urllib/request.py in open(self, fullurl, data, timeout)
    523             req = meth(req)
    524 
--> 525         response = self._open(req, data)
    526 
    527         # post-process response

/opt/conda/lib/python3.7/urllib/request.py in _open(self, req, data)
    541         protocol = req.type
    542         result = self._call_chain(self.handle_open, protocol, protocol +
--> 543                                   '_open', req)
    544         if result:
    545             return result

/opt/conda/lib/python3.7/urllib/request.py in _call_chain(self, chain, kind, meth_name, *args)
    501         for handler in handlers:
    502             func = getattr(handler, meth_name)
--> 503             result = func(*args)
    504             if result is not None:
    505                 return result

/opt/conda/lib/python3.7/urllib/request.py in https_open(self, req)
   1391         def https_open(self, req):
   1392             return self.do_open(http.client.HTTPSConnection, req,
-> 1393                 context=self._context, check_hostname=self._check_hostname)
   1394 
   1395         https_request = AbstractHTTPHandler.do_request_

/opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1350                           encode_chunked=req.has_header('Transfer-encoding'))
   1351             except OSError as err: # timeout error
-> 1352                 raise URLError(err)
   1353             r = h.getresponse()
   1354         except:

URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

Solution 1

If you are using Kaggle or Colab, first check whether the notebook has been cut off from the internet (on Kaggle, the Internet option in the notebook settings).

(Screenshot: Kaggle notebook settings)
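If you would rather check from code than from the settings panel, a minimal sketch like the one below (just the standard library socket module) probes the exact point where the traceback above failed, the getaddrinfo call:

import socket

try:
    # The traceback above failed inside getaddrinfo, so DNS resolution
    # is a good minimal test for whether the notebook is online.
    socket.getaddrinfo("download.pytorch.org", 443)
    print("Name resolution works; the notebook appears to have internet access.")
except socket.gaierror as exc:
    print(f"No internet access (name resolution failed): {exc}")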

Solution 2

For something like a Jupyter notebook set up on a corporate intranet, the fix above may not be an option. (I ran into this myself.)

The error message tells you where it was trying to download from, in this case "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth".

You could download the pth file from that URL yourself and load it directly with something like torch.load, but there is a more convenient way.
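For reference, the manual route looks roughly like this (a sketch, assuming the file has already been downloaded from the URL above and copied next to the script):

# Manual alternative (sketch): build the architecture without weights,
# then load the downloaded checkpoint as a plain state_dict.
import torch
import torchvision

model = torchvision.models.vgg16_bn(pretrained=False)  # architecture only, no download
state_dict = torch.load("vgg16_bn-6c64b313.pth", map_location="cpu")
model.load_state_dict(state_dict)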

The more convenient way: copy the pth file downloaded above next to the script you are running, keeping the cache layout shown in the traceback (i.e. ./hub/checkpoints/vgg16_bn-6c64b313.pth), and set the following.

import os
os.environ['TORCH_HOME'] = './'

With this set, when torchvision.models loads the model it checks the location you specified first, so it picks up the saved checkpoint instead of trying to download it.
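Put together, the offline load ends up looking like this (a sketch, assuming the checkpoint was copied into ./hub/checkpoints/ as described above):

import os
import torchvision

# TORCH_HOME is read when the download is attempted, so set it before the model call.
os.environ['TORCH_HOME'] = './'

# Expects ./hub/checkpoints/vgg16_bn-6c64b313.pth to exist already;
# torch.hub finds the cached file and skips the download.
model = torchvision.models.vgg16_bn(pretrained=True)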
