
[core] Deserialize torch.Tensors to the correct device #50134

Closed
stephanie-wang opened this issue Jan 29, 2025 · 21 comments
Labels: core (Issues that should be addressed in Ray Core), enhancement (Request for new feature and/or capability), gpu (GPU related issues), P0 (Issues that should be fixed in short order)

Description

Ray currently serializes torch.Tensors to the object store and then deserializes them using torch's default deserialization method. This can result in deserialization to the wrong device. Ideally, on deserialization, we should place the tensor directly on the correct device. Currently we do this in Ray Compiled Graphs, but we could also support it for all Ray programs (although we cannot eliminate unnecessary copies).

Some questions to consider:

  • We need a way to disable the behavior, both globally and for individual tensors. If a task/actor returns a tensor whose device doesn't match its default device, we can disable the custom deserializer but we probably need another way to override the behavior as well. For example, if a CPU actor passes a CPU tensor to a GPU actor and we want the tensor to remain on the CPU.
  • We need a way to set the default torch.device context in Ray. In Compiled Graphs, this is currently done in the ChannelContext (see the snippet below for what such a default device context looks like in plain torch).
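For reference, a minimal illustration of a default torch.device context in plain torch (not a Ray API; assumes torch >= 2.0 and an available CUDA device):

import torch

# Global default for factory functions like torch.randn / torch.zeros.
torch.set_default_device("cuda:0")
t1 = torch.randn(3)            # allocated on cuda:0

# Scoped override: torch.device works as a context manager in torch >= 2.0.
with torch.device("cpu"):
    t2 = torch.randn(3)        # allocated on cpu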

Example:

@ray.remote(num_gpus=1)
class Actor:
  def alloc(self):
    return torch.randn(1000, device="cuda")
  def read(self, tensor):
    assert tensor.device

# GPU 0
a = Actor.remote()
# GPU 1
b = Actor.remote()

t = a.alloc.remote()
# Return GPU 0.
ray.get(a.read.remote(t))
# Return GPU 1.
ray.get(b.read.remote(t))
# Driver has no GPUs. Return CPU.
ray.get(t).device


@stephanie-wang stephanie-wang added enhancement Request for new feature and/or capability triage Needs triage (eg: priority, bug/not-bug, and owning component) core Issues that should be addressed in Ray Core labels Jan 29, 2025
@jjyao jjyao added P0 Issues that should be fixed in short order gpu GPU related issues and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jan 30, 2025

edoakes commented Jan 30, 2025

@stephanie-wang can you explain more about your thinking re: "we cannot eliminate unnecessary copies"? (You've thought about this a lot more than I have.)

One direction I was thinking is to use the storage._shared_cuda() API to get a CUDA IPC handle for the tensor, then pass it to the downstream actor and reconstruct the tensor pointing at the same device memory. The tricky part is keeping the reference to the tensor alive in the process that produces it, but we should be able to plug into the existing reference counting implementation for that.
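For context, a minimal sketch of the mechanism this would build on (plain torch, independent of Ray; assumes both processes see the same CUDA GPU): torch.multiprocessing already passes CUDA tensors between processes as CUDA IPC handles rather than byte copies, and the producer has to keep the original tensor alive:

import torch
import torch.multiprocessing as mp

def consumer(q):
    # The tensor arrives as a view over the producer's GPU memory (via a CUDA
    # IPC handle), not as a deserialized copy.
    t_view = q.get()
    print(t_view.device, float(t_view.sum()))

if __name__ == "__main__":
    mp.set_start_method("spawn")
    t = torch.randn(1000, device="cuda")
    q = mp.Queue()
    q.put(t)                     # sends an IPC handle, not the tensor bytes
    p = mp.Process(target=consumer, args=(q,))
    p.start()
    p.join()                     # producer must keep t alive while it's in use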


edoakes commented Jan 30, 2025

I think we should also introduce a warning when this happens as it's likely unexpected to the user. Something like "serializing GPU tensor foobar, a CPU copy will be made, to avoid this you can do ..."

sven1977 commented

Awesome! Thanks for opening this issue, @stephanie-wang.

I think we should also introduce a warning when this happens as it's likely unexpected to the user. Something like "serializing GPU tensor foobar, a CPU copy will be made, to avoid this you can do ..."

Yes, I think the warning makes sense. When I tried it the first time with just one GPU, I thought the tensor would NOT be copied to CPU in between, because it magically came right out of the object store on the correct device. That's why my intuition was that direct GPU-tensor handover from actor to actor (sharing the same GPU) was already implemented.

stephanie-wang (author) commented

The tricky part is keeping the reference to the tensor alive in the process that produces it, but we should be able to plug into the existing reference counting implementation for that.

That's basically the idea behind the proposed GPU support for Ray Core API :) But I want to avoid doing this without the associated API changes because it brings up questions of how we should manage the GPU data on the sending actor:

  • The user may modify the tensor after returning it from the task. In contrast, the object store approach will serialize an immutable copy before returning control to the actor. Not much we can do to get around this even with the new API but at least requiring the user to annotate functions would make them more aware of the problem.
  • We need to buffer GPU data on the sending and receiving actors. If we buffer too much data, we will run out of GPU memory.

I think it would be better to have a dumb, kind of slow, but fully reliable approach for the normal Ray Core API, and then we can improve on it with the GPU-native API. Anyway, it is probably a good idea to have both options for the future since the latter may take some time to stabilize.


edoakes commented Jan 31, 2025

Agree with all of the above. I'd propose we:

  1. Fix this with a minimal change for "dumb but correct"
  2. Add the IPC solution to the scope of the GPU<>GPU API proposal
  3. Once (2) is supported, add a warning message for (1) that points at the docs on how to do it.


anmscale commented Jan 31, 2025

Hey @stephanie-wang @edoakes, I am trying to understand the issue here. I ran the code (slightly modified) on a node with 4xA10G. I would like to clarify a few things below.

import ray
import torch
@ray.remote(num_gpus=1)
class Actor:
    def alloc(self, device):
        return torch.randn(1000, device=device)

    def read(self, tensor):
        # changed to print
        print(tensor.device)

# GPU 0
A = Actor.remote()
# GPU 1
B = Actor.remote()
t = A.alloc.remote('cuda')
ray.get(A.read.remote(t))

Output:
(Actor pid=49498) cuda:0

Explanation-1:

  1. Object created at actor A
  2. Reference t to the object is sent to the driver
  3. Reference t is sent to actor A
  4. A can access the object, which doesn't trigger any tensor transfer
ray.get(B.read.remote(t))

Output:
(Actor pid=49497) cuda:0

Explanation-2:

  1. Reference t is sent to actor B
  2. B cannot access the object because it's not in GPU-1 memory, which triggers object transfer
  3. Object is sent over the object store and deserialized into GPU-1 memory
  4. One caveat here is that cuda:0 refers to GPU-1 because that's what this actor calls it (a question on that below)
ray.get(t).device

Output:
device(type='cuda', index=0)

Explanation-3:
Object is retrieved on the driver, but somehow it still refers to the tensor on GPU-0 (unexpected), while this should have triggered a transfer to CPU because the driver doesn't have any GPU.

Questions

  1. Can you please verify the steps I describe above?
  2. Can you please clarify the expected behavior vs. the current one?
  3. How can I make sure the second cuda:0 is actually referring to GPU-1 (B's GPU)?

stephanie-wang (author) commented

When torch serializes the tensor, it will store the data plus the torch device, which in this case will be "cuda:0". Also, the 0 refers to an index into CUDA_VISIBLE_DEVICES, not the physical GPU 0. When an actor requests num_gpus, Ray will set the CUDA_VISIBLE_DEVICES env variable to a list of one or more physical GPU indices. So if B is allocated physical GPU 1, "cuda:0" will actually refer to GPU 1.

The driver has no physical GPUs allocated from Ray's perspective, and it will just have the default GPU visibility, which is all GPUs. Therefore, from its perspective, "cuda:0" will refer to GPU 0.

Explanation-1, step 4: A can access the object, which doesn't trigger any tensor transfer

Not exactly. A copy will be stored in the object store, so there is a GPU -> CPU -> GPU copy happening here.

Explanation-2, step 2: B cannot access the object because it's not in GPU-1 memory, which triggers object transfer

That's not exactly right: the object transfer gets triggered no matter what, because right now we always serialize the tensor into the object store.

Question 3: How can I make sure the second cuda:0 is actually referring to GPU-1 (B's GPU)?

You can check CUDA_VISIBLE_DEVICES on the actor.
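For example, a quick way to inspect this from inside an actor (a small diagnostic sketch, assuming a cluster with at least one GPU):

import os
import ray
import torch

@ray.remote(num_gpus=1)
class DeviceInfo:
    def info(self):
        # Ray sets CUDA_VISIBLE_DEVICES per worker, so torch's "cuda:0" is the
        # first entry of this list, not necessarily physical GPU 0.
        return {
            "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
            "ray_gpu_ids": ray.get_gpu_ids(),
            "torch_device_count": torch.cuda.device_count(),
        }

info = DeviceInfo.remote()
print(ray.get(info.info.remote()))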

anmscale commented

Thanks for the detailed response @stephanie-wang! Can we scope out exactly what needs to be fixed in this issue?


edoakes commented Jan 31, 2025

Thanks for the detailed response @stephanie-wang! Can we scope out exactly what needs to be fixed in this issue?

For the minimal fix: when returning a tensor to another actor/driver on the same node, we'll remap the device accordingly.

As @stephanie-wang mentioned, we should have a way to turn this behavior off globally (an internal feature flag) as well as for an individual tensor/actor call. As I mentioned above we also should likely log some kind of useful warning message when it happens because it will cause a performance penalty that can be fixed in the future using the GPU<>GPU API.


anmscale commented Feb 3, 2025

I investigated the current behavior and noticed an issue where a GPU tensor is read by a CPU-only worker: in that case, the default torch.Tensor serialization fails because it tries to reconstruct the tensor on a GPU device. In the following code I modify the serialization of torch tensors to solve this issue, and I think this should be implemented in Ray's serialization code.

import ray
import torch
import io

class TensorData:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor
        self.device = str(tensor.device)

    def __reduce_ex__(self, protocol):
        buffer = io.BytesIO()
        torch.save(self.tensor, buffer)
        return (self.__class__._rebuild,
                (buffer.getvalue(), self.device))

    @staticmethod
    def _rebuild(tensor_bytes: bytes, original_device: str):
        buffer = io.BytesIO(tensor_bytes)
        
        # Default to CPU if no CUDA available
        if not torch.cuda.is_available() and "cuda" in original_device:
            print(f"Warning: {original_device} requested but CUDA is unavailable. Using CPU instead.")
            target_device = "cpu"
        else:
            target_device = original_device
        
        tensor = torch.load(buffer, map_location=target_device)
        return TensorData(tensor.to(target_device))

@ray.remote(num_gpus=1)
class Actor:
    def __init__(self):
        self.device = "cuda" if ray.get_gpu_ids() else "cpu"
        
    def alloc(self):
        tensor_data = TensorData(torch.randn(1000, device=self.device))
        print(f"Allocated TensorData Device: {tensor_data.device}, ID: {id(tensor_data)}")
        return tensor_data

    def read(self, tensor_data: TensorData):
        return f"Read TensorData Device: {tensor_data.device}, ID: {id(tensor_data)}"

ray.init()

a = Actor.remote()  # GPU 0
b = Actor.remote()  # GPU 1
c = Actor.options(num_gpus=0).remote()  # CPU

t_cuda_0 = a.alloc.remote()
t_cuda_1 = b.alloc.remote()

print("1.", ray.get(a.read.remote(t_cuda_0)))  # Should be on GPU 0
print("2.", ray.get(b.read.remote(t_cuda_0)))  # Should move to GPU 1
print("3.", ray.get(c.read.remote(t_cuda_0)))  # Should move to CPU if on a CPU-only node
print("4.", ray.get(a.read.remote(t_cuda_1)))  # Should move to GPU 0

# Driver has no GPU, so it should be on CPU
tensor_from_driver = ray.get(t_cuda_0)
print(f"5. Tensor on driver - Device: {tensor_from_driver.device}, ID: {id(tensor_from_driver)}")

Note that I ran the code above on a CPU-only head node (with a worker node that has 2 GPUs).

Output:

(Actor pid=30322, ip=10.0.0.193) Allocated TensorData Device: cuda:0, ID: 125660037049648
1. Read TensorData Device: cuda:0, ID: 125660037300960
(Actor pid=30321, ip=10.0.0.193) Allocated TensorData Device: cuda:0, ID: 133275103329584
2. Read TensorData Device: cuda:0, ID: 133275052724960
3. Read TensorData Device: cpu, ID: 132562613214416
4. Read TensorData Device: cuda:0, ID: 125595202840800
Warning: cuda:0 requested but CUDA is unavailable. Using CPU instead.
5. Tensor on driver - Device: cpu, ID: 131242725584272
(Actor pid=30323, ip=10.0.0.193) Warning: cuda:0 requested but CUDA is unavailable. Using CPU instead.

Questions

  1. Do you agree this is the desired behavior for GPU-to-CPU?
  2. Above, when a tensor created on one GPU is then read by another one, we silently transfer the tensor to cuda:0. This works because each worker sees its device as cuda:0, but do we want to warn the user that this implicit transfer is happening?

stephanie-wang (author) commented

Yes, please also take a look at how this is currently done in Ray Compiled Graphs: code. Ideally we should reuse this code to keep the codepaths unified.


anmscale commented Feb 4, 2025

There's a main difference between Ray Compiled Graphs and Ray Core: in RCG we assume a 1:1 mapping between an actor and a GPU, which means we don't have the issue where a tensor on cuda:1 in A is transferred to B, which only has cuda:0.
I wrote the code below to handle all the cases I can think of: GPU-to-GPU and GPU-to-CPU.

import ray
import torch
import io
from ray.util.placement_group import placement_group


class TensorData:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor
        self.device = str(tensor.device)

    def __reduce_ex__(self, protocol):
        buffer = io.BytesIO()
        torch.save(self.tensor, buffer)
        return (self.__class__._rebuild, (buffer.getvalue(), self.device))

    @staticmethod
    def _rebuild(tensor_bytes: bytes, original_device: str):
        buffer = io.BytesIO(tensor_bytes)
        node_id_short = ray.get_runtime_context().get_node_id()[:8]
        gpu_ids = [str(id) for id in ray.get_gpu_ids()]
        
        if torch.cuda.is_available():
            device_id = original_device.split(":")[1]
            if device_id in gpu_ids:
                target_device = original_device
            else:
                print(f"({node_id_short}) Warning: {original_device} requested but is not available. Using cuda:0 instead.")
                target_device = "cuda:0"
        else:
            if "cuda" in original_device:
                print(
                    f"({node_id_short}) Warning: {original_device} requested but CUDA is unavailable. Using CPU instead."
                )
            target_device = "cpu"
        
        tensor = torch.load(buffer, map_location=target_device)
        return TensorData(tensor.to(target_device))


@ray.remote
class Actor:
    def __init__(self):
        self.gpu_ids = [str(id) for id in ray.get_gpu_ids()]
        self.node_id = ray.get_runtime_context().get_node_id()
        # Take first 8 characters of the node_id
        print(f"({self.node_id[:8]}) GPU IDs: {self.gpu_ids}")

    def alloc(self, device: str):
        if device == "cuda":
            device = f"cuda:{self.gpu_ids[0]}"
        elif "cuda" in device:
            assert device.split(":")[1] in self.gpu_ids, f"{device.split(':')[1]} not in {self.gpu_ids}"
        else:
            device = "cpu"

        tensor_data = TensorData(torch.randn(1000, device=device))
        print(
            f"({self.node_id[:8]}) Alloc Tensor ID: {id(tensor_data)}, on Device: {tensor_data.device}"
        )
        return tensor_data

    def read(self, tensor_data: TensorData):
        return f"({self.node_id[:8]}) Read Tensor ID: {id(tensor_data)}, on Device: {tensor_data.device}"


ray.init()

pg1 = placement_group([{"GPU": 4, "CPU": 1}])
pg2 = placement_group([{"GPU": 4, "CPU": 1}])
ray.get([pg1.ready(), pg2.ready()])

a = Actor.options(num_gpus=2, placement_group=pg1).remote()
b = Actor.options(num_gpus=1, placement_group=pg2).remote()

a_cuda_0 = a.alloc.remote('cuda:0')
a_cuda_1 = a.alloc.remote('cuda:1')
b_cuda_0 = b.alloc.remote('cuda:0')


print("1.", ray.get(b.read.remote(a_cuda_0)))  # Should move to b_cuda:0
print("2.", ray.get(a.read.remote(b_cuda_0)))  # Should move to a_cuda:0
print("3.", ray.get(b.read.remote(a_cuda_1)))  # Should move to b_cuda:0 but give a warning

# Driver has no GPU, so it should be on CPU
a_cpu_1 = ray.get(a_cuda_1)
print(f"4. Driver Tensor ID: {id(a_cpu_1)}, on Device: {a_cpu_1.device}")

ray.shutdown()

Output

(base) ray@ip-10-0-3-222:~/default$ RAY_DEDUP_LOGS=0  python test.py 
2025-02-04 09:27:24,727 INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 10.0.3.222:6379...
2025-02-04 09:27:24,739 INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at https://session-81vks6vwzgp8sdg337ezjukqsi.i.anyscaleuserdata.com 
2025-02-04 09:27:25,440 INFO packaging.py:366 -- Pushing file package 'gcs://_ray_pkg_10053df94a2b068712ca39b8cf390add78a6206a.zip' (252.08MiB) to Ray cluster...
2025-02-04 09:27:26,946 INFO packaging.py:379 -- Successfully pushed file package 'gcs://_ray_pkg_10053df94a2b068712ca39b8cf390add78a6206a.zip'.
(Actor pid=106783, ip=10.0.19.50) (2b9b19b6) GPU IDs: ['0', '1']
(Actor pid=104577, ip=10.0.5.145) (dde1cbf0) GPU IDs: ['0']
(Actor pid=104577, ip=10.0.5.145) (dde1cbf0) Alloc Tensor ID: 135402358035648, on Device: cuda:0
1. (dde1cbf0) Read Tensor ID: 135468579499168, on Device: cuda:0
(Actor pid=106783, ip=10.0.19.50) (2b9b19b6) Alloc Tensor ID: 134773861504288, on Device: cuda:0
2. (2b9b19b6) Read Tensor ID: 134777303309472, on Device: cuda:0
3. (dde1cbf0) Read Tensor ID: 135402339162576, on Device: cuda:0
(741a44f2) Warning: cuda:1 requested but CUDA is unavailable. Using CPU instead.
4. Driver Tensor ID: 125431708100112, on Device: cpu
(Actor pid=106783, ip=10.0.19.50) (2b9b19b6) Alloc Tensor ID: 134777303141536, on Device: cuda:1

That being said, the RCG code seems to handle different data types and converts tensors to numpy, so potentially this solves the zero-copy issue we have with torch tensors @edoakes?


anmscale commented Feb 10, 2025

@stephanie-wang The following code reuses the compiled graph's serialization/deserialization logic:

import torch

from ray.util import register_serializer
from ray.util.placement_group import placement_group
from ray.experimental.channel.serialization_context import _SerializationContext

# Define a custom serializer that only relies on the built-in methods.
def my_tensor_serializer(tensor: torch.Tensor):
    ctx = _SerializationContext.get_current()
    return ctx.serialize_numpy(tensor)

def my_tensor_deserializer(serialized):
    ctx = _SerializationContext.get_current()
    return ctx.deserialize_numpy(serialized)

# Register the custom serializer for torch.Tensor.
register_serializer(
    torch.Tensor,
    serializer=my_tensor_serializer,
    deserializer=my_tensor_deserializer,
)

However, I don't think it handles all cases for GPU tensors. Let me summarize below what these are (a sketch of the resulting device mapping follows the list):

  1. Source tensor is on CPU ⇒ target tensor will be on CPU
  2. If the source tensor is on GPU but the target worker doesn't have a GPU ⇒ give a warning to the user and deserialize the target tensor on CPU (current code throws an exception)
  3. If the source tensor is on GPU and the target worker has a GPU:
    a) If the target worker doesn't have the same cuda device id ⇒ give a warning and deserialize the target tensor on cuda:0 (current code throws an exception)
    b) Otherwise, deserialize the target tensor on the same device id as the source (e.g. cuda:1 on cuda:1)

Do you agree we should print a warning instead of throwing an exception?
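A rough sketch of the device-mapping policy implied by cases 1-3 (the helper name and warning messages are hypothetical, not existing Ray code):

import warnings
import torch

def choose_target_device(original_device: str) -> str:
    # Case 1: CPU tensors stay on CPU.
    if not original_device.startswith("cuda"):
        return "cpu"
    # Case 2: GPU tensor arriving at a worker without CUDA -> warn, fall back to CPU.
    if not torch.cuda.is_available():
        warnings.warn(f"{original_device} requested but CUDA is unavailable; using CPU.")
        return "cpu"
    # Case 3b: the requested device index exists locally -> keep it.
    index = int(original_device.split(":")[1]) if ":" in original_device else 0
    if index < torch.cuda.device_count():
        return original_device
    # Case 3a: device index not visible here -> warn, fall back to cuda:0.
    warnings.warn(f"{original_device} requested but only "
                  f"{torch.cuda.device_count()} GPU(s) are visible; using cuda:0.")
    return "cuda:0"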


edoakes commented Feb 10, 2025

Do you agree we should print a warning instead of throwing an exception?

I don't think we have a choice here -- we have to print a warning rather than raise, else it would be a breaking behavior change.

We could decide to add a warning that the behavior will change and then change it to raise an exception in the future. However I think for the base Ray API it should be OK to do the working-but-slow thing. We can always revisit it in the future.


edoakes commented Feb 10, 2025

I want to point out one other thing -- from the looks of it, ctx.serialize_numpy(tensor) will serialize as a numpy array to enable zero-copy deserialization. We cannot drop-in replace the serializer for torch tensors to do this because this would also be a breaking behavior change.

IMO we should separately add another warning here for torch tensors that are returned on CPU (maybe over the Ray inlined object size threshold?) that tells users they aren't zero-copy deserialized by default and gives them an API to do that easily (like ray.util.ZeroCopyTorchTensor(t) or similar). I was already planning to ask @israbbani to add this in the near future.
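For reference, a minimal illustration of the zero-copy path being discussed (ray.util.ZeroCopyTorchTensor above is only a suggested name, not an existing API; this snippet assumes a local ray.init()): numpy arrays come out of the object store as read-only views over shared memory, and torch.from_numpy can wrap that memory without a copy:

import numpy as np
import ray
import torch

ray.init()

ref = ray.put(np.random.rand(1000).astype(np.float32))

# Zero-copy: `arr` is a read-only view backed by the object store's shared memory.
arr = ray.get(ref)

# Shares the same memory rather than copying it; the resulting tensor must be
# treated as read-only (torch warns about the non-writable array).
t = torch.from_numpy(arr)
print(t.shape, t.dtype)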

stephanie-wang (author) commented

Yes, probably we want to add a zero_copy flag to the deserialization function.

But I am a bit confused why the current code throws an exception. AFAIK, the code in the snippet is supposed to deserialize the tensor to the correct local device, so I don't think it should throw an exception in this case:

If source tensor is on GPU but target device doesn’t have GPU ⇒ give a warning to the user and deserialize target tensor on CPU (current code throws an exception)


anmscale commented Feb 10, 2025

But I am a bit confused why the current code throws an exception. AFAIK, the code in the snippet is supposed to deserialize the tensor to the correct local device, so I don't think it should throw an exception in this case:

Let me clarify: currently Ray tries to deserialize on the exact same device and will throw a runtime error if it doesn't match. For example, in your original code snippet (which I slightly modified here), only the GPU-to-CPU case fails with an error. The GPU-to-GPU case works because both of these devices are called cuda:0 on each actor respectively. See the comments below:

@ray.remote(num_gpus=1)
class Actor:
  def alloc(self):
    return torch.randn(1000, device="cuda")
  def read(self, tensor):
    return tensor.device  # return instead of assert

# GPU 0
a = Actor.remote()
# GPU 1
b = Actor.remote()

t = a.alloc.remote()
# Return GPU 0 (cuda:0)
print(ray.get(a.read.remote(t)))    
# Return GPU 1 (cuda:0)
print(ray.get(b.read.remote(t)))
# Driver has no GPUs. 
# raises RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.
print(ray.get(t).device)

In my code snippet, I test another case where one actor has 2 GPUs and the other has only 1 GPU. So if we try to transfer a tensor located on cuda:1 to an actor with 1 GPU, it raises the following exception:

RuntimeError: Attempting to deserialize object on CUDA device 1 but torch.cuda.device_count() is 1.

PS: I am running the code on a CPU-only head node, in a cluster with 2 worker nodes, each with GPUs.
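For reference, the same behavior can be reproduced with plain torch outside Ray (assuming the process saving the tensor has a GPU): torch.load restores onto the saved device unless map_location remaps it, which is the knob a Ray-side fix would rely on:

import io
import torch

buf = io.BytesIO()
torch.save(torch.randn(4, device="cuda"), buf)

buf.seek(0)
# torch.load(buf) on a CPU-only process raises:
#   RuntimeError: Attempting to deserialize object on a CUDA device but
#   torch.cuda.is_available() is False.
t = torch.load(buf, map_location="cpu")   # remaps instead of raising
print(t.device)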

anmscale commented

I don't think we have a choice here -- we have to print a warning rather than raise, else it would be a breaking behavior change.

Hmm, not really. We currently raise an error. The change would be to print a warning instead of raising an error.


edoakes commented Feb 10, 2025

I don't think we have a choice here -- we have to print a warning rather than raise, else it would be a breaking behavior change.

Hmm, not really. We currently raise an error. The change would be to print a warning instead of raising an error.

We raise an error in the compiled graphs path but not in the regular Ray API, right? (unless I'm misunderstanding)

anmscale commented

I don't think we have a choice here -- we have to print a warning rather than raise, else it would be a breaking behavior change.

Hmm, not really. We currently raise an error. The change would be to print a warning instead of raising an error.

We raise an error in the compiled graphs path but not in the regular Ray API, right? (unless I'm misunderstanding)

I meant in the Ray API. Please run the latest code snippet on a CPU head node with at least 1 GPU worker node.


edoakes commented Feb 10, 2025

OK, I understand now from the latest code sample. I am good with either:

  1. No longer erroring, and instead printing a warning and implicitly copying to CPU.
  2. Still erroring, but improving the error message and telling people to explicitly use .to("cpu") when returning the tensor.

In principle I prefer (2), but (1) is more in line with the existing implicit GPU->CPU->GPU copying behavior, so I would suggest we go with that.
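For illustration, option (2) would push the copy into user code, along these lines (a sketch only, showing the explicit copy the error message would point users to):

import ray
import torch

@ray.remote(num_gpus=1)
class Actor:
    def alloc(self):
        # Explicit, visible copy to CPU before returning, instead of an
        # implicit copy during serialization.
        return torch.randn(1000, device="cuda").to("cpu")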

edoakes added a commit that referenced this issue Mar 9, 2025
Changing the default tensor serialization in compiled graphs. Also added
a comprehensive set of unit tests covering cases for torch.Tensor
serialization in both Ray core and compiled graphs.

## Related issue number

Related to issues:
  - #50134
  - #50452
Also related to #47742

---------

Signed-off-by: Amjad Almahairi <anm@anyscale.com>
Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>