MNIST Softmax Visualization

by allenlu2007



MNIST plus a softmax classifier is the classic textbook combination. Here it is shown directly in a Jupyter notebook on GCP Datalab.

In summary: imshow() for an original image, imshow() for the learned weights, and plot() for the loss curve.

Softmax Tutorial

First we import the needed libraries.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.sparse

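Next we import the MNIST data files (this loader, tensorflow.examples.tutorials.mnist, ships with TensorFlow 1.x). We use 500 training examples and 100 test examples. Note that both batches come from mnist.train: the 100 "test" examples are simply the next batch of the training split, so they do not overlap the first 500, but they are not the official MNIST test set.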

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
batch = mnist.train.next_batch(500)
tb = mnist.train.next_batch(100)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

Let’s take a look at one of the images in the set. Looks like a 4! We pass the colormap cmap='jet' so that pixel intensities are shown in color (calling plt.colorbar() afterwards adds the color scale).

exampleNumber = 2 #Pick the example we want to visualize
example = batch[0][exampleNumber,:] #Then we load that example.
plt.imshow(np.reshape(example,[28,28]),cmap='jet') #Next we reshape it to 28x28 and display it.

imshow() ("image show") in matplotlib.pyplot is a very important function. OpenCV, MATLAB, and Octave likewise use a function named imshow() as their basic function for displaying an image.
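As a minimal sketch of what the cell above does, the snippet below reshapes a flat 784-pixel vector into a 28x28 array and shows it with imshow(). The image is a synthetic gradient (an assumption standing in for a real MNIST example), and a colorbar is added to show the intensity scale that cmap='jet' maps to.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running as a script; omit this line in a notebook
import matplotlib.pyplot as plt

# Synthetic stand-in for one MNIST example: a flat vector of 784 = 28*28 pixel values in [0, 1]
flat_pixels = np.linspace(0.0, 1.0, 784)

# Reshape the flat vector back into the 28x28 image grid (row-major order)
image = np.reshape(flat_pixels, [28, 28])

fig, ax = plt.subplots()
im = ax.imshow(image, cmap='jet')  # map each intensity to a color in the jet colormap
fig.colorbar(im, ax=ax)            # draw the color scale next to the image
```

The same reshape-then-imshow pattern is what turns the learned weight columns into images later in the post.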

y = batch[1]
x = batch[0]
testY = tb[1]
testX = tb[0]

Before we can train the model on this data, we need to define a few functions that the training and testing steps use.

Here we define the loss function for softmax regression.
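For reference, the quantities getLoss computes can be written as follows, with X the m x 784 input matrix, W the weight matrix, Y the one-hot label matrix, P the softmax probabilities, and λ the regularization factor lam. This is the standard L2-regularized cross-entropy for softmax regression, given here as a reading aid for the code rather than taken from the original notebook.

```latex
S = XW, \qquad P_{ik} = \frac{e^{S_{ik}}}{\sum_{j} e^{S_{ij}}}

L(W) = -\frac{1}{m}\sum_{i=1}^{m}\sum_{k} Y_{ik}\,\log P_{ik} \;+\; \frac{\lambda}{2}\sum_{d,k} W_{dk}^{2}

\nabla_W L = -\frac{1}{m}\,X^{\top}(Y - P) \;+\; \lambda W
```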

def getLoss(w,x,y,lam):
    m = x.shape[0] #First we get the number of training examples
    y_mat = oneHotIt(y) #Next we convert the integer class coding into a one-hot representation
    scores = np.dot(x,w) #Then we compute raw class scores given our input and current weights
    prob = softmax(scores) #Next we perform a softmax on these scores to get their probabilities
    loss = (-1.0 / m) * np.sum(y_mat * np.log(prob)) + (lam/2)*np.sum(w*w) #Cross-entropy loss plus L2 penalty (-1.0 avoids integer division in Python 2)
    grad = (-1.0 / m) * np.dot(x.T,(y_mat - prob)) + lam*w #And compute the gradient for that loss
    return loss,grad
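One way to confirm that the analytic gradient in getLoss matches its loss is a finite-difference gradient check. The sketch below re-implements the same loss on a tiny random problem (the function name softmax_loss_grad and the problem sizes are hypothetical, chosen only for illustration) and compares every gradient entry with a central difference.

```python
import numpy as np

def softmax_loss_grad(w, x, y_onehot, lam):
    """Hypothetical re-implementation of getLoss: L2-regularized softmax cross-entropy."""
    m = x.shape[0]
    scores = x.dot(w)                                    # raw class scores, shape (m, k)
    scores = scores - scores.max(axis=1, keepdims=True)  # row-wise shift for numerical stability
    exp_scores = np.exp(scores)
    prob = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    loss = (-1.0 / m) * np.sum(y_onehot * np.log(prob)) + (lam / 2) * np.sum(w * w)
    grad = (-1.0 / m) * x.T.dot(y_onehot - prob) + lam * w
    return loss, grad

rng = np.random.default_rng(0)
m, d, k = 5, 4, 3                                  # tiny problem: 5 examples, 4 features, 3 classes
x = rng.normal(size=(m, d))
y_onehot = np.eye(k)[rng.integers(0, k, size=m)]   # random labels, one-hot encoded
w = rng.normal(scale=0.1, size=(d, k))
lam = 0.5

_, analytic = softmax_loss_grad(w, x, y_onehot, lam)

# Central finite differences: perturb one weight at a time in both directions
eps = 1e-6
numeric = np.zeros_like(w)
for idx in np.ndindex(*w.shape):
    w_plus, w_minus = w.copy(), w.copy()
    w_plus[idx] += eps
    w_minus[idx] -= eps
    numeric[idx] = (softmax_loss_grad(w_plus, x, y_onehot, lam)[0]
                    - softmax_loss_grad(w_minus, x, y_onehot, lam)[0]) / (2 * eps)

max_abs_diff = np.max(np.abs(analytic - numeric))
print(max_abs_diff)  # should be very small if the gradient formula is right
```

A large difference here would point to a mistake in the gradient line, which is much harder to spot from the training curve alone.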

The function below converts integer class coding, a one-dimensional array of labels, into a one-hot variant: an array of size m (examples) x n (classes).

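def oneHotIt(Y):
    m = Y.shape[0] #Number of examples
    #Y = Y[:,0]
    OHX = scipy.sparse.csr_matrix((np.ones(m), (Y, np.array(range(m))))) #Sparse (class x example) matrix with a 1 at (label, example index)
    OHX = np.array(OHX.todense()).T #Densify and transpose to (example x class)
    return OHX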
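An equivalent dense one-hot encoding can be built by indexing an identity matrix. The sketch below (the helper name one_hot_dense is hypothetical) passes n_classes explicitly, which avoids a pitfall of the sparse version: without an explicit shape, csr_matrix infers the number of rows from the largest label present, so a batch that happened to contain no 9s would produce only 9 columns.

```python
import numpy as np
import scipy.sparse

def one_hot_dense(labels, n_classes):
    """Hypothetical helper: row i has a 1 in column labels[i] and zeros elsewhere."""
    return np.eye(n_classes)[labels]

def one_hot_sparse(labels):
    """Same construction as oneHotIt: sparse (class x example) matrix, densified and transposed."""
    m = labels.shape[0]
    ohx = scipy.sparse.csr_matrix((np.ones(m), (labels, np.arange(m))))
    return np.array(ohx.todense()).T

labels = np.array([3, 0, 2, 3, 1])
dense = one_hot_dense(labels, n_classes=4)
sparse_version = one_hot_sparse(labels)
print(dense.shape)  # (5, 4)
```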

Here we perform the softmax transformation, which turns each row of class scores into probabilities that sum to 1 (100%).

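def softmax(z):
    z = z - np.max(z, axis=1, keepdims=True) #Subtract each row's max for numerical stability (does not change the result)
    sm = (np.exp(z).T / np.sum(np.exp(z),axis=1)).T
    return sm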
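Subtracting the maximum score before exponentiating does not change the probabilities, because softmax is invariant to adding the same constant to every score in a row; it only keeps np.exp from overflowing. A small sketch (the function name stable_softmax is hypothetical) using a row-wise maximum:

```python
import numpy as np

def stable_softmax(z):
    """Row-wise softmax; z has shape (n_examples, n_classes)."""
    shifted = z - np.max(z, axis=1, keepdims=True)  # the largest score in each row becomes 0
    exp_z = np.exp(shifted)
    return exp_z / np.sum(exp_z, axis=1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1001.0, 1002.0]])  # a naive np.exp(1000.0) would overflow to inf
probs = stable_softmax(scores)
print(probs.sum(axis=1))  # each row sums to 1
```

Both rows get identical probabilities, since the second row is the first shifted by a constant.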

Here we determine the probabilities and predictions for each class when given a set of input data:

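def getProbsAndPreds(someX):
    probs = softmax(np.dot(someX,w)) #Uses the global weights w learned in the training loop
    preds = np.argmax(probs,axis=1) #Predicted class = index of the highest probability
    return probs,preds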

This is the main loop of the softmax regression.

Here we initialize our weights, regularization factor, number of iterations, and learning rate. We then loop: compute the loss and gradient, then apply a gradient-descent step to the weights.

w = np.zeros([x.shape[1],len(np.unique(y))])
lam = 1
iterations = 1000
learningRate = 1e-5
losses = []
for i in range(0,iterations):
    loss,grad = getLoss(w,x,y,lam)
    losses.append(loss) #Record the loss so we can plot it
    w = w - (learningRate * grad)
print(loss)
plt.plot(losses)

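The same loop can be tried end to end on a tiny synthetic problem. This sketch is a stand-in for the MNIST run under stated assumptions: three Gaussian blobs in 2-D replace the images, a constant column is appended as a bias term (the original code has none), the helper name train_softmax is hypothetical, and the learning rate is larger than the 1e-5 used above because the data are small and well scaled. It records the loss at every iteration, so the curve can be drawn with plt.plot(losses).

```python
import numpy as np

def train_softmax(x, y, n_classes, lam=1e-3, learning_rate=0.1, iterations=200):
    """Hypothetical batch gradient descent for L2-regularized softmax regression."""
    m, d = x.shape
    w = np.zeros((d, n_classes))
    y_onehot = np.eye(n_classes)[y]
    losses = []
    for _ in range(iterations):
        scores = x.dot(w)
        scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
        prob = np.exp(scores)
        prob /= prob.sum(axis=1, keepdims=True)
        loss = (-1.0 / m) * np.sum(y_onehot * np.log(prob)) + (lam / 2) * np.sum(w * w)
        grad = (-1.0 / m) * x.T.dot(y_onehot - prob) + lam * w
        w = w - learning_rate * grad  # gradient-descent step
        losses.append(loss)
    return w, losses

rng = np.random.default_rng(0)
centers = np.array([[0.0, 3.0], [3.0, -3.0], [-3.0, -3.0]])
y = np.repeat(np.arange(3), 50)              # 50 points per class
x = centers[y] + rng.normal(size=(150, 2))   # unit-variance blobs around each center
x = np.hstack([x, np.ones((150, 1))])        # constant column acts as a bias term

w, losses = train_softmax(x, y, n_classes=3)
accuracy = np.mean(np.argmax(x.dot(w), axis=1) == y)
print(losses[0], losses[-1], accuracy)
```

With all-zero initial weights every class gets probability 1/3, so the first recorded loss is log(3); a falling curve from there is the shape the plot(loss) figure should show.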
def getAccuracy(someX,someY):
    prob,preds = getProbsAndPreds(someX)
    accuracy = sum(preds == someY)/(float(len(someY))) #Fraction of correct predictions
    return accuracy
print('Training Accuracy: ', getAccuracy(x,y))
print('Test Accuracy: ', getAccuracy(testX,testY))
Training Accuracy:  0.902
Test Accuracy:  0.85
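getAccuracy reduces to taking the argmax of each row of probabilities and averaging the matches. A toy example with made-up probabilities (three examples, three classes; the values are illustrative, not from the MNIST run):

```python
import numpy as np

probs = np.array([[0.7, 0.2, 0.1],   # predicted class 0
                  [0.1, 0.3, 0.6],   # predicted class 2
                  [0.2, 0.5, 0.3]])  # predicted class 1
labels = np.array([0, 1, 1])

preds = np.argmax(probs, axis=1)     # index of the largest probability in each row
accuracy = np.mean(preds == labels)  # fraction of correct predictions
print(preds, accuracy)               # [0 2 1] 0.6666666666666666
```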

One of the benefits of a simple model like softmax is that we can visualize the weights for each class as a 28x28 image and see which pixels push the prediction toward that class. Here we look at the weights for the ‘3’ class.

classWeightsToVisualize = 3
plt.imshow(np.reshape(w[:,classWeightsToVisualize],[28,28]))

classWeightsToVisualize = 0
plt.imshow(np.reshape(w[:,classWeightsToVisualize],[28,28]))

classWeightsToVisualize = 1
plt.imshow(np.reshape(w[:,classWeightsToVisualize],[28,28]))
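To compare all ten classes at once, the columns of w can be laid out in a grid of subplots. In this sketch w is a random placeholder with the same 784 x 10 shape as the trained weights (an assumption so the snippet runs on its own); in the notebook you would use the trained w instead.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running as a script; omit this line in a notebook
import matplotlib.pyplot as plt

# Placeholder weights with the trained model's shape: 784 pixels x 10 classes
w = np.random.default_rng(0).normal(size=(784, 10))

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for digit, ax in enumerate(axes.ravel()):
    # Column `digit` holds one weight per pixel; reshape it back into the 28x28 image grid
    ax.imshow(np.reshape(w[:, digit], [28, 28]), cmap='jet')
    ax.set_title(str(digit))
    ax.axis('off')
fig.tight_layout()
```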