There's a need to have a mechanism in place that will process messages received from the network concurrently. However, only X number of messages can be allowed to be processed concurrently and there's a restriction: Similar messages must be processed sequentially. For simplicity's sake, similarity can be identified by the Tuple<T1, T2> or Integer TId generic parameter.
Consumption is straightforward: Throw any number of work items at SemaphoreWorkQueue<TId, TState> and it will do the processing. A single unit test is provided below for clarification.
Where can this be improved for performance and are there any pitfalls?
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace CodeReview.StackExchange.Com
{
/// <summary>
/// Executes actions concurrently limited by the semaphore count, but does not honor concurrency for actions with similar identifiers.
/// </summary>
public class SemaphoreWorkQueue<TId, TState> : IDisposable
{
private SemaphoreSlim _semaphore;
private Task _workerTask;
private ConcurrentDictionary<TId, TaskInfo> _runningTasks;
private ConcurrentQueue<TaskInfo> _waitQueue;
private CancellationToken _cancellationToken;
private volatile bool _disengaged;
/// <summary>
/// Initializes a new instance with an initial count of 10 concurrent actions and CancellationToken.None.
/// </summary>
public SemaphoreWorkQueue()
: this(10, CancellationToken.None)
{
}
/// <summary>
/// Initializes a new instance with desired number of concurrent actions and CancellationToken.None.
/// </summary>
public SemaphoreWorkQueue(int concurrentTaskCount)
: this(concurrentTaskCount, CancellationToken.None)
{
}
/// <summary>
/// Initializes a new instance with desired number of concurrent actions and cancellation token.
/// </summary>
public SemaphoreWorkQueue(int concurrentTaskCount, CancellationToken cancellationToken)
{
if (concurrentTaskCount < 1)
throw new ArgumentException("Parameter concurrentTaskCount cannot be less then zero.", "concurrentTaskCount");
_cancellationToken = cancellationToken;
_semaphore = new SemaphoreSlim(concurrentTaskCount);
_waitQueue = new ConcurrentQueue<TaskInfo>();
_runningTasks = new ConcurrentDictionary<TId, TaskInfo>();
_workerTask = Task.Factory.StartNew(async () => await this.Engage(), cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default);
}
/// <summary>
/// Queues an action with a specified identifier.
/// </summary>
public void EnqueueWork(TId id, Action<TState> action)
{
this.EnqueueWork(id, action, default(TState), _cancellationToken);
}
/// <summary>
/// Queues an action with a specified identifier and state.
/// </summary>
public void EnqueueWork(TId id, Action<TState> action, TState state)
{
this.EnqueueWork(id, action, state, _cancellationToken);
}
/// <summary>
/// Queues an action with a specified identifier and cancellation token.
/// </summary>
public void EnqueueWork(TId id, Action<TState> action, CancellationToken cancellationToken)
{
this.EnqueueWork(id, action, default(TState), cancellationToken);
}
/// <summary>
/// Queues an action with a specified identifier, state and cancellation token.
/// </summary>
public void EnqueueWork(TId id, Action<TState> action, TState state, CancellationToken cancellationToken)
{
if (id == null)
throw new ArgumentNullException("id", "Parameter id is required");
if (action == null)
throw new ArgumentNullException("action", "Parameter action is required");
var task = new TaskInfo();
task.Id = id;
task.State = state;
task.Action = action;
task.CancellationToken = cancellationToken;
_waitQueue.Enqueue(task);
}
private async Task<bool> Engage()
{
while (!_disengaged && !_cancellationToken.IsCancellationRequested)
{
TaskInfo waitingTask;
if (_waitQueue.TryDequeue(out waitingTask))
{
if (waitingTask.IsCancellationRequested)
break;
if (_runningTasks.ContainsKey(waitingTask.Id))
{
_waitQueue.Enqueue(waitingTask);
}
else
{
_runningTasks.TryAdd(waitingTask.Id, waitingTask);
await _semaphore.WaitAsync(); //decrease semaphore count
if (_disengaged)
break;
StartTask(waitingTask);
}
}
await Task.Delay(100);
}
return await Task.FromResult<bool>(_disengaged);
}
private void StartTask(TaskInfo taskInfo)
{
Task.Run(() => {
try
{
taskInfo.Action(taskInfo.State);
}
finally
{
_runningTasks.TryRemove(taskInfo.Id, out taskInfo);
_semaphore.Release(); //increase semaphore count and remove task
}
});
}
private class TaskInfo
{
public TId Id;
public TState State;
public Action<TState> Action;
public CancellationToken CancellationToken;
public bool IsCancellationRequested
{
get { return CancellationToken.IsCancellationRequested; }
}
}
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
private void Dispose(bool disposing)
{
if (disposing)
{
_disengaged = true;
_semaphore.Dispose();
_runningTasks.Clear();
TaskInfo task;
while (_waitQueue.TryDequeue(out task)) ;
}
}
}
}
And the unit test:
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace CodeReview.StackExchange.Com
{
[TestClass]
public class SemaphoreWorkQueueTests
{
[TestMethod]
public void EnqueueWork_Can_Process_Tasks_Concurrently_No_Overlap()
{
var actions = new List<Action<int>>();
var rnd = new Random();
var executedTasks = new ConcurrentBag<ActionInfo>();
var maxTaskCount = 128;
var concurrentTaskCount = 8;
var delayMilliseconds = 1000;
using (var semaphoreQueue = new SemaphoreWorkQueue<int, int>(concurrentTaskCount))
{
for (var i = 0; i < concurrentTaskCount; i++) //create X actions
{
actions.Add(new Action<int>((int n) => {
var info = new ActionInfo();
info.Id = n;
info.StartTime = DateTime.Now.TimeOfDay;
Debug.WriteLine("I'm action #{0} @ {1} on Thread {2}", n, info.StartTime, Thread.CurrentThread.ManagedThreadId);
Task.Delay(delayMilliseconds).Wait();
info.EndTime = DateTime.Now.TimeOfDay;
executedTasks.Add(info);
}));
}
for (var i = 0; i < maxTaskCount; i++) //enqueue actions and start processing them randomly
{
var next = rnd.Next(0, concurrentTaskCount);
semaphoreQueue.EnqueueWork(next, actions[next], next);
}
Task.Run(async () => { //wait for SemaphoreWorkQueue to process all tasks
while (executedTasks.Count < maxTaskCount)
await Task.Delay(delayMilliseconds);
}).Wait(TimeSpan.FromMilliseconds(maxTaskCount * delayMilliseconds));
}
var groupById = executedTasks.GroupBy(x => x.Id); //group by similar identifiers and sort by StartTime
foreach (var item in groupById)
{
if (item.Count() > 1)
{
var group = item.ToList();
group.Sort((ActionInfo a, ActionInfo b) => { return a.StartTime.CompareTo(b.StartTime); });
for (var i = 1; i < group.Count; i++)
{
var first = group[i - 1];
var second = group[i];
/* sorted group will look like below
*
* StartTime 11:00:00, EndTime 11:00:01 <--first
* StartTime 11:00:02, EndTime 11:00:03 <--second's StartTime should be greater than first's
* StartTime 11:00:04, EndTime 11:00:05
* StartTime 11:00:06, EndTime 11:00:07
*/
if (second.StartTime < first.EndTime)
Assert.Fail("Task with same Id started while other task with same Id was already running.");
}
}
}
}
[DebuggerDisplay("{StartTime} {EndTime}")]
private class ActionInfo
{
public int Id { get; set; }
public TimeSpan StartTime { get; set; }
public TimeSpan EndTime { get; set; }
}
}
}