A library for training crosscoders and, by extension, (cross-layer, skip) transcoders, SAEs, and other sparse coding models.
- Training a multi-layer / multi-model crosscoder
- With Vanilla L1 loss as in the original crosscoder paper: here
- With TopK/BatchTopK/GroupMax: here
- With JumpReLU according to Anthropic's January 2025 update: here
- According to Anthropic's February 2025 model diffing update
- (Any of these can also be an SAE by running on one LLM only and setting `hookpoints` to a single layer)
- Training a cross-layer (skip) transcoder
- crosscoding dimensions: the dimensions over which the crosscoder is applied.
  - e.g. in a cross-layer crosscoder, the crosscoding dimensions are `(layers,)`
  - e.g. in a cross-model crosscoder, the crosscoding dimensions are `(models,)`
  - e.g. in a cross-model, cross-layer crosscoder, the crosscoding dimensions are `(models, layers)`
- hookpoints: the hookpoints at which activations are harvested.
- latents: the hidden activations of the crosscoder/transcoder.
- topk-style: blanket term for TopK, BatchTopK, and GroupMax activations. Lumped together as they are all trained the same way.
- Jan Update: the "January 2025 update" describing a specific JumpReLU loss.
- Feb Update: the "February 2025 model diffing update" describing a technique for improving model-diffing crosscoder training with shared latents.
All sparse coding models are abstracted over activation functions, and losses are handled by trainers. This is nice because different training schemes are usually a combination of (activation function, loss function, some hyperparameters, some scheduling), and this way we put all of that in the trainer class in a type-safe way.
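As a rough illustration of the pattern (hypothetical names and signatures, not the library's actual API):

```python
from typing import Callable, Generic, TypeVar

import torch

# The model is generic over its activation function; the trainer bundles the
# rest of the training scheme (loss, hyperparameters, scheduling).
ActFn = TypeVar("ActFn", bound=Callable[[torch.Tensor], torch.Tensor])


class SparseCoderSketch(torch.nn.Module, Generic[ActFn]):
    def __init__(self, act_fn: ActFn, d_model: int, n_latents: int):
        super().__init__()
        self.act_fn = act_fn
        self.W_enc = torch.nn.Parameter(torch.randn(d_model, n_latents) * 0.01)
        self.W_dec = torch.nn.Parameter(torch.randn(n_latents, d_model) * 0.01)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        latents = self.act_fn(x @ self.W_enc)
        return latents, latents @ self.W_dec


class L1TrainerSketch:
    """One training scheme = (activation fn, loss, hyperparameters) in one
    place; the type annotation pins this trainer to ReLU-activated models."""

    def __init__(self, model: SparseCoderSketch[torch.nn.ReLU], l1_coef: float):
        self.model = model
        self.l1_coef = l1_coef

    def loss(self, x: torch.Tensor) -> torch.Tensor:
        latents, recon = self.model(x)
        return (recon - x).pow(2).mean() + self.l1_coef * latents.abs().sum(-1).mean()
```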
This library makes extensive use of "shape suffixes" and einops to attempt to make the quite complex and varying tensor shapes a bit more manageable. The most common shapes are:
- `B`: Batch
- `M`: Models
- `P`: hookPoints (for example, different layers of the residual stream)
- `L`: Latents (aka "features")
- `D`: model Dimension (aka `d_model`)
- `X`: an arbitrary number of crosscoding dimensions (usually 0, 1, or 2 of them, such as `(n_models, n_layers)`)
Shape suffixes should be interpreted in PascalCase, where lowercase denotes a more specific version of a shape. For example, `W_enc_XiDiL` means shape (Input Crosscoding dims, Input D_model, Latents).
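For instance, encoding with `einops.einsum` under this convention, where the crosscoding dims `X` happen to be `(models, hookpoints)` (an illustrative snippet, not the library's code):

```python
import torch
from einops import einsum

B, M, P, D, L = 32, 2, 3, 768, 4096  # Batch, Models, hookPoints, D_model, Latents

acts_BMPD = torch.randn(B, M, P, D)   # activations from 2 models at 3 hookpoints
W_enc_MPDL = torch.randn(M, P, D, L)  # encoder weight over crosscoding dims (M, P)

# Encode: sum over the crosscoding dims and d_model, leaving (Batch, Latents).
pre_acts_BL = einsum(acts_BMPD, W_enc_MPDL, "b m p d, m p d l -> b l")
```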
We currently only have one dataloader type, and handle the reshaping of activations for a given model / training scheme in the trainer classes. Once again, this keeps most of the complexity in one place.

We currently harvest activations from the LLM(s) at training time. You can cache activations to disk to avoid re-harvesting them in subsequent runs. This, however, is probably the least-developed part of the library.
The library is structured roughly as follows:
### Models (`crosscode.models`)
- `BaseCrosscoder`: Generic base class for all crosscoding models. It's allowed to have different input and output crosscoding dimensions and `d_model`s, and is meant to be subclassed in a way that concretifies the dimensions.
  - For example, `CrossLayerTranscoder` is a subclass of BaseCrosscoder that concretifies the input crosscoding dimensions to be `()` and the output crosscoding dimensions to be `(n_layers,)` (see the sketch after this list).
- `ModelHookpointAcausalCrosscoder`: An acausal crosscoder that can be applied across multiple models / layers.
  - with `n_layers = 1` and `n_models = 2`, it's a normal model diffing crosscoder.
  - with `n_layers > 1` and `n_models = 1`, it's a normal cross-layer acausal transcoder.
  - with `n_layers > 1` and `n_models > 1`, it's a cross-layer, cross-model, acausal crosscoder (???).
- `CrossLayerTranscoder`: A cross-layer acausal transcoder.
- `CompoundCrossLayerTranscoder`: A wrapper around a list of CrossLayerTranscoder that applies them in parallel, as described in the "Circuit Tracing" paper.
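A minimal sketch of the subclassing idea (hypothetical signatures; the library's real classes carry more state and logic):

```python
import torch


class BaseCrosscoderSketch(torch.nn.Module):
    """Generic over input/output crosscoding dims: encodes from
    (*in_dims, d_in) and decodes to (*out_dims, d_out)."""

    def __init__(self, in_dims: tuple[int, ...], d_in: int,
                 out_dims: tuple[int, ...], d_out: int, n_latents: int):
        super().__init__()
        self.W_enc = torch.nn.Parameter(torch.randn(*in_dims, d_in, n_latents) * 0.01)
        self.W_dec = torch.nn.Parameter(torch.randn(n_latents, *out_dims, d_out) * 0.01)


class CrossLayerTranscoderSketch(BaseCrosscoderSketch):
    """Concretifies the dims: reads one layer's activations (input
    crosscoding dims = ()) and writes to n_layers layers (output
    crosscoding dims = (n_layers,))."""

    def __init__(self, d_model: int, n_layers: int, n_latents: int):
        super().__init__(in_dims=(), d_in=d_model,
                         out_dims=(n_layers,), d_out=d_model, n_latents=n_latents)
```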
### Activations (`crosscode.models.activations`)
A collection of activation functions that can be used with the model classes.
### Initialization (`crosscode.models.initialization`)
A collection of `InitStrategy`s for initializing crosscoder weights.
- `InitStrategy`: A base class for all initialization strategies.
- `AnthropicTransposeInit`: Initializes the weights of a ModelHookpointAcausalCrosscoder using the "Anthropic Transpose" method (see the sketch after this list).
- `IdenticalLatentsInit`: Initializes the weights of a ModelHookpointAcausalCrosscoder such that the first `n_shared_latents` are identical for all models.
- `JanUpdateInit`: Initializes the weights of a ModelHookpointAcausalCrosscoder using the method described in the "January 2025 update" paper.
- There are also some other, more speculative initialization strategies in here.
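A rough sketch of the pattern (hypothetical names; my reading of "Anthropic Transpose" as "random fixed-norm decoder rows, encoder set to the decoder's transpose" is an assumption):

```python
from abc import ABC, abstractmethod

import torch


class InitStrategySketch(ABC):
    """Mutates a model's weights in place before training."""

    @abstractmethod
    def init_weights(self, model: torch.nn.Module) -> None: ...


class AnthropicTransposeInitSketch(InitStrategySketch):
    def __init__(self, dec_init_norm: float = 0.1):
        self.dec_init_norm = dec_init_norm

    @torch.no_grad()
    def init_weights(self, model: torch.nn.Module) -> None:
        # Assumes the model exposes W_enc of shape (D, L) and W_dec of shape (L, D).
        # Random decoder directions, rescaled to a fixed norm per latent...
        W_dec_LD = torch.randn_like(model.W_dec)
        W_dec_LD *= self.dec_init_norm / W_dec_LD.norm(dim=-1, keepdim=True)
        model.W_dec.copy_(W_dec_LD)
        # ...then the encoder starts as the decoder's transpose.
        model.W_enc.copy_(W_dec_LD.T)
```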
### Trainers (`crosscode.trainers`)
(The trainers make extensive use of inheritance, which I really don't like. I might refactor this to use composition instead.)
- `BaseTrainer`: Training boilerplate: gradient accumulation, logging, optimizer, LR scheduler, etc. (see the sketch after this list).
- `BaseCrossLayerTranscoderTrainer`: A trainer for CrossLayerTranscoder models; handles dataloading and splitting activations into input and output layers.
  - `L1CrossLayerTranscoderTrainer`: Trains a CrossLayerTranscoder with L1 loss.
  - `TopkCrossLayerTranscoderTrainer`: Trains a CrossLayerTranscoder with TopK loss.
- `BaseModelHookpointAcausalTrainer`: A trainer for ModelHookpointAcausalCrosscoder models.
  - `TopkStyleAcausalCrosscoderTrainer`: A trainer for TopK-style models.
  - `BaseFebUpdateDiffingTrainer`: An extension of BaseModelHookpointAcausalTrainer that implements shared latents as in the "February 2025 model diffing update".
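Illustratively, the kind of boilerplate a base trainer owns (a generic sketch, not the library's actual implementation):

```python
import torch


class TrainerSketch:
    def __init__(self, model, dataloader, loss_fn, lr=1e-4, grad_acc_steps=4):
        self.model = model
        self.dataloader = dataloader
        self.loss_fn = loss_fn  # subclasses supply the scheme-specific loss
        self.grad_acc_steps = grad_acc_steps
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # e.g. linear warmup over the first 1000 optimizer steps
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lambda step: min(1.0, (step + 1) / 1000)
        )

    def train(self):
        for i, batch in enumerate(self.dataloader):
            # Gradient accumulation: average the loss over accumulation steps.
            loss = self.loss_fn(self.model, batch) / self.grad_acc_steps
            loss.backward()
            if (i + 1) % self.grad_acc_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()
                # (logging / checkpointing would go here)
```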
### Data (`crosscode.data`)

Data loading via harvesting LLM activations on text.
- `ActivationsDataloader`: Dataloader for activations. Supports harvesting for multiple models and multiple hookpoints.
- `TokenSequenceLoader`: Used by ActivationsDataloader to load sequences of tokens from huggingface and chunk them into batches for activations harvesting. Can shuffle across sequences.
- `ActivationsHarvester`: Used by ActivationsDataloader to harvest LLM activations on those sequences.
- `ActivationCache`: Used by ActivationsHarvester to cache activations to disk (if enabled).
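How these pieces fit together, as I read the descriptions above (arrows are data flow; shapes use the suffix convention from earlier):

```python
# TokenSequenceLoader      huggingface text -> shuffled batches of token sequences
#         |
#         v
# ActivationsHarvester     runs the LLM(s) on each batch, grabbing activations
#         |                at every hookpoint (reads/writes ActivationCache on
#         v                disk, if caching is enabled)
# ActivationsDataloader    yields activation batches, e.g. shape (B, M, P, D)
#         |
#         v
# Trainer                  consumes batches and updates the sparse coding model
```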