diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 7503ad73c..1132380be 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
 
 def optimizer_update_32bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     beta1: float,
     eps: float,
     step: int,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
         The maximum update norm relative to the weight norm.
     skip_zeros : bool
         Whether to skip zero-valued gradients or not (default: False).
+    return_updates : Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p` (default: None).
     """
 
     param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
     optim_func(
         get_ptr(g),
         get_ptr(p),
+        get_ptr(return_updates),
         get_ptr(state1),
         get_ptr(state2),
         get_ptr(unorm_vec),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(
 
 def optimizer_update_8bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
     eps: float,
     step: int,
     lr: float,
-    qmap1: Tensor,
+    qmap1: torch.Tensor,
     qmap2: Optional[torch.Tensor],
-    max1: Tensor,
+    max1: torch.Tensor,
     max2: Optional[torch.Tensor],
-    new_max1: Tensor,
+    new_max1: torch.Tensor,
     new_max2: Optional[torch.Tensor],
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
         The tensor for the update norm.
     max_unorm : float
         The maximum update norm relative to the weight norm.
+    return_updates : Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p` (default: None).
     """
 
     param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
         str2optimizer8bit[optimizer_name][0](
             get_ptr(p),
             get_ptr(g),
+            get_ptr(return_updates),
             get_ptr(state1),
             get_ptr(state2),
             get_ptr(unorm_vec),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
         str2optimizer8bit[optimizer_name][1](
             get_ptr(p),
             get_ptr(g),
+            get_ptr(return_updates),
             get_ptr(state1),
             get_ptr(state2),
             get_ptr(unorm_vec),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     optim_func = None
     prev_device = pre_call(g.device)
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
     optim_func(
         get_ptr(p),
         get_ptr(g),
+        get_ptr(return_updates),
         get_ptr(state1),
         get_ptr(state2),
         ct.c_float(beta1),
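
For reference, `return_updates` threads through to the CUDA kernels changed below: when it is a null pointer the kernels behave exactly as before (in-place parameter update plus weight decay), and when a tensor is passed they write only the raw optimizer step into it and leave `p` and the weight decay to the caller. A minimal PyTorch paraphrase of that branching, with illustrative names that are not part of this patch:

    from typing import Optional

    import torch

    def _apply_or_return(p: torch.Tensor, delta: torch.Tensor, lr: float,
                         weight_decay: float,
                         return_updates: Optional[torch.Tensor]) -> None:
        # Mirrors the `return_updates == nullptr` branching added to the kernels.
        if return_updates is None:
            p.add_(delta)                        # in-place update, as before
            if weight_decay > 0.0:
                p.mul_(1.0 - lr * weight_decay)  # decoupled weight decay stays in the kernel path
        else:
            # `p` is left untouched; the caller (the GaLore path below) projects the
            # update back to full rank and applies weight decay itself.
            return_updates.copy_(delta)
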
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 07174c38d..05252246b 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -9,6 +9,7 @@
     AdamW,
     AdamW8bit,
     AdamW32bit,
+    GaLoreAdamW8bit,
     PagedAdamW,
     PagedAdamW8bit,
     PagedAdamW32bit,
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index 4bf3f6436..a3c6402da 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -2,7 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from bitsandbytes.optim.optimizer import Optimizer2State
+import torch
+
+from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State
+
+_galore_available = False
+try:
+    from galore_torch.galore_projector import GaLoreProjector
+
+    _galore_available = True
+except ImportError:
+    pass
 
 
 class AdamW(Optimizer2State):
@@ -127,6 +137,117 @@ def __init__(
         )
 
 
+class GaLoreAdamW8bit(Optimizer2State):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=8,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
+        if not _galore_available:
+            raise RuntimeError("The galore_torch package must be installed to use GaLoreAdamW8bit.")
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        overflows = []
+
+        if not self.initialized:
+            self.check_overrides()
+            self.to_gpu()  # needed for fairseq pure fp16 training
+            self.initialized = True
+
+        # if self.is_paged: self.page_mng.prefetch_all()
+        for gindex, group in enumerate(self.param_groups):
+            for pindex, p in enumerate(group["params"]):
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    grad = state["projector"].project(p.grad, state["step"])
+
+                    # suboptimal implementation
+                    # p.saved_data = p.data.clone()
+                    # p.data = grad.clone().to(p.data.dtype).to(p.data.device)
+                    # p.data.zero_()
+                    # p.grad = grad
+                    lor_update = torch.zeros_like(
+                        grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
+                    )
+
+                if "state1" not in state:
+                    self.init_state(group, p, gindex, pindex)
+
+                self.prefetch_state(p)
+
+                if "rank" in group:
+                    galore_p = GaLoreWrappedParameter(p=p, grad=grad)
+                    self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)
+
+                    # GaLore Projection Back
+                    p.data.add_(state["projector"].project_back(lor_update))
+
+                    if "weight_decay" in group and group["weight_decay"] > 0:
+                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
+                else:
+                    self.update_step(group, p, gindex, pindex)
+
+                torch.cuda.synchronize()
+
+        if self.is_paged:
+            # all paged operations are asynchronous, so we need
+            # to sync to make sure all tensors are in the right state
+            torch.cuda.synchronize()
+
+        return loss
+
+
 class AdamW32bit(Optimizer2State):
     def __init__(
         self,
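
A rough usage sketch for the new optimizer, assuming the optional `galore_torch` package is installed. The group keys (`rank`, `update_proj_gap`, `scale`, `proj_type`) are the ones `step()` reads above; the numeric values are placeholders for illustration, not recommendations from this patch:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(4096, 4096, bias=False).cuda()

    param_groups = [{
        "params": list(model.parameters()),  # 2D weight matrices; GaLore projects matrices
        "rank": 128,                         # placeholder projection rank
        "update_proj_gap": 200,              # placeholder re-projection interval
        "scale": 0.25,                       # placeholder update scale
        "proj_type": "std",
    }]

    optimizer = bnb.optim.GaLoreAdamW8bit(param_groups, lr=1e-3)

    loss = model(torch.randn(8, 4096, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
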
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 03e0e01d7..2c0f77295 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
+from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -18,6 +19,12 @@ def __init__(self, initial_data):
             setattr(self, key, initial_data[key])
 
 
+@dataclass
+class GaLoreWrappedParameter:
+    p: torch.Tensor
+    grad: torch.Tensor
+
+
 class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
@@ -320,7 +327,7 @@ def get_config(self, gindex, pindex, group):
     def init_state(self, group, p, gindex, pindex):
         raise NotImplementedError("init_state method needs to be overridden")
 
-    def update_step(self, group, p, gindex, pindex):
+    def update_step(self, group, p, gindex, pindex, return_updates=None):
         raise NotImplementedError("The update_step method needs to be overridden")
 
     def get_state_buffer(self, p, dtype=torch.float32):
@@ -494,13 +501,25 @@ def init_state(self, group, p, gindex, pindex):
             state["unorm_vec"] = torch.zeros((1,), device=p.device)
 
     @torch.no_grad()
-    def update_step(self, group, p, gindex, pindex):
-        # avoid update error from non-contiguous memory layout
-        p.data = p.data.contiguous()
-        p.grad = p.grad.contiguous()
+    def update_step(
+        self,
+        group: Dict[str, Any],
+        p: Union[torch.Tensor, GaLoreWrappedParameter],
+        gindex: int,
+        pindex: int,
+        return_updates: Optional[torch.Tensor] = None,
+    ):
+        if isinstance(p, GaLoreWrappedParameter):
+            # Unwrap for GaLore
+            param_to_optimize = p.p
+        else:
+            param_to_optimize = p
 
-        state = self.state[p]
-        grad = p.grad
+        state = self.state[param_to_optimize]
+
+        # avoid update error from non-contiguous memory layout
+        param_to_optimize.data = param_to_optimize.data.contiguous()
+        grad = p.grad.contiguous()
 
         config = self.get_config(gindex, pindex, group)
 
@@ -521,7 +540,7 @@ def update_step(self, group, p, gindex, pindex):
             F.optimizer_update_32bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 config["betas"][0],
                 config["eps"],
@@ -536,13 +555,14 @@ def update_step(self, group, p, gindex, pindex):
                 state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )
 
         elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
             F.optimizer_update_8bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -560,6 +580,7 @@ def update_step(self, group, p, gindex, pindex):
                 gnorm_scale=gnorm_scale,
                 unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
+                return_updates=return_updates,
             )
 
             # swap maxes
@@ -569,7 +590,7 @@ def update_step(self, group, p, gindex, pindex):
             F.optimizer_update_8bit_blockwise(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -586,6 +607,7 @@ def update_step(self, group, p, gindex, pindex):
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )
 
 
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 867390f2c..2ff520dd2 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -68,27 +68,6 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
   }
 }
 
-__device__ float d2DequantizeFP4(unsigned char val)
-{
-  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-  if((val & 0b0110) == 0)
-  {
-    // subnormal
-    if((val & 0b0001) == 0)
-      return 0.0f;
-    else
-      return sign*0.0625f;
-  }
-  else
-  {
-    // normal
-    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
-    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
-
-    return sign*exponent*fraction;
-  }
-}
-
 __device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
 {
   float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
@@ -165,60 +144,6 @@ __device__ unsigned char dQuantizeFP4(float x)
         return 0b0000+sign;
 }
 
-__device__ half dhDequantizeNF4(unsigned char val)
-{
-  // the values for this tree was generated by test_normal_map_tree
-  // in the file tests/test_functional.py
-  if((val & 0b1000) == 8)
-    if((val & 0b0100) == 4) // 1
-      if((val & 0b0010) == 2) // 11
-        if((val & 0b0001) == 1) // 111
-          return 1.0f;
-        else
-          return 0.7229568362236023f;
-      else
-        if((val & 0b0001) == 1) // 110
-          return 0.5626170039176941f;
-        else
-          return 0.44070982933044434f;
-    else
-      if((val & 0b0010) == 2) //10
-        if((val & 0b0001) == 1) // 101
-          return 0.33791524171829224f;
-        else
-          return 0.24611230194568634f;
-      else
-        if((val & 0b0001) == 1) // 100
-          return 0.16093020141124725f;
-        else
-          return 0.07958029955625534f;
-
-  else
-    if((val & 0b0100) == 4) // 0
-      if((val & 0b0010) == 2) //01
-        if((val & 0b0001) == 1) // 011
-          return 0.0f;
-        else
-          return -0.09105003625154495f;
-      else
-        if((val & 0b0001) == 1) // 010
-          return -0.18477343022823334f;
-        else
-          return -0.28444138169288635f;
-    else
-      if((val & 0b0010) == 2) //00
-        if((val & 0b0001) == 1) // 001
-          return -0.39491748809814453f;
-        else
-          return -0.5250730514526367f;
-      else
-        if((val & 0b0001) == 1) // 000
-          return -0.6961928009986877f;
-        else
-          return -1.0f;
-
-}
-
 __device__ float dDequantizeNF4(unsigned char val)
 {
 
@@ -872,7 +797,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
 
 template<typename T, int OPTIMIZER>
 __launch_bounds__(TH, 1)
-__global__ void kOptimizer32bit2State(T* g, T* p,
+__global__ void kOptimizer32bit2State(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -931,7 +856,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
       __syncthreads();
       LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
       __syncthreads();
-      Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+      Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
       // Load additional state1 data for AdEMAMix
       // TODO: Make constexpr after updating min compiler
@@ -975,17 +900,22 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
 									{
 										s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
 										s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
-										p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
 
-                    if(weight_decay > 0.0f)
-                        p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
+                    if (return_updates == nullptr) {
+                      p_vals[j] = (T)(((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))));
+
+                      if(weight_decay > 0.0f)
+                          p_vals[j] = (T)(((float)p_vals[j])*(1.0f-(lr*weight_decay)));
+                    } else {
+                      p_vals[j] = (T)(update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+                    }
 									}
                   break;
           }
       }
 
       __syncthreads();
-      Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+      Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
       __syncthreads();
       StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
       __syncthreads();
@@ -1081,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
 
 template<typename T, int OPTIMIZER>
 __launch_bounds__(TH, 1)
-__global__ void kOptimizer32bit1State(T *g, T *p,
+__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1127,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
       __syncthreads();
       LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
       __syncthreads();
-      Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+      Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
       # pragma unroll 4
       for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
       {
         g_vals[j] = gnorm_scale*((float)g_vals[j]);
-        if(weight_decay > 0.0f)
+        if(weight_decay > 0.0f && return_updates == nullptr)
           g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
       }
 
@@ -1150,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
 										else
 											s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
 
-										p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j]));
 										break;
 								case LION:
-										p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
 										s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
 										break;
 								case RMSPROP:
 										s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
-										p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
 										break;
 								case ADAGRAD:
 										s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
-										p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
 										break;
 						}
 					}
       }
 
       __syncthreads();
-      Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+      Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
       __syncthreads();
       StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
   }
@@ -1298,7 +1228,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
 template<typename T, int OPTIMIZER>
 __global__ void
 __launch_bounds__(NUM_THREADS2, 1)
-kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* state1, unsigned char* state2,
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -1369,7 +1299,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
         if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
 
@@ -1404,12 +1334,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         # pragma unroll 4
         for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
         {
-            p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
-            if(weight_decay > 0.0f)
-                p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
+            if (return_updates == nullptr) {
+              p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
+              if(weight_decay > 0.0f)
+                  p_vals[j] = (T)(update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)));
+            } else {
+              p_vals[j] = (T)((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))));
+            }
         }
 
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
         __syncthreads();
         StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
         __syncthreads();
@@ -1513,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
 template<typename T, int OPTIMIZER>
 __global__ void
 __launch_bounds__(1024, 1)
-kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -1569,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
         if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
 
@@ -1579,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
             g_val = float(g_vals[j]);
             g_val *= gnorm_scale;
 
-            if(weight_decay > 0.0f) {
+            if(weight_decay > 0.0f && return_updates == nullptr) {
               switch(OPTIMIZER) {
 		case ADAGRAD:
                 case MOMENTUM:
@@ -1602,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                   else
                     s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
 
-                  p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j]));
                   break;
               case LION:
-                  p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
                   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
                   break;
               case RMSPROP:
                   s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
-                  p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
                   break;
             }
 
@@ -1626,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
             }
         }
 
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
         __syncthreads();
         StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
         __syncthreads();
@@ -1687,6 +1621,7 @@ __global__ void
 kOptimizerStatic8bit2StateBlockwise(
     T* p,
     T* __restrict__ const g,
+    T* __restrict__ return_updates,
     unsigned char* state1,
     unsigned char* state2,
     const float beta1,
@@ -1881,7 +1816,7 @@ kOptimizerStatic8bit2StateBlockwise(
         }
 
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
         //  reduce: 2.67/1.69 -> 2.67/1.70
         # pragma unroll N_PER_TH
         for(unsigned int j = 0; j < N_PER_TH; j++)
@@ -1895,18 +1830,24 @@ kOptimizerStatic8bit2StateBlockwise(
                     (sqrtf(s2_vals[j]) / correction2) + eps
                   )
                 ));
+
+                if (weight_decay > 0.0f)
+                  p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
               } else {
-                p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+                if (return_updates == nullptr) {
+                  p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+                  if(weight_decay > 0.0f)
+                      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
+                } else {
+                  p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
+                }
               }
-
-              if(weight_decay > 0.0f)
-									p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
-						}
+            }
         }
 
         //  store: 0.85/1.44 -> 2.48/1.57
         __syncthreads();
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
         //  quantizaztion: 2.67/1.70  -> 3.4/3.3
         # pragma unroll N_PER_TH
@@ -1952,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
 __launch_bounds__(256, 3)
 __global__ void
-kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
+kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
@@ -2016,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
 
         new_local_abs_max1 = -FLT_MAX;
 
@@ -2028,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
             g_val *= gnorm_scale;
             if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
             {
-              if(weight_decay > 0.0f) {
+              if(weight_decay > 0.0f && return_updates == nullptr) {
                 switch(OPTIMIZER) {
                   case MOMENTUM:
                   case ADAGRAD:
@@ -2091,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 							switch(OPTIMIZER)
 							{
 									case MOMENTUM:
-										p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]);
 										break;
 									case LION:
-										p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]);
 										break;
 									case RMSPROP:
 										g_val = g_vals[j];
-										p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
 										break;
 									case ADAGRAD:
 										g_val = g_vals[j];
-										p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+										p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
 										break;
 							}
 						}
@@ -3841,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
 
 #define MAKE_Optimizer32bit1State(oname, gtype) \
-template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
+template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_Optimizer32bit1State(MOMENTUM, half)
@@ -3870,17 +3811,17 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
 MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
 MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
 
-template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 
@@ -3906,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
 MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
 
 #define MAKE_optimizerStatic8bit1State(oname, gtype) \
-template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1,  \
+template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1,  \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1,  \
                 const float beta2,  \
@@ -3941,7 +3882,8 @@ MAKE_PreconditionStatic8bit2State(ADAM, half)
 MAKE_PreconditionStatic8bit2State(ADAM, float)
 
 #define MAKE_optimizerStatic8bit2State(oname, gtype) \
-template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
+template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, \
+                unsigned char* state1, unsigned char* state2, \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2, \
                 const float eps, const int step, const float lr, \
@@ -4041,7 +3983,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
 template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
 
 #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
-template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
+template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>( \
+                gtype* p, gtype* __restrict__ const g, gtype* __restrict__ return_updates, \
+                unsigned char* state1, unsigned char* state2, \
                 const float beta1, const float beta2, const float beta3, const float alpha, \
                 const float eps, const int step, const float lr, \
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
@@ -4058,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
 
 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
-		gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
+		gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \
                 const float beta1, const float beta2, \
                 const float eps, const int step, const float lr, \
                 float* __restrict__ const quantiles1, \
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index ec6daebe5..376639993 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -25,7 +25,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
                 const int step, const float lr, const float gnorm_scale, const int n);
 
 template<typename T, int OPTIMIZER>
-__global__ void kOptimizer32bit2State(T* g, T* p,
+__global__ void kOptimizer32bit2State(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha,
                 const float eps, const float weight_decay,
@@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 const int step, const float lr, const float gnorm_scale, const int n);
 
 template<typename T, int OPTIMIZER>
-__global__ void kOptimizer32bit1State(T* g, T* p,
+__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates,
                 float* state1,  float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
 
 template<typename T, int OPTIMIZER>
 __global__ void
-kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -80,7 +80,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
 
 template<typename T, int OPTIMIZER>
 __global__ void
-kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* state1, unsigned char* state2,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -89,13 +89,14 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
                 float weight_decay, const float gnorm_scale, const int n);
 
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
-		T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
+                T* p, T* __restrict__ const g, T* __restrict__ return_updates,
+                unsigned char* state1, unsigned char* state2,
                 const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
-		T* p, T* __restrict__ const g, unsigned char* state1,
+		T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 7ca854baf..e3c99a875 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -92,7 +92,7 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
 
 
 
-template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
@@ -109,7 +109,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-			kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+			kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
 			break;
 		case MOMENTUM:
@@ -122,12 +122,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
 			}
 
-			kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+			kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
 			break;
     case LION:
       // in lion, the momentum update after the parameter update
-      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
 
       if(max_unorm > 0.0f)
@@ -140,7 +140,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
 	}
 }
 
-template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
+template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* return_updates,
                 unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2,
@@ -162,7 +162,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 			CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
 			kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
-			kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+			kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
 																														quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
 		break;
@@ -172,13 +172,13 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 			CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
 			kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
-			kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+			kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
 																														quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
 			break;
     case LION:
       // in lion, the momentum update happens after the parameter update
-      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                                             quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
 
@@ -199,6 +199,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     T* p,
     T* g,
+    T* return_updates,
     unsigned char* state1,
     unsigned char* state2,
     float beta1,
@@ -226,7 +227,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
 			num_blocks = n/BLOCKSIZE_2STATE;
 			num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
 			kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
-				p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
+				p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
 				quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
 				skip_zeros, n
 			);
@@ -238,7 +239,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     case LION:
 			num_blocks = n/BLOCKSIZE_1STATE;
 			num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
-			kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+			kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr,
 																														quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
 		break;
@@ -807,7 +808,7 @@ template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char
 template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
 
 #define MAKE_optimizer32bit(name, gtype) \
-template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
+template void optimizer32bit<gtype, name>(gtype* g, gtype* p, gtype* return_updates, \
                 float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float beta3, const float alpha, \
                 const float eps, const float weight_decay, \
@@ -833,7 +834,8 @@ MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, float)
 
 #define MAKE_optimizerStatic8bit(name, gtype) \
-template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, gtype* return_updates, \
+                unsigned char* state1, unsigned char* state2, \
                 float *unorm, float max_unorm, float param_norm, \
                 float beta1, float beta2, \
                 float eps, int step, float lr,  \
@@ -855,7 +857,7 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)
 
 
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
-template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
+template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, gtype* return_updates, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,  \
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
 
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index b0ecc4622..f61de4095 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -148,12 +148,13 @@ void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t s
 template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream);
 
-template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay,
                 int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
 
-template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
+template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* return_updates,
+                unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2,
                 float eps, int step, float lr,
@@ -162,10 +163,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
                 float weight_decay,
                 const float gnorm_scale, int n);
 
-template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
+
+template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
-								bool skip_zeros, int n);
+                bool skip_zeros, int n);
 
 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
 
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index f0ee84c29..14ceb17b8 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -50,12 +50,12 @@ MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
 
 
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
-void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
+void fname##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                const float beta1, const float beta2, const float beta3, const float alpha, \
 			   const float eps, const float weight_decay, \
                const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+{ optimizer32bit<gtype, oname>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 
 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -75,7 +75,7 @@ MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
 
 
 #define MAKE_FUNC8(fname, oname, gtype, gbits) \
-void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, unsigned char* state1, unsigned char* state2, \
 								float *unorm, float max_unorm, float param_norm, \
                 float beta1, float beta2, \
                 float eps, int step, float lr,  \
@@ -83,7 +83,7 @@ void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1,
                 float* max1, float* max2, float* new_max1, float* new_max2, \
                 float weight_decay, float gnorm_scale, int n) \
 {  \
-	optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+	optimizerStatic8bit<gtype, oname>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
 			                                  quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
 } \
 
@@ -97,10 +97,10 @@ MAKE_FUNC8(lion, LION, float, 32)
 MAKE_FUNC8(lion, LION, half, 16)
 
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
-void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
+void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
-{	optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
+{	optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
 
 MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -233,12 +233,13 @@ extern "C"
   void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); }
 
 	#define MAKE_CFUNC32(name, gtype, gbits) \
-	void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
+	void c##name##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \
 								 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
 								 const float beta1, const float beta2, const float beta3, const float alpha, \
 								 const float eps, const float weight_decay, \
 								 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
-	{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+	{ name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+
 
 	MAKE_CFUNC32(adam, float, fp32)
 	MAKE_CFUNC32(adam, half, fp16)
@@ -257,7 +258,8 @@ extern "C"
 	MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
 
 	#define MAKE_CFUNC8(name, gtype, gbits) \
-	void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+	void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
+				unsigned char* state1, unsigned char* state2, \
                 float *unorm, float max_unorm, float param_norm, \
                 float beta1, float beta2, \
                 float eps, int step, float lr,  \
@@ -265,7 +267,7 @@ extern "C"
                 float* max1, float* max2, float* new_max1, float* new_max2, \
                 float weight_decay, float gnorm_scale, int n) \
   {  \
-	    name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+	    name##_static_8bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
 			                                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
   } \
 
@@ -279,10 +281,11 @@ extern "C"
 	MAKE_CFUNC8(lion, half, 16)
 
   #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
-  void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
+  void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,  \
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
-  {	fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+  {	fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+
 
 	MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
 	MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
diff --git a/setup.py b/setup.py
index 3a1bcb574..434a2eaf4 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 
 
 def read(fname):
-    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+    return open(os.path.join(os.path.dirname(__file__), fname), encoding="utf8").read()
 
 
 # Tested with wheel v0.29.0