diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java
new file mode 100644
index 0000000000..e4ae9ce75f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java
@@ -0,0 +1,124 @@
+package org.apache.commons.math4.ml.clustering;
+
+import org.apache.commons.math4.exception.ConvergenceException;
+import org.apache.commons.math4.exception.util.LocalizedFormats;
+import org.apache.commons.math4.ml.distance.DistanceMeasure;
+import org.apache.commons.math4.ml.distance.EuclideanDistance;
+import org.apache.commons.math4.stat.descriptive.moment.Variance;
+import org.apache.commons.rng.UniformRandomProvider;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Common functions used in clustering
+ */
+public class ClusterUtils {
+    /**
+     * Use only for static
+     */
+    private ClusterUtils() {
+    }
+
+    public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
+
+    /**
+     * Predict which cluster is best for the point
+     *
+     * @param clusters cluster to predict into
+     * @param point    point to predict
+     * @param measure  distance measurer
+     * @param <T>      type of cluster point
+     * @return the cluster which has nearest center to the point
+     */
+    public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
+        double minDistance = Double.POSITIVE_INFINITY;
+        CentroidCluster<T> nearestCluster = null;
+        for (CentroidCluster<T> cluster : clusters) {
+            double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
+            if (distance < minDistance) {
+                minDistance = distance;
+                nearestCluster = cluster;
+            }
+        }
+        return nearestCluster;
+    }
+
+    /**
+     * Predict which cluster is best for the point
+     *
+     * @param clusters cluster to predict into
+     * @param point    point to predict
+     * @param <T>      type of cluster point
+     * @return the cluster which has nearest center to the point
+     */
+    public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
+        return predict(clusters, point, DEFAULT_MEASURE);
+    }
+
+    /**
+     * Computes the centroid for a set of points.
+     *
+     * @param points    the set of points
+     * @param dimension the point dimension
+     * @return the computed centroid for the set of points
+     */
+    public static <T extends Clusterable> Clusterable centroidOf(final Collection<T> points, final int dimension) {
+        final double[] centroid = new double[dimension];
+        for (final T p : points) {
+            final double[] point = p.getPoint();
+            for (int i = 0; i < centroid.length; i++) {
+                centroid[i] += point[i];
+            }
+        }
+        for (int i = 0; i < centroid.length; i++) {
+            centroid[i] /= points.size();
+        }
+        return new DoublePoint(centroid);
+    }
+
+
+    /**
+     * Get a random point from the {@link Cluster} with the largest distance variance.
+     *
+     * @param clusters the {@link Cluster}s to search
+     * @param measure  DistanceMeasure
+     * @param random   Random generator
+     * @return a random point from the selected cluster
+     * @throws ConvergenceException if clusters are all empty
+     */
+    public static <T extends Clusterable> T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters,
+                                                                               final DistanceMeasure measure,
+                                                                               final UniformRandomProvider random)
+            throws ConvergenceException {
+        double maxVariance = Double.NEGATIVE_INFINITY;
+        Cluster<T> selected = null;
+        for (final CentroidCluster<T> cluster : clusters) {
+            if (!cluster.getPoints().isEmpty()) {
+                // compute the distance variance of the current cluster
+                final Clusterable center = cluster.getCenter();
+                final Variance stat = new Variance();
+                for (final T point : cluster.getPoints()) {
+                    stat.increment(measure.compute(point.getPoint(), center.getPoint()));
+                }
+                final double variance = stat.getResult();
+
+                // select the cluster with the largest variance
+                if (variance > maxVariance) {
+                    maxVariance = variance;
+                    selected = cluster;
+                }
+
+            }
+        }
+
+        // did we find at least one non-empty cluster ?
+        if (selected == null) {
+            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+        }
+
+        // extract a random point from the cluster
+        final List<T> selectedPoints = selected.getPoints();
+        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+    }
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
index 74699ffb07..48208b432c 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
@@ -19,13 +19,14 @@
 
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 
 import org.apache.commons.math4.exception.ConvergenceException;
 import org.apache.commons.math4.exception.MathIllegalArgumentException;
 import org.apache.commons.math4.exception.NumberIsTooSmallException;
 import org.apache.commons.math4.exception.util.LocalizedFormats;
+import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
+import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
 import org.apache.commons.math4.ml.distance.DistanceMeasure;
 import org.apache.commons.math4.ml.distance.EuclideanDistance;
 import org.apache.commons.rng.simple.RandomSource;
@@ -35,42 +36,67 @@
 
 /**
  * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
+ *
  * @param <T> type of the points to cluster
  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
  * @since 3.2
  */
 public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
 
-    /** Strategies to use for replacing an empty cluster. */
+    /**
+     * Strategies to use for replacing an empty cluster.
+     */
     public enum EmptyClusterStrategy {
 
-        /** Split the cluster with largest distance variance. */
+        /**
+         * Split the cluster with largest distance variance.
+         */
         LARGEST_VARIANCE,
 
-        /** Split the cluster with largest number of points. */
+        /**
+         * Split the cluster with largest number of points.
+         */
         LARGEST_POINTS_NUMBER,
 
-        /** Create a cluster around the point farthest from its centroid. */
+        /**
+         * Create a cluster around the point farthest from its centroid.
+         */
         FARTHEST_POINT,
 
-        /** Generate an error. */
+        /**
+         * Generate an error.
+         */
         ERROR
 
     }
 
-    /** The number of clusters. */
+    /**
+     * The number of clusters.
+     */
     private final int k;
 
-    /** The maximum number of iterations. */
+    /**
+     * The maximum number of iterations.
+     */
     private final int maxIterations;
 
-    /** Random generator for choosing initial centers. */
+    /**
+     * Random generator for choosing initial centers.
+     */
     private final UniformRandomProvider random;
 
-    /** Selected strategy for empty clusters. */
+    /**
+     * Selected strategy for empty clusters.
+     */
     private final EmptyClusterStrategy emptyStrategy;
 
-    /** Build a clusterer.
+    /**
+     * Centroid initial algorithm
+     */
+    private final CentroidInitializer centroidInitializer;
+
+    /**
+     * Build a clusterer.
      * <p>
      * The default strategy for handling empty clusters that may appear during
      * algorithm iterations is to split the cluster with largest distance variance.
@@ -83,45 +109,48 @@ public KMeansPlusPlusClusterer(final int k) {
         this(k, -1);
     }
 
-    /** Build a clusterer.
+    /**
+     * Build a clusterer.
      * <p>
      * The default strategy for handling empty clusters that may appear during
      * algorithm iterations is to split the cluster with largest distance variance.
      * <p>
      * The euclidean distance will be used as default distance measure.
      *
-     * @param k the number of clusters to split the data into
+     * @param k             the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
+     *                      If negative, no maximum will be used.
      */
     public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
         this(k, maxIterations, new EuclideanDistance());
     }
 
-    /** Build a clusterer.
+    /**
+     * Build a clusterer.
      * <p>
      * The default strategy for handling empty clusters that may appear during
      * algorithm iterations is to split the cluster with largest distance variance.
      *
-     * @param k the number of clusters to split the data into
+     * @param k             the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
-     * @param measure the distance measure to use
+     *                      If negative, no maximum will be used.
+     * @param measure       the distance measure to use
      */
     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
         this(k, maxIterations, measure, RandomSource.create(RandomSource.MT_64));
     }
 
-    /** Build a clusterer.
+    /**
+     * Build a clusterer.
      * <p>
      * The default strategy for handling empty clusters that may appear during
      * algorithm iterations is to split the cluster with largest distance variance.
      *
-     * @param k the number of clusters to split the data into
+     * @param k             the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
-     * @param measure the distance measure to use
-     * @param random random generator to use for choosing initial centers
+     *                      If negative, no maximum will be used.
+     * @param measure       the distance measure to use
+     * @param random        random generator to use for choosing initial centers
      */
     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                    final DistanceMeasure measure,
@@ -129,29 +158,33 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
         this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
     }
 
-    /** Build a clusterer.
+    /**
+     * Build a clusterer.
      *
-     * @param k the number of clusters to split the data into
+     * @param k             the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
-     * @param measure the distance measure to use
-     * @param random random generator to use for choosing initial centers
+     *                      If negative, no maximum will be used.
+     * @param measure       the distance measure to use
+     * @param random        random generator to use for choosing initial centers
      * @param emptyStrategy strategy to use for handling empty clusters that
-     * may appear during algorithm iterations
+     *                      may appear during algorithm iterations
      */
     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                    final DistanceMeasure measure,
                                    final UniformRandomProvider random,
                                    final EmptyClusterStrategy emptyStrategy) {
         super(measure);
-        this.k             = k;
+        this.k = k;
         this.maxIterations = maxIterations;
-        this.random        = random;
+        this.random = random;
         this.emptyStrategy = emptyStrategy;
+        // It is a Common KMeans algorithm if centroidInitializer is not KMeansPlusPlus algorithm.
+        this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random);
     }
 
     /**
      * Return the number of clusters this instance will use.
+     *
      * @return the number of clusters
      */
     public int getK() {
@@ -160,6 +193,7 @@ public int getK() {
 
     /**
      * Returns the maximum number of iterations this instance will use.
+     *
      * @return the maximum number of iterations, or -1 if no maximum is set
      */
     public int getMaxIterations() {
@@ -168,6 +202,7 @@ public int getMaxIterations() {
 
     /**
      * Returns the random generator this instance will use.
+     *
      * @return the random generator
      */
     public UniformRandomProvider getRandomGenerator() {
@@ -176,6 +211,7 @@ public UniformRandomProvider getRandomGenerator() {
 
     /**
      * Returns the {@link EmptyClusterStrategy} used by this instance.
+     *
      * @return the {@link EmptyClusterStrategy}
      */
     public EmptyClusterStrategy getEmptyClusterStrategy() {
@@ -188,13 +224,13 @@ public EmptyClusterStrategy getEmptyClusterStrategy() {
      * @param points the points to cluster
      * @return a list of clusters containing the points
      * @throws MathIllegalArgumentException if the data points are null or the number
-     *     of clusters is larger than the number of data points
-     * @throws ConvergenceException if an empty cluster is encountered and the
-     * {@link #emptyStrategy} is set to {@code ERROR}
+     *                                      of clusters is larger than the number of data points
+     * @throws ConvergenceException         if an empty cluster is encountered and the
+     *                                      {@link #emptyStrategy} is set to {@code ERROR}
      */
     @Override
     public List<CentroidCluster<T>> cluster(final Collection<T> points)
-        throws MathIllegalArgumentException, ConvergenceException {
+            throws MathIllegalArgumentException, ConvergenceException {
 
         // sanity checks
         MathUtils.checkNotNull(points);
@@ -205,7 +241,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
         }
 
         // create the initial clusters
-        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
+        List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);
 
         // create an array containing the latest assignment of a point to a cluster
         // no need to initialize the array, as it will be filled with the first assignment
@@ -221,21 +257,21 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
                 final Clusterable newCenter;
                 if (cluster.getPoints().isEmpty()) {
                     switch (emptyStrategy) {
-                        case LARGEST_VARIANCE :
+                        case LARGEST_VARIANCE:
                             newCenter = getPointFromLargestVarianceCluster(clusters);
                             break;
-                        case LARGEST_POINTS_NUMBER :
+                        case LARGEST_POINTS_NUMBER:
                             newCenter = getPointFromLargestNumberCluster(clusters);
                             break;
-                        case FARTHEST_POINT :
+                        case FARTHEST_POINT:
                             newCenter = getFarthestPoint(clusters);
                             break;
-                        default :
+                        default:
                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                     }
                     emptyCluster = true;
                 } else {
-                    newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
+                    newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
                 }
                 newClusters.add(new CentroidCluster<T>(newCenter));
             }
@@ -254,8 +290,8 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
     /**
      * Adds the given points to the closest {@link Cluster}.
      *
-     * @param clusters the {@link Cluster}s to add the points to
-     * @param points the points to add to the given {@link Cluster}s
+     * @param clusters    the {@link Cluster}s to add the points to
+     * @param points      the points to add to the given {@link Cluster}s
      * @param assignments points assignments to clusters
      * @return the number of points assigned to different clusters as the iteration before
      */
@@ -278,131 +314,6 @@ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
         return assignedDifferently;
     }
 
-    /**
-     * Use K-means++ to choose the initial centers.
-     *
-     * @param points the points to choose the initial centers from
-     * @return the initial centers
-     */
-    private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
-
-        // Convert to list for indexed access. Make it unmodifiable, since removal of items
-        // would screw up the logic of this method.
-        final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
-
-        // The number of points in the list.
-        final int numPoints = pointList.size();
-
-        // Set the corresponding element in this array to indicate when
-        // elements of pointList are no longer available.
-        final boolean[] taken = new boolean[numPoints];
-
-        // The resulting list of initial centers.
-        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
-
-        // Choose one center uniformly at random from among the data points.
-        final int firstPointIndex = random.nextInt(numPoints);
-
-        final T firstPoint = pointList.get(firstPointIndex);
-
-        resultSet.add(new CentroidCluster<T>(firstPoint));
-
-        // Must mark it as taken
-        taken[firstPointIndex] = true;
-
-        // To keep track of the minimum distance squared of elements of
-        // pointList to elements of resultSet.
-        final double[] minDistSquared = new double[numPoints];
-
-        // Initialize the elements.  Since the only point in resultSet is firstPoint,
-        // this is very easy.
-        for (int i = 0; i < numPoints; i++) {
-            if (i != firstPointIndex) { // That point isn't considered
-                double d = distance(firstPoint, pointList.get(i));
-                minDistSquared[i] = d*d;
-            }
-        }
-
-        while (resultSet.size() < k) {
-
-            // Sum up the squared distances for the points in pointList not
-            // already taken.
-            double distSqSum = 0.0;
-
-            for (int i = 0; i < numPoints; i++) {
-                if (!taken[i]) {
-                    distSqSum += minDistSquared[i];
-                }
-            }
-
-            // Add one new data point as a center. Each point x is chosen with
-            // probability proportional to D(x)2
-            final double r = random.nextDouble() * distSqSum;
-
-            // The index of the next point to be added to the resultSet.
-            int nextPointIndex = -1;
-
-            // Sum through the squared min distances again, stopping when
-            // sum >= r.
-            double sum = 0.0;
-            for (int i = 0; i < numPoints; i++) {
-                if (!taken[i]) {
-                    sum += minDistSquared[i];
-                    if (sum >= r) {
-                        nextPointIndex = i;
-                        break;
-                    }
-                }
-            }
-
-            // If it's not set to >= 0, the point wasn't found in the previous
-            // for loop, probably because distances are extremely small.  Just pick
-            // the last available point.
-            if (nextPointIndex == -1) {
-                for (int i = numPoints - 1; i >= 0; i--) {
-                    if (!taken[i]) {
-                        nextPointIndex = i;
-                        break;
-                    }
-                }
-            }
-
-            // We found one.
-            if (nextPointIndex >= 0) {
-
-                final T p = pointList.get(nextPointIndex);
-
-                resultSet.add(new CentroidCluster<T> (p));
-
-                // Mark it as taken.
-                taken[nextPointIndex] = true;
-
-                if (resultSet.size() < k) {
-                    // Now update elements of minDistSquared.  We only have to compute
-                    // the distance to the new center to do this.
-                    for (int j = 0; j < numPoints; j++) {
-                        // Only have to worry about the points still not taken.
-                        if (!taken[j]) {
-                            double d = distance(p, pointList.get(j));
-                            double d2 = d * d;
-                            if (d2 < minDistSquared[j]) {
-                                minDistSquared[j] = d2;
-                            }
-                        }
-                    }
-                }
-
-            } else {
-                // None found --
-                // Break from the while loop to prevent
-                // an infinite loop.
-                break;
-            }
-        }
-
-        return resultSet;
-    }
-
     /**
      * Get a random point from the {@link Cluster} with the largest distance variance.
      *
@@ -502,9 +413,9 @@ private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws
             for (int i = 0; i < points.size(); ++i) {
                 final double distance = distance(points.get(i), center);
                 if (distance > maxDistance) {
-                    maxDistance     = distance;
+                    maxDistance = distance;
                     selectedCluster = cluster;
-                    selectedPoint   = i;
+                    selectedPoint = i;
                 }
             }
 
@@ -523,7 +434,7 @@ private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws
      * Returns the nearest {@link Cluster} to the given point
      *
      * @param clusters the {@link Cluster}s to search
-     * @param point the point to find the nearest {@link Cluster} for
+     * @param point    the point to find the nearest {@link Cluster} for
      * @return the index of the nearest {@link Cluster} to the given point
      */
     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
@@ -540,26 +451,4 @@ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, fin
         }
         return minCluster;
     }
-
-    /**
-     * Computes the centroid for a set of points.
-     *
-     * @param points the set of points
-     * @param dimension the point dimension
-     * @return the computed centroid for the set of points
-     */
-    private Clusterable centroidOf(final Collection<T> points, final int dimension) {
-        final double[] centroid = new double[dimension];
-        for (final T p : points) {
-            final double[] point = p.getPoint();
-            for (int i = 0; i < centroid.length; i++) {
-                centroid[i] += point[i];
-            }
-        }
-        for (int i = 0; i < centroid.length; i++) {
-            centroid[i] /= points.size();
-        }
-        return new DoublePoint(centroid);
-    }
-
 }
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java
new file mode 100644
index 0000000000..9c848d61c2
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.java
@@ -0,0 +1,297 @@
+package org.apache.commons.math4.ml.clustering;
+
+import org.apache.commons.math4.exception.ConvergenceException;
+import org.apache.commons.math4.exception.MathIllegalArgumentException;
+import org.apache.commons.math4.exception.NumberIsTooSmallException;
+import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
+import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
+import org.apache.commons.math4.ml.distance.DistanceMeasure;
+import org.apache.commons.math4.util.MathUtils;
+import org.apache.commons.math4.util.Pair;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.ListSampler;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * A very fast clustering algorithm base on KMeans(Refer to Python sklearn.cluster.MiniBatchKMeans)
+ * Use a partial points in initialize cluster centers, and mini batch in iterations.
+ * It finish in few seconds when clustering millions of data, and has few differences between KMeans.
+ * See https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
+ *
+ * @param <T> Type of the points to cluster
+ */
+public class MiniBatchKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
+    /**
+     * The number of clusters.
+     */
+    private final int k;
+
+    /**
+     * The maximum number of iterations.
+     */
+    private final int maxIterations;
+
+    /**
+     * Batch data size in iteration.
+     */
+    private final int batchSize;
+    /**
+     * Iteration count of initialize the centers.
+     */
+    private final int initIterations;
+    /**
+     * Data size of batch to initialize the centers, default 3*k
+     */
+    private final int initBatchSize;
+    /**
+     * Max iterate times when no improvement on step iterations.
+     */
+    private final int maxNoImprovementTimes;
+    /**
+     * Random generator for choosing initial centers.
+     */
+    private final UniformRandomProvider random;
+
+    /**
+     * Centroid initial algorithm
+     */
+    private final CentroidInitializer centroidInitializer;
+
+
+    /**
+     * Build a clusterer.
+     *
+     * @param k                     the number of clusters to split the data into
+     * @param maxIterations         the maximum number of iterations to run the algorithm for.
+     *                              If negative, no maximum will be used.
+     * @param batchSize             the mini batch size for training iterations.
+     * @param initIterations        the iterations to find out the best clusters centers.
+     * @param initBatchSize         the mini batch size to initial the clusters centers.
+     * @param maxNoImprovementTimes the max iterations times when the square distance has no improvement.
+     * @param measure               the distance measure to use
+     * @param random                random generator to use for choosing initial centers
+     *                              may appear during algorithm iterations
+     * @param centroidInitializer   the centroid initializer algorithm
+     */
+    public MiniBatchKMeansClusterer(final int k, int maxIterations, final int batchSize, final int initIterations,
+                                    final int initBatchSize, final int maxNoImprovementTimes,
+                                    final DistanceMeasure measure, final UniformRandomProvider random,
+                                    final CentroidInitializer centroidInitializer) {
+        super(measure);
+        this.k = k;
+        this.maxIterations = maxIterations > 0 ? maxIterations : 100;
+        this.batchSize = batchSize;
+        this.initIterations = initIterations;
+        this.initBatchSize = initBatchSize;
+        this.maxNoImprovementTimes = maxNoImprovementTimes;
+        this.random = random;
+        this.centroidInitializer = centroidInitializer;
+    }
+
+    /**
+     * Build a clusterer
+     *
+     * @param k             the number of clusters to split the data into
+     * @param maxIterations the maximum number of iterations to run the algorithm for.
+     *                      If negative, no maximum will be used.
+     * @param measure       the distance measure to use
+     * @param random        random generator to use for choosing initial centers
+     *                      may appear during algorithm iterations
+     */
+    public MiniBatchKMeansClusterer(int k, int maxIterations, DistanceMeasure measure, UniformRandomProvider random) {
+        this(k, maxIterations, 100, 3, 100 * 3, 10,
+                measure, random, new KMeansPlusPlusCentroidInitializer(measure, random));
+    }
+
+    /**
+     * Runs the MiniBatch K-means clustering algorithm.
+     *
+     * @param points the points to cluster
+     * @return a list of clusters containing the points
+     * @throws MathIllegalArgumentException if the data points are null or the number
+     *                                      of clusters is larger than the number of data points
+     */
+    @Override
+    public List<CentroidCluster<T>> cluster(Collection<T> points) throws MathIllegalArgumentException, ConvergenceException {
+        // sanity checks
+        MathUtils.checkNotNull(points);
+
+        // number of clusters has to be smaller or equal the number of data points
+        if (points.size() < k) {
+            throw new NumberIsTooSmallException(points.size(), k, false);
+        }
+
+        int pointSize = points.size();
+        int batchSize = this.batchSize;
+        int batchCount = pointSize / batchSize + ((pointSize % batchSize > 0) ? 1 : 0);
+        int maxIterations = (this.maxIterations <= 0) ? Integer.MAX_VALUE : (this.maxIterations * batchCount);
+        MiniBatchImprovementEvaluator evaluator = new MiniBatchImprovementEvaluator();
+        List<CentroidCluster<T>> clusters = initialCenters(points);
+        for (int i = 0; i < maxIterations; i++) {
+            //Clear points in clusters
+            clearClustersPoints(clusters);
+            //Random sampling a mini batch of points.
+            List<T> batchPoints = randomMiniBatch(points, batchSize);
+            // Processing the mini batch training step
+            Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
+            double squareDistance = pair.getFirst();
+            clusters = pair.getSecond();
+            // Evaluate the training can finished early.
+            if (evaluator.convergence(squareDistance, pointSize)) break;
+        }
+        clearClustersPoints(clusters);
+        //Add every mini batch points to their nearest cluster.
+        for (T point : points) {
+            addToNearestCentroidCluster(point, clusters);
+        }
+        return clusters;
+    }
+
+    /**
+     * clear clustered points
+     *
+     * @param clusters The clusters to clear
+     */
+    private void clearClustersPoints(List<CentroidCluster<T>> clusters) {
+        for (CentroidCluster<T> cluster : clusters) {
+            cluster.getPoints().clear();
+        }
+    }
+
+    /**
+     * Mini batch iteration step
+     *
+     * @param batchPoints The mini batch points.
+     * @param clusters    The cluster centers.
+     * @return Square distance of all the batch points to the nearest center, and newly clusters.
+     */
+    private Pair<Double, List<CentroidCluster<T>>> step(
+            List<T> batchPoints,
+            List<CentroidCluster<T>> clusters) {
+        //Add every mini batch points to their nearest cluster.
+        for (T point : batchPoints) {
+            addToNearestCentroidCluster(point, clusters);
+        }
+        List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(clusters.size());
+        //Refresh then cluster centroid.
+        for (CentroidCluster<T> cluster : clusters) {
+            Clusterable newCenter;
+            if (cluster.getPoints().isEmpty()) {
+                newCenter = new DoublePoint(ClusterUtils.getPointFromLargestVarianceCluster(clusters, this.getDistanceMeasure(), random).getPoint());
+            } else {
+                newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
+            }
+            newClusters.add(new CentroidCluster<T>(newCenter));
+        }
+        // Add every mini batch points to their nearest cluster again.
+        double squareDistance = 0.0;
+        for (T point : batchPoints) {
+            double d = addToNearestCentroidCluster(point, newClusters);
+            squareDistance += d * d;
+        }
+        return new Pair<Double, List<CentroidCluster<T>>>(squareDistance, newClusters);
+    }
+
+    /**
+     * Get a mini batch of points
+     *
+     * @param points    all the points
+     * @param batchSize the mini batch size
+     * @return mini batch of all the points
+     */
+    private List<T> randomMiniBatch(Collection<T> points, int batchSize) {
+        ArrayList<T> list = new ArrayList<T>(points);
+        ListSampler.shuffle(random, list);
+        return list.subList(0, batchSize);
+    }
+
+    /**
+     * Initial cluster centers with multiply iterations, find out the best.
+     *
+     * @param points Points use to initial the cluster centers.
+     * @return Clusters with center
+     */
+    private List<CentroidCluster<T>> initialCenters(Collection<T> points) {
+        List<T> validPoints = initBatchSize < points.size() ? randomMiniBatch(points, initBatchSize) : new ArrayList<T>(points);
+        double nearestSquareDistance = Double.POSITIVE_INFINITY;
+        List<CentroidCluster<T>> bestCenters = null;
+        for (int i = 0; i < initIterations; i++) {
+            List<T> initialPoints = (initBatchSize < points.size()) ? randomMiniBatch(points, initBatchSize) : new ArrayList<T>(points);
+            List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(initialPoints, k);
+            Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
+            double squareDistance = pair.getFirst();
+            List<CentroidCluster<T>> newClusters = pair.getSecond();
+            //Find out a best centers that has the nearest total square distance.
+            if (squareDistance < nearestSquareDistance) {
+                nearestSquareDistance = squareDistance;
+                bestCenters = newClusters;
+            }
+        }
+        return bestCenters;
+    }
+
+    /**
+     * Add a point to the cluster which the nearest center belong to.
+     *
+     * @param point    The point to add.
+     * @param clusters The clusters to add to.
+     * @return The distance to nearest center.
+     */
+    private double addToNearestCentroidCluster(T point, List<CentroidCluster<T>> clusters) {
+        double minDistance = Double.POSITIVE_INFINITY;
+        CentroidCluster<T> nearestCentroidCluster = null;
+        for (CentroidCluster<T> centroidCluster : clusters) {
+            double distance = distance(point, centroidCluster.getCenter());
+            if (distance < minDistance) {
+                minDistance = distance;
+                nearestCentroidCluster = centroidCluster;
+            }
+        }
+        assert nearestCentroidCluster != null;
+        nearestCentroidCluster.addPoint(point);
+        return minDistance;
+    }
+
+    /**
+     * The Evaluator to evaluate whether the iteration should finish where square has no improvement for appointed times.
+     */
+    class MiniBatchImprovementEvaluator {
+        private Double ewaInertia = null;
+        private double ewaInertiaMin = Double.POSITIVE_INFINITY;
+        private int noImprovementTimes = 0;
+
+        /**
+         * Evaluate whether the iteration should finish where square has no improvement for appointed times
+         *
+         * @param squareDistance the total square distance of the mini batch points to their nearest center.
+         * @param pointSize      size of the the data points.
+         * @return true if no improvement for appointed times, otherwise false
+         */
+        public boolean convergence(double squareDistance, int pointSize) {
+            double batchInertia = squareDistance / batchSize;
+            if (ewaInertia == null) {
+                ewaInertia = batchInertia;
+            } else {
+                // Refer to sklearn, pointSize+1 maybe intent to avoid the div/0 error,
+                // but java double does not have a div/0 error
+                double alpha = batchSize * 2.0 / (pointSize + 1);
+                alpha = Math.min(alpha, 1.0);
+                ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
+            }
+
+            // Improved
+            if (ewaInertia < ewaInertiaMin) {
+                noImprovementTimes = 0;
+                ewaInertiaMin = ewaInertia;
+            } else {
+                // No improvement
+                noImprovementTimes++;
+            }
+            // Has no improvement continuous for many times
+            return noImprovementTimes >= maxNoImprovementTimes;
+        }
+    }
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java
index 5f364c040e..314b6f8aac 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/ClusterEvaluator.java
@@ -19,10 +19,7 @@
 
 import java.util.List;
 
-import org.apache.commons.math4.ml.clustering.CentroidCluster;
-import org.apache.commons.math4.ml.clustering.Cluster;
-import org.apache.commons.math4.ml.clustering.Clusterable;
-import org.apache.commons.math4.ml.clustering.DoublePoint;
+import org.apache.commons.math4.ml.clustering.*;
 import org.apache.commons.math4.ml.distance.DistanceMeasure;
 import org.apache.commons.math4.ml.distance.EuclideanDistance;
 
@@ -106,17 +103,7 @@ protected Clusterable centroidOf(final Cluster<T> cluster) {
         }
 
         final int dimension = points.get(0).getPoint().length;
-        final double[] centroid = new double[dimension];
-        for (final T p : points) {
-            final double[] point = p.getPoint();
-            for (int i = 0; i < centroid.length; i++) {
-                centroid[i] += point[i];
-            }
-        }
-        for (int i = 0; i < centroid.length; i++) {
-            centroid[i] /= points.size();
-        }
-        return new DoublePoint(centroid);
+        return ClusterUtils.centroidOf(points,dimension);
     }
 
 }
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
new file mode 100644
index 0000000000..4adb67406c
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
@@ -0,0 +1,22 @@
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Interface abstract the algorithm for clusterer to choose the initial centers.
+ */
+public interface CentroidInitializer {
+
+    /**
+     * Choose the initial centers.
+     *
+     * @param points the points to choose the initial centers from
+     * @param k      The number of clusters
+     * @return the initial centers
+     */
+    <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k);
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
new file mode 100644
index 0000000000..bc94987979
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
@@ -0,0 +1,169 @@
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.math4.ml.distance.DistanceMeasure;
+import org.apache.commons.rng.UniformRandomProvider;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
+ */
+public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer {
+    private final DistanceMeasure measure;
+    private final UniformRandomProvider random;
+
+    /**
+     * Build a K-means++ CentroidInitializer
+     * @param measure the distance measure to use
+     * @param random the random to use.
+     */
+    public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider random) {
+        this.measure = measure;
+        this.random = random;
+    }
+
+    /**
+     * Use K-means++ to choose the initial centers.
+     *
+     * @param points the points to choose the initial centers from
+     * @param k      The number of clusters
+     * @return the initial centers
+     */
+    @Override
+    public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final Collection<T> points, final int k) {
+        // Convert to list for indexed access. Make it unmodifiable, since removal of items
+        // would screw up the logic of this method.
+        final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
+
+        // The number of points in the list.
+        final int numPoints = pointList.size();
+
+        // Set the corresponding element in this array to indicate when
+        // elements of pointList are no longer available.
+        final boolean[] taken = new boolean[numPoints];
+
+        // The resulting list of initial centers.
+        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
+
+        // Choose one center uniformly at random from among the data points.
+        final int firstPointIndex = random.nextInt(numPoints);
+
+        final T firstPoint = pointList.get(firstPointIndex);
+
+        resultSet.add(new CentroidCluster<T>(firstPoint));
+
+        // Must mark it as taken
+        taken[firstPointIndex] = true;
+
+        // To keep track of the minimum distance squared of elements of
+        // pointList to elements of resultSet.
+        final double[] minDistSquared = new double[numPoints];
+
+        // Initialize the elements.  Since the only point in resultSet is firstPoint,
+        // this is very easy.
+        for (int i = 0; i < numPoints; i++) {
+            if (i != firstPointIndex) { // That point isn't considered
+                double d = distance(firstPoint, pointList.get(i));
+                minDistSquared[i] = d * d;
+            }
+        }
+
+        while (resultSet.size() < k) {
+
+            // Sum up the squared distances for the points in pointList not
+            // already taken.
+            double distSqSum = 0.0;
+
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    distSqSum += minDistSquared[i];
+                }
+            }
+
+            // Add one new data point as a center. Each point x is chosen with
+            // probability proportional to D(x)2
+            final double r = random.nextDouble() * distSqSum;
+
+            // The index of the next point to be added to the resultSet.
+            int nextPointIndex = -1;
+
+            // Sum through the squared min distances again, stopping when
+            // sum >= r.
+            double sum = 0.0;
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    sum += minDistSquared[i];
+                    if (sum >= r) {
+                        nextPointIndex = i;
+                        break;
+                    }
+                }
+            }
+
+            // If it's not set to >= 0, the point wasn't found in the previous
+            // for loop, probably because distances are extremely small.  Just pick
+            // the last available point.
+            if (nextPointIndex == -1) {
+                for (int i = numPoints - 1; i >= 0; i--) {
+                    if (!taken[i]) {
+                        nextPointIndex = i;
+                        break;
+                    }
+                }
+            }
+
+            // We found one.
+            if (nextPointIndex >= 0) {
+
+                final T p = pointList.get(nextPointIndex);
+
+                resultSet.add(new CentroidCluster<T>(p));
+
+                // Mark it as taken.
+                taken[nextPointIndex] = true;
+
+                if (resultSet.size() < k) {
+                    // Now update elements of minDistSquared.  We only have to compute
+                    // the distance to the new center to do this.
+                    for (int j = 0; j < numPoints; j++) {
+                        // Only have to worry about the points still not taken.
+                        if (!taken[j]) {
+                            double d = distance(p, pointList.get(j));
+                            double d2 = d * d;
+                            if (d2 < minDistSquared[j]) {
+                                minDistSquared[j] = d2;
+                            }
+                        }
+                    }
+                }
+
+            } else {
+                // None found --
+                // Break from the while loop to prevent
+                // an infinite loop.
+                break;
+            }
+        }
+
+        return resultSet;
+    }
+
+    /**
+     * Calculates the distance between two {@link Clusterable} instances
+     * with the configured {@link DistanceMeasure}.
+     *
+     * @param p1 the first clusterable
+     * @param p2 the second clusterable
+     * @return the distance between the two clusterables
+     */
+    protected double distance(final Clusterable p1, final Clusterable p2) {
+        return measure.compute(p1.getPoint(), p2.getPoint());
+    }
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
new file mode 100644
index 0000000000..723876711b
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
@@ -0,0 +1,44 @@
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.ListSampler;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Random choose the initial centers.
+ */
+public class RandomCentroidInitializer implements CentroidInitializer {
+    private final UniformRandomProvider random;
+
+    /**
+     * Build a random RandomCentroidInitializer
+     *
+     * @param random the random to use.
+     */
+    public RandomCentroidInitializer(final UniformRandomProvider random) {
+        this.random = random;
+    }
+
+    /**
+     * Random choose the initial centers.
+     *
+     * @param points the points to choose the initial centers from
+     * @param k      The number of clusters
+     * @return the initial centers
+     */
+    @Override
+    public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(Collection<T> points, int k) {
+        ArrayList<T> list = new ArrayList<T>(points);
+        ListSampler.shuffle(random, list);
+        List<CentroidCluster<T>> result = new ArrayList<>(k);
+        for (int i = 0; i < k; i++) {
+            result.add(new CentroidCluster<>(list.get(i)));
+        }
+        return result;
+    }
+}
diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java
new file mode 100644
index 0000000000..980a62cec0
--- /dev/null
+++ b/src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java
@@ -0,0 +1,76 @@
+package org.apache.commons.math4.ml.clustering;
+
+import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator;
+import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
+import org.apache.commons.math4.ml.distance.DistanceMeasure;
+import org.apache.commons.math4.ml.distance.EuclideanDistance;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+public class MiniBatchKMeansClustererTest {
+    private DistanceMeasure measure = new EuclideanDistance();
+
+    /**
+     * Compare the result to KMeansPlusPlusClusterer
+     */
+    @Test
+    public void testCompareToKMeans() {
+        //Generate 4 cluster
+        int randomSeed = 0;
+        List<DoublePoint> data = generateCircles(randomSeed);
+        KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(4, -1, measure,
+                RandomSource.create(RandomSource.MT_64, randomSeed));
+        MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans = new MiniBatchKMeansClusterer<>(4, -1,
+                measure, RandomSource.create(RandomSource.MT_64, randomSeed));
+        for (int i = 0; i < 100; i++) {
+            List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
+            List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
+            Assert.assertEquals(4, kMeansClusters.size());
+            Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
+            int totalDiffCount = 0;
+            for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
+                CentroidCluster<DoublePoint> miniBatchCluster = ClusterUtils.predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
+                totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
+            }
+            ClusterEvaluator<DoublePoint> clusterEvaluator = new SumOfClusterVariances<>(measure);
+            double kMeansScore = clusterEvaluator.score(kMeansClusters);
+            double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
+            double diffPointsRatio = totalDiffCount * 1.0 / data.size();
+            double scoreDiffRatio = (miniBatchKMeansScore - kMeansScore) /
+                    kMeansScore;
+            // MiniBatchKMeansClusterer has few score differences between KMeansClusterer
+            Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%\"", scoreDiffRatio * 100, diffPointsRatio * 100),
+                    scoreDiffRatio < 0.1);
+        }
+    }
+
+    private List<DoublePoint> generateCircles(int randomSeed) {
+        List<DoublePoint> data = new ArrayList<>();
+        Random random = new Random(randomSeed);
+        data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
+        data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
+        data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
+        data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
+        return data;
+    }
+
+    List<DoublePoint> generateCircle(int count, double[] center, double radius, Random random) {
+        double x0 = center[0];
+        double y0 = center[1];
+        ArrayList<DoublePoint> list = new ArrayList<DoublePoint>(count);
+        for (int i = 0; i < count; i++) {
+            double ao = random.nextDouble() * 720 - 360;
+            double r = random.nextDouble() * radius * 2 - radius;
+            double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
+            double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
+            list.add(new DoublePoint(new double[]{x1, y1}));
+        }
+        return list;
+    }
+
+}