인공지능 대학원/AI 영상처리
CIFAR-10로 CNN 이미지 분류
열쩡왔쩡
2025. 3. 31. 14:56
# 필요한 라이브러리 불러오기
import torch # 딥러닝 기본 라이브러리
from torch import nn # 신경망을 쉽게 만들 수 있는 도구
from torch.utils.data import DataLoader # 데이터를 배치(batch) 단위로 불러오는 도구
import torchvision # 이미지 데이터셋 관련 도구 모음
from torchvision import datasets, transforms
import matplotlib.pyplot as plt # 그래프 그릴 때 사용
import numpy as np # 수학 계산을 쉽게 해주는 도구
# GPU가 가능하면 GPU 사용하고, 아니면 CPU 사용
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"사용 중인 디바이스: {device}")
하이퍼파라미터 설정 (모델 훈련에 영향을 주는 값들)
learning_rate = 0.001 # 모델이 학습할 때 얼마나 빠르게 배울지 결정
epochs = 1 # 데이터셋을 몇 번 반복해서 학습할지 (연습용으로 1회만 함)
batch_size = 100 # 한번에 학습할 이미지 개수
drop_out = 0.3 # (사용 안 했지만) 과적합 방지를 위한 드롭아웃 비율
CIFAR-10 데이터셋 불러오기
# 이미지 데이터 전처리: 숫자로 바꾸고 정규화
transform = transforms.Compose([
transforms.ToTensor(), # 이미지를 숫자(Tensor)로 변환
transforms.Normalize((0.5, 0.5, 0.5), # 평균값을 0.5로 맞추고
(0.5, 0.5, 0.5)) # 표준편차도 0.5로 맞춤 (이미지 명도 균형 맞추기)
])
# CIFAR-10 학습용 데이터 다운로드
training_data = datasets.CIFAR10(root='CIFAR10_data/',
train=True,
transform=transform,
download=True)
# CIFAR-10 테스트용 데이터 다운로드
test_data = datasets.CIFAR10(root='CIFAR10_data/',
train=False,
transform=transform,
download=True)
# 데이터를 한 번에 batch_size만큼 불러오기 위한 설정
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
CNN(합성곱 신경망) 모델 만들기
# CNN 모델 클래스 정의
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 첫 번째 합성곱 층: 입력 채널 3개(RGB), 출력 채널 32개
self.layer1 = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), # 3x32x32 → 32x32x32
nn.ReLU(), # 비선형 함수 (복잡한 패턴 학습에 도움)
nn.MaxPool2d(kernel_size=2, stride=2) # 32x32x32 → 32x16x16 (절반 크기로 줄이기)
)
# 두 번째 합성곱 층: 입력 채널 32개 → 출력 채널 64개
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 32x16x16 → 64x16x16
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2) # 64x16x16 → 64x8x8
)
# 마지막 단계: 완전연결층 (8x8x64 → 10 클래스)
self.fc = nn.Linear(8*8*64, 10, bias=True) # CIFAR-10은 총 10개 클래스 (비행기, 개구리 등)
nn.init.xavier_uniform_(self.fc.weight) # 가중치 초기화 (학습 잘되게 하기 위함)
def forward(self, x):
out = self.layer1(x) # layer1 통과
out = self.layer2(out) # layer2 통과
out = out.view(out.size(0), -1) # 데이터를 일렬로 펼치기 (벡터화)
out = self.fc(out) # 완전연결층 통과
return out
모델 학습하기
# 모델 인스턴스 만들고, GPU로 옮기기
model = CNN().to(device)
# 손실 함수: 정답과 예측 값의 차이를 계산
criterion = nn.CrossEntropyLoss().to(device)
# 최적화 도구: Adam (학습을 자동으로 도와줌)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 학습 시작
total_batch = len(train_dataloader)
print('학습 시작! 잠시 기다려 주세요...')
for epoch in range(epochs):
avg_loss = 0 # 평균 손실값 저장할 변수
for X, Y in train_dataloader:
X = X.to(device)
Y = Y.to(device)
optimizer.zero_grad() # 이전 계산 기록 초기화
hypothesis = model(X) # 예측값 계산
cost = criterion(hypothesis, Y) # 실제 정답과 비교한 손실값
cost.backward() # 역전파로 기울기 계산
optimizer.step() # 가중치 업데이트
avg_loss += cost / total_batch
print(f'[Epoch {epoch+1}] 평균 손실값: {avg_loss:.6f}')
print('학습 완료!')
테스트 데이터로 정확도 확인하기
# 모델을 평가 모드로 전환하고, 테스트 정확도 계산
with torch.no_grad(): # 평가할 때는 기울기 계산 안 함 (속도↑)
X_test = test_data.data # numpy 형태의 이미지
X_test = torch.tensor(X_test).permute(0, 3, 1, 2).float() / 255.0 # 채널 순서 맞추기 + 정규화
X_test = transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))(X_test) # 테스트 이미지도 정규화
X_test = X_test.to(device)
Y_test = torch.tensor(test_data.targets).to(device)
prediction = model(X_test) # 예측 수행
correct_prediction = torch.argmax(prediction, 1) == Y_test # 예측 vs 실제 비교
accuracy = correct_prediction.float().mean() # 평균 정확도 계산
print('테스트 정확도(Accuracy): {:.2f}%'.format(accuracy.item() * 100))
정리
- CIFAR-10은 10개의 카테고리를 가진 32x32 크기의 RGB 이미지 데이터셋
- CNN 모델은 이 이미지들을 보고 각 이미지가 어떤 것인지 분류하게 학습됨
- 모델 구조는:
- 합성곱 → 활성화(ReLU) → 맥스풀링 두 번 반복
- 완전 연결층(Fully Connected Layer) 로 결과 출력
MNIST 와 CIFAR-10 차이점
이미지 크기 | 28x28 | 32x32 | CIFAR은 더 큼 |
채널 수 | 1 (흑백) | 3 (RGB) | Conv2D in_channels 변경 필요 |
마지막 fc 입력 크기 | 7*7*64 | 8*8*64 | MaxPool 후 크기 계산 결과 |
데이터셋 | datasets.MNIST | datasets.CIFAR10 | 과제 요구사항 |
정답 라벨 접근 방식 | test_data.test_labels | test_data.targets | CIFAR10은 .targets 사용 |
이미지 reshape 방식 | (N, 1, 28, 28) | (N, 3, 32, 32) | 채널 수 반영 |
728x90