Skip to content

[OpenMP][OMPT] Fix device identifier collision during callbacks #65595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openmp/libomptarget/include/omptargetplugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ int32_t __tgt_rtl_data_notify_mapped(int32_t ID, void *HstPtr, int64_t Size);
// host address \p HstPtr and \p Size bytes.
int32_t __tgt_rtl_data_notify_unmapped(int32_t ID, void *HstPtr);

// Set the global device identifier offset, such that the plugin may determine a
// unique device number. \p DeviceIdOffset is the value the plugin adds to its
// own plugin-local device ids when it reports a device number, so that ids
// coming from different plugins remain distinguishable.
int32_t __tgt_rtl_set_device_offset(int32_t DeviceIdOffset);

#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 2 additions & 0 deletions openmp/libomptarget/include/rtl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ struct RTLInfoTy {
typedef int32_t(data_unlock_ty)(int32_t, void *);
typedef int32_t(data_notify_mapped_ty)(int32_t, void *, int64_t);
typedef int32_t(data_notify_unmapped_ty)(int32_t, void *);
typedef int32_t(set_device_offset_ty)(int32_t);
typedef int32_t(activate_record_replay_ty)(int32_t, uint64_t, bool, bool);

int32_t Idx = -1; // RTL index, index is the number of devices
Expand Down Expand Up @@ -125,6 +126,7 @@ struct RTLInfoTy {
data_unlock_ty *data_unlock = nullptr;
data_notify_mapped_ty *data_notify_mapped = nullptr;
data_notify_unmapped_ty *data_notify_unmapped = nullptr;
set_device_offset_ty *set_device_offset = nullptr;
activate_record_replay_ty *activate_record_replay = nullptr;

// Are there images associated with this RTL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,8 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
bool ExpectedStatus = false;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))
performOmptCallback(device_initialize,
/* device_num */ DeviceId,
/* device_num */ DeviceId +
Plugin.getDeviceIdStartIndex(),
/* type */ getComputeUnitKind().c_str(),
/* device */ reinterpret_cast<ompt_device_t *>(this),
/* lookup */ ompt::lookupCallbackByName,
Expand Down Expand Up @@ -587,7 +588,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
return Plugin::success();
}

Error GenericDeviceTy::deinit() {
Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
// Delete the memory manager before deinitializing the device. Otherwise,
// we may delete device allocations after the device is deinitialized.
if (MemoryManager)
Expand All @@ -605,7 +606,9 @@ Error GenericDeviceTy::deinit() {
if (ompt::Initialized) {
bool ExpectedStatus = true;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false))
performOmptCallback(device_finalize, /* device_num */ DeviceId);
performOmptCallback(device_finalize,
/* device_num */ DeviceId +
Plugin.getDeviceIdStartIndex());
}
#endif

Expand Down Expand Up @@ -656,7 +659,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
size_t Bytes =
getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
performOmptCallback(device_load,
/* device_num */ DeviceId,
/* device_num */ DeviceId +
Plugin.getDeviceIdStartIndex(),
/* FileName */ nullptr,
/* File Offset */ 0,
/* VmaInFile */ nullptr,
Expand Down Expand Up @@ -1362,7 +1366,7 @@ Error GenericPluginTy::deinitDevice(int32_t DeviceId) {
return Plugin::success();

// Deinitialize the device and release its resources.
if (auto Err = Devices[DeviceId]->deinit())
if (auto Err = Devices[DeviceId]->deinit(*this))
return Err;

// Delete the device and invalidate its reference.
Expand Down Expand Up @@ -1815,6 +1819,12 @@ int32_t __tgt_rtl_init_device_info(int32_t DeviceId,
return OFFLOAD_SUCCESS;
}

// Plugin entry point: store \p DeviceIdOffset in the plugin as its device
// identifier start index. Always reports OFFLOAD_SUCCESS.
int32_t __tgt_rtl_set_device_offset(int32_t DeviceIdOffset) {
  auto &&ThePlugin = Plugin::get();
  ThePlugin.setDeviceIdStartIndex(DeviceIdOffset);
  return OFFLOAD_SUCCESS;
}

#ifdef __cplusplus
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
/// Deinitialize the device and free all its resources. After this call, the
/// device is no longer considered ready, so no queries or modifications are
/// allowed.
Error deinit();
Error deinit(GenericPluginTy &Plugin);
virtual Error deinitImpl() = 0;

/// Load the binary image into the device and return the target table.
Expand Down Expand Up @@ -946,6 +946,12 @@ struct GenericPluginTy {
/// Get the number of active devices.
int32_t getNumDevices() const { return NumDevices; }

/// Get the plugin-specific device identifier offset.
int32_t getDeviceIdStartIndex() const { return DeviceIdStartIndex; }

/// Set the plugin-specific device identifier offset.
void setDeviceIdStartIndex(int32_t Offset) { DeviceIdStartIndex = Offset; }

/// Get the ELF code to recognize the binary image of this plugin.
virtual uint16_t getMagicElfBits() const = 0;

Expand Down Expand Up @@ -1010,6 +1016,11 @@ struct GenericPluginTy {
/// Number of devices available for the plugin.
int32_t NumDevices = 0;

/// Index offset, which when added to a DeviceId, will yield a unique
/// user-observable device identifier. This is especially important when
/// DeviceIds of multiple plugins / RTLs need to be distinguishable.
int32_t DeviceIdStartIndex = 0;

/// Array of pointers to the devices. Initially, they are all set to nullptr.
/// Once a device is initialized, the pointer is stored in the position given
/// by its device id. A position with nullptr means that the corresponding
Expand Down
8 changes: 4 additions & 4 deletions openmp/libomptarget/src/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ void *DeviceTy::allocData(int64_t Size, void *HstPtr, int32_t Kind) {
void *TargetPtr = nullptr;
OMPT_IF_BUILT(InterfaceRAII TargetDataAllocRAII(
RegionInterface.getCallbacks<ompt_target_data_alloc>(),
RTLDeviceID, HstPtr, &TargetPtr, Size,
DeviceID, HstPtr, &TargetPtr, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

TargetPtr = RTL->data_alloc(RTLDeviceID, Size, HstPtr, Kind);
Expand All @@ -594,7 +594,7 @@ int32_t DeviceTy::deleteData(void *TgtAllocBegin, int32_t Kind) {
/// RAII to establish tool anchors before and after data deletion
OMPT_IF_BUILT(InterfaceRAII TargetDataDeleteRAII(
RegionInterface.getCallbacks<ompt_target_data_delete>(),
RTLDeviceID, TgtAllocBegin,
DeviceID, TgtAllocBegin,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

return RTL->data_delete(RTLDeviceID, TgtAllocBegin, Kind);
Expand Down Expand Up @@ -632,7 +632,7 @@ int32_t DeviceTy::submitData(void *TgtPtrBegin, void *HstPtrBegin, int64_t Size,
OMPT_IF_BUILT(
InterfaceRAII TargetDataSubmitRAII(
RegionInterface.getCallbacks<ompt_target_data_transfer_to_device>(),
RTLDeviceID, TgtPtrBegin, HstPtrBegin, Size,
DeviceID, TgtPtrBegin, HstPtrBegin, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

if (!AsyncInfo || !RTL->data_submit_async || !RTL->synchronize)
Expand Down Expand Up @@ -660,7 +660,7 @@ int32_t DeviceTy::retrieveData(void *HstPtrBegin, void *TgtPtrBegin,
OMPT_IF_BUILT(
InterfaceRAII TargetDataRetrieveRAII(
RegionInterface.getCallbacks<ompt_target_data_transfer_from_device>(),
RTLDeviceID, HstPtrBegin, TgtPtrBegin, Size,
DeviceID, HstPtrBegin, TgtPtrBegin, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

if (!RTL->data_retrieve_async || !RTL->synchronize)
Expand Down
6 changes: 6 additions & 0 deletions openmp/libomptarget/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ bool RTLsTy::attemptLoadRTL(const std::string &RTLName, RTLInfoTy &RTL) {
DynLibrary->getAddressOfSymbol("__tgt_rtl_data_notify_mapped");
*((void **)&RTL.data_notify_unmapped) =
DynLibrary->getAddressOfSymbol("__tgt_rtl_data_notify_unmapped");
*((void **)&RTL.set_device_offset) =
DynLibrary->getAddressOfSymbol("__tgt_rtl_set_device_offset");

// Record Replay RTL
*((void **)&RTL.activate_record_replay) =
Expand Down Expand Up @@ -424,6 +426,10 @@ void RTLsTy::initRTLonce(RTLInfoTy &R) {
R.IsUsed = true;
UsedRTLs.push_back(&R);

// If possible, set the device identifier offset
if (R.set_device_offset)
R.set_device_offset(Start);

DP("RTL " DPxMOD " has index %d!\n", DPxPTR(R.LibraryHandler.get()), R.Idx);
}
}
Expand Down