|           Line data    Source code 
       1             : #region Copyright
       2             : // // -----------------------------------------------------------------------
       3             : // // <copyright company="cdmdotnet Limited">
       4             : // //   Copyright cdmdotnet Limited. All rights reserved.
       5             : // // </copyright>
       6             : // // -----------------------------------------------------------------------
       7             : #endregion
       8             : 
       9             : using System;
      10             : using System.Collections.Generic;
      11             : using Cqrs.Domain.Exceptions;
      12             : 
      13             : namespace Cqrs.Domain
      14             : {
      15             :         /// <summary>
      16             :         /// This is a Unit of Work. This shouldn't normally be used as a singleton.
      17             :         /// </summary>
      18             :         public class UnitOfWork<TAuthenticationToken> : IUnitOfWork<TAuthenticationToken>
      19           1 :         {
      20             :                 private IRepository<TAuthenticationToken> Repository { get; set; }
      21             : 
      22             :                 private Dictionary<Guid, IAggregateDescriptor<TAuthenticationToken>> TrackedAggregates { get; set; }
      23             : 
      24           0 :                 public UnitOfWork(IRepository<TAuthenticationToken> repository)
      25             :                 {
      26             :                         if(repository == null)
      27             :                                 throw new ArgumentNullException("repository");
      28             : 
      29             :                         Repository = repository;
      30             :                         TrackedAggregates = new Dictionary<Guid, IAggregateDescriptor<TAuthenticationToken>>();
      31             :                 }
      32             : 
      33             :                 /// <summary>
      34             :                 /// Add an item into the <see cref="IUnitOfWork{TAuthenticationToken}"/> ready to be committed.
      35             :                 /// </summary>
      36           1 :                 public void Add<TAggregateRoot>(TAggregateRoot aggregate)
      37             :                         where TAggregateRoot : IAggregateRoot<TAuthenticationToken>
      38             :                 {
      39             :                         if (!IsTracked(aggregate.Id))
      40             :                         {
      41             :                                 var aggregateDescriptor = new AggregateDescriptor<TAggregateRoot, TAuthenticationToken>
      42             :                                 {
      43             :                                         Aggregate = aggregate,
      44             :                                         Version = aggregate.Version
      45             :                                 };
      46             :                                 TrackedAggregates.Add(aggregate.Id, aggregateDescriptor);
      47             :                         }
      48             :                         else if (((TrackedAggregates[aggregate.Id]).Aggregate) != (IAggregateRoot<TAuthenticationToken>)aggregate)
      49             :                                 throw new ConcurrencyException(aggregate.Id);
      50             :                 }
      51             : 
      52             :                 /// <summary>
      53             :                 /// Get an item from the <see cref="IUnitOfWork{TAuthenticationToken}"/> if it has already been loaded or get it from the <see cref="IRepository{TAuthenticationToken}"/>.
      54             :                 /// </summary>
      55           1 :                 public TAggregateRoot Get<TAggregateRoot>(Guid id, int? expectedVersion = null)
      56             :                         where TAggregateRoot : IAggregateRoot<TAuthenticationToken>
      57             :                 {
      58             :                         if(IsTracked(id))
      59             :                         {
      60             :                                 var trackedAggregate = (TAggregateRoot)TrackedAggregates[id].Aggregate;
      61             :                                 if (expectedVersion != null && trackedAggregate.Version != expectedVersion)
      62             :                                         throw new ConcurrencyException(trackedAggregate.Id);
      63             :                                 return trackedAggregate;
      64             :                         }
      65             : 
      66             :                         var aggregate = Repository.Get<TAggregateRoot>(id);
      67             :                         if (expectedVersion != null && aggregate.Version != expectedVersion)
      68             :                                 throw new ConcurrencyException(id);
      69             :                         Add(aggregate);
      70             : 
      71             :                         return aggregate;
      72             :                 }
      73             : 
      74             :                 private bool IsTracked(Guid id)
      75             :                 {
      76             :                         return TrackedAggregates.ContainsKey(id);
      77             :                 }
      78             : 
      79             :                 /// <summary>
      80             :                 /// Commit any changed <see cref="AggregateRoot{TAuthenticationToken}"/> added to this <see cref="IUnitOfWork{TAuthenticationToken}"/> via <see cref="Add{T}"/>
      81             :                 /// into the <see cref="IRepository{TAuthenticationToken}"/>
      82             :                 /// </summary>
      83           1 :                 public void Commit()
      84             :                 {
      85             :                         foreach (IAggregateDescriptor<TAuthenticationToken> descriptor in TrackedAggregates.Values)
      86             :                         {
      87             :                                 Repository.Save(descriptor.Aggregate, descriptor.Version);
      88             :                         }
      89             :                         TrackedAggregates.Clear();
      90             :                 }
      91             :         }
      92             : }
 |