Skip to content
34 changes: 22 additions & 12 deletions sycl/source/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,32 @@ ur_native_handle_t device::getNative() const { return impl->getNative(); }

bool device::has(aspect Aspect) const { return impl->has(Aspect); }

/// Shared implementation of the enable/disable peer-access entry points.
///
/// \tparam ApiKind UR entry point forwarded to the adapter.
/// \param self     device the request is issued on.
/// \param peer     device whose memory is to be accessed.
/// \param Device   native UR handle of \p self.
/// \param Peer     native UR handle of \p peer.
/// \param Adapter  adapter used to dispatch the UR call.
/// \param errorMsg message of the exception thrown on a platform mismatch.
/// \throws exception with errc::invalid when the two devices belong to
///         different platforms.
template <detail::UrApiKind ApiKind>
static void p2pAccessHelper(const device &self, const device &peer,
                            ur_device_handle_t Device, ur_device_handle_t Peer,
                            detail::adapter_impl &Adapter,
                            const char *errorMsg) {
  // A device paired with itself needs no adapter call at all.
  const bool IsSameDevice = (Device == Peer);
  if (!IsSameDevice) {
    // Reject the request before anything is forwarded to the adapter.
    if (peer.get_platform() != self.get_platform()) {
      throw exception(errc::invalid, errorMsg);
    }

    Adapter.call<ApiKind>(Device, Peer);
  }
}

/// Enables access from this device to memory of \p peer.
///
/// Nothing is done when \p peer is this same device. The UR call is issued
/// exactly once, and only after p2pAccessHelper has verified that both
/// devices belong to the same platform.
///
/// \param peer device whose memory should become accessible.
/// \throws exception with errc::invalid when the devices are on different
///         platforms.
void device::ext_oneapi_enable_peer_access(const device &peer) {
  // The helper owns the same-device shortcut, the platform validation and the
  // adapter dispatch; calling the adapter directly here as well would bypass
  // the validation and issue the UR call a second time.
  p2pAccessHelper<detail::UrApiKind::urUsmP2PEnablePeerAccessExp>(
      *this, peer, impl->getHandleRef(), peer.impl->getHandleRef(),
      impl->getAdapter(),
      "Cannot enable peer access between different platforms");
}

/// Disables access from this device to memory of \p peer.
///
/// Nothing is done when \p peer is this same device. The UR call is issued
/// exactly once, and only after p2pAccessHelper has verified that both
/// devices belong to the same platform.
///
/// \param peer device whose memory should no longer be accessible.
/// \throws exception with errc::invalid when the devices are on different
///         platforms.
void device::ext_oneapi_disable_peer_access(const device &peer) {
  // The helper owns the same-device shortcut, the platform validation and the
  // adapter dispatch; calling the adapter directly here as well would bypass
  // the validation and issue the UR call a second time.
  p2pAccessHelper<detail::UrApiKind::urUsmP2PDisablePeerAccessExp>(
      *this, peer, impl->getHandleRef(), peer.impl->getHandleRef(),
      impl->getAdapter(),
      "Cannot disable peer access between different platforms");
}

bool device::ext_oneapi_can_access_peer(const device &peer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ if(UR_BUILD_ADAPTER_L0_V2)
${CMAKE_CURRENT_SOURCE_DIR}/helpers/kernel_helpers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/helpers/memory_helpers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/helpers/mutable_helpers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
${CMAKE_CURRENT_SOURCE_DIR}/virtual_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sampler.hpp
Expand Down Expand Up @@ -194,6 +193,7 @@ if(UR_BUILD_ADAPTER_L0_V2)
${CMAKE_CURRENT_SOURCE_DIR}/v2/queue_immediate_in_order.cpp
${CMAKE_CURRENT_SOURCE_DIR}/v2/queue_immediate_out_of_order.cpp
${CMAKE_CURRENT_SOURCE_DIR}/v2/usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/v2/usm_p2p.cpp
)
install_ur_library(ur_adapter_level_zero_v2)

Expand Down
2 changes: 2 additions & 0 deletions unified-runtime/source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ ur_result_t urContextCreate(

Context->initialize();
*RetContext = reinterpret_cast<ur_context_handle_t>(Context);
// TODO: delete below 'if' when memory isolation in the context is
// implemented in the driver
if (IndirectAccessTrackingEnabled) {
std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
Platform->Contexts.push_back(*RetContext);
Expand Down
21 changes: 21 additions & 0 deletions unified-runtime/source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2353,3 +2353,24 @@ void ZeUSMImportExtension::doZeUSMRelease(ze_driver_handle_t DriverHandle,
void *HostPtr) {
ZE_CALL_NOCHECK(zexDriverReleaseImportedPointer, (DriverHandle, HostPtr));
}

// Writes the device id to the stream, or the literal "NONE" when the device
// has no id assigned.
std::ostream &operator<<(std::ostream &os,
                         ur_device_handle_t_ const &device_handle) {
  const auto &MaybeId = device_handle.Id;
  if (!MaybeId.has_value())
    return os << "NONE";

  return os << MaybeId.value();
}

std::ostream &operator<<(std::ostream &os,
ur_device_handle_t_::PeerStatus peer_status) {
switch (peer_status) {
case ur_device_handle_t_::PeerStatus::DISABLED:
return os << "DISABLED";
case ur_device_handle_t_::PeerStatus::ENABLED:
return os << "ENABLED";
case ur_device_handle_t_::PeerStatus::NO_CONNECTION:
return os << "NO_CONNECTION";
}
return os << "UNKNOWN";
}
14 changes: 13 additions & 1 deletion unified-runtime/source/adapters/level_zero/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,29 @@ struct ur_device_handle_t_ : ur_object {
std::unordered_map<ur_exp_image_native_handle_t, ze_image_handle_t>
ZeOffsetToImageHandleMap;

// unique ephemeral identifier of the device in the adapter
// Peer devices for which the user enabled or disabled P2P access through
// urUsmP2P(Enable|Disable)PeerAccessExp. Entries are indexed by device id.
enum class PeerStatus : char { ENABLED, DISABLED, NO_CONNECTION };
std::vector<PeerStatus>
peers; // info if our device can access given peer device allocations

// unique ephemeral identifier of the device in the adapter
std::optional<DeviceId> Id;

ur::RefCount RefCount;
};

// Streams the device's id, or "NONE" when the device has no id assigned.
std::ostream &operator<<(std::ostream &os,
ur_device_handle_t_ const &device_handle);
// Streams the symbolic name of a PeerStatus value ("ENABLED", "DISABLED",
// "NO_CONNECTION"), or "UNKNOWN" for any other value.
std::ostream &operator<<(std::ostream &os,
ur_device_handle_t_::PeerStatus peer_status);

// Collects a flat vector of unique devices for USM memory pool creation.
// Traverses the input devices and their sub-devices, ensuring each Level Zero
// device handle appears only once in the result.
inline std::vector<ur_device_handle_t> CollectDevicesForUsmPoolCreation(
const std::vector<ur_device_handle_t> &Devices) {

std::vector<ur_device_handle_t> DevicesAndSubDevices;
std::unordered_set<ze_device_handle_t> Seen;

Expand Down
43 changes: 40 additions & 3 deletions unified-runtime/source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,9 @@ ur_platform_handle_t_::getDeviceFromNativeHandle(ze_device_handle_t ZeDevice) {
std::shared_lock<ur_shared_mutex> Lock(URDevicesCacheMutex);
auto it = std::find_if(URDevicesCache.begin(), URDevicesCache.end(),
[&](std::unique_ptr<ur_device_handle_t_> &D) {
return D.get()->ZeDevice == ZeDevice &&
(D.get()->RootDevice == nullptr ||
D.get()->RootDevice->RootDevice == nullptr);
return D->ZeDevice == ZeDevice &&
(D->RootDevice == nullptr ||
D->RootDevice->RootDevice == nullptr);
});
if (it != URDevicesCache.end()) {
return (*it).get();
Expand Down Expand Up @@ -914,6 +914,43 @@ ur_result_t ur_platform_handle_t_::populateDeviceCacheIfNeeded() {
ZeDeviceSynchronizeSupported = Supported;
}

for (auto &dev : URDevicesCache) {
dev->peers = std::vector<ur_device_handle_t_::PeerStatus>(
URDevicesCache.size(), ur_device_handle_t_::PeerStatus::NO_CONNECTION);

for (size_t peerId = 0; peerId < URDevicesCache.size(); ++peerId) {
if (peerId == dev->Id.value())
continue;

ZeStruct<ze_device_p2p_properties_t> p2pProperties;
ZE2UR_CALL(
zeDeviceGetP2PProperties,
(dev->ZeDevice, URDevicesCache[peerId]->ZeDevice, &p2pProperties));
if (!(p2pProperties.flags & ZE_DEVICE_P2P_PROPERTY_FLAG_ACCESS)) {
UR_LOG(INFO,
"p2p access to memory of dev:{} from dev:{} not possible due to "
"lack of p2p property",
peerId, dev->Id.value());
continue;
}

ze_bool_t p2p;
ZE2UR_CALL(zeDeviceCanAccessPeer,
(dev->ZeDevice, URDevicesCache[peerId]->ZeDevice, &p2p));
if (!p2p) {
UR_LOG(INFO,
"p2p access to memory of dev:{} from dev:{} not possible due to "
"no connection",
peerId, dev->Id.value());
continue;
}

UR_LOG(INFO, "p2p access to memory of dev:{} from dev:{} can be enabled",
peerId, dev->Id.value());
dev->peers[peerId] = ur_device_handle_t_::PeerStatus::DISABLED;
}
}

return UR_RESULT_SUCCESS;
}

Expand Down
9 changes: 4 additions & 5 deletions unified-runtime/source/adapters/level_zero/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@ struct ur_platform_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter>,
uint32_t VersionMinor,
uint32_t VersionBuild);

// Keep track of all contexts in the platform. In v1 L0 this is needed to
// manage the lifetime of memory allocations in each context when there are
// kernels with indirect access.
// TODO: that v1 use should be deleted when memory isolation in the context is
// implemented in the driver.
// In v2 the list is used during ext_oneapi_enable_peer_access and
// ext_oneapi_disable_peer_access calls.
std::list<ur_context_handle_t> Contexts;
ur_shared_mutex ContextsMutex;

Expand Down
18 changes: 12 additions & 6 deletions unified-runtime/source/adapters/level_zero/usm_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@

namespace ur::level_zero {

// Request to enable peer access from commandDevice to memory of peerDevice.
//
// L0 has peer devices enabled by default, so this adapter has nothing to
// switch on: the request is only logged and reported as successful.
ur_result_t urUsmP2PEnablePeerAccessExp(ur_device_handle_t commandDevice,
                                        ur_device_handle_t peerDevice) {
  // The handles are logged as opaque addresses.
  UR_LOG(INFO,
         "ignored enabling peer access from {} to memory of {}, because P2P is "
         "always enabled in Level Zero V1 adapter",
         static_cast<void *>(commandDevice), static_cast<void *>(peerDevice));
  return UR_RESULT_SUCCESS;
}

// Request to disable peer access from commandDevice to memory of peerDevice.
//
// L0 has peer devices enabled by default, so this adapter has nothing to
// switch off: the request is only logged and reported as successful.
ur_result_t urUsmP2PDisablePeerAccessExp(ur_device_handle_t commandDevice,
                                         ur_device_handle_t peerDevice) {
  // The handles are logged as opaque addresses.
  UR_LOG(INFO,
         "ignored disabling peer access from {} to memory of {}, because P2P "
         "is always enabled in Level Zero V1 adapter",
         static_cast<void *>(commandDevice), static_cast<void *>(peerDevice));
  return UR_RESULT_SUCCESS;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,39 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy(
wait_list_view &waitListView, ur_event_handle_t phEvent) {
TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMMemcpy");

// Check P2P access: if the source pointer is device memory on a different
// device AND the destination is also device memory (not host/shared), verify
// that peer access has been enabled. Copies to host memory always succeed
// regardless of P2P state.
ZeStruct<ze_memory_allocation_properties_t> dstMemProps;
ze_device_handle_t dstZeDevice = nullptr;
auto zeDstResult = ZE_CALL_NOCHECK(
zeMemGetAllocProperties,
(hContext.get()->getZeHandle(), pDst, &dstMemProps, &dstZeDevice));
if (zeDstResult == ZE_RESULT_SUCCESS &&
dstMemProps.type == ZE_MEMORY_TYPE_DEVICE) {
ZeStruct<ze_memory_allocation_properties_t> srcMemProps;
ze_device_handle_t srcZeDevice = nullptr;
auto zeSrcResult = ZE_CALL_NOCHECK(
zeMemGetAllocProperties,
(hContext.get()->getZeHandle(), pSrc, &srcMemProps, &srcZeDevice));
if (zeSrcResult == ZE_RESULT_SUCCESS &&
srcMemProps.type == ZE_MEMORY_TYPE_DEVICE && srcZeDevice &&
srcZeDevice != hDevice.get()->ZeDevice) {
auto *srcDevice =
hContext.get()->getPlatform()->getDeviceFromNativeHandle(srcZeDevice);
if (srcDevice && srcDevice->Id.has_value() &&
hDevice.get()->Id.has_value() &&
hDevice.get()->Id.value() < srcDevice->peers.size()) {
std::scoped_lock<ur_shared_mutex> lock(srcDevice->Mutex);
if (srcDevice->peers[hDevice.get()->Id.value()] !=
ur_device_handle_t_::PeerStatus::ENABLED) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}
}
}
}

auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_MEMCPY);
auto [pWaitEvents, numWaitEvents, _] = waitListView;

Expand Down
Loading
Loading