diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index a95dc53755..d65b295645 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -63,7 +63,15 @@ Currently, quantization string is in same format as the one being passed in llam ### Model Types - `linear`: Simple linear layer -- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid +- `ln_linear_`: LayerNorm + Linear + Activation, where activation can be: + - `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid + - `ln_linear_relu`: LayerNorm + Linear + ReLU + - `ln_linear_leakyrelu`: LayerNorm + Linear + LeakyReLU + - `ln_linear_relu6`: LayerNorm + Linear + ReLU6 + - `ln_linear_gelu`: LayerNorm + Linear + GELU + - `ln_linear_silu`: LayerNorm + Linear + SiLU + - `ln_linear_hardswish`: LayerNorm + Linear + Hardswish +- `transformer_block`: Transformer block with self-attention and MLP ### Device Options - `cuda`: NVIDIA GPU @@ -71,6 +79,58 @@ Currently, quantization string is in same format as the one being passed in llam - `mps`: Apple Silicon GPU - `cpu`: CPU fallback +### Shape Generation Options +- `custom`: Manually specify shapes as a list of [m, k, n] dimensions + ```yaml + matrix_shapes: + - name: "custom" + shapes: [ + [1024, 1024, 1024], # [m, k, n] + [2048, 4096, 1024] + ] + ``` + +- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13) + - Generates shapes for: "attn.wqkv", "attn.w0", "ffn.w13", "ffn.w2" + ```yaml + matrix_shapes: + - name: "llama" + ``` + +- `pow2`: Generate shapes with dimensions that are powers of 2 + - Parameters: + - `min_power`: Minimum power of 2 (default: 10, which is 1024) + - `max_power`: Maximum power of 2 (default: 14, which is 16,384) + ```yaml + matrix_shapes: + - name: "pow2" + min_power: 10 # 2^10 = 1024 + max_power: 12 # 2^12 = 4096 + ``` + +- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half + - Parameters: + - 
`min_power`: Minimum power of 2 (default: 10, which is 1024) + - `max_power`: Maximum power of 2 (default: 14, which is 16,384) + ```yaml + matrix_shapes: + - name: "pow2_extended" + min_power: 10 # Generates: 1024, 1536, 2048, 3072, etc. + max_power: 11 + ``` + +- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions + - Parameters: + - `min_power`: Minimum power of 2 (default: 8, which is 256) + - `max_power`: Maximum power of 2 (default: 15, which is 32,768) + - Note: This generates all combinations of M, K, N dimensions, which can be a large number of shapes + ```yaml + matrix_shapes: + - name: "sweep" + min_power: 8 # 2^8 = 256 + max_power: 9 # 2^9 = 512 + ``` + ## Output Results are saved to a CSV file in the specified output directory diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 3af0ceb57b..b7a8e8d7c4 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -22,12 +22,14 @@ BenchmarkConfig, BenchmarkResult, clean_caches, - create_model_and_input, model_inference_time_in_ms, string_to_config, ) from torchao.quantization import quantize_ from torchao.sparsity.sparse_api import sparsify_ +from torchao.testing.model_architectures import ( + create_model_and_input_data, +) def run(config: BenchmarkConfig) -> BenchmarkResult: @@ -38,7 +40,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: # Create output directory if it doesn't exist Path(config.output_dir).mkdir(parents=True, exist_ok=True) - base_model, input_data = create_model_and_input( + base_model, input_data = create_model_and_input_data( config.model_type, config.m, config.k, diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index e38fc93819..fbd7f08388 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ 
-48,9 +48,50 @@ def get_shapes_for_config( name = shape_config["name"] if name == "custom": shapes.extend([(name, shape) for shape in shape_config["shapes"]]) + elif name == "llama": + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + bsz, seq_len = 4, 4096 + M = bsz * seq_len + llama_shapes = { + "attn.wqkv": (M, 8192, 1280), + "attn.w0": (M, 1024, 8192), + "ffn.w13": (M, 8192, 7168), + "ffn.w2": (M, 3584, 8192), + } + shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()]) + elif name == "pow2": + # Generate shapes with dimensions that are powers of 2 + min_power_of_2 = shape_config.get("min_power", 10) # 1024 + max_power_of_2 = shape_config.get("max_power", 14) # 16,384 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val = 2**power_of_2 + shapes.append((f"{name}_{idx}", [val, val, val])) + elif name == "pow2_extended": + # Generate shapes with dimensions that are powers of 2 and powers of 2 + half + min_power_of_2 = shape_config.get("min_power", 10) # 1024 + max_power_of_2 = shape_config.get("max_power", 14) # 16,384 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val1 = 2**power_of_2 + val2 = 2**power_of_2 + 2 ** (power_of_2 - 1) + shapes.append((f"{name}_{idx * 2}", [val1, val1, val1])) + shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2])) + elif name == "sweep": + # Generate a sweep of shapes with different powers of 2 for M, K, N + min_p2 = shape_config.get("min_power", 8) # 256 + max_p2 = shape_config.get("max_power", 15) # 32,768 + counter = 0 + for M_p2 in range(min_p2, max_p2 + 1): + M = 2**M_p2 + for K_p2 in range(min_p2, max_p2 + 1): + K = 2**K_p2 + for N_p2 in range(min_p2, max_p2 + 1): + N = 2**N_p2 + shapes.append((f"{name}_{counter}", [M, K, N])) + counter += 1 else: raise NotImplementedError( - f"Shape config {name} not supported. Currently only supports custom shapes." + f"Shape config {name} not supported. 
Supported options: custom, llama, pow2, pow2_extended, sweep." ) return shapes diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 5ea3f5d642..2fc0433c36 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -26,3 +26,48 @@ model_params: device: "cuda" model_type: "linear" enable_profiler: true # Enable profiling for this model + + - name: "ln_linear_sigmoid_cuda" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "ln_linear_sigmoid" + enable_profiler: true + + - name: "bf16_transformer_block" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], # For transformer_block, k is the hidden dimension + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition) + enable_profiler: true + + - name: "large_bf16_ln_linear" + matrix_shapes: + - name: "llama" # Example of using LLaMa shapes + - name: "pow2" # Example of using power of 2 shapes + min_power: 10 # 1024 + max_power: 12 # 4096 + - name: "pow2_extended" # Example of using extended power of 2 shapes + min_power: 10 # 1024 + max_power: 11 # 2048 + - name: "sweep" # Example of using sweep shapes (commented out as it generates many shapes) + min_power: 8 # 256 + max_power: 9 # 512 + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "linear" + enable_profiler: true # Enable profiling for this model diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py 
b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py index 0e398b4899..e3971b5986 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py @@ -15,8 +15,8 @@ ) from benchmarks.microbenchmarks.utils import ( BenchmarkConfig, - ToyLinearModel, ) +from torchao.testing.model_architectures import ToyLinearModel class TestBenchmarkProfiler(unittest.TestCase): diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index a8683a1de8..7f93213a22 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -57,12 +57,72 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_get_shapes_for_config(self): + # Test custom shapes shapes = get_shapes_for_config( self.test_config["model_params"][0]["matrix_shapes"] ) self.assertEqual(len(shapes), 1) self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) + # Test llama shapes + llama_shapes = get_shapes_for_config([{"name": "llama"}]) + self.assertEqual(len(llama_shapes), 4) # 4 LLaMa shapes + self.assertTrue( + any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_attn.w0") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes) + ) + + # Test pow2 shapes + pow2_shapes = get_shapes_for_config( + [{"name": "pow2", "min_power": 10, "max_power": 12}] + ) + self.assertEqual(len(pow2_shapes), 3) # 3 powers of 2 (10, 11, 12) + self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024])) # 2^10 + self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048])) # 2^11 + self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096])) # 2^12 + + # Test pow2_extended shapes + 
pow2_extended_shapes = get_shapes_for_config( + [{"name": "pow2_extended", "min_power": 10, "max_power": 11}] + ) + self.assertEqual( + len(pow2_extended_shapes), 4 + ) # 2 powers of 2, each with 2 variants + self.assertEqual( + pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024]) + ) # 2^10 + self.assertEqual( + pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536]) + ) # 2^10 + 2^9 + self.assertEqual( + pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048]) + ) # 2^11 + self.assertEqual( + pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072]) + ) # 2^11 + 2^10 + + # Test sweep shapes (limited to a small range for testing) + sweep_shapes = get_shapes_for_config( + [{"name": "sweep", "min_power": 8, "max_power": 9}] + ) + # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations) + self.assertEqual(len(sweep_shapes), 8) + # Check that all shapes have the expected format + for name, shape in sweep_shapes: + self.assertTrue(name.startswith("sweep_")) + self.assertEqual(len(shape), 3) # [M, K, N] + # Check that all dimensions are powers of 2 between 2^8 and 2^9 + for dim in shape: + self.assertTrue(dim in [256, 512]) # 2^8, 2^9 + def test_get_param_combinations(self): model_param = self.test_config["model_params"][0] shapes, params = get_param_combinations(model_param) diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 14f226bd7e..bb721e9e03 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -16,15 +16,17 @@ BlockSparseWeightConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, - LNLinearSigmoid, SemiSparseWeightConfig, - ToyLinearModel, clean_caches, - create_model_and_input, generate_results_csv, get_default_device, string_to_config, ) +from torchao.testing.model_architectures import ( + LNLinearActivationModel, + ToyLinearModel, + 
create_model_and_input_data, +) class TestUtils(unittest.TestCase): @@ -153,7 +155,7 @@ def test_toy_linear_model(self): self.assertEqual(out.dtype, torch.float32) def test_ln_linear_sigmoid(self): - model = LNLinearSigmoid(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + model = LNLinearActivationModel(fc_dim1=64, fc_dim2=32, dtype=torch.float32) x = torch.randn(16, 64) out = model(x) self.assertEqual(out.shape, (16, 32)) @@ -162,9 +164,9 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range - def test_create_model_and_input(self): + def test_create_model_and_input_data(self): m, k, n = 16, 64, 32 - model, input_data = create_model_and_input( + model, input_data = create_model_and_input_data( model_type="linear", m=m, k=k, @@ -175,7 +177,7 @@ def test_create_model_and_input(self): self.assertIsInstance(model, ToyLinearModel) self.assertEqual(input_data.shape, (m, k)) - model, input_data = create_model_and_input( + model, input_data = create_model_and_input_data( model_type="ln_linear_sigmoid", m=m, k=k, @@ -183,7 +185,7 @@ def test_create_model_and_input(self): high_precision_dtype=torch.float32, device="cpu", ) - self.assertIsInstance(model, LNLinearSigmoid) + self.assertIsInstance(model, LNLinearActivationModel) self.assertEqual(input_data.shape, (m, k)) def test_generate_results_csv(self): diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 2fef1317fc..3907abfa89 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -137,30 +137,6 @@ def to_dict(self) -> Dict[str, Any]: return result_dict -class ToyLinearModel(torch.nn.Module): - def __init__(self, k=64, n=32, dtype=torch.bfloat16): - super().__init__() - self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) - - def forward(self, x): - x = self.linear1(x) - return x - - -class LNLinearSigmoid(torch.nn.Module): - def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16): - 
super().__init__() - self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False) - self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - x = self.ln(x) - x = self.fc(x) - x = self.sigmoid(x) - return x - - def string_to_config( quantization: Optional[str], sparsity: Optional[str], **kwargs ) -> AOBaseConfig: @@ -337,34 +313,6 @@ def model_inference_time_in_ms(model, input_data): return res * 1e6 -def create_model_and_input( - model_type: str, - m: int, - k: int, - n: int, - high_precision_dtype: torch.dtype = torch.bfloat16, - device: str = get_default_device(), -): - """Create a model and input data for benchmarking. - - Args: - model_type (str): type of the model to be created - batch_size (int): batch size of the input data - device (str): device to run the model on - high_precision_dtype (torch.dtype): data type of the model - m, k, n (int): dimensions of the model and input data - """ - if model_type == "linear": - model = ToyLinearModel(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - elif model_type == "ln_linear_sigmoid": - model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - else: - raise ValueError(f"Unknown model type: {model_type}") - return model, input_data - - def clean_caches(): import gc diff --git a/test/test_model_architecture.py b/test/test_model_architecture.py new file mode 100644 index 0000000000..973939a56a --- /dev/null +++ b/test/test_model_architecture.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from parameterized import parameterized + +from torchao.testing.model_architectures import create_model_and_input_data +from torchao.utils import get_available_devices + + +class TestModels(unittest.TestCase): + @parameterized.expand([(device,) for device in get_available_devices()]) + def test_toy_linear_model(self, device): + # Skip if device is not available + if device == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + model, input_data = create_model_and_input_data( + "linear", 10, 64, 32, device=device + ) + output = model(input_data) + self.assertEqual(output.shape, (10, 32)) + + @parameterized.expand([(device,) for device in get_available_devices()]) + def test_ln_linear_activation_model(self, device): + # Skip if device is not available + if device == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + model, input_data = create_model_and_input_data( + "ln_linear_sigmoid", 10, 64, 32, device=device + ) + output = model(input_data) + self.assertEqual(output.shape, (10, 32)) + + @parameterized.expand([(device,) for device in get_available_devices()]) + def test_transformer_block(self, device): + # Skip if device is not available + if device == "cuda" and not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + model, input_data = create_model_and_input_data( + "transformer_block", 10, 64, 32, device=device + ) + output = model(input_data) + self.assertEqual(output.shape, (10, 16, 64)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py new file mode 100644 index 0000000000..f59a1271b1 --- /dev/null +++ b/torchao/testing/model_architectures.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Tuple

import torch
import torch.nn as nn

# Pulls the optional activation suffix out of names such as
# "ln_linear_sigmoid". Compiled once at import time instead of on every call.
_LN_LINEAR_PATTERN = re.compile(r"ln_linear_?(\w+)?")


# TODO: Refactor torchao and tests to use these models
class ToyLinearModel(torch.nn.Module):
    """A single bias-free linear layer mapping ``k`` features to ``n``."""

    def __init__(self, k=64, n=32, dtype=torch.bfloat16):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.linear1(x)


class LNLinearActivationModel(nn.Module):
    """LayerNorm -> bias-free Linear -> activation.

    Args:
        fc_dim1: Size of the normalized input features (linear input size).
        fc_dim2: Output feature size of the linear layer.
        dtype: Dtype the linear layer's weight is cast to.
        activation: Case-insensitive activation name; one of the keys of
            ``_ACTIVATIONS``.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """

    # Name -> module class. Storing classes (not instances) means only the
    # selected activation is ever constructed.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "leakyrelu": nn.LeakyReLU,
        "relu6": nn.ReLU6,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "hardswish": nn.Hardswish,
    }

    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"):
        super().__init__()

        activation = activation.lower()
        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {activation}")

        # LayerNorm has no learnable parameters here (elementwise_affine=False).
        self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False)
        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype)
        self.activation = self._ACTIVATIONS[activation]()

    def forward(self, x):
        """Normalize ``x``, project it, and apply the activation."""
        x = self.ln(x)
        x = self.fc(x)
        return self.activation(x)


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed along the last dimension.
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, cast back to the input dtype, then scale."""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP.

    Both sub-layers use RMSNorm on their input and a residual connection.
    All linear layers are bias-free.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_dim``.
        dtype: Dtype the weights are cast to.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
        super().__init__()
        # Fail early with a clear message; otherwise the qkv reshape in
        # forward() fails with an opaque shape error.
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Self-attention
        self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype)
        self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype)

        # MLP
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
        self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(
            dtype
        )
        self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(
            dtype
        )

        # Layer norms
        self.norm1 = RMSNorm(hidden_dim).to(dtype)
        self.norm2 = RMSNorm(hidden_dim).to(dtype)

        # Activation
        self.activation = torch.nn.GELU()

    def forward(self, x):
        """Run the block on ``x`` of shape ``[batch_size, seq_len, hidden_dim]``.

        Returns a tensor of the same shape.
        """
        batch_size, seq_len, _ = x.shape

        # Self-attention
        residual = x
        x = self.norm1(x)

        qkv = self.qkv(x)  # [batch_size, seq_len, 3 * hidden_dim]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(
            2, 0, 3, 1, 4
        )  # [3, batch_size, num_heads, seq_len, head_dim]
        q, k, v = qkv  # Each has shape [batch_size, num_heads, seq_len, head_dim]

        # Fold heads into the batch dimension so attention is a plain
        # batched matmul over [batch_size * num_heads, seq_len, head_dim].
        q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
        k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
        v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim)

        # Scaled dot-product attention scores
        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim**0.5))
        attn = torch.softmax(attn, dim=-1)

        # Apply attention to values
        x = attn @ v  # [batch_size * num_heads, seq_len, head_dim]

        # Unfold heads and merge them back into the hidden dimension
        x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
        x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)

        # Project back to hidden dimension
        x = self.proj(x)
        x = residual + x

        # MLP
        residual = x
        x = self.norm2(x)
        x = self.mlp_fc1(x)
        x = self.activation(x)
        x = self.mlp_fc2(x)
        x = residual + x

        return x


def create_model_and_input_data(
    model_type: str,
    m: int,
    k: int,
    n: int,
    high_precision_dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
    activation: str = "relu",
) -> Tuple[torch.nn.Module, torch.Tensor]:
    """Create a model and matching random input data for benchmarking.

    Args:
        model_type: ``"linear"``, ``"transformer_block"``, or any name
            containing ``"ln_linear"`` with an optional activation suffix
            (e.g. ``"ln_linear_sigmoid"``).
        m: Leading (batch) dimension of the input data.
        k: Input feature size; for ``"transformer_block"`` the hidden dimension.
        n: Output feature size of the linear layer. Unused for
            ``"transformer_block"``.
        high_precision_dtype: Dtype of the model weights and input data.
        device: Device the model and input data are placed on.
        activation: Activation for ``ln_linear`` models whose ``model_type``
            carries no suffix. A suffix in ``model_type`` takes precedence.

    Returns:
        A ``(model, input_data)`` tuple. ``input_data`` has shape ``(m, k)``,
        or ``(m, 16, k)`` for ``"transformer_block"``.

    Raises:
        ValueError: If ``model_type`` is unknown or the activation is
            unsupported.
    """
    if model_type == "linear":
        model = ToyLinearModel(k, n, high_precision_dtype).to(device)
        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
    elif "ln_linear" in model_type:
        # Prefer the activation encoded in model_type; otherwise honor the
        # caller's `activation` argument (previously it was silently ignored).
        match = _LN_LINEAR_PATTERN.search(model_type)
        suffix = match.group(1) if match else None
        chosen_activation = suffix if suffix else activation
        model = LNLinearActivationModel(
            k, n, high_precision_dtype, activation=chosen_activation
        ).to(device)
        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
    elif model_type == "transformer_block":
        # For transformer block, k is the hidden dimension
        model = TransformerBlock(
            k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype
        ).to(device)
        # Input shape is [batch_size, seq_len, hidden_dim]; seq_len is fixed at 16
        input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return model, input_data