Understanding estimators
So, what exactly is an estimator? The concept lies at the heart of scikit-learn. Estimators are objects (in the sense of Python’s object-oriented programming (OOP)) that implement algorithms for learning from data, and they share a consistent interface across the entire library. Every estimator in scikit-learn, whether a model or a transformer, implements fit(), which learns from data. Models additionally implement predict(), which makes predictions on new data based on the trained model, while transformers implement transform() instead. Both fit() and predict() were mentioned previously, and they are the two methods you will use most often. Learning from data and then predicting on new data is the raison d’être of ML.
For example, in one of the simplest—yet often most powerful—ML models, LinearRegression(), calling fit() with training data allows the model to learn the optimal coefficients for predicting outcomes. Afterward, predict() can be used on new data to generate predictions:
from sklearn.linear_model import LinearRegression
import numpy as np

# Example data
X = np.array([[1], [2], [3], [4], [5]])  # Feature matrix
y = np.array([1, 2, 3, 3.5, 5])          # Target values

# Create and fit the model
model = LinearRegression()
model.fit(X, y)

# Predict values for new data
X_new = np.array([[6], [7]])
predictions = model.predict(X_new)
print(predictions)
# Output (approximately): [5.75 6.7]
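To make "learning the optimal coefficients" concrete, the following minimal sketch refits the same toy data and inspects what fit() stored on the model. In scikit-learn, attributes learned during fitting end with a trailing underscore, such as coef_ and intercept_. The variable names slope, intercept, and manual are ours, chosen only for illustration:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Same toy data as in the example above
X = np.array([[1], [2], [3], [4], [5]])
y = np.array([1, 2, 3, 3.5, 5])

# fit() returns the fitted estimator, so the calls can be chained
model = LinearRegression().fit(X, y)

slope = model.coef_[0]        # one coefficient per feature
intercept = model.intercept_  # the bias term

print(f"slope={slope:.2f}, intercept={intercept:.2f}")
# slope=0.95, intercept=0.05

# predict() simply applies the learned line: y = slope * x + intercept
manual = slope * 6 + intercept
print(np.isclose(manual, model.predict(np.array([[6]]))[0]))
```

The last line checks that computing the prediction by hand from the learned coefficients agrees with what predict() returns for the same input.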
Some estimators also provide a shortcut method, fit_predict(), that combines these operations into a single call. There is a reason why scikit-learn keeps fit() and predict() separate and also offers fit_predict(). Typically, fit_predict() is used when you want predictions for the same dataset the model was trained on. This is often the case in unsupervised learning, for example with KMeans clustering, where the training data contains no target variable to predict. In supervised learning scenarios, where we do have a target, fit() is applied to the training data and predict() to the holdout dataset.
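This is not to say you can’t call fit() and predict() separately in unsupervised learning scenarios; datasets can still be split into training, validation, and testing sets. The following example shows fit_predict() applied to a single dataset with KMeans: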
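# LinearRegression has no fit_predict() method,
# so this example uses clustering instead:
from sklearn.cluster import KMeans
import numpy as np

# Example data
X = np.array([[1], [2], [3], [4], [5]])

# KMeans clustering example; random_state makes the run repeatable
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = kmeans.fit_predict(X)
print(labels)
# Possible output: [0 0 0 1 1]
# (cluster numbering is arbitrary, and the middle point may join either cluster)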
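KMeans can also be used with separate fit() and predict() calls, which is useful when some points are held back or arrive later. The sketch below is one way this might look; the data, the variable names, and the fixed random_state and n_init values are assumptions made for illustration and repeatability:

```python
import numpy as np
from sklearn.cluster import KMeans

# Two well-separated groups, so the clustering is unambiguous
X_train = np.array([[1.0], [1.5], [2.0], [10.0], [10.5], [11.0]])
X_new = np.array([[1.2], [10.8]])  # points not seen during fitting

kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
kmeans.fit(X_train)                 # learn the cluster centers
new_labels = kmeans.predict(X_new)  # assign each new point to its nearest center

print(kmeans.labels_)  # cluster labels of the training points
print(new_labels)      # cluster labels of the new points

# fit_predict() is shorthand for fit() followed by reading labels_
shortcut = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X_train)
print(np.array_equal(shortcut, kmeans.labels_))
```

The label numbers themselves (0 or 1) carry no meaning; what matters is that each new point receives the same label as the training points it sits closest to.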
scikit-learn’s design ensures that whether you are working with simple linear regression or more complex algorithms such as random forests, the pattern remains the same, promoting consistency and ease of use.
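As a rough illustration of this consistency, the sketch below trains a linear regression and a random forest with exactly the same calls and evaluates both on a holdout set. The synthetic data, the 75/25 split, and the hyperparameter values are assumptions chosen for the example, not recommendations:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic data (an assumption for illustration): y is roughly 3x + 2
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(200, 1))
y = 3 * X[:, 0] + 2 + rng.normal(scale=0.5, size=200)

# Hold out a quarter of the rows for evaluation
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

models = [
    LinearRegression(),
    RandomForestRegressor(n_estimators=50, random_state=0),
]

scores = {}
for model in models:
    model.fit(X_train, y_train)    # same call for every estimator
    preds = model.predict(X_test)  # same call for every estimator
    # score() returns R^2 for regressors
    scores[type(model).__name__] = model.score(X_test, y_test)

print(scores)
```

Because both models expose the same methods, swapping one algorithm for another only requires changing the entry in the models list; the training and evaluation loop stays untouched.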
Throughout this book, we will explore various estimators, including LinearRegression() (Chapter 5), DecisionTreeClassifier() (Chapter 8), and KNeighborsClassifier() (Chapter 4), while demonstrating how to use them to train models, evaluate performance, and make predictions, all using the familiar fit() and predict() structure.