Crosscode

A library for training crosscoders and, by extension, (cross-layer, skip) transcoders, SAEs, and other sparse coding models.

Examples:

  • Training a multi-layer / multi-model crosscoder
    • With Vanilla L1 loss as in the original crosscoder paper: here
    • With TopK/BatchTopK/GroupMax: here
    • With JumpReLU according to Anthropic's January 2025 update: here
    • According to Anthropic's February 2025 model diffing update
      • With JumpReLU (as in the paper): here
      • With TopK/BatchTopK/GroupMax: here
    • (Any of these can also be used to train an SAE by running on one LLM only and setting hookpoints to a single layer.)
  • Training a cross-layer (skip) transcoder
    • With L1 loss: here
    • With TopK/BatchTopK/GroupMax: here

Key terms:

  • crosscoding dimensions: the dimensions over which the crosscoder is applied.
    • e.g. in a cross-layer crosscoder, the crosscoding dimensions are (layers,)
    • e.g. in a cross-model crosscoder, the crosscoding dimensions are (models,)
    • e.g. in a cross-model, cross-layer crosscoder, the crosscoding dimensions are (models, layers)
  • hookpoints: the hookpoints at which activations are harvested.
  • latents: the hidden activations of the crosscoder/transcoder.
  • topk-style: blanket term for the TopK, BatchTopK, and GroupMax activation functions, lumped together because they are all trained the same way.
  • Jan Update: Anthropic's "January 2025 update", which describes a specific JumpReLU loss.
  • Feb Update: Anthropic's "February 2025 model diffing update", which describes a technique for improving model-diffing crosscoder training with shared latents.

Conventions:

Models, Trainers, Loss Functions.

All sparse coding models are abstracted over activation functions, and losses are handled by trainers. This split is convenient because a training scheme is usually a combination of (activation function, loss function, some hyperparameters, some scheduling), so all of that lives together in the trainer class in a type-safe way.
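
For example, a minimal sketch of the pattern (the class names, method signatures, and warmup schedule here are illustrative assumptions, not the library's actual API):

```python
from dataclasses import dataclass

import torch


@dataclass
class L1TrainerConfig:
    lam_max: float     # final L1 sparsity coefficient
    warmup_steps: int  # steps over which the coefficient is ramped up


class L1Trainer:
    """Sketch: a vanilla-L1 scheme pairs a ReLU-style model with an
    MSE + L1 loss and a sparsity-coefficient schedule."""

    def __init__(self, cfg: L1TrainerConfig) -> None:
        self.cfg = cfg

    def loss(self, step: int, acts_BD: torch.Tensor, recon_BD: torch.Tensor,
             latents_BL: torch.Tensor) -> torch.Tensor:
        lam = self.cfg.lam_max * min(1.0, step / self.cfg.warmup_steps)
        mse = (recon_BD - acts_BD).pow(2).mean()
        return mse + lam * latents_BL.abs().sum(-1).mean()


class TopKTrainer:
    """Sketch: topk-style schemes need no sparsity term at all, because
    the activation function enforces sparsity directly."""

    def loss(self, acts_BD: torch.Tensor, recon_BD: torch.Tensor) -> torch.Tensor:
        return (recon_BD - acts_BD).pow(2).mean()
```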

Tensor Shapes

This library makes extensive use of "shape suffixes" and einops to make its fairly complex and varying tensor shapes more manageable. The most common shapes are:

  • B: Batch
  • M: Models
  • P: hookPoints (for example, different layers of the residual stream)
  • L: Latents (aka "features")
  • D: model Dimension (aka d_model)
  • X: an arbitrary number of crosscoding dimensions (usually 0, 1, or 2 of them, such as (n_models, n_layers))

Shape suffixes should be interpreted in PascalCase, where a lowercase letter denotes a more specific version of the shape it follows. For example, _W_enc_XiDiL means shape (input crosscoding dims, input d_model, latents).
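
A small illustration of the convention with einops (the variable names follow the suffix scheme; the sizes and weight layout are invented for the example):

```python
import torch
from einops import einsum

B, M, P, D, L = 4, 2, 3, 16, 64  # example sizes only

acts_BMPD = torch.randn(B, M, P, D)   # Batch, Models, hookPoints, D_model
W_enc_MPDL = torch.randn(M, P, D, L)  # encoder weight over crosscoding dims (M, P)

# Contract over the crosscoding dims and d_model to get latent pre-activations:
latents_BL = einsum(acts_BMPD, W_enc_MPDL, "b m p d, m p d l -> b l")
assert latents_BL.shape == (B, L)
```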

Dataloading

We currently have only one dataloader type; the reshaping of activations for a given model / training scheme is handled in the trainer classes, again keeping most of the complexity in one place.

We currently harvest activations from the LLM(s) at training time. You can cache activations to disk to avoid re-harvesting them in subsequent runs. This, however, is probably the least-developed part of the library.
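
What caching amounts to, as a minimal sketch (the file layout and helper name are assumptions, not the library's actual cache format):

```python
from pathlib import Path

import torch


def harvest_or_load(cache_dir: Path, batch_idx: int, harvest_fn) -> torch.Tensor:
    """Hypothetical helper: reuse cached activations for a batch if present,
    otherwise harvest them from the LLM(s) and write them to disk."""
    path = cache_dir / f"batch_{batch_idx}.pt"
    if path.exists():
        return torch.load(path)
    acts = harvest_fn(batch_idx)  # run the LLM(s), stack hookpoint activations
    cache_dir.mkdir(parents=True, exist_ok=True)
    torch.save(acts, path)
    return acts
```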

Structure

The library is structured roughly as follows:

  • BaseCrosscoder: Generic base class for all crosscoding models. It is allowed to have different input and output crosscoding dimensions and d_models, and is meant to be subclassed in a way that concretifies the dimensions (see the sketch after this list).
    • For example, CrossLayerTranscoder is a subclass of BaseCrosscoder that concretifies the input crosscoding dimensions to be (), and the output dimensions to be (n_layers,).
  • ModelHookpointAcausalCrosscoder: An acausal crosscoder that can be applied across multiple models / layers.
    • with n_layers = 1 and n_models=2, it's a normal model diffing crosscoder.
    • with n_layers > 1 and n_models=1, it's a normal cross-layer acausal transcoder.
    • with n_layers > 1 and n_models > 1, it's a cross-layer, cross-model, acausal crosscoder (???).
  • CrossLayerTranscoder: A cross-layer acausal transcoder
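
A rough sketch of the subclassing pattern just described (constructor signatures and attribute names are illustrative, not the real BaseCrosscoder API):

```python
import torch
from torch import nn


class BaseCrosscoderSketch(nn.Module):
    """Illustrative generic base: input and output crosscoding dims may differ."""

    def __init__(self, in_dims: tuple[int, ...], out_dims: tuple[int, ...],
                 d_model: int, n_latents: int) -> None:
        super().__init__()
        self._W_enc_XiDL = nn.Parameter(torch.randn(*in_dims, d_model, n_latents))
        self._W_dec_LXoD = nn.Parameter(torch.randn(n_latents, *out_dims, d_model))


class CrossLayerTranscoderSketch(BaseCrosscoderSketch):
    """Concretifies the base: input crosscoding dims are (), output dims are
    (n_layers,), matching the description above."""

    def __init__(self, n_layers: int, d_model: int, n_latents: int) -> None:
        super().__init__(in_dims=(), out_dims=(n_layers,),
                         d_model=d_model, n_latents=n_latents)
```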

A collection of activation functions that can be used with the model classes.

A collection of InitStrategy implementations for initializing crosscoder weights (the pattern is sketched after the list below).

  • InitStrategy: A base class for all initialization strategies.
  • AnthropicTransposeInit: Initializes the weights of a ModelHookpointAcausalCrosscoder using the "Anthropic Transpose" method.
  • IdenticalLatentsInit: Initializes the weights of a ModelHookpointAcausalCrosscoder such that the first n_shared_latents are identical for all models.
  • JanUpdateInit: Initializes the weights of a ModelHookpointAcausalCrosscoder using the method described in the "January 2025 update" paper.
  • There are also some other, more speculative initialization strategies in here.
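
To illustrate the pattern, a hedged sketch of an init strategy. Only the "initialize the decoder rows to a fixed norm, set the encoder to the decoder's transpose" idea is taken from Anthropic's write-ups; the base-class shape, attribute names, and default norm are assumptions:

```python
import torch


class InitStrategySketch:
    """Hypothetical base: an init strategy mutates a model's weights in place."""

    def init_weights(self, model) -> None:
        raise NotImplementedError


class AnthropicTransposeInitSketch(InitStrategySketch):
    """Sketch of the 'Anthropic Transpose' idea: draw decoder rows with a
    fixed norm, then set the encoder to the decoder's transpose."""

    def __init__(self, dec_init_norm: float = 0.1) -> None:
        self.dec_init_norm = dec_init_norm

    @torch.no_grad()
    def init_weights(self, model) -> None:
        W_dec_LD = torch.randn_like(model.W_dec_LD)
        W_dec_LD *= self.dec_init_norm / W_dec_LD.norm(dim=-1, keepdim=True)
        model.W_dec_LD.copy_(W_dec_LD)
        model.W_enc_DL.copy_(W_dec_LD.T)
```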

(The trainers make extensive use of inheritance, which I really don't like. I might refactor this to use composition instead.)

Data loading via harvesting LLM activations on text.

  • ActivationsDataloader: Dataloader for activations. Supports harvesting for multiple models and multiple hookpoints (see the sketch below).
    • TokenSequenceLoader: Used by ActivationsDataloader to load sequences of tokens from Hugging Face and chunk them into batches for activation harvesting. Can shuffle across sequences.
    • ActivationHarvester: Used by ActivationsDataloader to harvest LLM activations on those sequences.
      • ActivationCache: Used by ActivationHarvester to cache activations to disk (if enabled).
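
Put together, the flow is roughly as follows (a hypothetical composition; constructor arguments and method names are invented for illustration):

```python
class ActivationsDataloaderSketch:
    """Hypothetical wiring of the three components above behind one
    iterator that yields activation batches ready for a trainer."""

    def __init__(self, token_loader, harvester, cache=None) -> None:
        self.token_loader = token_loader  # TokenSequenceLoader-like
        self.harvester = harvester        # ActivationHarvester-like
        self.cache = cache                # optional ActivationCache-like

    def __iter__(self):
        for i, tokens_BS in enumerate(self.token_loader):
            acts_BMPD = self.cache.get(i) if self.cache else None
            if acts_BMPD is None:
                acts_BMPD = self.harvester.harvest(tokens_BS)
                if self.cache:
                    self.cache.put(i, acts_BMPD)
            yield acts_BMPD
```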
