diff --git a/images/confusion_matrix.png b/images/confusion_matrix.png new file mode 100644 index 0000000..f2fe755 Binary files /dev/null and b/images/confusion_matrix.png differ diff --git a/main.py b/main.py index 9dc5c06..9f6ea01 100644 --- a/main.py +++ b/main.py @@ -41,6 +41,25 @@ def main(): # evaluation evaluate_model(model, X, y_pred, y_test, le) + + draw_confusion_matrix(y_test, y_pred, le) + +def draw_confusion_matrix(y_test, y_pred, le): + y_test_decoded = le.inverse_transform(y_test) + y_pred_decoded = le.inverse_transform(y_pred) + + cm = confusion_matrix(y_test_decoded, y_pred_decoded, labels=le.classes_) + + # Plot + plt.figure(figsize=(6,5)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_, + yticklabels=le.classes_) + plt.xlabel("Predicted") + plt.ylabel("Actual") + plt.title("Confusion Matrix") + plt.tight_layout() + plt.savefig("images/confusion_matrix.png", dpi=300) # Save for README + plt.show() def predict_target(model, X_test): y_pred = model.predict(X_test)