Skip to content

Commit c17e6e6

Browse files
penpornk and tensorflower-gardener
authored and committed
[xla:cpu] Add FFI custom call thunk runtime support to PJRT CPU client.
Also add a benchmark that uses PJRT CPU Client. PiperOrigin-RevId: 647916282
1 parent 2c51dd3 commit c17e6e6

File tree

5 files changed

+99
-5
lines changed

5 files changed

+99
-5
lines changed

third_party/xla/xla/pjrt/cpu/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ cc_library(
163163
"//xla/pjrt:semaphore",
164164
"//xla/pjrt:transpose",
165165
"//xla/pjrt:utils",
166-
"//xla/pjrt/distributed:key_value_store_interface",
167166
"//xla/service:buffer_assignment",
168167
"//xla/service:compiler",
169168
"//xla/service:computation_placer_hdr",

third_party/xla/xla/pjrt/cpu/cpu_client.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,10 +1588,18 @@ absl::StatusOr<PjRtLoadedExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
15881588
cpu::Thunk::CollectiveExecuteParams collective_params,
15891589
cpu::Thunk::CollectiveExecuteParams::Create(&run_options));
15901590

1591+
// TODO(penporn): Consolidate with other thunk parameter set up calls.
1592+
TF_ASSIGN_OR_RETURN(
1593+
cpu::Thunk::CustomCallExecuteParams custom_call_execute_params,
1594+
cpu::Thunk::CustomCallExecuteParams::Create(&run_options));
1595+
15911596
cpu::Thunk::ExecuteParams execute_params = {
1592-
&cpu_executable->host_kernels(), &allocations,
1597+
&cpu_executable->host_kernels(),
1598+
&allocations,
15931599
cpu::runtime::GetXfeedManager(run_options.device_ordinal()),
1594-
run_options.intra_op_thread_pool(), &collective_params};
1600+
run_options.intra_op_thread_pool(),
1601+
&collective_params,
1602+
&custom_call_execute_params};
15951603

15961604
auto execute_event = cpu_executable->thunks().Execute(
15971605
execute_params, [&](cpu::ThunkExecutor::Task task) {
@@ -1714,11 +1722,18 @@ absl::StatusOr<PjRtLoadedExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
17141722
collective_params =
17151723
cpu::Thunk::CollectiveExecuteParams::Create(&run_options);
17161724

1725+
absl::StatusOr<cpu::Thunk::CustomCallExecuteParams>
1726+
custom_call_params =
1727+
cpu::Thunk::CustomCallExecuteParams::Create(&run_options);
1728+
17171729
if (collective_params.ok()) {
17181730
cpu::Thunk::ExecuteParams execute_params = {
1719-
&cpu_executable->host_kernels(), &allocations,
1731+
&cpu_executable->host_kernels(),
1732+
&allocations,
17201733
cpu::runtime::GetXfeedManager(run_options.device_ordinal()),
1721-
run_options.intra_op_thread_pool(), &*collective_params};
1734+
run_options.intra_op_thread_pool(),
1735+
&*collective_params,
1736+
&*custom_call_params};
17221737

17231738
auto execute_event = cpu_executable->thunks().Execute(
17241739
execute_params, [&](cpu::ThunkExecutor::Task task) {

third_party/xla/xla/service/cpu/benchmarks/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ xla_cc_test(
148148
],
149149
)
150150

151+
# Test target for the custom-call benchmark, built from
# custom_call_benchmark_test.cc; it depends on the shared HLO benchmark runner
# and the XLA FFI libraries.
xla_cc_test(
152+
name = "custom_call_benchmark_test",
153+
srcs = ["custom_call_benchmark_test.cc"],
154+
deps = [
155+
":hlo_benchmark_runner",
156+
"//xla/ffi",
157+
"//xla/ffi:ffi_api",
158+
"//xla/tests:hlo_test_base",
159+
"//xla/tests:test_macros_header",
160+
"@com_google_absl//absl/status",
161+
"@com_google_absl//absl/types:span",
162+
"@local_tsl//tsl/platform:logging",
163+
"@local_tsl//tsl/platform:test_benchmark",
164+
"@local_tsl//tsl/platform:test_main",
165+
],
166+
)
167+
151168
xla_cc_test(
152169
name = "gather_benchmark_test",
153170
srcs = ["gather_benchmark_test.cc"],
third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <string_view>
17+
18+
#include "absl/status/status.h"
19+
#include "absl/types/span.h"
20+
#include "xla/ffi/ffi.h"
21+
#include "xla/ffi/ffi_api.h"
22+
#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h"
23+
#include "tsl/platform/logging.h"
24+
#include "tsl/platform/test_benchmark.h"
25+
26+
namespace xla::cpu {
27+
namespace {
28+
29+
// No-op FFI handler used as the benchmark's custom-call target: it receives a
// single F32 scalar result buffer, leaves it untouched, and always returns OK.
static absl::Status Minimal(
30+
ffi::Result<ffi::BufferR0<PrimitiveType::F32>> unused) {
31+
return absl::OkStatus();
32+
}
33+
34+
// Wraps Minimal as the FFI handler kMinimal. The binding declares no
// arguments and exactly one F32 scalar result, matching Minimal's signature.
XLA_FFI_DEFINE_HANDLER(
35+
kMinimal, Minimal,
36+
ffi::Ffi::Bind()
37+
.Ret<ffi::BufferR0<PrimitiveType::F32>>()); // Unused out buffer
38+
39+
// Registers kMinimal under the target name "__xla_bm$$minimal" for the "Host"
// platform; the benchmark's HLO module refers to the handler by this name.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$minimal", "Host",
40+
kMinimal);
41+
42+
// Benchmarks a typed-FFI custom call that does no work. The HLO module below
// consists of a single custom-call instruction producing an f32 scalar and
// targeting the "__xla_bm$$minimal" handler.
static void BM_CustomCall_Minimal(benchmark::State& state) {
43+
const char* kModuleStr = R"(
44+
HloModule module
45+
46+
ENTRY custom_call {
47+
ROOT custom-call = f32[] custom-call(),
48+
custom_call_target="__xla_bm$$minimal",
49+
api_version=API_VERSION_TYPED_FFI
50+
}
51+
)";
52+
// The module takes no parameters, so no arguments or text replacements are
// passed to the runner. NOTE(review): CHECK_OK presumably aborts the process
// if the runner returns a non-OK status — confirm.
CHECK_OK(RunHloBenchmark(state, kModuleStr, /*args=*/{},
53+
/*replacements=*/{}));
54+
// Report one processed item per benchmark iteration.
state.SetItemsProcessed(state.iterations());
55+
}
56+
57+
// NOTE(review): MeasureProcessCPUTime presumably makes the reported CPU time
// cover the whole process rather than only the benchmark thread — confirm
// against the benchmark library documentation.
BENCHMARK(BM_CustomCall_Minimal)->MeasureProcessCPUTime();
58+
59+
} // namespace
60+
} // namespace xla::cpu

third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> CustomCallThunk::CallTypedFFI(
9090
"No registered implementation for FFI custom call to %s for Host",
9191
target_name_);
9292
}
93+
if (params.custom_call_params == nullptr) {
94+
return Internal("CustomCallExecuteParams cannot be nullptr.");
95+
}
9396

9497
// Build the FFI call frame.
9598
ffi::CallFrameBuilder builder(

0 commit comments

Comments
 (0)