Skip to content

[OpenMP][OMPT] Fix device identifier collision during callbacks #65595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openmp/libomptarget/include/omptargetplugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ int32_t __tgt_rtl_data_notify_mapped(int32_t ID, void *HstPtr, int64_t Size);
// host address \p HstPtr and \p Size bytes.
int32_t __tgt_rtl_data_notify_unmapped(int32_t ID, void *HstPtr);

// Set the global device identifier offset, such that the plugin may determine a
// unique device number.
int32_t __tgt_rtl_set_device_offset(int32_t DeviceIdOffset);

#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 2 additions & 0 deletions openmp/libomptarget/include/rtl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ struct RTLInfoTy {
typedef int32_t(data_unlock_ty)(int32_t, void *);
typedef int32_t(data_notify_mapped_ty)(int32_t, void *, int64_t);
typedef int32_t(data_notify_unmapped_ty)(int32_t, void *);
typedef int32_t(set_device_offset_ty)(int32_t);
typedef int32_t(activate_record_replay_ty)(int32_t, uint64_t, bool, bool);

int32_t Idx = -1; // RTL index, index is the number of devices
Expand Down Expand Up @@ -125,6 +126,7 @@ struct RTLInfoTy {
data_unlock_ty *data_unlock = nullptr;
data_notify_mapped_ty *data_notify_mapped = nullptr;
data_notify_unmapped_ty *data_notify_unmapped = nullptr;
set_device_offset_ty *set_device_offset = nullptr;
activate_record_replay_ty *activate_record_replay = nullptr;

// Are there images associated with this RTL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,8 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
bool ExpectedStatus = false;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))
performOmptCallback(device_initialize,
/* device_num */ DeviceId,
/* device_num */ DeviceId +
Plugin.getGlobalDeviceIdOffset(),
/* type */ getComputeUnitKind().c_str(),
/* device */ reinterpret_cast<ompt_device_t *>(this),
/* lookup */ ompt::lookupCallbackByName,
Expand Down Expand Up @@ -587,7 +588,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
return Plugin::success();
}

Error GenericDeviceTy::deinit() {
Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
// Delete the memory manager before deinitializing the device. Otherwise,
// we may delete device allocations after the device is deinitialized.
if (MemoryManager)
Expand All @@ -605,7 +606,9 @@ Error GenericDeviceTy::deinit() {
if (ompt::Initialized) {
bool ExpectedStatus = true;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false))
performOmptCallback(device_finalize, /* device_num */ DeviceId);
performOmptCallback(device_finalize,
/* device_num */ DeviceId +
Plugin.getGlobalDeviceIdOffset());
}
#endif

Expand Down Expand Up @@ -656,7 +659,8 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
size_t Bytes =
getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
performOmptCallback(device_load,
/* device_num */ DeviceId,
/* device_num */ DeviceId +
Plugin.getGlobalDeviceIdOffset(),
/* FileName */ nullptr,
/* File Offset */ 0,
/* VmaInFile */ nullptr,
Expand Down Expand Up @@ -1362,7 +1366,7 @@ Error GenericPluginTy::deinitDevice(int32_t DeviceId) {
return Plugin::success();

// Deinitialize the device and release its resources.
if (auto Err = Devices[DeviceId]->deinit())
if (auto Err = Devices[DeviceId]->deinit(*this))
return Err;

// Delete the device and invalidate its reference.
Expand Down Expand Up @@ -1815,6 +1819,12 @@ int32_t __tgt_rtl_init_device_info(int32_t DeviceId,
return OFFLOAD_SUCCESS;
}

// Plugin entry point: hands \p DeviceIdOffset to the active plugin instance,
// which stores it as its global device identifier offset. Always reports
// OFFLOAD_SUCCESS.
int32_t __tgt_rtl_set_device_offset(int32_t DeviceIdOffset) {
  auto &&ActivePlugin = Plugin::get();
  ActivePlugin.setGlobalDeviceIdOffset(DeviceIdOffset);
  return OFFLOAD_SUCCESS;
}

#ifdef __cplusplus
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
/// Deinitialize the device and free all its resources. After this call, the
/// device is no longer considered ready, so no queries or modifications are
/// allowed.
Error deinit();
Error deinit(GenericPluginTy &Plugin);
virtual Error deinitImpl() = 0;

/// Load the binary image into the device and return the target table.
Expand Down Expand Up @@ -946,6 +946,14 @@ struct GenericPluginTy {
/// Get the number of active devices.
int32_t getNumDevices() const { return NumDevices; }

/// Get the plugin-specific device identifier offset.
int32_t getGlobalDeviceIdOffset() const { return GlobalDeviceIdOffset; }

/// Set the plugin-specific device identifier offset.
void setGlobalDeviceIdOffset(int32_t Offset) {
GlobalDeviceIdOffset = Offset;
}

/// Get the ELF code to recognize the binary image of this plugin.
virtual uint16_t getMagicElfBits() const = 0;

Expand Down Expand Up @@ -1010,6 +1018,10 @@ struct GenericPluginTy {
/// Number of devices available for the plugin.
int32_t NumDevices = 0;

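/// Offset which, when added to a plugin-local DeviceId, yields a unique,
/// user-observable device identifier. This matters when more than one plugin
/// is in use, where plugin-local DeviceIds on their own could collide.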
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be reworded to make it clear that this is (mainly/only?) relevant in scenarios with multiple different supported devices?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid point, I'm not pushing that naming or description.

We could also name the corresponding variable PluginDeviceStartIndex (or similar); "global" was my go-to, since a globally unique DeviceId is what we wanted to determine.
Or was your initial thought that a clarification in the comment would suffice?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid point, I'm not pushing that naming or description.

We could also name the corresponding variable PluginDeviceStartIndex (or similar); "global" was my go-to, since a globally unique DeviceId is what we wanted to determine. Or was your initial thought that a clarification in the comment would suffice?

I like PluginDeviceStartIndex because that partially corresponds to the variable name where it originates from (libomptarget rtl.cpp).

Copy link
Contributor Author

@mhalk mhalk Sep 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are in the GenericPluginTy namespace, I'll drop the Plugin prefix and suggest: DeviceIdStartIndex.
I'll also go with a more verbose description, as JP suggested, so that it is clearer when this is actually useful.

int32_t GlobalDeviceIdOffset = 0;

/// Array of pointers to the devices. Initially, they are all set to nullptr.
/// Once a device is initialized, the pointer is stored in the position given
/// by its device id. A position with nullptr means that the corresponding
Expand Down
8 changes: 4 additions & 4 deletions openmp/libomptarget/src/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ void *DeviceTy::allocData(int64_t Size, void *HstPtr, int32_t Kind) {
void *TargetPtr = nullptr;
OMPT_IF_BUILT(InterfaceRAII TargetDataAllocRAII(
RegionInterface.getCallbacks<ompt_target_data_alloc>(),
RTLDeviceID, HstPtr, &TargetPtr, Size,
DeviceID, HstPtr, &TargetPtr, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

TargetPtr = RTL->data_alloc(RTLDeviceID, Size, HstPtr, Kind);
Expand All @@ -594,7 +594,7 @@ int32_t DeviceTy::deleteData(void *TgtAllocBegin, int32_t Kind) {
/// RAII to establish tool anchors before and after data deletion
OMPT_IF_BUILT(InterfaceRAII TargetDataDeleteRAII(
RegionInterface.getCallbacks<ompt_target_data_delete>(),
RTLDeviceID, TgtAllocBegin,
DeviceID, TgtAllocBegin,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

return RTL->data_delete(RTLDeviceID, TgtAllocBegin, Kind);
Expand Down Expand Up @@ -632,7 +632,7 @@ int32_t DeviceTy::submitData(void *TgtPtrBegin, void *HstPtrBegin, int64_t Size,
OMPT_IF_BUILT(
InterfaceRAII TargetDataSubmitRAII(
RegionInterface.getCallbacks<ompt_target_data_transfer_to_device>(),
RTLDeviceID, TgtPtrBegin, HstPtrBegin, Size,
DeviceID, TgtPtrBegin, HstPtrBegin, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

if (!AsyncInfo || !RTL->data_submit_async || !RTL->synchronize)
Expand Down Expand Up @@ -660,7 +660,7 @@ int32_t DeviceTy::retrieveData(void *HstPtrBegin, void *TgtPtrBegin,
OMPT_IF_BUILT(
InterfaceRAII TargetDataRetrieveRAII(
RegionInterface.getCallbacks<ompt_target_data_transfer_from_device>(),
RTLDeviceID, HstPtrBegin, TgtPtrBegin, Size,
DeviceID, HstPtrBegin, TgtPtrBegin, Size,
/* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)

if (!RTL->data_retrieve_async || !RTL->synchronize)
Expand Down
6 changes: 6 additions & 0 deletions openmp/libomptarget/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ bool RTLsTy::attemptLoadRTL(const std::string &RTLName, RTLInfoTy &RTL) {
DynLibrary->getAddressOfSymbol("__tgt_rtl_data_notify_mapped");
*((void **)&RTL.data_notify_unmapped) =
DynLibrary->getAddressOfSymbol("__tgt_rtl_data_notify_unmapped");
*((void **)&RTL.set_device_offset) =
DynLibrary->getAddressOfSymbol("__tgt_rtl_set_device_offset");

// Record Replay RTL
*((void **)&RTL.activate_record_replay) =
Expand Down Expand Up @@ -424,6 +426,10 @@ void RTLsTy::initRTLonce(RTLInfoTy &R) {
R.IsUsed = true;
UsedRTLs.push_back(&R);

// If possible, set the device identifier offset
if (R.set_device_offset)
R.set_device_offset(Start);

DP("RTL " DPxMOD " has index %d!\n", DPxPTR(R.LibraryHandler.get()), R.Idx);
}
}
Expand Down