Confusion matrix can now be drawn
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
@@ -41,6 +41,25 @@ def main():
|
|||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
evaluate_model(model, X, y_pred, y_test, le)
|
evaluate_model(model, X, y_pred, y_test, le)
|
||||||
|
|
||||||
|
draw_confusion_matrix(y_test, y_pred, le)
|
||||||
|
|
||||||
|
def draw_confusion_matrix(y_test, y_pred, le):
|
||||||
|
y_test_decoded = le.inverse_transform(y_test)
|
||||||
|
y_pred_decoded = le.inverse_transform(y_pred)
|
||||||
|
|
||||||
|
cm = confusion_matrix(y_test_decoded, y_pred_decoded, labels=le.classes_)
|
||||||
|
|
||||||
|
# Plot
|
||||||
|
plt.figure(figsize=(6,5))
|
||||||
|
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_,
|
||||||
|
yticklabels=le.classes_)
|
||||||
|
plt.xlabel("Predicted")
|
||||||
|
plt.ylabel("Actual")
|
||||||
|
plt.title("Confusion Matrix")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("images/confusion_matrix.png", dpi=300) # Save for README
|
||||||
|
plt.show()
|
||||||
|
|
||||||
def predict_target(model, X_test):
|
def predict_target(model, X_test):
|
||||||
y_pred = model.predict(X_test)
|
y_pred = model.predict(X_test)
|
||||||
|
|||||||
Reference in New Issue
Block a user