머신러닝/딥러닝 공부

분류 (4) - 다중 분류 (Multiclass Classification) 코드 구현 본문

AI 공부/Machine Learning

분류 (4) - 다중 분류 (Multiclass Classification) 코드 구현

호사린가마데라닌 2021. 10. 29. 15:16

다중 분류 모델을 구현하기 위해 이번에는 Multiclass라는 클래스를 만들었습니다. 다른 모델들의 클래스와 코드가 거의 비슷하지만 가중치가 행렬 형태이기 때문에 행렬 연산이 들어갑니다.


저는 행렬의 연산을 계산하기 위해 np.dot() 함수를 사용하였습니다. np.dot() 함수는 인자로 들어온 두 행렬이 모두 1차원이라면 np.sum처럼 점곱(inner product)를 계산해서 리턴해주고, 2차원 행렬이 섞여있을 경우 행렬의 곱을 계산하여 리턴해줍니다.

 

import numpy as np

a=np.array([1,2,3])
b=np.array([2,4,6])

c=np.dot(a,b)
print(c)

np.dot()의 점곱 기능

 

 

import numpy as np

a=np.array([[1,2],[3,4]])
b=np.array([[-2],[1]])

c=np.dot(a,b)
print(c)

np.dot()의 행렬곱 기능

 


따라서 제 코드에서 np.dot()함수가 등장하면 행렬의 곱셈을 계산하는 과정이라고 이해하시면 됩니다.

제 코드에서 n은 x의 특성(feature)의 개수이고 m은 클래스의 개수입니다.

이제 코드를 보겠습니다.

 

 

 

 

그림의 forward(x)와 softmax(z) 함수, 그리고 loss(x,y) 함수를 구현해보겠습니다.

 

class Multiclass:
  def __init__(self,learning_rate=0.01):
    self.w=None #가중치 행렬
    self.b=None #바이어스 배열
    self.lr=learning_rate #학습률 
    self.losses=[] #매 에포크마다 손실 저장할 리스트
    self.weight_history=[] #매 에포크마다 가중치 저장할 리스트
    self.bias_history=[] #매 에포크마다 바이어스 저장할 리스트

  def forward(self,x):
    z=np.dot(x,self.w)+self.b
    z=np.clip(z,-100,None) #NaN 방지
    return z

  def softmax(self,z):
    exp_z=np.exp(z)
    a=exp_z/np.sum(exp_z)
    return a

  def loss(self,x,y):
    z=self.forward(x)
    a=self.softmax(z)
    return -np.sum(y*np.log(a)) #손실 계산 후 리턴

 

 

forward()와 softmax()는 처음 모델이 받은 입력값(x)를 통과시켜 z와 a값을 계산하여 줍니다. 시그모이드 함수와 마찬가지로 softmax함수에서 z값이 너무 작아 exp의 값이 0이 되는 상황을 방지하기 위해 numpy의 clip() 함수를 사용했습니다. 손실을 계산하기 위해 loss함수도 구현해주었습니다.

 

 

한 가지 신경 써야 할 점은 현재 입력값(x)은 1 x n 행렬, 가중치는 n x m 행렬이라는 것입니다. 행렬의 연산으로 인해 forward()함수를 통과하여 만들어진 z는 (1 x n) x (n x m) = (1 x m) 행렬입니다. softmax에서 행렬끼리의 연산으로 인해 행렬의 형태가 바뀌는 상황이 발생하지 않으므로 여전히 (1 x m) 행렬입니다. 

 

 

loss 함수는 크로스 엔트로피를 사용하기 때문에 그대로 구현해주었습니다.

 

 

다음은 back propagation 알고리즘을 이용하여 가중치와 바이어스를 업데이트 해주기 위해 가중치와 바이어스의 기울기를 구해주는  gradient() 함수입니다.

 

 

 

  def gradient(self,x,y):
    z=self.forward(x) #선형방정식을 통한 z값 산출
    a=self.softmax(z) #z값을 softmax에 통과시켜 a값 산출

    w_grad=-np.dot(x.reshape(-1,1),(y-a).reshape(1,-1)) #가중치의 기울기
    b_grad=-(y-a) #바이어스의 기울기

    return w_grad,b_grad

 

 

바이어스의 경우 1 x m 행렬(정확히는 len(b)=m인 리스트)이기 때문에 위의 코드와 같이 계산해주어 업데이트를 해주면 됩니다. 다만 가중치의 경우 n x m 행렬이고 입력값으로 받은 x는 1 x n 행렬이므로 그대로 np.dot()을 해주면 안 됩니다. 

 

 

x를 (n x 1) 행렬로 바꿔주고, (1 x m) 행렬인 (y-a)와 np.dot()을 해주어야 원하는 행렬인 (n x m) 의 가중치 기울기들을 원소로 갖는 행렬을 얻을 수가 있습니다. 이 때문에 위의 코드에서 reshape() 함수를 이용해서 행렬의 모양을 바꿔주었습니다.

 

 

마지막으로 모델을 훈련시키는 코드입니다.

 def fit(self,x_data,y_data,epochs=40):
    #가중치와 바이어스 초기화
    self.w=np.random.normal(0,1,(x_data.shape[1],y_data.shape[1])) #표준정규분포로 초기화 
    self.b=np.zeros(y_data.shape[1]) #0으로 초기화

    #에포크
    for epoch in range(epochs):
      l=0 #손실값을 계산할 변수
      w_grad=np.zeros((x_data.shape[1],y_data.shape[1]))
      b_grad=np.zeros(y_data.shape[1])
      
      for x,y in zip(x_data,y_data):
        l+=self.loss(x,y) #매 에포크마다 손실값 계산
        
        w_i,b_i=self.gradient(x,y) #가중치와 바이어스의 기울기를 계산
		
        w_grad+=w_i #가중치의 기울기 누적
        b_grad+=b_i #바이어스의 기울기 누적 
        
      self.w-=self.lr*(w_grad/len(y_data) #가중치 업데이트
      self.b-=self.lr*(b_grad/len(y_data) #바이어스 업데이트

      self.losses.append(l/len(y_data)) #손실 저장
      self.weight_history.append(self.w) #가중치 저장
      self.bias_history.append(self.b) #바이어스 저장
      
      print(f'epoch({epoch+1}) ===> loss : {l/len(y_data):.5f}')

 

 

매 에포크마다 손실을 계산하여 저장하고, 연산한 값을 이용하여 가중치와 바이어스를 업데이트 해주고 있습니다. 

 

 

이 모델을 이용하여 MNIST 손글씨 데이터셋과 MNIST Fashion 데이터셋을 학습시켜 정확도를 판단해 보겠습니다. 그전에 주어진 모델의 성능을 평가하기 위해 데이터셋을 전처리(Preprocessing)할 필요가 있습니다. 또한 머신러닝에서 모델을 학습할 때 보통 주어진 데이터셋을 훈련 세트(Trainig dataset), 검증 세트(Validation dataset) 그리고 테스트 세트(Test set)로 나누어 학습시킵니다. 이는 모델의 과소 적합(Under-fitting) 및 과대 적합(Over-fitting) 여부를 확인하기 위해서가 가장 주된 이유인데, 이 부분은 머신러닝에서 가장 주의해야하고 또 가장 중요한 부분이기 때문에 다음 포스팅에서 좀 더 자세하게 정리할 생각입니다. 

 

 

이 모델을 이용하여 MNIST 손글씨 데이터 세트를 학습시킵니다.

 

https://yhyun225.tistory.com/19

 

분류(5) - MNIST 숫자 손글씨 분류 모델( 다중 선형 분류 )

1) 다중 분류 모델 코드 이전 포스팅에서 수정했던 다중 분류 코드를 가져오겠습니다. https://yhyun225.tistory.com/15?category=964332 분류 (4) - 다중 분류 (Multiclass Classification) 코드 구현 다중 분류..

yhyun225.tistory.com

 

https://yhyun225.tistory.com/16

Comments