딥러닝

[HuggingFace] cannot import name 'PartialState' from 'accelerate' 오류 해결 방법

chanmuzi 2023. 5. 29. 13:40

간단한 딥러닝 모델을 구현하기 위해서 코드를 작성하던 도중, 이전에 본 적 없었던 에러를 만나게 됐습니다.

 

에러 메세지는 다음과 같습니다.

"""

cannot import name 'PartialState' from 'accelerate' (/opt/conda/lib/python3.10/site-packages/accelerate/__init__.py)

"""

 

실행한 코드는 다음과 같습니다.

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer
from torch.utils.data import Dataset

model_name = "bert-base-uncased"
learning_rate = 1e-5
max_seq_length = 128
batch_size = 16
num_epochs = 5

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=1)
 

 

정확한 원인을 파악하지는 못했는데, 아마 관련 자료들이 전부 최신인 걸 봐서는 진짜 torch 업데이트로 인한 버전 이슈인가 싶습니다.

저도 로컬에서 torch를 2.0으로 최근 업데이트 했었고, 다른 가상환경에서도 관련 에러가 발생했을 때의 torch version이 2.0이더라구요.

 

어쨌든 이 문제는 굉장히 간단하게 해결할 수 있습니다.

!pip uninstall -y transformers accelerate
!pip install transformers accelerate

transforemrs, accelerate 두 개를 모두 삭제 후 재설치 해주시면 끝입니다.

(주피터 노트북 환경에서 uninstall을 하면 yes/no를 선택해야 하는 상황이 있는데, 이때 -y 옵션을 주게 되면 바로 uninstall이 진행됩니다)

물론 다른 라이브러리들과의 호환성을 위해서 torch의 버전을 낮춰주는 것도 방법이 될 수 있는데 저는 시도를 해보지 않았습니다.