katboost.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # conda create --name CAT
  2. # conda activate CAT
  3. # pip install catboost --user
  4. # pip install seaborn scikit-learn
  5. # pip freeze > katrequireements.txt
  6. # conda deactivate
  7. '''
  8. 📌 Кто такой catBoost? 🐈 CatBoost означает «категорическое» повышение.
  9. CatBoost — это библиотека градиентного бустинга, созданная 🌐ндексом.
  10. Прогнозы делаются на основе ансамбля слабых обучающих алгоритмов, а именно небрежных деревьев.
  11. Вот несколько преимущества использования этой библиотеки:
  12. ➖ позволяет использовать категориальные признаки без предварительной обработки
  13. ➖ дает отличные результаты с параметрами по умолчанию
  14. ➖ под капотом умеет обрабатывать пропущенные значения
  15. ➖ можно использовать и для регрессии, и для класссификации
  16. '''
  17. '''
  18. Она использует небрежные (oblivious) деревья решений, чтобы вырастить сбалансированное дерево.
  19. Одни и те же функции используются для создания левых и правых разделений (split) на каждом уровне дерева.
  20. В плане простоты использования и легкости входа для новичков, пожалуй является топ-1 библиотекой для табличных данных
  21. и вот почему:
  22. ⏩Принимает категориальные фичи сразу без всякой предварительной обработки.
  23. ⏩Чтобы перенести обучение с CPU на GPU достаточно поменять значение 1 параметра, без установки доп.пакетов или специальных версий, как в других библиотеках
  24. ⏩Даже с дефолтными параметрами выдает хорошую точность модели. Основные параметры не константные, а подбираются самой библиотекой, в зависимости от размера входных данных.
  25. ⏩Может принимать текстовые признаки, эмбеддинги, временные признаки.
  26. ⏩Без дополнительных манипуляций и оберток встраивается в стандартные пайплайны (например, sklearn).
  27. ⏩Идет в комплекте с "батарейками": feature_selection, object_selection, cross_validation, grid_search и пр.
  28. '''
  29. '''
  30. Бустинг – это ансамблевый метод машинного обучения, целью которого является объединение нескольких слабых моделей
  31. предсказания для создания одной сильной. Слабая модель – это такая, которая выполняет предсказания немного лучше,
  32. чем наугад, в то время как сильная модель обладает высокой предсказательной способностью. Цель бустинга – улучшить
  33. точность предсказаний.
  34. Бустинг работает путём последовательного добавления моделей в ансамбль. Каждая следующая модель строится таким образом,
  35. чтобы исправлять ошибки, сделанные предыдущими моделями. Это достигается путём фокусировки на наиболее проблемных данных,
  36. которые были неверно классифицированы или предсказаны ранее.
  37. Одной из основных фич бустинга является динамическое взвешивание обучающих данных. После каждого этапа обучения модели в
  38. ансамбле, данные, на которых были допущены ошибки, получают больший вес. Это означает, что последующие модели уделяют
  39. больше внимания именно этим трудным случаям.
  40. Когда используются решающие деревья, каждое последующее дерево строится с учетом ошибок, сделанных предыдущими деревьями.
  41. Новые деревья учатся на ошибках, улучшая общую точность ансамбля.
  42. Несмотря на свою мощь, бустинг может быть склонен к переобучению, особенно если в ансамбле слишком много моделей или они
  43. слишком сложные. Для контроля переобучения к примеру ранняя остановка (early stopping).
  44. '''
  45. import pandas as pd
  46. import matplotlib.pyplot as plt
  47. import seaborn as sns
  48. import numpy as np
  49. from sklearn.model_selection import train_test_split
  50. from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
  51. from catboost import CatBoostClassifier
  52. # Load the Titanic dataset
  53. # titanic = sns.load_dataset('titanic') # https://github.com/mwaskom/seaborn-data/blob/master/titanic.csv
  54. # target = 'survived'
  55. titanic = pd.read_csv('titanic.csv')
  56. print(titanic.head())
  57. target = 'survived'
  58. # preprocessing data
  59. # filling missing value in deck column with a new category: Unknown
  60. categories = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'Unknown']
  61. titanic['deck'] = pd.Categorical(
  62. titanic['deck'], categories=categories, ordered=True)
  63. titanic['deck'] = titanic['deck'].fillna('Unknown')
  64. # filling missing value in age column using mean imputation
  65. age_mean = titanic['age'].fillna(0).mean()
  66. titanic['age'] = titanic['age'].fillna(age_mean)
  67. # droping missing values in embark as there are only 2
  68. titanic = titanic.dropna()
  69. # droping alive column to make the problem more challenging
  70. titanic = titanic.drop('alive', axis=1)
  71. # Create the feature matrix (X) and target vector (y)
  72. X = titanic.drop(target, axis=1)
  73. y = titanic[target]
  74. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  75. # specifying categorical features
  76. categorical_features = ['sex', 'pclass', 'sibsp', 'parch', 'embarked',
  77. 'class', 'who', 'adult_male', 'embark_town', 'alone', 'deck']
  78. # create and train the CatBoostClassifier
  79. model = CatBoostClassifier(iterations=100, depth=8, learning_rate=0.1, cat_features=categorical_features,
  80. loss_function='Logloss', custom_metric=['AUC'], random_seed=42)
  81. # model.fit(X_train, y_train)
  82. # model.save_model('catboost_classification_titanic.model')
  83. model_name = CatBoostClassifier() # parameters not required.
  84. model_name.load_model('catboost_classification_titanic.model')
  85. # predicting accuracy
  86. y_pred = model_name.predict(X_test)
  87. # print(y_pred)
  88. X_test['predicted'] = y_pred
  89. print(X_test.head(11))
  90. # saving the dataframe
  91. X_test.to_csv('titanic-predicted.csv')
  92. accuracy = accuracy_score(y_test, y_pred)
  93. print(f"Accuracy: {accuracy:.2f}")
  94. # Plot the confusion matrix as a heatmap
  95. confusion = confusion_matrix(y_test, y_pred)
  96. plt.figure(figsize=(8, 6))
  97. sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', xticklabels=[
  98. 'Predicted Negative', 'Predicted Positive'], yticklabels=['Actual Negative', 'Actual Positive'])
  99. plt.xlabel('Predicted')
  100. plt.ylabel('Actual')
  101. plt.title('Confusion Matrix')
  102. plt.show()
  103. importances = model_name.get_feature_importance()
  104. feature_names = X.columns
  105. sorted_indices = np.argsort(importances)[::-1]
  106. plt.figure(figsize=(10, 6))
  107. plt.bar(range(len(feature_names)), importances[sorted_indices])
  108. plt.xticks(range(len(feature_names)), feature_names[sorted_indices], rotation=90)
  109. plt.title("Feature Importance")
  110. plt.show()
  111. # Print the classification report
  112. print("Classification Report:")
  113. print(classification_report(y_test, y_pred))
  114. '''
  115. Accuracy: 0.80
  116. Classification Report:
  117. precision recall f1-score support
  118. 0 0.82 0.87 0.84 109
  119. 1 0.77 0.70 0.73 69
  120. accuracy 0.80 178
  121. macro avg 0.80 0.78 0.79 178
  122. weighted avg 0.80 0.80 0.80 178
  123. '''