Confusion matrix can now be drawn
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
@@ -41,6 +41,25 @@ def main():
|
||||
|
||||
# evaluation
|
||||
evaluate_model(model, X, y_pred, y_test, le)
|
||||
|
||||
draw_confusion_matrix(y_test, y_pred, le)
|
||||
|
||||
def draw_confusion_matrix(y_test, y_pred, le):
|
||||
y_test_decoded = le.inverse_transform(y_test)
|
||||
y_pred_decoded = le.inverse_transform(y_pred)
|
||||
|
||||
cm = confusion_matrix(y_test_decoded, y_pred_decoded, labels=le.classes_)
|
||||
|
||||
# Plot
|
||||
plt.figure(figsize=(6,5))
|
||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_,
|
||||
yticklabels=le.classes_)
|
||||
plt.xlabel("Predicted")
|
||||
plt.ylabel("Actual")
|
||||
plt.title("Confusion Matrix")
|
||||
plt.tight_layout()
|
||||
plt.savefig("images/confusion_matrix.png", dpi=300) # Save for README
|
||||
plt.show()
|
||||
|
||||
def predict_target(model, X_test):
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
Reference in New Issue
Block a user