
Gaussian Mixture Model

A Gaussian Mixture Model (GMM) is a probabilistic model used for clustering and density estimation. It assumes that the data is generated from a mixture of several Gaussian distributions, each representing a cluster within the dataset. Unlike K-means, which deterministically assigns each data point to the nearest cluster centroid, GMM assigns each data point a probability of belonging to every cluster, allowing for soft clustering.

GMM is ideal when:

  • Clusters have elliptical shapes or different spreads: GMM captures varying shapes and densities, unlike K-means, which assumes clusters are spherical.
  • Soft clustering is preferred: If you want the probability of a data point belonging to each cluster rather than a hard assignment (illustrated in the sketch after this list).
  • Data has overlapping clusters: GMM allows a point to belong partially to multiple clusters, which is helpful when clusters have significant overlap.
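
For intuition, here is a minimal sketch on synthetic data (the variable names are illustrative) contrasting K-means' hard labels with GMM's per-point cluster probabilities:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Two overlapping 2D blobs (synthetic data, for illustration only)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal([0, 0], 1.0, (100, 2)),
               rng.normal([2, 0], 1.0, (100, 2))])

hard_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
soft_probs = GaussianMixture(n_components=2, random_state=0).fit(X).predict_proba(X)

print(hard_labels[:3])  # one cluster index per point
print(soft_probs[:3])   # one probability row per point, summing to 1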

Applications of GMM

  1. Image Segmentation: Used to segment images into regions, where each region can be represented by a Gaussian distribution in color or intensity space.
  2. Speech Recognition: Models sound waves or frequencies where different Gaussian distributions represent different phonemes or sounds.
  3. Anomaly Detection: Helps in identifying anomalies by learning the normal distribution of the data and flagging points with low probability under the model (see the sketch after this list).
  4. Finance: Modeling returns or risk in portfolios where different Gaussian distributions represent different market conditions.
  5. Customer Segmentation: Clusters customers into segments based on purchasing behavior, allowing overlapping segments where customers belong to multiple clusters.
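
As an example of item 3, scikit-learn's score_samples returns the log-density of each point under the fitted mixture, and points below a chosen threshold can be flagged. A minimal sketch (the synthetic data and the percentile threshold are illustrative choices):

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(42)
X_train = rng.normal(0, 1, (500, 2))            # "normal" behavior
X_test = np.vstack([rng.normal(0, 1, (5, 2)),   # inliers
                    rng.normal(8, 1, (5, 2))])  # obvious outliers

gmm = GaussianMixture(n_components=2, random_state=42).fit(X_train)

# Flag test points whose log-density falls below the 1st percentile of training scores
threshold = np.percentile(gmm.score_samples(X_train), 1)
print(gmm.score_samples(X_test) < threshold)    # True marks likely anomalies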

Step 1: Define Model Parameters

Each Gaussian in the mixture model is described by three parameters:

  • Mean (μ): The center of the Gaussian distribution.
  • Covariance (Σ): The spread of the distribution. This parameter enables GMM to capture different shapes (spherical, elliptical).
  • Mixing Coefficient (π): The weight of each Gaussian component in the mixture, representing the fraction of data points generated by that component; the coefficients across all components sum to 1.

Let’s say we have K clusters. Then each Gaussian cluster will have:

  • A mean vector μₖ.
  • A covariance matrix Σₖ.
  • A mixing coefficient πₖ, where $\sum_{k=1}^{K} \pi_k = 1$.
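
Putting these together, the density of a point x under the mixture is a weighted sum of the K component densities:

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k), \qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1$$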

Step 2: Initialize Parameters

  1. Choose K, the number of Gaussian components (clusters).
  2. Initialize the mean (μ), covariance (Σ), and mixing coefficients (π) randomly or using some heuristic like K-means clustering.
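
A minimal sketch of the K-means heuristic (the helper name is hypothetical; scikit-learn's GaussianMixture does something similar internally with its default init_params='kmeans'):

import numpy as np
from sklearn.cluster import KMeans

def init_from_kmeans(X, K, seed=0):
    """Initialize GMM parameters (means, covariances, weights) from a K-means run."""
    labels = KMeans(n_clusters=K, n_init=10, random_state=seed).fit_predict(X)
    mu = np.array([X[labels == k].mean(axis=0) for k in range(K)])
    sigma = np.array([np.cov(X[labels == k].T) for k in range(K)])
    pi = np.array([(labels == k).mean() for k in range(K)])
    return mu, sigma, pi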

Step 3: Expectation-Maximization (EM) Algorithm

The core of GMM is the Expectation-Maximization (EM) algorithm, which iteratively adjusts the model parameters to maximize the likelihood of the data. The algorithm has two main steps:

Expectation (E) Step

  1. Calculate Responsibilities: For each data point, compute the responsibility of each Gaussian component. The responsibility $r_{ik}$ represents the probability that data point $i$ belongs to Gaussian $k$.
    • Using Bayes’ theorem, the responsibility is calculated as: $r_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}$
    • Here, $\mathcal{N}(x_i \mid \mu_k, \Sigma_k)$ is the probability density function (PDF) of the Gaussian with mean μₖ and covariance Σₖ, evaluated at point xᵢ.
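
A minimal NumPy sketch of this E-step (the variable names are illustrative; mu, sigma, and pi are the parameters from initialization above):

import numpy as np
from scipy.stats import multivariate_normal

def e_step(X, mu, sigma, pi):
    """Return the N x K responsibility matrix r."""
    K = len(pi)
    # Unnormalized: pi_k * N(x_i | mu_k, Sigma_k) for every point and component
    r = np.column_stack([
        pi[k] * multivariate_normal.pdf(X, mean=mu[k], cov=sigma[k])
        for k in range(K)
    ])
    # Normalize each row so responsibilities sum to 1 across components
    return r / r.sum(axis=1, keepdims=True)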

Maximization (M) Step

  1. Update Parameters: After computing the responsibilities, update the parameters of each Gaussian component to maximize the likelihood function.
    • Update Mean $\mu_k$: $\mu_k = \frac{\sum_{i=1}^{N} r_{ik}\, x_i}{\sum_{i=1}^{N} r_{ik}}$
    • Update Covariance $\Sigma_k$: $\Sigma_k = \frac{\sum_{i=1}^{N} r_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^T}{\sum_{i=1}^{N} r_{ik}}$
    • Update Mixing Coefficient $\pi_k$: $\pi_k = \frac{1}{N} \sum_{i=1}^{N} r_{ik}$
    • These updated parameters maximize the likelihood of observing the data given the GMM.

Repeat: Alternate between the E-step and the M-step until the algorithm converges. Convergence is typically defined as a minimal change in the log-likelihood or parameter values between iterations.
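
Continuing the sketch, the M-step updates and the surrounding EM loop might look like this (reuses e_step and init_from_kmeans from the earlier snippets; the tolerance is an illustrative choice):

import numpy as np
from scipy.stats import multivariate_normal

def m_step(X, r):
    """Update (mu, sigma, pi) from the responsibility matrix r."""
    N, K = r.shape
    Nk = r.sum(axis=0)                    # effective number of points per component
    mu = (r.T @ X) / Nk[:, None]          # responsibility-weighted means
    sigma = np.array([
        (r[:, k, None] * (X - mu[k])).T @ (X - mu[k]) / Nk[k]
        for k in range(K)
    ])                                    # responsibility-weighted covariances
    pi = Nk / N                           # mixing coefficients
    return mu, sigma, pi

def fit_gmm(X, K, n_iter=100, tol=1e-6, seed=0):
    mu, sigma, pi = init_from_kmeans(X, K, seed)
    prev_ll = -np.inf
    for _ in range(n_iter):
        r = e_step(X, mu, sigma, pi)
        mu, sigma, pi = m_step(X, r)
        # Log-likelihood of the data under the current parameters
        dens = sum(pi[k] * multivariate_normal.pdf(X, mu[k], sigma[k]) for k in range(K))
        ll = np.log(dens).sum()
        if ll - prev_ll < tol:            # converged: log-likelihood barely moved
            break
        prev_ll = ll
    return mu, sigma, pi, r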

Step 4: Cluster Assignment

Once the EM algorithm converges:

  • Hard Clustering: Assign each data point to the cluster with the highest responsibility (i.e., highest probability).
  • Soft Clustering: Retain the responsibility values for each data point, giving a probability distribution over clusters.
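
With the responsibilities in hand, both assignments are one-liners (continuing the illustrative names from the sketches above; with scikit-learn, predict and predict_proba play the same roles):

mu, sigma, pi, r = fit_gmm(X, K=3)

hard = r.argmax(axis=1)   # hard clustering: index of the most responsible component
soft = r                  # soft clustering: one probability row per data point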

Step 5: Evaluate and Interpret

  • Centroids and Covariance Matrices: The learned means and covariances of each Gaussian component describe the centers and shapes of the clusters.
  • Mixing Coefficients: The π values give insight into the relative sizes of each cluster.

Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

# Step 1: Load and Preprocess the Data
# Load the Wine dataset
wine = load_wine()
data = pd.DataFrame(wine.data, columns=wine.feature_names)

# Standardize the features
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)

# Step 2: Initialize and Fit the GMM Model
# Set the number of clusters (e.g., 3)
k = 3
gmm = GaussianMixture(n_components=k, random_state=42)
gmm.fit(scaled_data)

# Step 3: Predict Clusters
# Get cluster labels for each data point
clusters = gmm.predict(scaled_data)

# Add cluster labels to the data
data['Cluster'] = clusters

# Step 4: Analyze and Interpret Results
# Print GMM parameters
print("Means of each Gaussian component:\n", gmm.means_)
print("\nCovariances of each Gaussian component:\n", gmm.covariances_)
print("\nMixing coefficients:\n", gmm.weights_)

# Calculate the responsibility for each data point (soft clustering)
responsibilities = gmm.predict_proba(scaled_data)
print("\nResponsibilities (first 5 data points):\n", responsibilities[:5])

# Step 5: Visualize the Clustering
# Plot the clusters on the first two principal components for better visualization
from sklearn.decomposition import PCA

# Reduce data to 2D using PCA for visualization
pca = PCA(n_components=2)
reduced_data = pca.fit_transform(scaled_data)

# Scatter plot for clusters
plt.figure(figsize=(10, 6))
plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=clusters, cmap='viridis', marker='o', label='Data Points')

# Project the learned Gaussian means into the same 2D PCA space
centroids_2d = pca.transform(gmm.means_)
plt.scatter(centroids_2d[:, 0], centroids_2d[:, 1], c='red', marker='X', s=100, label='Centroids')
plt.title("Gaussian Mixture Model Clustering on Wine Dataset")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.colorbar(label='Cluster')
plt.legend()
plt.show()

# Step 6: Detailed Cluster Analysis
# Calculate the mean values for each feature per cluster to understand cluster characteristics
cluster_analysis = data.groupby('Cluster').mean()
print("\nMean values of each feature per cluster:")
print(cluster_analysis)

Output:
Means of each Gaussian component:
 [[-0.92712973 -0.39128153 -0.49499088  0.17039636 -0.48628894 -0.07608854
    0.0181207  -0.02946784  0.05791398 -0.90330802  0.45780093  0.2673636
   -0.75592434]
  [ 0.16553389  0.8713558   0.18913283  0.52649363 -0.0735832  -0.97921403
   -1.21592372  0.72617188 -0.77902384  0.94391583 -1.16490901 -1.29327167
   -0.40596983]
  [ 0.83076818 -0.30532723  0.3610863  -0.60846529  0.56702326  0.8806841
    0.97622314 -0.56358591  0.57719366  0.16891615  0.47624476  0.77976728
    1.1200574 ]]

Covariances of each Gaussian component:
 (three 13×13 matrices, one per component; omitted here for brevity)

Mixing coefficients:
 [0.3643262  0.28609457 0.34957923]

Responsibilities (first 5 data points):
 [[1.36788523e-13 4.07977194e-51 1.00000000e+00]
  [1.80955191e-09 2.12970309e-45 9.99999998e-01]
  [1.77819585e-09 4.31184715e-56 9.99999998e-01]
  [1.83249181e-22 2.35180356e-87 1.00000000e+00]
  [3.94006740e-04 1.24671913e-23 9.99605993e-01]]


Mean values of each feature per cluster:
           alcohol  malic_acid       ash  alcalinity_of_ash   magnesium  \
Cluster
0        12.250923    1.897385  2.231231          20.063077   92.738462
1        13.134118    3.307255  2.417647          21.241176   98.666667
2        13.676774    1.997903  2.466290          17.462903  107.967742

         total_phenols  flavanoids  nonflavanoid_phenols  proanthocyanins  \
Cluster
0             2.247692    2.050000              0.357692         1.624154
1             1.683922    0.818824              0.451961         1.145882
2             2.847581    3.003226              0.292097         1.922097

         color_intensity       hue  od280/od315_of_diluted_wines      proline
Cluster
0               2.973077  1.062708                      2.803385   510.169231
1               7.234706  0.691961                      1.696667   619.058824
2               5.453548  1.065484                      3.163387  1100.225806
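
Step 2 fixed k = 3 up front; in practice you can compare several candidate values with an information criterion. A minimal sketch using scikit-learn's bic method on the scaled_data from the implementation above:

# Compare candidate numbers of components; lower BIC is better
for k in range(1, 7):
    gmm_k = GaussianMixture(n_components=k, random_state=42).fit(scaled_data)
    print(k, gmm_k.bic(scaled_data))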
