A fast and efficient k-means implementation for PyTorch, with support for GPU and CPU.
Welcome to fastkmeans! This is an extremely tiny library, meant to be slotted in anywhere you need "fast-enough" PyTorch-native k-means clustering. It's compatible with any PyTorch-compatible CPU or GPU. It matches faiss with its PyTorch backend and outperforms it by ~4-5× on a single GPU with its Triton backend. It is also free of install woes, relying on just two dependencies you already have installed: torch and numpy.
[uv] pip install fastkmeans
... and that's all you need to do! FastKMeans is now ready to use.
There's very, very little to this library. It provides a single interface, FastKMeans, which you can import from fastkmeans. This interface has been designed to be a slot-in replacement for existing FAISS implementations, while also being mostly sklearn-compliant. Effectively, this means that four methods are exposed:
- train(): mimics the FAISS API, training the model on a dataset.
- fit(): mimics the sklearn API, training the model on a dataset.
- predict(): mimics the sklearn API, using the trained clusters to predict which cluster new points belong to.
- fit_predict(): mimics the sklearn API, chaining the two calls above.
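The FAISS-style and sklearn-style halves of this interface can be made concrete with a toy, NumPy-only class that has the same four-method shape. This is a hypothetical illustration: ToyKMeans and its internals are not part of fastkmeans, and the real constructor arguments are documented in the FastKMeans docstring. The main convention difference it shows is that the sklearn-style fit() returns the model itself, so calls can be chained. predict() returns one cluster id per input row.

```python
import numpy as np


class ToyKMeans:
    """Hypothetical toy class mirroring a train/fit/predict/fit_predict interface.

    NOT fastkmeans' implementation; it only illustrates the API shape.
    """

    def __init__(self, d: int, k: int, niter: int = 20, seed: int = 0):
        self.d, self.k, self.niter = d, k, niter
        self.rng = np.random.default_rng(seed)
        self.centroids = None  # (k, d) array once trained

    def _assign(self, x: np.ndarray) -> np.ndarray:
        # Squared Euclidean distance from every point to every centroid.
        dists = ((x[:, None, :] - self.centroids[None, :, :]) ** 2).sum(axis=-1)
        return dists.argmin(axis=1)

    def train(self, x: np.ndarray) -> None:
        # FAISS-style: train in place and expose the result as .centroids.
        x = np.asarray(x, dtype=np.float32)
        if x.shape[1] != self.d:
            raise ValueError(f"expected {self.d}-dimensional data, got {x.shape[1]}")
        init = self.rng.choice(len(x), size=self.k, replace=False)
        self.centroids = x[init].copy()
        for _ in range(self.niter):
            labels = self._assign(x)
            for j in range(self.k):
                members = x[labels == j]
                if len(members):  # keep the old centroid if a cluster empties
                    self.centroids[j] = members.mean(axis=0)

    def fit(self, x: np.ndarray) -> "ToyKMeans":
        # sklearn-style: same training, but return self for chaining.
        self.train(x)
        return self

    def predict(self, x: np.ndarray) -> np.ndarray:
        # One cluster index per input row.
        return self._assign(np.asarray(x, dtype=np.float32))

    def fit_predict(self, x: np.ndarray) -> np.ndarray:
        return self.fit(x).predict(x)


# Usage: two well-separated 2-D blobs of 50 points each.
rng = np.random.default_rng(0)
blobs = np.concatenate([rng.normal(0.0, 0.1, (50, 2)), rng.normal(5.0, 0.1, (50, 2))])
model = ToyKMeans(d=2, k=2).fit(blobs)  # sklearn-style chaining
labels = model.predict(blobs)           # one cluster id per row
```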
Whenever possible, the library attempts to mimic the FAISS API, albeit with a bit more flexibility. We encourage you to check out the API docstring to see what the arguments are, as they are straightforward.
The default behaviour of the library mostly follows faiss's, including downsampling data to a maximum of 256 points per centroid to speed up calculations, which can be freely modified and/or disabled. The only major difference is that, by default, the library adopts an early stopping mechanism based on a tol parameter: the algorithm stops when the centroids move by no more than tol between iterations. This differs from faiss, whose default behaviour is to run for a fixed number of iterations no matter what; you can restore this behaviour by setting tol to -1.
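The two defaults described above can be sketched in NumPy: capping the training set at a fixed number of points per centroid, and stopping once centroids stop moving by more than tol. This is a hypothetical, simplified illustration rather than fastkmeans' code. The helper names subsample and lloyd, the parameter name max_points_per_centroid, and the choice to measure movement as the largest per-centroid shift are all assumptions made here. With tol set to -1 the stopping test can never fire, since a distance is never negative, so the loop always runs all niter iterations, like faiss.

```python
import numpy as np


def subsample(x: np.ndarray, k: int, max_points_per_centroid: int = 256, seed: int = 0) -> np.ndarray:
    """Randomly keep at most k * max_points_per_centroid rows (hypothetical helper)."""
    limit = k * max_points_per_centroid
    if len(x) <= limit:
        return x
    rng = np.random.default_rng(seed)
    return x[rng.choice(len(x), size=limit, replace=False)]


def lloyd(x: np.ndarray, k: int, niter: int = 25, tol: float = 1e-8, seed: int = 0):
    """Plain Lloyd iterations with a movement-based early stop.

    Returns (centroids, iterations_run). A negative tol disables early stopping.
    """
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)].astype(np.float64)
    for it in range(1, niter + 1):
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        new = centroids.copy()
        for j in range(k):
            members = x[labels == j]
            if len(members):  # an empty cluster keeps its previous centroid
                new[j] = members.mean(axis=0)
        # Largest distance any single centroid moved this iteration (assumed metric).
        shift = np.linalg.norm(new - centroids, axis=1).max()
        centroids = new
        if shift <= tol:  # never true for tol == -1, since shift >= 0
            return centroids, it
    return centroids, niter


# Usage: 1000 points in two blobs, clustered with k=2.
rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(0.0, 0.1, (500, 2)), rng.normal(5.0, 0.1, (500, 2))])
small = subsample(x, k=2, max_points_per_centroid=256)  # keeps 512 of 1000 rows
_, iters_early = lloyd(small, k=2, niter=50)            # stops once centroids settle
_, iters_fixed = lloyd(small, k=2, niter=50, tol=-1)    # faiss-like: always 50
```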
The algorithm is implemented with double-chunking logic: both the data points and the centroids are split into moderately sized chunks, avoiding the risk of OOMs. The defaults allow you to cluster 26_214_400 128-dimensional points into 262_144 clusters with ~11GB memory usage (including storing the data in fp32). The chunk sizes are exposed as arguments (see the API docstring) if you need to tweak them. As a rule of thumb, increasing chunk sizes speeds up computation at the cost of memory usage, and decreasing them has the reverse effect.
Note: See the notes on the Triton backend below for Triton-specific chunking details.
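The double-chunking idea can be sketched for the assignment step as follows. Rather than materialising the full n × k distance matrix, the sketch loops over blocks of points and, inside that, over blocks of centroids, keeping a running best distance and index per point. Peak memory for the distance block is then roughly chunk_size_data × chunk_size_centroids instead of n × k. This NumPy version is a hypothetical illustration. The function name chunked_assign and its parameter names are assumptions, not fastkmeans' internals.

```python
import numpy as np


def chunked_assign(x, centroids, chunk_size_data=1024, chunk_size_centroids=256):
    """Nearest-centroid index for every row of x, computed block by block (hypothetical)."""
    n = len(x)
    labels = np.empty(n, dtype=np.int64)
    c_sq = (centroids ** 2).sum(axis=1)  # ||c||^2, reused by every data chunk
    for start in range(0, n, chunk_size_data):
        xb = x[start:start + chunk_size_data]
        x_sq = (xb ** 2).sum(axis=1, keepdims=True)  # ||x||^2 for this data chunk
        best_dist = np.full(len(xb), np.inf)
        best_idx = np.zeros(len(xb), dtype=np.int64)
        for c_start in range(0, len(centroids), chunk_size_centroids):
            cb = centroids[c_start:c_start + chunk_size_centroids]
            # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, only for this (data, centroid) block.
            d = x_sq - 2.0 * xb @ cb.T + c_sq[c_start:c_start + len(cb)]
            local_idx = d.argmin(axis=1)
            local_dist = d[np.arange(len(xb)), local_idx]
            # Keep whichever is closer: the best so far or this block's best.
            better = local_dist < best_dist
            best_dist[better] = local_dist[better]
            best_idx[better] = local_idx[better] + c_start  # back to a global centroid id
        labels[start:start + len(xb)] = best_idx
    return labels


# Usage: 300 centroids split into ragged chunks of 64, data in chunks of 500.
rng = np.random.default_rng(0)
x = rng.normal(size=(2000, 8))
centroids = rng.normal(size=(300, 8))
labels = chunked_assign(x, centroids, chunk_size_data=500, chunk_size_centroids=64)
```

Larger chunks mean fewer, bigger matrix multiplications (faster) but a larger temporary distance block (more memory), which matches the rule of thumb above.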
The main motivation behind fastkmeans is to provide a considerably easier way to package late-interaction models such as ColBERT, in its Stanford implementation, its PyLate implementation, and the RAGatouille high-level API. The use of clustering for ColBERT is somewhat peculiar, as it relies on large numbers of clusters for relatively few data points (~100 per cluster centre). This has been a major obstacle to making ColBERT more usable. The existing alternatives, while great in their own right, have flaws for this particular use:
- faiss is highly optimized and is the original library used by the ColBERT authors and by most implementations nowadays. However, it is rather difficult to install. The official install is only supported via conda or from source, there are no "perfect" PyPI wheels, and many segfault issues have been reported. It can also be finicky, causing issues with PyTorch if not installed via conda as well. Finally, it requires different packages and different install methods for its CPU and GPU variants. This is a major problem for maintaining libraries such as RAGatouille: it adds friction for users and makes it impossible to provide a nice default.
- fast-pytorch-kmeans is a great library which provides a lightning-fast k-means implementation in PyTorch. However, it relies on highly vectorized operations which are exceedingly memory-hungry. It consistently OOMs on consumer hardware when trying to index even a moderate number of ColBERT documents (or produces suboptimal clusters with minibatching).
- scikit-learn, while being the ML giant whose shoulders we all stand on, only supports CPU. This becomes unusably slow when indexing larger volumes of documents, especially as there'll (almost) always be a GPU available in these situations.
There are other libraries (such as NVIDIA's own implementations), but they again require more dependencies than we'd like for nimble packaging and/or are less flexible in terms of hardware.
A few limitations to keep in mind:
- On a few toy datasets and MNIST, fastkmeans reaches roughly the same NMI as faiss and scikit-learn, indicating that it creates at least somewhat coherent clusters. However, it's not extensively tested, especially in non-ColBERT uses, so your mileage may vary.
- The "chunking" defaults used to avoid OOMs are rather simplistic, and you might need to tweak the numbers depending on your hardware and dataset size.
- The library currently assumes you'll be using it either on a CPU or a single GPU. Multiple GPUs don't currently provide a major speed-up. This might change in the future, though we expect users wanting to index 10M+ documents to likely have the more robust faiss available on their machine anyway.
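NMI (normalized mutual information) compares two labelings without caring how cluster ids are numbered. A score of 1.0 means the partitions match exactly, while values near 0 mean they share essentially no information. Below is a small NumPy sketch using the arithmetic-mean normalisation, NMI = I(U;V) / ((H(U) + H(V)) / 2), which is scikit-learn's default for normalized_mutual_info_score. The helper nmi is hypothetical and is not the evaluation code behind the comparison above.

```python
import numpy as np


def nmi(labels_a, labels_b) -> float:
    """Normalized mutual information with arithmetic-mean normalisation (hypothetical helper)."""
    # Map arbitrary cluster ids to 0..n_clusters-1.
    a = np.unique(np.asarray(labels_a), return_inverse=True)[1]
    b = np.unique(np.asarray(labels_b), return_inverse=True)[1]
    # Joint distribution of (cluster in a, cluster in b).
    joint = np.zeros((a.max() + 1, b.max() + 1))
    np.add.at(joint, (a, b), 1)
    joint /= len(a)
    pa, pb = joint.sum(axis=1), joint.sum(axis=0)  # marginals, all strictly positive
    nz = joint > 0
    mi = (joint[nz] * np.log(joint[nz] / np.outer(pa, pb)[nz])).sum()
    h_a = -(pa * np.log(pa)).sum()
    h_b = -(pb * np.log(pb)).sum()
    if h_a + h_b == 0:  # both labelings put everything in a single cluster
        return 1.0
    return float(mi / ((h_a + h_b) / 2))


print(nmi([0, 0, 1, 1, 2, 2], [5, 5, 3, 3, 9, 9]))  # same partition, different ids: close to 1.0
print(nmi([0, 0, 1, 1], [0, 1, 0, 1]))              # 0.0: independent labelings
```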
fastkmeans's Triton kmeans kernel is ~4-5 times faster than single-GPU faiss or fastkmeans's PyTorch backend. On a modern GPU (Ampere or newer), the Triton backend is enabled by default.
While the Triton kernel uses significantly less memory than the PyTorch implementation, increasing the chunk size above 512K can result in slower performance.
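Ampere corresponds to CUDA compute capability 8.x, so "Ampere or newer" amounts to a compute capability of at least (8, 0). The check is sketched below. The helper names prefers_triton and current_capability and the triton_installed flag are hypothetical, not fastkmeans' actual detection logic. torch.cuda.get_device_capability() is the real PyTorch call that returns the (major, minor) pair for the current device.

```python
from typing import Optional, Tuple


def prefers_triton(capability: Optional[Tuple[int, int]], triton_installed: bool = True) -> bool:
    """Hypothetical rule: use Triton only on Ampere (sm_80) or newer GPUs."""
    if capability is None or not triton_installed:  # CPU-only, or Triton unavailable
        return False
    return tuple(capability) >= (8, 0)


def current_capability() -> Optional[Tuple[int, int]]:
    """(major, minor) compute capability of the current CUDA device, or None."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


print(prefers_triton(current_capability()))
```

For example, an RTX 4090 reports (8, 9) and an A100 (8, 0), so both qualify, while a T4 reports (7, 5) and does not.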
Below, fastkmeans is benchmarked against faiss on a single RTX 4090 GPU, with 128-dimensional data points at data scales commonly used in ColBERT-style indexing: 8192, 16384, 32768, 65536, and 131072 clusters, each with 100 × (number of clusters) data points.
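For a sense of scale, the benchmark sizes above work out as follows. The calculation assumes 100 points per cluster, 128 dimensions, and 4 bytes per fp32 value. This is plain arithmetic on the stated setup, not measured memory usage. The helper benchmark_sizes is hypothetical.

```python
def benchmark_sizes(dim: int = 128, points_per_cluster: int = 100, bytes_per_value: int = 4):
    """(clusters, data points, MiB of raw fp32 data) for each benchmark setting."""
    rows = []
    for k in (8192, 16384, 32768, 65536, 131072):
        n_points = k * points_per_cluster
        rows.append((k, n_points, n_points * dim * bytes_per_value / 2**20))
    return rows


for k, n_points, mib in benchmark_sizes():
    print(f"{k:>7} clusters -> {n_points:>11,} points, {mib:>6.0f} MiB of fp32 data")
```

The smallest setting is 819,200 points (400 MiB of raw data). The largest is 13,107,200 points, or 6.25 GiB.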
To benchmark fastkmeans against faiss on your own machine, install faiss and PyTorch 2.5 via the bench_env.yaml Conda environment:
conda env create -f bench_env.yaml
conda activate fastkmeans
pip install fastkmeans
Then, run the benchmark script:
CUDA_VISIBLE_DEVICES=0 python speedbench.py --do-faiss --do-fastkmeans --do-fastkmeans-triton --do-evals
If you use fastkmeans and want or need to cite it in your work, please feel free to use the citation below:
@misc{fastkmeans2025,
author = {Benjamin Clavié and Benjamin Warner},
title = {fastkmeans: Accelerated KMeans Clustering in PyTorch and Triton},
year = {2025},
howpublished = {\url{https://github.com/AnswerDotAI/fastkmeans/}}
}