Active February 14, 2022

Examples of how to create a confusion matrix with scikit-learn in Python and derive the true positive, true negative, false positive and false negative counts from it.

### Create a confusion matrix with scikit-learn

A confusion matrix can be built with scikit-learn's `confusion_matrix` function:

```
from sklearn.metrics import confusion_matrix

y_true = [1,1,0,0,1]
y_pred = [1,1,1,0,1]

cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

print(cm)
```

returns

```
[[1 1]
 [0 3]]
```
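To read the matrix, it helps to know that rows correspond to the true labels and columns to the predicted labels, so `cm[i, j]` counts the samples of true class `i` predicted as class `j`. The sketch below rebuilds the same matrix by hand to illustrate this; the name `manual_cm` is only illustrative, and the loop assumes the labels are the integers 0 and 1 so they can be used directly as indices.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 0, 1]
y_pred = [1, 1, 1, 0, 1]

# Count each (true, predicted) pair: row = true label, column = predicted label
manual_cm = np.zeros((2, 2), dtype=int)
for t, p in zip(y_true, y_pred):
    manual_cm[t, p] += 1

print(manual_cm)

# Compare with scikit-learn's result
same = np.array_equal(manual_cm, confusion_matrix(y_true, y_pred, labels=[0, 1]))
print(same)
```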

### Get tn, fp, fn, tp for a binary classification

```
# ravel() flattens the 2x2 matrix row by row, which gives tn, fp, fn, tp in that order
tn, fp, fn, tp = confusion_matrix(list(y_true), list(y_pred), labels=[0, 1]).ravel()

print('True Positive', tp)
print('True Negative', tn)
print('False Positive', fp)
print('False Negative', fn)
```

gives

```
True Positive 3
True Negative 1
False Positive 1
False Negative 0
```
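Unpacking four values from `.ravel()` only works when the matrix is 2x2. Passing `labels=[0, 1]` explicitly, as done above, keeps that shape even when one class never appears in the data. A small sketch with made-up inputs (the `*_one_class` lists are hypothetical data, not from the example above):

```python
from sklearn.metrics import confusion_matrix

# Hypothetical data in which only class 1 occurs
y_true_one_class = [1, 1, 1]
y_pred_one_class = [1, 1, 1]

# Without labels, only the observed class is used and the matrix is 1x1
cm_no_labels = confusion_matrix(y_true_one_class, y_pred_one_class)
print(cm_no_labels.shape)

# With labels=[0, 1] the matrix stays 2x2, so the four-way unpacking works
cm_with_labels = confusion_matrix(y_true_one_class, y_pred_one_class, labels=[0, 1])
tn, fp, fn, tp = cm_with_labels.ravel()
print(tn, fp, fn, tp)
```

Recent scikit-learn versions may also emit a warning in the first call, suggesting the `labels` parameter.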

#### Calculate the accuracy score

```
from sklearn.metrics import accuracy_score

accuracy_score(y_true, y_pred)
```

gives

```
0.8
```

same as doing

```
acc = (tp+tn) / (tp+tn+fn+fp)

print(acc)
```

gives

```
0.8
```

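#### Calculate the tp, tn, fp and fn fractions

Each count can be expressed as a fraction of the total number of samples. The total is

```
tot = cm.sum()
```

same as

```
tot = tn+tp+fp+fn
```

returns here

```
5
```

and then

```
print('True Positive fraction', tp/tot)
print('True Negative fraction', tn/tot)
print('False Positive fraction', fp/tot)
print('False Negative fraction', fn/tot)
```

returns

```
True Positive fraction 0.6
True Negative fraction 0.2
False Positive fraction 0.2
False Negative fraction 0.0
```

Note that these are fractions of all samples, not the conventional rates: the true positive rate (recall) is usually defined as tp / (tp + fn) and the false positive rate as fp / (fp + tn).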
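In the usual definition, each rate is normalized by the number of actual positives or actual negatives rather than by the total sample count. A minimal sketch of those conventional rates, assuming the same `y_true` and `y_pred` as in the binary example (the names `tpr`, `tnr`, `fpr` and `fnr` are illustrative):

```python
from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 0, 1]
y_pred = [1, 1, 1, 0, 1]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

# Rates over the actual positives (tp + fn)
tpr = tp / (tp + fn)  # true positive rate (recall, sensitivity)
fnr = fn / (tp + fn)  # false negative rate

# Rates over the actual negatives (tn + fp)
tnr = tn / (tn + fp)  # true negative rate (specificity)
fpr = fp / (tn + fp)  # false positive rate

print('True Positive Rate', tpr)
print('True Negative Rate', tnr)
print('False Positive Rate', fpr)
print('False Negative Rate', fnr)
```

With this data there are 3 actual positives and 2 actual negatives, so both denominators are non-zero; with a class that never occurs, the corresponding division would need a guard.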

### Get tn, fp, fn, tp with more than two categories

```
from sklearn.metrics import confusion_matrix

y_true = [1,1,0,0,1,1,2,2]
y_pred = [1,1,1,0,2,1,2,2]

# Without the labels argument, all classes found in the data are used (here 0, 1 and 2)
cm = confusion_matrix(y_true, y_pred)

print(cm)
```

gives

```
[[1 1 0]
 [0 3 1]
 [0 0 2]]
```

and

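```
import numpy as np

# One value per class, treating that class as "positive" and all others as "negative"
fp = cm.sum(axis=0) - np.diag(cm)  # column total minus the diagonal
fn = cm.sum(axis=1) - np.diag(cm)  # row total minus the diagonal
tp = np.diag(cm)
tn = cm.sum() - (fp + fn + tp)

print(fp,fn,tp,tn)
```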

returns

```
[0 1 1] [1 1 0] [1 3 2] [6 3 5]
```

Get fp, fn, tp and tn for a given category, for example category 0:

```
idx = 0
print(fp[idx], fn[idx], tp[idx], tn[idx])
```

returns

```
0 1 1 6
```

Category 1

```
idx = 1
print(fp[idx], fn[idx], tp[idx], tn[idx])
```

returns

```
1 1 3 3
```

Category 2

```
idx = 2
print(fp[idx], fn[idx], tp[idx], tn[idx])
```

returns

```
1 0 2 5
```
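The per-class counts above can be cross-checked with scikit-learn's `multilabel_confusion_matrix`, which returns one 2x2 matrix per class in a one-vs-rest fashion, laid out as `[[tn, fp], [fn, tp]]`. A minimal sketch, assuming the same three-class data (the name `mcm` is illustrative):

```python
from sklearn.metrics import multilabel_confusion_matrix

y_true = [1, 1, 0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 1, 0, 2, 1, 2, 2]

# Shape (n_classes, 2, 2): one [[tn, fp], [fn, tp]] matrix per class
mcm = multilabel_confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

tn = mcm[:, 0, 0]
fp = mcm[:, 0, 1]
fn = mcm[:, 1, 0]
tp = mcm[:, 1, 1]

print(fp, fn, tp, tn)
```

The printed arrays should match the ones derived from the row and column sums of `cm`.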

#### Calculate the accuracy score

```
from sklearn.metrics import accuracy_score

accuracy_score(y_true, y_pred)
```

gives

```
0.75
```

same as

```
acc = np.diag(cm).sum() / cm.sum()
```

### Plot a confusion matrix with matplotlib and seaborn

```
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

# Apply the seaborn theme (the matplotlib style named 'seaborn' was renamed 'seaborn-v0_8' in matplotlib 3.6)
sn.set_theme()

# Wrap the matrix in a DataFrame so the heatmap axes show the class indices
df_cm = pd.DataFrame(cm,
    index = [i for i in range(cm.shape[0])],
    columns = [i for i in range(cm.shape[1])])

fig = plt.figure()

ax = fig.add_subplot(111)
ax.set_aspect(1)

cmap = sn.cubehelix_palette(light=1, as_cmap=True)

res = sn.heatmap(df_cm, annot=True, fmt='.2f', cmap=cmap)

# Put row 0 at the bottom of the plot
res.invert_yaxis()

#plt.yticks([0.5,1.5,2.5], [ '0', '1', '2'],va='center')

plt.title('Confusion Matrix')

plt.savefig('confusion_matrix_1.png', dpi=100, bbox_inches='tight' )

plt.show()
```

#### Normalize the confusion matrix

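```
total = cm.sum()
cm = cm * 100.0 / total   # each cell as a percentage of all samples
df_cm = pd.DataFrame(cm)  # rebuild the DataFrame passed to the heatmap
```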

and replace

```
res = sn.heatmap(df_cm, annot=True, fmt='.2f', cmap=cmap)
```

by

```
res = sn.heatmap(df_cm, annot=True, vmin=0.0, vmax=100.0, fmt='.2f', cmap=cmap)
```
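scikit-learn can also normalize the matrix directly through the `normalize` parameter of `confusion_matrix` (available in scikit-learn 0.22 and later). A short sketch, assuming the three-class data from above; `cm_all` and `cm_rows` are illustrative names. `normalize="all"` divides by the total number of samples, while `normalize="true"` divides each row by the number of samples of that true class.

```python
from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 1, 0, 2, 1, 2, 2]

# Percentage of all samples in each cell
cm_all = confusion_matrix(y_true, y_pred, normalize="all") * 100.0

# Percentage of each true class (each row sums to 100)
cm_rows = confusion_matrix(y_true, y_pred, normalize="true") * 100.0

print(cm_all)
print(cm_rows)
```

Row normalization is often easier to read when the classes have very different sample counts, since each row then shows how a given true class was distributed across the predictions.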


##### Daidalos

Hi, I am Ben.

I developed this website from scratch with Django to share my notes with everyone. If you have any ideas or suggestions to improve the site, let me know! (You can contact me using the form on the welcome page.) Thanks!
