Assignment 9
과제 9
과제 | 9 |
---|---|
점수 | 0 |
제출 | False |
파일 | 다운로드 |
MNIST dataset은 0 부터 9까지의 손글씨가 쓰여진 70000장의 이미지들로 이루어져 있습니다.
MNIST examples: https://en.wikipedia.org/wiki/MNIST_database
각 이미지는 \(28 \times 28\)의 크기를 가지며, 정답숫자와 함께 제공됩니다. 우리는 이 과제에서, best approximation을 이용한 MNIST의 손글씨 이미지 분류 작업을 해보려 합니다.
1. MNIST dataset 불러오기
문서에 첨부된 파일을 다운받은 후 압축을 해제하세요. 총 네개의 mat
파일을 얻어야 합니다.
파일명 | 설명 |
---|---|
train_image.mat |
60000장의 손글씨 이미지들 입니다. train_image의 이미지들을 이용하여 best approximate solution을 구하려 합니다. |
train_target.mat |
train_image의 각 이미지에 대한 정답 숫자들을 순서대로 포함하고 있습니다. |
test_image.mat |
10000장의 손글씨 이미지들 입니다. train_image를 통해 얻어낸 solution을 검증하기 위해 사용합니다. |
test_target.mat |
test_image의 정답 숫자들 입니다. |
매트랩에 다음 명령어들을 순서대로 입력하여 작업 공간
에 데이터를 불러옵니다. 또한, 각 데이터의 형을 double
로 지정합니다.
load('training_image.mat'); train_x = double(train_x);
load('training_target.mat'); train_y = double(train_y);
load('test_image.mat'); test_x = double(test_x);
load('test_target.mat'); test_y = double(test_y);
작업 공간
에 4개의 변수가 생성되었는지 확인하세요.
Name Size Bytes Class
test_x 28x28x10000 62720000 double
test_y 1x10000 80000 double
train_x 28x28x60000 376320000 double
train_y 1x60000 480000 double
2. Normal equation 구성 및 풀이 [0과 1 분류]
우선은 0과 1 단 두개의 숫자만 구분하도록 하는 solution을 구해보겠습니다. 정답 숫자가 0 또는 1인 데이터만 추출하여 변수를 재구성합니다.
% find indices whose value is either 0 or 1.
train_idx = [find(train_y == 0), find(train_y == 1)];
test_idx = [find(test_y == 0), find(test_y == 1)];
% Leave only data that are 0 or 1.
train_x = train_x(:, :, train_idx); train_y = train_y(:, train_idx);
test_x = test_x(:, :, test_idx); test_y = test_y(:, test_idx);
각 행이 하나의 이미지로 이루어진 데이터 행렬 \(M\)을 구성하기 위해 train_x
에 reshape
명령어를 사용합니다.
M = reshape(train_x, 28^2, [])'; % [] means that 'matche dimension automatically'.
y = train_y';
(거의 자명하게도) \(Mx=y\)를 만족하는 \(y\)는 존재하지 않습니다. 따라서 우리는 best approximate solution을 구하려고 합니다.
-
간단하게, 매트랩의
\
구문을 사용하여 solution을 구할 수 있습니다.\
는 자동으로 가장 적합한 solution을 구해줄 것 입니다.x = M \ y; % the most simple way.
하지만, 이 경우
\
는 상당히 느리게 작동할 뿐 아니라, 우리 수준에서는 어떻게 solution을 구해주는지 이해하기 힘듭니다. -
Normal equation을 직접 구성하여 solution을 구할 수도 있습니다. 이 때, linear system의 해를 구해주는
linsolve
명령어를 사용합니다.
(이 경우,linsolve
는 \(LDU\)-분해를 이용하여 해를 구합니다. 또한,\
를 사용해도 역시 같은 알고리즘을 이용합니다.)A = M' * M; b = M' * y; x = linsolve(A, b);
만약, 여러분이 위 코드를 그대로 사용하여 solution을 구했다면, 구해진 solution은 매우 잘못된 값을 포함하고 있을 확률이 높습니다. 이 것은 \(A\)의 원소가 대체적으로 0에 가까우며, 아주 적은량의 원소만 매우 크게 나타나서 수치적으로 불안정하기 때문입니다. 수치적 안정성을 높이는 방법은 여러가지가 있을 수 있지만, 여기서는 가장 간단하게 \(A\)에 작은 noise(perturbation)를 더하는 방법으로 해결하겠습니다. 물론, 이 방법으로 구해지는 solution은 실제값과는 조금 다릅니다. 하지만 여기서는 작은 오차는 무시하도록 합시다.
(이 부분은 이해하지 않고 넘어가도 됩니다.)A = M' * M + rand(size(M, 2)) * 0.001; % add random perturbation. b = M' * y; x = linsolve(A, b);
1과 2의 소요시간을 각각 확인해보죠.
%% method 1
tic; x = M \ y; toc;
%% method 2
tic;
A = M' * M + rand(size(M, 2)) * 0.001;
b = M' * y;
x = inv(A) * b;
toc;
3. Solution 검증
앞서 불러온 0과 1로만 재구성 된 test
데이터들을 이용하여 데이터 행렬 \(T\)와 \(z\)를 만듭니다.
T = reshape(test_x, 28^2, [])';
z = test_y';
앞서 구한 solution \(x\)를 이용하여 예측값 \(w\)를 구한 뒤, 반올림하여 \(z\)의 값과 비교합니다.
w = T * x; % compute the prediction by the best approximate solution `x`
acc = sum(round(w) == z) / length(w); % compute the accuracy: (# of correct answers) / (# of data points)
fprintf('Accuracy = %.2f %% \n', acc * 100)
우리가 만들어낸 모델(solution \(x\))은 98%가 넘는 정확도를 가지는, 꽤 정확한 분류모델인 걸 알 수 있습니다.
4. 더 많은 숫자 분류
이제 0과 1뿐만 아니라 더 많은 숫자들을 분류해 보겠습니다. 먼저 분류하고 싶은 숫자들을 지정한 뒤,
num_for_use = [0, 1, 2, 3];
for
문을 활용하여 분류하고 싶은 숫자들만 골라내겠습니다. 구성 방식은 위와 완전히 같습니다.
train_idx = []; % create empty idcies
test_idx = [];
for i = num_for_use
train_idx = [train_idx, find(train_y == i)];
test_idx = [test_idx, find(test_y == i)];
end
이제, 나머지 부분은 앞선 예제와 같습니다. 0, 1, 2, 3 네 개의 숫자를 분류하는 모델의 정확도는 70% 정도입니다.
연습문제
-
여러개의 숫자를 분류하는 모델을 만들어보세요.
-
위의 내용을 그대로 따라서 0과 9를 분류하는 모델을 만들고 검증 해보세요. 실제로는 0과 1을 분류하는 모델과 똑같이, 98% 정도의 정확도가 나와야 합니다. 여러분이 얻은 결과는 어떤가요?
결과가 예상과 다르다면, 어느 부분에서 달라졌는지 생각해보고 수정해보세요.
Assignment 9
Assignment | 9 |
---|---|
Score | 0 |
Submission | False |
Files | Download |
MNIST dataset consists of 70K images of handwritten numbers 0 to 9.
MNIST examples: https://en.wikipedia.org/wiki/MNIST_database
Each image has a size of \(28 \times 28\) and comes with a correct number. In this assignment, we try to classify handwritten images of MNIST using the best approximation.
1. Loading MNIST dataset
After downloading the file attached to the document, extract it. There should be four mat
files.
File name | Description |
---|---|
train_image.mat |
It’s 60K handwritten images. We are trying to get the best approximate solution using the images in train_image. |
train_target.mat |
Contains the correct numbers for each image in train_image in order. |
test_image.mat |
10K handwritten images. Will be used to validate solutions obtained through train_image. |
test_target.mat |
Correct numbers for test_image. |
Enter the following commands in matlab in order to load data into the Workspace
. Also, specify the type of each data as double
.
load('training_image.mat'); train_x = double(train_x);
load('training_target.mat'); train_y = double(train_y);
load('test_image.mat'); test_x = double(test_x);
load('test_target.mat'); test_y = double(test_y);
Make sure 4 variables are created in the Workspace
.
Name Size Bytes Class
test_x 28x28x10000 62720000 double
test_y 1x10000 80000 double
train_x 28x28x60000 376320000 double
train_y 1x60000 480000 double
2. Composing and solving Normal equation [Binary classification]
First, let’s find a solution to distinguish only two numbers 0 and 1. Reconstruct the variable by extracting only the data with correct answer numbers 0 or 1.
% find indices whose value is either 0 or 1.
train_idx = [find(train_y == 0), find(train_y == 1)];
test_idx = [find(test_y == 0), find(test_y == 1)];
% Leave only data that are 0 or 1.
train_x = train_x(:, :, train_idx); train_y = train_y(:, train_idx);
test_x = test_x(:, :, test_idx); test_y = test_y(:, test_idx);
Use the reshape
command on train_x
to construct a data matrix \(M\) where each row is an image.
M = reshape(train_x, 28^2, [])'; % [] means that 'matche dimension automatically'.
y = train_y';
(Almost trivially) There are no \(y\) that satisfies \(Mx=y\). So we are goint to try to find the best approximate solution.
-
Simply, we can use
\
syntax to get the solution.\
will automatically find the most suitable solution.x = M \ y; % the most simple way.
However, in this case
\
works quite slowly, and it is difficult to understand how to find a solution at our level. -
We can also find the solution by directly constructing the normal equation. In this case, the
linsolve
command is used to find the solution of the linear system.
(In this case,linsolve
uses \(LDU\)-decomposition to find the solution. Also, the syntax\
uses the same algorithm.)A = M' * M; b = M' * y; x = linsolve(A, b);
If you find a solution using the above code, there is a high probability that the obtained solution contains very wrong values. This is because \(A\) is numerically unstable since the elements of \(A\) are generally close to zero, and only very small amounts of elements appear very large. There are many ways to make the \(A\) numerical stable, but here we will use the simplest way which is adding a small noise (perturbation) to \(A\). Of course, the solution obtained in this way is slightly different from the actual value. But let’s ignore small errors here.
(You don’t have to try to understand this part.)A = M' * M + rand(size(M, 2)) * 0.001; % add random perturbation. b = M' * y; x = linsolve(A, b);
Let’s compare the time required for 1 and 2 respectively.
%% method 1
tic; x = M \ y; toc;
%% method 2
tic;
A = M' * M + rand(size(M, 2)) * 0.001;
b = M' * y;
x = inv(A) * b;
toc;
3. Solution validation
Create the data matrices \(T\) and \(z\) using the test
data reconstructed only with 0 and 1, which were loaded earlier.
T = reshape(test_x, 28^2, [])';
z = test_y';
Calculate the predicted value \(w\) using the solution \(x\) obtained earlier, round it and compare it with the value of \(z\).
w = T * x; % compute the prediction by the best approximate solution `x`
acc = sum(round(w) == z) / length(w); % compute the accuracy: (# of correct answers) / (# of data points)
fprintf('Accuracy = %.2f %% \n', acc * 100)
We can see that the classification model(solution \(x\)) is quite accurate with 98% accuary.
4. More classes of numbers
Now let’s classify more numbers. First, specify the numbers we want to classify,
num_for_use = [0, 1, 2, 3];
and then use the `for’-loop to select only the numbers we want to classify. The way is exactly the same as above.
train_idx = []; % create empty idcies
test_idx = [];
for i = num_for_use
train_idx = [train_idx, find(train_y == i)];
test_idx = [test_idx, find(test_y == i)];
end
Now, the rest is the same as in the previous example. The model that classifies the four numbers 0, 1, 2, 3 is about 70% accurate.
Exercises
-
Make a model to classify multiple numbers.
-
Create and validate a model that classifies 0 and 9 by following the above. In practice, the accuracy should be similar to a model that classifies 0’s and 1’s (about 98%). What results did you get?
If the result is not what you expected, think about where it changed and try to correct it.
Leave a comment