Skip to content

Commit 5c17111

Browse files
authored
[Fix][Runtime][RPC] Fix remote tensor handle cleanup for RPC return values (#19410)
This PR fixes RPC tensor cleanup for tensors returned from remote calls. When a remote function returns a `Tensor`, the RPC protocol sends both: - the remote backing data pointer - the remote tensor object handle used for deletion Previously, `TensorFromRemoteOpaqueHandle` stored only the data pointer and called `FreeHandle(space_.data)` during local tensor destruction. That is incorrect: `FreeHandle` is meant for remote object handles, not raw data-space pointers. This could lead to invalid cleanup behavior and crashes during teardown in RPC workflows, including the cross-compilation + RPC tutorial scenario reported in #18923. This change: - stores the remote tensor object handle in `RemoteSpace` - calls `FreeHandle(remote_tensor_handle)` during tensor destruction - keeps cleanup fault-tolerant if the remote connection is already closed
1 parent e0e9315 commit 5c17111

5 files changed

Lines changed: 150 additions & 12 deletions

File tree

python/tvm/rpc/testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222

2323
import tvm
24+
import tvm.testing
2425

2526

2627
# RPC test functions to be registered for unit-tests purposes

src/runtime/rpc/rpc_module.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,23 @@ Tensor TensorFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* hand
6262
// the pointer to the remote space is passed in as the data pointer
6363
tensor->data = &(space_);
6464
}
65-
void FreeData(DLTensor* tensor) { space_.sess->FreeHandle(space_.data); }
65+
void FreeData(DLTensor* tensor) {
66+
if (space_.object_handle != nullptr) {
67+
try {
68+
space_.sess->FreeHandle(space_.object_handle);
69+
} catch (const Error& e) {
70+
// fault tolerance to remote close
71+
}
72+
}
73+
}
6674

6775
private:
6876
RemoteSpace space_;
6977
};
7078
RemoteSpace space;
7179
space.sess = sess;
7280
space.data = handle;
81+
space.object_handle = remote_tensor_handle;
7382
ffi::Shape shape(template_tensor->shape, template_tensor->shape + template_tensor->ndim);
7483
return Tensor::FromNDAlloc(RemoteSpaceAlloc(space), shape, template_tensor->dtype, dev);
7584
}

src/runtime/rpc/rpc_session.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,14 @@ struct RemoteSpace {
281281
void* data;
282282
/*! \brief Reference to the underlying RPC session. */
283283
std::shared_ptr<RPCSession> sess;
284+
/*!
285+
* \brief The remote Tensor object handle, if this RemoteSpace wraps a returned Tensor.
286+
*
287+
* Returned RPC Tensors carry both the backing data pointer and a Tensor object handle. The
288+
* object handle must be released with FreeHandle so the remote side can correctly decrement the
289+
* Tensor refcount and free the backing storage when it is no longer shared.
290+
*/
291+
void* object_handle{nullptr};
284292
};
285293

286294
/*!
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gtest/gtest.h>
21+
#include <tvm/runtime/device_api.h>
22+
#include <tvm/runtime/logging.h>
23+
#include <tvm/runtime/tensor.h>
24+
25+
#include <memory>
26+
#include <string>
27+
28+
#include "../../../src/runtime/rpc/rpc_session.h"
29+
30+
namespace tvm {
31+
namespace runtime {
32+
33+
Tensor TensorFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* handle,
34+
DLTensor* template_tensor, Device dev,
35+
void* remote_tensor_handle);
36+
37+
namespace {
38+
39+
class RecordingRPCSession final : public RPCSession {
40+
public:
41+
PackedFuncHandle GetFunction(const std::string& name) final { return nullptr; }
42+
43+
void CallFunc(PackedFuncHandle func, ffi::PackedArgs args,
44+
const FEncodeReturn& fencode_return) final {}
45+
46+
void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {}
47+
48+
void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {}
49+
50+
void FreeHandle(void* handle) final {
51+
++free_handle_calls;
52+
last_freed_handle = handle;
53+
if (throw_on_free) {
54+
TVM_FFI_THROW(InternalError) << "simulated remote close";
55+
}
56+
}
57+
58+
DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing = false) final { return nullptr; }
59+
60+
bool IsLocalSession() const final { return false; }
61+
62+
int free_handle_calls{0};
63+
void* last_freed_handle{nullptr};
64+
bool throw_on_free{false};
65+
};
66+
67+
DLTensor MakeTemplateTensor() {
68+
static int64_t shape[1] = {4};
69+
DLTensor tensor{};
70+
tensor.data = nullptr;
71+
tensor.device = Device{kDLCPU, 0};
72+
tensor.ndim = 1;
73+
tensor.dtype = DataType::Float(32);
74+
tensor.shape = shape;
75+
tensor.strides = nullptr;
76+
tensor.byte_offset = 0;
77+
return tensor;
78+
}
79+
80+
Device MakeRemoteDevice(const std::shared_ptr<RPCSession>& sess) {
81+
return AddRPCSessionMask(Device{kDLCPU, 0}, sess->table_index());
82+
}
83+
84+
} // namespace
85+
86+
TEST(RPCTensorTest, ReturnedTensorFreesRemoteTensorHandle) {
87+
auto sess = std::make_shared<RecordingRPCSession>();
88+
DLTensor template_tensor = MakeTemplateTensor();
89+
void* data_handle = reinterpret_cast<void*>(0x1234);
90+
void* tensor_handle = reinterpret_cast<void*>(0x5678);
91+
92+
{
93+
auto tensor = TensorFromRemoteOpaqueHandle(sess, data_handle, &template_tensor,
94+
MakeRemoteDevice(sess), tensor_handle);
95+
EXPECT_NE(tensor.defined(), false);
96+
}
97+
98+
EXPECT_EQ(sess->free_handle_calls, 1);
99+
EXPECT_EQ(sess->last_freed_handle, tensor_handle);
100+
EXPECT_NE(sess->last_freed_handle, data_handle);
101+
}
102+
103+
TEST(RPCTensorTest, ReturnedTensorDestructorIgnoresFreeHandleErrors) {
104+
auto sess = std::make_shared<RecordingRPCSession>();
105+
sess->throw_on_free = true;
106+
DLTensor template_tensor = MakeTemplateTensor();
107+
void* data_handle = reinterpret_cast<void*>(0x1234);
108+
void* tensor_handle = reinterpret_cast<void*>(0x5678);
109+
110+
EXPECT_NO_THROW({
111+
auto tensor = TensorFromRemoteOpaqueHandle(sess, data_handle, &template_tensor,
112+
MakeRemoteDevice(sess), tensor_handle);
113+
});
114+
EXPECT_EQ(sess->free_handle_calls, 1);
115+
EXPECT_EQ(sess->last_freed_handle, tensor_handle);
116+
}
117+
118+
} // namespace runtime
119+
} // namespace tvm

tests/python/runtime/test_runtime_rpc.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sys
2323
import tempfile
2424
import time
25+
import gc
2526

2627
import numpy as np
2728
import pytest
@@ -386,22 +387,22 @@ def check_error_handling():
386387

387388
@tvm.testing.requires_rpc
388389
def test_rpc_return_tensor():
389-
# start server
390-
server = rpc.Server(key="x1")
391-
client = rpc.connect("127.0.0.1", server.port, key="x1")
392-
393-
m = client.get_function("rpc.test.remote_return_nd")
394-
get_arr = m("get_arr")
395-
ref_count = m("ref_count")
396-
get_elem = m("get_elem")
397-
get_arr_elem = m("get_arr_elem")
398-
399-
# array test
400390
def run_arr_test():
391+
server = rpc.Server(key="x1")
392+
client = rpc.connect("127.0.0.1", server.port, key="x1")
393+
m = client.get_function("rpc.test.remote_return_nd")
394+
get_arr = m("get_arr")
395+
get_elem = m("get_elem")
396+
get_arr_elem = m("get_arr_elem")
397+
401398
arr = get_arr()
402399
assert get_elem(0) == 0.0
403400
assert get_arr_elem(arr, 0) == 0.0
404401

402+
del arr
403+
gc.collect()
404+
assert get_elem(0) == 0.0
405+
405406
run_arr_test()
406407

407408

0 commit comments

Comments
 (0)