Reference: https://gist.github.com/awjuliani/5ce098b4b76244b7a9e3

MNIST with a softmax classifier is a classic textbook combination. The results are displayed directly in a Jupyter notebook on GCP Datalab.

**In summary: imshow(original image), imshow(weights), plot(loss)**

# Softmax Tutorial

First we import the needed libraries.

Next we import MNIST data files. We use 500 training examples, and 100 test examples.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

Let's take a look at one of the images in the set. Looks like a 4! To render it in color and pair it with a colorbar, pass a colormap such as cmap='jet' to imshow().
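As a minimal sketch of that display step, the snippet below reshapes a flattened 784-value vector into a 28 x 28 image and shows it with the 'jet' colormap and a colorbar. The random array is only a stand-in for a real MNIST row, and the variable names are illustrative rather than taken from the original notebook.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

# Stand-in for one flattened MNIST image: 784 pixel values in [0, 1).
# (Assumption: real data would come from the MNIST loader instead.)
rng = np.random.default_rng(0)
flat_image = rng.random(784)

# MNIST digits are 28 x 28 pixels, so reshape before displaying.
image = flat_image.reshape(28, 28)

fig, ax = plt.subplots()
im = ax.imshow(image, cmap="jet")  # map pixel intensity to the 'jet' colors
fig.colorbar(im, ax=ax)            # legend relating colors to pixel values
```

Without the cmap argument, imshow() falls back to matplotlib's default colormap; the colorbar works either way.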

<matplotlib.image.AxesImage at 0x7fa009767790>

imshow() ("image show") in matplotlib.pyplot is a very important function. OpenCV, MATLAB, and Octave likewise provide a function of the same name, imshow(), as their basic function for displaying an image.

### Before we can get to training our model using the data, we will have to define a few functions that the training and testing process can use.

Here we define the loss function for softmax regression.
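A loss function of this kind might look roughly like the sketch below. It computes the mean cross-entropy plus an L2 penalty, and returns the gradient with respect to the weights. The name softmax_loss, the argument names, and the exact form of the penalty are assumptions made for illustration, not the original notebook's code.

```python
import numpy as np

def softmax_loss(w, x, y_onehot, lam):
    """Regularized cross-entropy loss and its gradient (illustrative sketch).

    w: (features, classes) weights, x: (m, features) inputs,
    y_onehot: (m, classes) one-hot labels, lam: L2 regularization factor.
    """
    m = x.shape[0]
    scores = x @ w
    # Subtract the row-wise max for numerical stability, then take log-softmax.
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs)
    loss = -np.sum(y_onehot * log_probs) / m + 0.5 * lam * np.sum(w * w)
    grad = -(x.T @ (y_onehot - probs)) / m + lam * w
    return loss, grad

# Tiny synthetic check: 5 examples, 4 features, 3 classes.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))
y_onehot = np.eye(3)[np.array([0, 1, 2, 1, 0])]
w = np.zeros((4, 3))
loss, grad = softmax_loss(w, x, y_onehot, lam=1.0)
```

With all-zero weights every class gets the same probability, so the loss equals log(number of classes), which is a handy sanity check.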

The function below converts integer class labels, stored as a one-dimensional array, into a one-hot variant: an array of size m (examples) x n (classes).
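One possible way to write such a conversion with NumPy is sketched here. The function name one_hot and its num_classes parameter are hypothetical choices for this example.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn a 1-D array of integer labels into an (m, num_classes) matrix."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.shape[0], num_classes))
    # Row i gets a single 1 in the column given by labels[i].
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

labels = np.array([4, 0, 9])
encoded = one_hot(labels, num_classes=10)
```

Each row of the result contains exactly one 1, in the column of its class, so for MNIST the matrix has 10 columns.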

Here we perform the softmax transformation, which turns the raw class scores of each example into probabilities that sum to 100% across the classes.
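A minimal sketch of the transformation, assuming the scores arrive as an (examples x classes) array. Subtracting the row-wise maximum before exponentiating is a common stability trick added here; it does not change the result.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax over an (m, classes) array of raw scores."""
    # Shifting by the row max avoids overflow in exp() for large scores.
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
probs = softmax(scores)
```

The second row shows why the shift matters: exp(1000) would overflow, yet the shifted version returns equal probabilities for the three equal scores.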

Here we determine the probabilities and predictions for each class when given a set of input data:
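The prediction step could be sketched as follows: compute class probabilities from the scores, then take the most probable class for each example. The helper names predict and accuracy are assumptions for this example, and the tiny hand-made inputs only illustrate the shapes involved.

```python
import numpy as np

def predict(x, w):
    """Return (probabilities, predicted class indices) for inputs x."""
    scores = x @ w
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)
    preds = np.argmax(probs, axis=1)  # most probable class per example
    return probs, preds

def accuracy(x, y, w):
    """Fraction of examples whose predicted class matches the label."""
    _, preds = predict(x, w)
    return np.mean(preds == y)

# Three examples with three features; identity weights make class k
# score highest when feature k is largest.
x = np.array([[2.0, 0.0, 0.0],
              [0.0, 3.0, 0.0],
              [0.0, 0.0, 1.0]])
w = np.eye(3)
y = np.array([0, 1, 1])  # last label is deliberately "wrong"
probs, preds = predict(x, w)
acc = accuracy(x, y, w)
```

An accuracy helper like this is the kind of function that would produce the training and test accuracy figures reported further down.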

### This is the main loop of the softmax regression.

Here we initialize our weights, regularization factor, number of iterations, and learning rate. We then loop: on each iteration we compute the loss and its gradient, and apply the gradient to update the weights.
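The loop described above can be sketched as plain gradient descent. To keep the example self-contained it trains on three synthetic 2-D clusters rather than MNIST, and the hyperparameter values (lam, iterations, learning_rate) are illustrative assumptions, not the settings used in the original notebook.

```python
import numpy as np

def loss_and_grad(w, x, y_onehot, lam):
    """Regularized softmax cross-entropy loss and gradient (sketch)."""
    m = x.shape[0]
    scores = x @ w
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs)
    loss = -np.sum(y_onehot * log_probs) / m + 0.5 * lam * np.sum(w * w)
    grad = -(x.T @ (y_onehot - probs)) / m + lam * w
    return loss, grad

# Synthetic stand-in data: three clusters placed 120 degrees apart.
rng = np.random.default_rng(0)
angles = np.array([0.0, 2 * np.pi / 3, 4 * np.pi / 3])
centers = 3.0 * np.stack([np.cos(angles), np.sin(angles)], axis=1)
y = np.repeat(np.arange(3), 50)
x = centers[y] + 0.5 * rng.normal(size=(150, 2))
y_onehot = np.eye(3)[y]

# Initialize weights and hyperparameters (values chosen for this toy problem).
w = np.zeros((x.shape[1], 3))
lam = 0.01
iterations = 200
learning_rate = 0.1

losses = []
for _ in range(iterations):
    loss, grad = loss_and_grad(w, x, y_onehot, lam)
    losses.append(loss)
    w -= learning_rate * grad  # step against the gradient

train_accuracy = np.mean(np.argmax(x @ w, axis=1) == y)
```

Passing the recorded losses list to plt.plot() gives a loss curve like the one shown in the notebook output.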

[<matplotlib.lines.Line2D at 0x7fa0099b7090>]

Training Accuracy: 0.902
Test Accuracy: 0.85

### One of the benefits of a simple model like softmax is that we can visualize the weights for each of the classes, and see what it prefers. Here we look at the weights for the ‘3’ class.
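A sketch of how one class's weights can be turned into an image, assuming the weight matrix has shape (784, 10) with one column per digit. The random weights here are a placeholder for trained ones, and the variable names are illustrative.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

# Placeholder for trained weights: 784 pixel inputs x 10 digit classes.
rng = np.random.default_rng(0)
w = rng.normal(size=(784, 10))

digit = 3
# Each column holds one weight per pixel, so it reshapes to a 28 x 28 image.
class_weights = w[:, digit].reshape(28, 28)

fig, ax = plt.subplots()
im = ax.imshow(class_weights, cmap="jet")
fig.colorbar(im, ax=ax)
```

With trained weights, large positive values mark pixels that push the score toward that class and negative values mark pixels that count against it.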

<matplotlib.image.AxesImage at 0x7fa009695d50>

<matplotlib.image.AxesImage at 0x7fa0095d9d10>

<matplotlib.image.AxesImage at 0x7fa009517710>
