diff options
Diffstat (limited to 'src/Driver/EncryptedIoQueue.c')
-rw-r--r-- | src/Driver/EncryptedIoQueue.c | 166 |
1 files changed, 147 insertions, 19 deletions
diff --git a/src/Driver/EncryptedIoQueue.c b/src/Driver/EncryptedIoQueue.c index 8c2e8a41..91399c47 100644 --- a/src/Driver/EncryptedIoQueue.c +++ b/src/Driver/EncryptedIoQueue.c @@ -266,32 +266,106 @@ UpdateBuffer( return updated; } +static VOID CompleteIrpWorkItemRoutine(PDEVICE_OBJECT DeviceObject, PVOID Context) +{ + PCOMPLETE_IRP_WORK_ITEM workItem = (PCOMPLETE_IRP_WORK_ITEM)Context; + EncryptedIoQueueItem* item = (EncryptedIoQueueItem * ) workItem->Item; + EncryptedIoQueue* queue = item->Queue; + KIRQL oldIrql; + UNREFERENCED_PARAMETER(DeviceObject); + + __try + { + // Complete the IRP + TCCompleteDiskIrp(workItem->Irp, workItem->Status, workItem->Information); + + item->Status = workItem->Status; + OnItemCompleted(item, FALSE); // Do not free item here; it will be freed below + } + __finally + { + // If no active work items remain, signal the event + if (InterlockedDecrement(&queue->ActiveWorkItems) == 0) + { + KeSetEvent(&queue->NoActiveWorkItemsEvent, IO_NO_INCREMENT, FALSE); + } + + // Return the work item to the free list + KeAcquireSpinLock(&queue->WorkItemLock, &oldIrql); + InsertTailList(&queue->FreeWorkItemsList, &workItem->ListEntry); + KeReleaseSpinLock(&queue->WorkItemLock, oldIrql); + + // Release the semaphore to signal that a work item is available + KeReleaseSemaphore(&queue->WorkItemSemaphore, IO_NO_INCREMENT, 1, FALSE); + + // Free the item + ReleasePoolBuffer(queue, item); + } +} -static VOID CompletionThreadProc (PVOID threadArg) +// Handles the completion of the original IRP. +static VOID HandleCompleteOriginalIrp(EncryptedIoQueue* queue, EncryptedIoRequest* request) { - EncryptedIoQueue *queue = (EncryptedIoQueue *) threadArg; + NTSTATUS status = KeWaitForSingleObject(&queue->WorkItemSemaphore, Executive, KernelMode, FALSE, NULL); + if (queue->ThreadExitRequested) + return; + + if (!NT_SUCCESS(status)) + { + // Handle wait failure: we call the completion routine directly. + // This is not ideal since it can cause deadlock that we are trying to fix but it is better than losing the IRP. + CompleteOriginalIrp(request->Item, STATUS_INSUFFICIENT_RESOURCES, 0); + } + else + { + // Obtain a work item from the free list. + KIRQL oldIrql; + KeAcquireSpinLock(&queue->WorkItemLock, &oldIrql); + PLIST_ENTRY freeEntry = RemoveHeadList(&queue->FreeWorkItemsList); + KeReleaseSpinLock(&queue->WorkItemLock, oldIrql); + + PCOMPLETE_IRP_WORK_ITEM workItem = CONTAINING_RECORD(freeEntry, COMPLETE_IRP_WORK_ITEM, ListEntry); + + // Increment ActiveWorkItems. + InterlockedIncrement(&queue->ActiveWorkItems); + KeResetEvent(&queue->NoActiveWorkItemsEvent); + + // Prepare the work item. + workItem->Irp = request->Item->OriginalIrp; + workItem->Status = request->Item->Status; + workItem->Information = NT_SUCCESS(request->Item->Status) ? request->Item->OriginalLength : 0; + workItem->Item = request->Item; + + // Queue the work item. + IoQueueWorkItem(workItem->WorkItem, CompleteIrpWorkItemRoutine, DelayedWorkQueue, workItem); + } +} + +static VOID CompletionThreadProc(PVOID threadArg) +{ + EncryptedIoQueue* queue = (EncryptedIoQueue*)threadArg; PLIST_ENTRY listEntry; - EncryptedIoRequest *request; + EncryptedIoRequest* request; UINT64_STRUCT dataUnit; if (IsEncryptionThreadPoolRunning()) - KeSetPriorityThread (KeGetCurrentThread(), LOW_REALTIME_PRIORITY); + KeSetPriorityThread(KeGetCurrentThread(), LOW_REALTIME_PRIORITY); while (!queue->ThreadExitRequested) { - if (!NT_SUCCESS (KeWaitForSingleObject (&queue->CompletionThreadQueueNotEmptyEvent, Executive, KernelMode, FALSE, NULL))) + if (!NT_SUCCESS(KeWaitForSingleObject(&queue->CompletionThreadQueueNotEmptyEvent, Executive, KernelMode, FALSE, NULL))) continue; if (queue->ThreadExitRequested) break; - while ((listEntry = ExInterlockedRemoveHeadList (&queue->CompletionThreadQueue, &queue->CompletionThreadQueueLock))) + while ((listEntry = ExInterlockedRemoveHeadList(&queue->CompletionThreadQueue, &queue->CompletionThreadQueueLock))) { - request = CONTAINING_RECORD (listEntry, EncryptedIoRequest, CompletionListEntry); + request = CONTAINING_RECORD(listEntry, EncryptedIoRequest, CompletionListEntry); - if (request->EncryptedLength > 0 && NT_SUCCESS (request->Item->Status)) + if (request->EncryptedLength > 0 && NT_SUCCESS(request->Item->Status)) { - ASSERT (request->EncryptedOffset + request->EncryptedLength <= request->Offset.QuadPart + request->Length); + ASSERT(request->EncryptedOffset + request->EncryptedLength <= request->Offset.QuadPart + request->Length); dataUnit.Value = (request->Offset.QuadPart + request->EncryptedOffset) / ENCRYPTION_DATA_UNIT_SIZE; if (queue->CryptoInfo->bPartitionInInactiveSysEncScope) @@ -299,7 +373,7 @@ static VOID CompletionThreadProc (PVOID threadArg) else if (queue->RemapEncryptedArea) dataUnit.Value += queue->RemappedAreaDataUnitOffset; - DecryptDataUnits (request->Data + request->EncryptedOffset, &dataUnit, request->EncryptedLength / ENCRYPTION_DATA_UNIT_SIZE, queue->CryptoInfo); + DecryptDataUnits(request->Data + request->EncryptedOffset, &dataUnit, request->EncryptedLength / ENCRYPTION_DATA_UNIT_SIZE, queue->CryptoInfo); } // Dump("Read sector %lld count %d\n", request->Offset.QuadPart >> 9, request->Length >> 9); // Update subst sectors @@ -309,15 +383,14 @@ static VOID CompletionThreadProc (PVOID threadArg) if (request->CompleteOriginalIrp) { - CompleteOriginalIrp (request->Item, request->Item->Status, - NT_SUCCESS (request->Item->Status) ? request->Item->OriginalLength : 0); + HandleCompleteOriginalIrp(queue, request); } - ReleasePoolBuffer (queue, request); + ReleasePoolBuffer(queue, request); } } - PsTerminateSystemThread (STATUS_SUCCESS); + PsTerminateSystemThread(STATUS_SUCCESS); } @@ -471,8 +544,7 @@ static VOID IoThreadProc (PVOID threadArg) if (request->CompleteOriginalIrp) { - CompleteOriginalIrp (request->Item, request->Item->Status, - NT_SUCCESS (request->Item->Status) ? request->Item->OriginalLength : 0); + HandleCompleteOriginalIrp(queue, request); } ReleasePoolBuffer (queue, request); @@ -640,7 +712,7 @@ static VOID MainThreadProc (PVOID threadArg) { UINT64_STRUCT dataUnit; - dataBuffer = (PUCHAR) MmGetSystemAddressForMdlSafe (irp->MdlAddress, (HighPagePriority | ExDefaultMdlProtection)); + dataBuffer = (PUCHAR) MmGetSystemAddressForMdlSafe (irp->MdlAddress, (HighPagePriority | MdlMappingNoExecute)); if (!dataBuffer) { TCfree (buffer); @@ -760,7 +832,7 @@ static VOID MainThreadProc (PVOID threadArg) continue; } - dataBuffer = (PUCHAR) MmGetSystemAddressForMdlSafe (irp->MdlAddress, (HighPagePriority | ExDefaultMdlProtection)); + dataBuffer = (PUCHAR) MmGetSystemAddressForMdlSafe (irp->MdlAddress, (HighPagePriority | MdlMappingNoExecute)); if (dataBuffer == NULL) { @@ -972,7 +1044,7 @@ NTSTATUS EncryptedIoQueueStart (EncryptedIoQueue *queue) { NTSTATUS status; EncryptedIoQueueBuffer *buffer; - int i, preallocatedIoRequestCount, preallocatedItemCount, fragmentSize; + int i, j, preallocatedIoRequestCount, preallocatedItemCount, fragmentSize; preallocatedIoRequestCount = EncryptionIoRequestCount; preallocatedItemCount = EncryptionItemCount; @@ -1076,6 +1148,41 @@ retry_preallocated: buffer->InUse = FALSE; } + // Initialize the free work item list + InitializeListHead(&queue->FreeWorkItemsList); + KeInitializeSemaphore(&queue->WorkItemSemaphore, EncryptionMaxWorkItems, EncryptionMaxWorkItems); + KeInitializeSpinLock(&queue->WorkItemLock); + + queue->MaxWorkItems = EncryptionMaxWorkItems; + queue->WorkItemPool = (PCOMPLETE_IRP_WORK_ITEM)TCalloc(sizeof(COMPLETE_IRP_WORK_ITEM) * queue->MaxWorkItems); + if (!queue->WorkItemPool) + { + goto noMemory; + } + + // Allocate and initialize work items + for (i = 0; i < (int) queue->MaxWorkItems; ++i) + { + queue->WorkItemPool[i].WorkItem = IoAllocateWorkItem(queue->DeviceObject); + if (!queue->WorkItemPool[i].WorkItem) + { + // Handle allocation failure + // Free previously allocated work items + for (j = 0; j < i; ++j) + { + IoFreeWorkItem(queue->WorkItemPool[j].WorkItem); + } + TCfree(queue->WorkItemPool); + goto noMemory; + } + + // Insert the work item into the free list + ExInterlockedInsertTailList(&queue->FreeWorkItemsList, &queue->WorkItemPool[i].ListEntry, &queue->WorkItemLock); + } + + queue->ActiveWorkItems = 0; + KeInitializeEvent(&queue->NoActiveWorkItemsEvent, NotificationEvent, FALSE); + // Main thread InitializeListHead (&queue->MainThreadQueue); KeInitializeSpinLock (&queue->MainThreadQueueLock); @@ -1158,6 +1265,27 @@ NTSTATUS EncryptedIoQueueStop (EncryptedIoQueue *queue) TCStopThread (queue->IoThread, &queue->IoThreadQueueNotEmptyEvent); TCStopThread (queue->CompletionThread, &queue->CompletionThreadQueueNotEmptyEvent); + // Wait for active work items to complete + KeResetEvent(&queue->NoActiveWorkItemsEvent); + Dump("Queue stopping active work items=%d\n", queue->ActiveWorkItems); + while (InterlockedCompareExchange(&queue->ActiveWorkItems, 0, 0) > 0) + { + KeWaitForSingleObject(&queue->NoActiveWorkItemsEvent, Executive, KernelMode, FALSE, NULL); + // reset the event again in case multiple work items are completing + KeResetEvent(&queue->NoActiveWorkItemsEvent); + } + + // Free pre-allocated work items + for (ULONG i = 0; i < queue->MaxWorkItems; ++i) + { + if (queue->WorkItemPool[i].WorkItem) + { + IoFreeWorkItem(queue->WorkItemPool[i].WorkItem); + queue->WorkItemPool[i].WorkItem = NULL; + } + } + TCfree(queue->WorkItemPool); + TCfree (queue->FragmentBufferA); TCfree (queue->FragmentBufferB); TCfree (queue->ReadAheadBuffer); |