Line data Source code
1 : #region Copyright
2 : // // -----------------------------------------------------------------------
3 : // // <copyright company="cdmdotnet Limited">
4 : // // Copyright cdmdotnet Limited. All rights reserved.
5 : // // </copyright>
6 : // // -----------------------------------------------------------------------
7 : #endregion
8 :
9 : using System;
10 : using System.Collections.Generic;
11 : using Cqrs.Domain.Exceptions;
12 :
13 : namespace Cqrs.Domain
14 : {
/// <summary>
/// This is a Unit of Work. It tracks aggregates by their identifier and saves every tracked aggregate
/// through the <see cref="IRepository{TAuthenticationToken}"/> when <see cref="Commit"/> is called.
/// This shouldn't normally be used as a singleton.
/// </summary>
/// <typeparam name="TAuthenticationToken">The type of the authentication token used by the aggregates and the repository.</typeparam>
public class UnitOfWork<TAuthenticationToken> : IUnitOfWork<TAuthenticationToken>
{
    /// <summary>
    /// The repository aggregates are loaded from and saved into.
    /// </summary>
    private IRepository<TAuthenticationToken> Repository { get; set; }

    /// <summary>
    /// The aggregates currently tracked by this unit of work, keyed by aggregate identifier.
    /// </summary>
    private Dictionary<Guid, IAggregateDescriptor<TAuthenticationToken>> TrackedAggregates { get; set; }

    /// <summary>
    /// Instantiates a new <see cref="UnitOfWork{TAuthenticationToken}"/> that loads from and saves to the provided <paramref name="repository"/>.
    /// </summary>
    /// <param name="repository">The repository to load aggregates from and save aggregates into.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="repository"/> is null.</exception>
    public UnitOfWork(IRepository<TAuthenticationToken> repository)
    {
        if (repository == null)
            throw new ArgumentNullException("repository");

        Repository = repository;
        TrackedAggregates = new Dictionary<Guid, IAggregateDescriptor<TAuthenticationToken>>();
    }

    /// <summary>
    /// Add an item into the <see cref="IUnitOfWork{TAuthenticationToken}"/> ready to be committed.
    /// Adding the same instance again is a no-op.
    /// </summary>
    /// <param name="aggregate">The aggregate to track. Its version at the time of adding is recorded alongside it.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="aggregate"/> is null.</exception>
    /// <exception cref="ConcurrencyException">Thrown when a different instance with the same identifier is already tracked.</exception>
    public void Add<TAggregateRoot>(TAggregateRoot aggregate)
        where TAggregateRoot : IAggregateRoot<TAuthenticationToken>
    {
        // Fail with a descriptive argument error rather than a NullReferenceException on aggregate.Id.
        if (aggregate == null)
            throw new ArgumentNullException("aggregate");

        Guid aggregateId = aggregate.Id;

        IAggregateDescriptor<TAuthenticationToken> trackedDescriptor;
        if (!TrackedAggregates.TryGetValue(aggregateId, out trackedDescriptor))
        {
            var aggregateDescriptor = new AggregateDescriptor<TAggregateRoot, TAuthenticationToken>
            {
                Aggregate = aggregate,
                Version = aggregate.Version
            };
            TrackedAggregates.Add(aggregateId, aggregateDescriptor);
            return;
        }

        // The identifier is already tracked: only the very same instance is allowed through.
        if (trackedDescriptor.Aggregate != (IAggregateRoot<TAuthenticationToken>)aggregate)
            throw new ConcurrencyException(aggregateId);
    }

    /// <summary>
    /// Get an item from the <see cref="IUnitOfWork{TAuthenticationToken}"/> if it has already been loaded or get it from the <see cref="IRepository{TAuthenticationToken}"/>.
    /// An item loaded from the <see cref="IRepository{TAuthenticationToken}"/> is added to this <see cref="IUnitOfWork{TAuthenticationToken}"/> before being returned.
    /// </summary>
    /// <param name="id">The identifier of the aggregate to get.</param>
    /// <param name="expectedVersion">When not null, the version the aggregate must currently be at.</param>
    /// <returns>The tracked aggregate, or the aggregate loaded from the repository.</returns>
    /// <exception cref="ConcurrencyException">Thrown when <paramref name="expectedVersion"/> is provided and does not match the version of the aggregate.</exception>
    public TAggregateRoot Get<TAggregateRoot>(Guid id, int? expectedVersion = null)
        where TAggregateRoot : IAggregateRoot<TAuthenticationToken>
    {
        IAggregateDescriptor<TAuthenticationToken> trackedDescriptor;
        if (TrackedAggregates.TryGetValue(id, out trackedDescriptor))
        {
            var trackedAggregate = (TAggregateRoot)trackedDescriptor.Aggregate;
            if (expectedVersion != null && trackedAggregate.Version != expectedVersion)
                throw new ConcurrencyException(trackedAggregate.Id);
            return trackedAggregate;
        }

        var aggregate = Repository.Get<TAggregateRoot>(id);
        if (expectedVersion != null && aggregate.Version != expectedVersion)
            throw new ConcurrencyException(id);
        Add(aggregate);

        return aggregate;
    }

    /// <summary>
    /// Commit any changed <see cref="AggregateRoot{TAuthenticationToken}"/> added to this <see cref="IUnitOfWork{TAuthenticationToken}"/> via <see cref="Add{TAggregateRoot}"/>
    /// into the <see cref="IRepository{TAuthenticationToken}"/>
    /// </summary>
    /// <remarks>
    /// Each tracked aggregate is saved together with the version recorded when it was added.
    /// The tracked aggregates are only cleared once every save has returned; if a save throws, all aggregates remain tracked.
    /// </remarks>
    public void Commit()
    {
        foreach (IAggregateDescriptor<TAuthenticationToken> descriptor in TrackedAggregates.Values)
        {
            Repository.Save(descriptor.Aggregate, descriptor.Version);
        }
        TrackedAggregates.Clear();
    }
}
92 : }
|