# kat-taxis.py
# Train and evaluate a CatBoost classifier that predicts the taxi payment type.
  1. # https://github.com/mwaskom/seaborn-data/tree/master
  2. # https://github.com/mwaskom/seaborn-data/blob/master/taxis.csv
  3. # conda activate CAT
  4. import pandas as pd
  5. import matplotlib.pyplot as plt
  6. import seaborn as sns
  7. import numpy as np
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
  10. from catboost import CatBoostClassifier
  11. # Load the dataset
  12. dataframe = pd.read_csv('taxis.csv')
  13. print(dataframe.head())
  14. print(dataframe.shape)
  15. print(dataframe.info())
  16. '''
  17. pickup dropoff passengers distance fare tip tolls total color payment pickup_zone dropoff_zone pickup_borough dropoff_borough
  18. 0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 2.15 0.0 12.95 yellow credit card Lenox Hill West UN/Turtle Bay South Manhattan Manhattan
  19. 1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 0.00 0.0 9.30 yellow cash Upper West Side South Upper West Side South Manhattan Manhattan
  20. 2 2019-03-27 17:53:01 2019-03-27 18:00:25 1 1.37 7.5 2.36 0.0 14.16 yellow credit card Alphabet City West Village Manhattan Manhattan
  21. 3 2019-03-10 01:23:59 2019-03-10 01:49:51 1 7.70 27.0 6.15 0.0 36.95 yellow credit card Hudson Sq Yorkville West Manhattan Manhattan
  22. 4 2019-03-30 13:27:42 2019-03-30 13:37:14 3 2.16 9.0 1.10 0.0 13.40 yellow credit card Midtown East Yorkville West Manhattan Manhattan
  23. (6433, 14)
  24. <class 'pandas.core.frame.DataFrame'>
  25. RangeIndex: 6433 entries, 0 to 6432
  26. Data columns (total 14 columns):
  27. # Column Non-Null Count Dtype
  28. --- ------ -------------- -----
  29. 0 pickup 6433 non-null object
  30. 1 dropoff 6433 non-null object
  31. 2 passengers 6433 non-null int64
  32. 3 distance 6433 non-null float64
  33. 4 fare 6433 non-null float64
  34. 5 tip 6433 non-null float64
  35. 6 tolls 6433 non-null float64
  36. 7 total 6433 non-null float64
  37. 8 color 6433 non-null object
  38. 9 payment 6389 non-null object
  39. 10 pickup_zone 6407 non-null object
  40. 11 dropoff_zone 6388 non-null object
  41. 12 pickup_borough 6407 non-null object
  42. 13 dropoff_borough 6388 non-null object
  43. dtypes: float64(5), int64(1), object(8)
  44. memory usage: 703.7+ KB
  45. '''
  46. target = 'payment'
  47. # preprocessing data: cat_features must be integer or string, real number values and NaN values should be converted to string.
  48. dataframe = dataframe.dropna()
  49. # dataframe = dataframe.drop(['pickup', 'dropoff'], axis=1)
  50. dataframe['dropoff'] = dataframe['dropoff'].astype(str)
  51. dataframe['pickup'] = dataframe['pickup'].astype(str)
  52. # Create the feature matrix (X) and target vector (y)
  53. X = dataframe.drop(target, axis=1)
  54. y = dataframe[target]
  55. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  56. # specifying categorical features
  57. categorical_features = ['pickup', 'dropoff','color', 'pickup_zone', 'dropoff_zone', 'pickup_borough',
  58. 'dropoff_borough']
  59. # create and train the CatBoostClassifier
  60. model = CatBoostClassifier(iterations=100, depth=8, learning_rate=0.1, cat_features=categorical_features,
  61. loss_function='Logloss', custom_metric=['AUC'], random_seed=42)
  62. model.fit(X_train, y_train)
  63. '''
  64. 0: learn: 0.6233187 total: 176ms remaining: 17.5s
  65. 1: learn: 0.5582761 total: 182ms remaining: 8.92s
  66. 2: learn: 0.5070269 total: 196ms remaining: 6.33s
  67. 3: learn: 0.4623906 total: 209ms remaining: 5.03s
  68. 4: learn: 0.4268648 total: 224ms remaining: 4.26s
  69. 5: learn: 0.3923146 total: 236ms remaining: 3.69s
  70. 6: learn: 0.3627496 total: 244ms remaining: 3.24s
  71. 7: learn: 0.3379926 total: 256ms remaining: 2.95s
  72. 8: learn: 0.3163899 total: 271ms remaining: 2.75s
  73. 9: learn: 0.2968414 total: 279ms remaining: 2.51s
  74. 10: learn: 0.2796024 total: 291ms remaining: 2.36s
  75. 11: learn: 0.2645461 total: 305ms remaining: 2.24s
  76. 12: learn: 0.2509147 total: 319ms remaining: 2.13s
  77. 13: learn: 0.2390644 total: 332ms remaining: 2.04s
  78. 14: learn: 0.2283068 total: 346ms remaining: 1.96s
  79. 15: learn: 0.2187057 total: 360ms remaining: 1.89s
  80. 16: learn: 0.2098948 total: 374ms remaining: 1.82s
  81. 17: learn: 0.2025366 total: 388ms remaining: 1.76s
  82. 18: learn: 0.1956326 total: 401ms remaining: 1.71s
  83. 19: learn: 0.1889538 total: 414ms remaining: 1.66s
  84. 20: learn: 0.1831394 total: 428ms remaining: 1.61s
  85. 21: learn: 0.1776855 total: 441ms remaining: 1.56s
  86. 22: learn: 0.1730124 total: 455ms remaining: 1.52s
  87. 23: learn: 0.1686998 total: 469ms remaining: 1.48s
  88. 24: learn: 0.1649280 total: 482ms remaining: 1.45s
  89. 25: learn: 0.1614438 total: 497ms remaining: 1.42s
  90. 26: learn: 0.1581258 total: 518ms remaining: 1.4s
  91. 27: learn: 0.1551839 total: 533ms remaining: 1.37s
  92. 28: learn: 0.1524781 total: 546ms remaining: 1.34s
  93. 29: learn: 0.1499819 total: 560ms remaining: 1.31s
  94. 30: learn: 0.1476731 total: 575ms remaining: 1.28s
  95. 31: learn: 0.1458057 total: 588ms remaining: 1.25s
  96. 32: learn: 0.1439919 total: 602ms remaining: 1.22s
  97. 33: learn: 0.1420310 total: 616ms remaining: 1.2s
  98. 34: learn: 0.1404973 total: 630ms remaining: 1.17s
  99. 35: learn: 0.1389005 total: 644ms remaining: 1.14s
  100. 36: learn: 0.1371300 total: 658ms remaining: 1.12s
  101. 37: learn: 0.1357909 total: 672ms remaining: 1.09s
  102. 38: learn: 0.1345599 total: 685ms remaining: 1.07s
  103. 39: learn: 0.1338893 total: 699ms remaining: 1.05s
  104. 40: learn: 0.1328265 total: 712ms remaining: 1.02s
  105. 41: learn: 0.1316869 total: 726ms remaining: 1s
  106. 42: learn: 0.1307625 total: 740ms remaining: 981ms
  107. 43: learn: 0.1296670 total: 754ms remaining: 960ms
  108. 44: learn: 0.1287822 total: 763ms remaining: 932ms
  109. 45: learn: 0.1279533 total: 776ms remaining: 911ms
  110. 46: learn: 0.1271221 total: 790ms remaining: 891ms
  111. 47: learn: 0.1263117 total: 804ms remaining: 871ms
  112. 48: learn: 0.1259618 total: 817ms remaining: 850ms
  113. 49: learn: 0.1254444 total: 831ms remaining: 831ms
  114. 50: learn: 0.1246826 total: 845ms remaining: 812ms
  115. 51: learn: 0.1243692 total: 859ms remaining: 793ms
  116. 52: learn: 0.1237797 total: 873ms remaining: 774ms
  117. 53: learn: 0.1232281 total: 887ms remaining: 756ms
  118. 54: learn: 0.1225791 total: 901ms remaining: 737ms
  119. 55: learn: 0.1219648 total: 914ms remaining: 719ms
  120. 56: learn: 0.1219299 total: 919ms remaining: 693ms
  121. 57: learn: 0.1215192 total: 932ms remaining: 675ms
  122. 58: learn: 0.1213138 total: 946ms remaining: 657ms
  123. 59: learn: 0.1212711 total: 962ms remaining: 641ms
  124. 60: learn: 0.1207654 total: 977ms remaining: 624ms
  125. 61: learn: 0.1206010 total: 991ms remaining: 607ms
  126. 62: learn: 0.1200860 total: 1s remaining: 591ms
  127. 63: learn: 0.1199512 total: 1.02s remaining: 574ms
  128. 64: learn: 0.1194240 total: 1.03s remaining: 558ms
  129. 65: learn: 0.1189549 total: 1.05s remaining: 541ms
  130. 66: learn: 0.1188439 total: 1.06s remaining: 522ms
  131. 67: learn: 0.1187920 total: 1.06s remaining: 501ms
  132. 68: learn: 0.1187825 total: 1.07s remaining: 480ms
  133. 69: learn: 0.1185256 total: 1.08s remaining: 464ms
  134. 70: learn: 0.1185190 total: 1.09s remaining: 445ms
  135. 71: learn: 0.1181260 total: 1.1s remaining: 428ms
  136. 72: learn: 0.1178732 total: 1.12s remaining: 413ms
  137. 73: learn: 0.1175964 total: 1.13s remaining: 397ms
  138. 74: learn: 0.1172277 total: 1.14s remaining: 381ms
  139. 75: learn: 0.1168594 total: 1.16s remaining: 366ms
  140. 76: learn: 0.1166237 total: 1.17s remaining: 350ms
  141. 77: learn: 0.1163408 total: 1.18s remaining: 334ms
  142. 78: learn: 0.1160788 total: 1.2s remaining: 318ms
  143. 79: learn: 0.1158453 total: 1.21s remaining: 303ms
  144. 80: learn: 0.1155340 total: 1.23s remaining: 288ms
  145. 81: learn: 0.1152641 total: 1.24s remaining: 272ms
  146. 82: learn: 0.1150903 total: 1.25s remaining: 257ms
  147. 83: learn: 0.1147869 total: 1.27s remaining: 242ms
  148. 84: learn: 0.1145184 total: 1.28s remaining: 226ms
  149. 85: learn: 0.1142394 total: 1.29s remaining: 211ms
  150. 86: learn: 0.1139451 total: 1.31s remaining: 196ms
  151. 87: learn: 0.1136813 total: 1.32s remaining: 180ms
  152. 88: learn: 0.1134232 total: 1.34s remaining: 165ms
  153. 89: learn: 0.1133195 total: 1.35s remaining: 150ms
  154. 90: learn: 0.1130514 total: 1.36s remaining: 135ms
  155. 91: learn: 0.1129326 total: 1.38s remaining: 120ms
  156. 92: learn: 0.1127142 total: 1.39s remaining: 105ms
  157. 93: learn: 0.1127105 total: 1.4s remaining: 89.1ms
  158. 94: learn: 0.1125976 total: 1.41s remaining: 74.2ms
  159. 95: learn: 0.1123823 total: 1.42s remaining: 59.3ms
  160. 96: learn: 0.1123746 total: 1.43s remaining: 44.2ms
  161. 97: learn: 0.1123363 total: 1.44s remaining: 29.4ms
  162. 98: learn: 0.1122176 total: 1.45s remaining: 14.7ms
  163. 99: learn: 0.1120782 total: 1.47s remaining: 0us
  164. '''
  165. model.save_model('catboost_classification_taxis.model')
  166. model_name = CatBoostClassifier() # parameters not required.
  167. model_name.load_model('catboost_classification_taxis.model')
  168. # predicting accuracy
  169. y_pred = model_name.predict(X_test)
  170. # print(y_pred)
  171. X_test['predicted'] = y_pred
  172. print(X_test.head(11))
  173. # saving the dataframe
  174. X_test.to_csv('taxis-predicted.csv')
  175. accuracy = accuracy_score(y_test, y_pred)
  176. print(f"Accuracy: {accuracy:.2f}")
  177. # Plot the confusion matrix as a heatmap
  178. confusion = confusion_matrix(y_test, y_pred)
  179. plt.figure(figsize=(8, 6))
  180. sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', xticklabels=[
  181. 'Predicted Negative', 'Predicted Positive'], yticklabels=['Actual Negative', 'Actual Positive'])
  182. plt.xlabel('Predicted')
  183. plt.ylabel('Actual')
  184. plt.title('Confusion Matrix for taxis')
  185. plt.show()
  186. importances = model_name.get_feature_importance()
  187. feature_names = X.columns
  188. sorted_indices = np.argsort(importances)[::-1]
  189. plt.figure(figsize=(10, 6))
  190. plt.bar(range(len(feature_names)), importances[sorted_indices])
  191. plt.xticks(range(len(feature_names)), feature_names[sorted_indices], rotation=90)
  192. plt.title("Feature Importance for taxis")
  193. plt.show()
  194. # Print the classification report
  195. print("Classification Report for taxis:")
  196. print(classification_report(y_test, y_pred))
  197. '''
  198. pickup dropoff passengers distance fare tip ... color pickup_zone dropoff_zone pickup_borough dropoff_borough predicted
  199. 742 2019-03-05 09:52:36 2019-03-05 10:03:47 1 1.32 8.50 0.00 ... yellow Greenwich Village North Murray Hill Manhattan Manhattan cash
  200. 4824 2019-03-17 13:16:13 2019-03-17 13:40:32 1 2.90 17.00 0.00 ... yellow Clinton East Lenox Hill East Manhattan Manhattan cash
  201. 3108 2019-03-14 01:33:26 2019-03-14 01:45:58 1 4.17 14.50 3.66 ... yellow Lower East Side Midtown Center Manhattan Manhattan credit card
  202. 4985 2019-03-05 11:41:34 2019-03-05 12:12:45 1 4.03 21.50 2.50 ... yellow Clinton West Little Italy/NoLiTa Manhattan Manhattan credit card
  203. 219 2019-03-08 18:08:03 2019-03-08 18:15:42 1 1.40 7.00 2.25 ... yellow Kips Bay Alphabet City Manhattan Manhattan credit card
  204. 4154 2019-03-11 18:57:33 2019-03-11 19:06:21 1 1.10 7.50 2.00 ... yellow Midtown East Penn Station/Madison Sq West Manhattan Manhattan credit card
  205. 2280 2019-03-17 12:10:05 2019-03-17 12:15:02 1 1.00 6.00 0.70 ... yellow Lenox Hill East Yorkville East Manhattan Manhattan credit card
  206. 5456 2019-03-12 21:11:03 2019-03-12 21:41:36 1 15.78 42.82 0.00 ... green Maspeth Marble Hill Queens Manhattan credit card
  207. 1684 2019-03-28 19:53:03 2019-03-28 20:04:10 6 2.39 10.50 2.96 ... yellow Midtown Center Long Island City/Hunters Point Manhattan Queens credit card
  208. 4889 2019-03-28 12:24:30 2019-03-28 12:28:44 2 0.95 5.00 1.66 ... yellow Upper West Side North Manhattan Valley Manhattan Manhattan credit card
  209. 5395 2019-03-02 19:36:34 2019-03-02 20:05:06 1 3.37 18.50 3.00 ... yellow Upper East Side South Union Sq Manhattan Manhattan credit card
  210. [11 rows x 14 columns]
  211. Accuracy: 0.97
  212. Classification Report for taxis:
  213. precision recall f1-score support
  214. cash 0.92 0.98 0.95 365
  215. credit card 0.99 0.97 0.98 904
  216. accuracy 0.97 1269
  217. macro avg 0.95 0.97 0.96 1269
  218. weighted avg 0.97 0.97 0.97 1269
  219. '''