diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp index 3ed9132d007a8..4c98d749281cc 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp @@ -125,6 +125,7 @@ ur_queue_batched_t::renewBatchUnlocked(locked &batchLocked) { if (batchLocked->isLimitOfUsedCommandListsReached()) { return queueFinishUnlocked(batchLocked); } else { + UR_CALL(batchLocked->enqueueCurrentBatchUnlocked()); return batchLocked->renewRegularUnlocked(getNewRegularCmdList()); } } @@ -157,7 +158,7 @@ ur_queue_batched_t::onEventWaitListUse(ur_event_generation_t batch_generation) { auto batchLocked = currentCmdLists.lock(); if (batchLocked->isCurrentGeneration(batch_generation)) { - return queueFlushUnlocked(batchLocked); + return renewBatchUnlocked(batchLocked); } else { return UR_RESULT_SUCCESS; } @@ -228,7 +229,9 @@ ur_result_t batch_manager::batchFinish() { UR_CALL(activeBatch.releaseSubmittedKernels()); if (!isActiveBatchEmpty()) { - // Should have been enqueued as part of queueFinishUnlocked + // The active batch was already submitted to the immediate command list + // by queueFinishUnlocked. Reset it here so it is ready to record new + // commands. 
TRACK_SCOPE_LATENCY("ur_queue_batched_t::resetRegCmdlist"); ZE2UR_CALL(zeCommandListReset, (activeBatch.getZeCommandList())); @@ -432,7 +435,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMFreeExp( createEventIfRequestedRegular(phEvent, lockedBatch->getCurrentGeneration()))); - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t ur_queue_batched_t::enqueueMemBufferMap( @@ -634,7 +637,7 @@ ur_result_t ur_queue_batched_t::enqueueEventsWaitWithBarrier( phEvent, lockedBatch->getCurrentGeneration()))); } - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t @@ -652,7 +655,7 @@ ur_queue_batched_t::enqueueEventsWait(uint32_t numEventsInWaitList, waitListView, createEventIfRequestedRegular( phEvent, lockedBatch->getCurrentGeneration()))); - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t ur_queue_batched_t::enqueueMemBufferCopy( @@ -818,7 +821,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMDeviceAllocExp( lockedBatch->getCurrentGeneration()), UR_USM_TYPE_DEVICE)); - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t ur_queue_batched_t::enqueueUSMSharedAllocExp( @@ -840,7 +843,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMSharedAllocExp( lockedBatch->getCurrentGeneration()), UR_USM_TYPE_SHARED)); - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t ur_queue_batched_t::enqueueUSMHostAllocExp( @@ -861,7 +864,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMHostAllocExp( lockedBatch->getCurrentGeneration()), UR_USM_TYPE_HOST)); - return queueFlushUnlocked(lockedBatch); + return renewBatchUnlocked(lockedBatch); } ur_result_t ur_queue_batched_t::bindlessImagesImageCopyExp( @@ -969,7 +972,7 @@ ur_result_t ur_queue_batched_t::enqueueCommandBufferExp( // command buffer batch (also a regular list) to preserve the order of // operations if 
(!lockedBatch->isActiveBatchEmpty()) { - UR_CALL(queueFlushUnlocked(lockedBatch)); + UR_CALL(renewBatchUnlocked(lockedBatch)); } // Regular lists cannot be appended to other regular lists for execution, only @@ -1075,20 +1078,13 @@ ur_queue_batched_t::queueGetNativeHandle(ur_queue_native_desc_t * /*pDesc*/, return UR_RESULT_SUCCESS; } -ur_result_t -ur_queue_batched_t::queueFlushUnlocked(locked &batchLocked) { - UR_CALL(batchLocked->enqueueCurrentBatchUnlocked()); - - return renewBatchUnlocked(batchLocked); -} - ur_result_t ur_queue_batched_t::queueFlush() { auto batchLocked = currentCmdLists.lock(); if (batchLocked->isActiveBatchEmpty()) { return UR_RESULT_SUCCESS; } else { - return queueFlushUnlocked(batchLocked); + return renewBatchUnlocked(batchLocked); } } diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_batched.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_batched.hpp index df1adab5bc838..0f1d0932f227b 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_batched.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_batched.hpp @@ -199,8 +199,6 @@ struct ur_queue_batched_t : ur_object, ur_queue_t_ { ur_result_t queueFinishUnlocked(locked &batchLocked); - ur_result_t queueFlushUnlocked(locked &batchLocked); - ur_result_t markIssuedCommandInBatch(locked &batchLocked); public: diff --git a/unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp b/unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp index c322a14768186..b2c1128b770ae 100644 --- a/unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp +++ b/unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp @@ -73,16 +73,19 @@ std::ostream &operator<<(std::ostream &os, } struct urL0EnqueueAllocMultiQueueSameDeviceTest - : uur::urContextTestWithParam { + : uur::urContextTestWithParam< + uur::MultiQueueParam> { void SetUp() override { UUR_RETURN_ON_FATAL_FAILURE(urContextTestWithParam::SetUp()); - auto param = std::get<1>(this->GetParam()); + const 
auto &param = getAllocParam(); + + ur_queue_properties_t props = {UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, nullptr, + getQueueFlags()}; queues.reserve(param.numQueues); for (size_t i = 0; i < param.numQueues; i++) { ur_queue_handle_t queue = nullptr; - ASSERT_SUCCESS(urQueueCreate(context, device, 0, &queue)); - SKIP_IF_BATCHED_QUEUE(queue); + ASSERT_SUCCESS(urQueueCreate(context, device, &props, &queue)); queues.push_back(queue); } } @@ -95,6 +98,20 @@ struct urL0EnqueueAllocMultiQueueSameDeviceTest UUR_RETURN_ON_FATAL_FAILURE(urContextTestWithParam::TearDown()); } + const uur::MultiQueueParam & + getMultiQueueParam() const { + return uur::urContextTestWithParam< + uur::MultiQueueParam>::getParam(); + } + + const EnqueueAllocMultiQueueTestParam &getAllocParam() const { + return std::get<0>(this->getMultiQueueParam()); + } + + ur_queue_flag_t getQueueFlags() const { + return std::get<1>(this->getMultiQueueParam()); + } + std::vector queues; }; @@ -322,7 +339,7 @@ TEST_P(urL0EnqueueAllocTest, SuccessWithKernelRepeat) { ValidateEnqueueFree(ptr2); } -UUR_DEVICE_TEST_SUITE_WITH_PARAM( +UUR_MULTI_QUEUE_TYPE_TEST_SUITE_WITH_PARAM( urL0EnqueueAllocMultiQueueSameDeviceTest, ::testing::ValuesIn({ EnqueueAllocMultiQueueTestParam{1024, 256, 8, urEnqueueUSMHostAllocExp, @@ -334,20 +351,16 @@ UUR_DEVICE_TEST_SUITE_WITH_PARAM( urEnqueueUSMDeviceAllocExp, uur::GetDeviceUSMDeviceSupport}, }), - uur::deviceTestWithParamPrinter); + uur::deviceTestWithParamPrinterMulti); TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessMt) { - const size_t allocSize = std::get<1>(this->GetParam()).allocSize; - const size_t numQueues = std::get<1>(this->GetParam()).numQueues; - const size_t iterations = std::get<1>(this->GetParam()).iterations; + const size_t allocSize = getAllocParam().allocSize; + const size_t numQueues = getAllocParam().numQueues; + const size_t iterations = getAllocParam().iterations; const auto enqueueUSMAllocFunc = - std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc; + 
getAllocParam().funcParams.enqueueUSMAllocFunc; const auto checkUSMSupportFunc = - std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc; - - if (numQueues > 0) { - SKIP_IF_BATCHED_QUEUE(queues[0]); - } + getAllocParam().funcParams.checkUSMSupportFunc; ur_device_usm_access_capability_flags_t USMSupport = 0; ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport)); @@ -394,11 +407,11 @@ TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessMt) { TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessReuse) { GTEST_SKIP() << "Multi queue reuse is not supported."; - const size_t allocSize = std::get<1>(this->GetParam()).allocSize; + const size_t allocSize = getAllocParam().allocSize; const auto enqueueUSMAllocFunc = - std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc; + getAllocParam().funcParams.enqueueUSMAllocFunc; const auto checkUSMSupportFunc = - std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc; + getAllocParam().funcParams.checkUSMSupportFunc; ur_device_usm_access_capability_flags_t USMSupport = 0; ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport)); @@ -457,12 +470,12 @@ TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessReuse) { } TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessDependantMt) { - const size_t allocSize = std::get<1>(this->GetParam()).allocSize; - const size_t iterations = std::get<1>(this->GetParam()).iterations; + const size_t allocSize = getAllocParam().allocSize; + const size_t iterations = getAllocParam().iterations; const auto enqueueUSMAllocFunc = - std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc; + getAllocParam().funcParams.enqueueUSMAllocFunc; const auto checkUSMSupportFunc = - std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc; + getAllocParam().funcParams.checkUSMSupportFunc; ur_device_usm_access_capability_flags_t USMSupport = 0; ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport));