Bringing some clarity to UnitOfWork and Repository

It’s amazing what one can learn when they no longer have real work in front of them.

A while back, there was a thread on how best to use the UnitOfWork pattern in conjunction with the Repository pattern. I won’t link the thread though as it was almost entirely wrong!

There are two reasons most people misunderstand the UnitOfWork pattern (myself included).

The first is that by choosing to use it, you also choose to separate the concerns of reads and writes. If you are used to repositories that include insert, update and delete methods, this gives rise to confusion. That is, until you realise those methods are redundant given the new role of the repository.

The other reason people become confused is that they assume the UnitOfWork pattern will be wrapping a database transaction, and that is not always the case. Mediums such as serialized objects or XML files do not subscribe to such notions.

The key is realizing two things: we are now entertaining both a read and a write model, and, transactional or not, we still need a way to cache writes for batch operations later.

So what I’ve done is create two simple scenarios that better explain how it all fits together. The first is for NHibernate and the second is for EF 4.1.

NHibernate

namespace NHUnitOfWorkSample
{
    /// <summary>
    /// Minimal base type for entities identified by a value-type key.
    /// </summary>
    /// <typeparam name="TId">Value type used for the identifier.</typeparam>
    public abstract class Entity<TId> where TId : struct
    {
        /// <summary>
        /// Identifier of the entity. The getter is public and the setter is
        /// only reachable from derived types; the property is virtual.
        /// </summary>
        public virtual TId Id
        {
            get;
            protected set;
        }
    }
    /// <summary>
    /// Convenience base for the most commonly used identity: an
    /// <see cref="Entity{TId}"/> keyed by <c>int</c>.
    /// </summary>
    public abstract class Entity : Entity<int>
    {
    }
    /// <summary>
    /// Bare-bones repository contract. It exposes query operations only.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    /// <typeparam name="TId">Value type of the entity identifier.</typeparam>
    public interface IRepository<TEntity, TId>
        where TId : struct
        where TEntity : Entity<TId>
    {
        /// <summary>Returns the entity with the given identifier.</summary>
        /// <param name="id">Identifier to look up.</param>
        TEntity SelectById(TId id);

        /// <summary>Returns the entities of type <typeparamref name="TEntity"/>.</summary>
        IEnumerable<TEntity> Select();
    }
    /// <summary>
    /// Repository contract for the most commonly used identity: entities
    /// keyed by <c>int</c>. Adds no members of its own.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    public interface IRepository<TEntity> : IRepository<TEntity, int> where TEntity : Entity
    {
    }
    /// <summary>
    /// Bare-bones NHibernate implementation of
    /// <see cref="IRepository{TEntity, TId}"/>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    /// <typeparam name="TId">Value type of the entity identifier.</typeparam>
    public class Repository<TEntity, TId> : IRepository<TEntity, TId>
        where TEntity : Entity<TId>
        where TId : struct
    {
        /// <summary>
        /// Session obtained from the factory's current session when the
        /// repository was constructed.
        /// </summary>
        protected readonly ISession session;

        /// <summary>Captures the factory's current session.</summary>
        /// <param name="sessionFactory">Factory that supplies the current session.</param>
        public Repository(ISessionFactory sessionFactory)
        {
            this.session = sessionFactory.GetCurrentSession();
        }

        /// <summary>Lists the entities of type <typeparamref name="TEntity"/> through a criteria query.</summary>
        public IEnumerable<TEntity> Select()
        {
            var criteria = this.session.CreateCriteria<TEntity>();
            IEnumerable<TEntity> rows = criteria.List<TEntity>();
            return rows;
        }

        /// <summary>Fetches a single entity through the session's <c>Get</c> call.</summary>
        /// <param name="id">Identifier to look up.</param>
        public TEntity SelectById(TId id)
        {
            TEntity match = this.session.Get<TEntity>(id);
            return match;
        }
    }
    /// <summary>
    /// Repository for the most commonly used identity: entities keyed by
    /// <c>int</c>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    public class Repository<TEntity> : Repository<TEntity, int>, IRepository<TEntity>
        where TEntity : Entity
    {
        /// <summary>
        /// Forwards the session factory to the base repository. The base
        /// class declares no parameterless constructor, so this pass-through
        /// is required for the class to compile.
        /// </summary>
        /// <param name="sessionFactory">Factory that supplies the current session.</param>
        public Repository(ISessionFactory sessionFactory)
            : base(sessionFactory)
        {
        }
    }
    /// <summary>
    /// Factory contract for creating <see cref="IUnitOfWork"/> instances.
    /// </summary>
    public interface IUnitOfWorkFactory
    {
        /// <summary>Creates a unit of work.</summary>
        /// <param name="isolationLevel">
        /// Isolation level requested for the unit of work; defaults to
        /// <see cref="IsolationLevel.Unspecified"/>.
        /// </param>
        IUnitOfWork Create(
            IsolationLevel isolationLevel = IsolationLevel.Unspecified);
    }
    /// <summary>
    /// NHibernate implementation of <see cref="IUnitOfWorkFactory"/>.
    /// </summary>
    public class UnitOfWorkFactory : IUnitOfWorkFactory
    {
        // Only assigned in the constructor.
        private readonly ISessionFactory sessionFactory;

        /// <summary>Stores the session factory handed to every unit of work.</summary>
        /// <param name="sessionFactory">Factory passed on to each created unit of work.</param>
        public UnitOfWorkFactory(ISessionFactory sessionFactory)
        {
            this.sessionFactory = sessionFactory;
        }

        /// <summary>Creates a new <see cref="UnitOfWork"/>.</summary>
        /// <param name="isolationLevel">
        /// Isolation level forwarded to the unit of work; defaults to
        /// <see cref="IsolationLevel.Unspecified"/>.
        /// </param>
        /// <returns>The newly created unit of work.</returns>
        // The parameter must be declared here: the interface declares it and
        // the body forwards it to the UnitOfWork constructor.
        public IUnitOfWork Create(IsolationLevel isolationLevel = IsolationLevel.Unspecified)
        {
            return new UnitOfWork(sessionFactory, isolationLevel);
        }
    }
    /// <summary>
    /// Unit of work contract: entities are marked as new, dirty or removed,
    /// and the work is then committed or rolled back.
    /// </summary>
    public interface IUnitOfWork : IDisposable
    {
        /// <summary>Marks an entity as new.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkNew(object entity);

        /// <summary>Marks an entity as dirty.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkDirty(object entity);

        /// <summary>Marks an entity as removed.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkRemoved(object entity);

        /// <summary>Commits the unit of work.</summary>
        void Commit();

        /// <summary>Rolls the unit of work back.</summary>
        void Rollback();
    }
    /// <summary>
    /// NHibernate implementation of <see cref="IUnitOfWork"/>. Each call is
    /// forwarded to the current session or to the transaction begun on it.
    /// </summary>
    public class UnitOfWork : IUnitOfWork
    {
        private ISession session;
        private ITransaction transaction;

        /// <summary>
        /// Takes the factory's current session and begins a transaction on it.
        /// </summary>
        /// <param name="sessionFactory">Factory that supplies the current session.</param>
        /// <param name="isolationLevel">Isolation level passed to <c>BeginTransaction</c>.</param>
        public UnitOfWork(
            ISessionFactory sessionFactory,
            IsolationLevel isolationLevel = IsolationLevel.Unspecified)
        {
            ISession current = sessionFactory.GetCurrentSession();
            this.session = current;
            this.transaction = current.BeginTransaction(isolationLevel);
        }

        /// <summary>Passes the entity to the session's <c>Save</c>.</summary>
        public void MarkNew(object entity)
        {
            this.session.Save(entity);
        }

        /// <summary>Passes the entity to the session's <c>Update</c>.</summary>
        public void MarkDirty(object entity)
        {
            this.session.Update(entity);
        }

        /// <summary>Passes the entity to the session's <c>Delete</c>.</summary>
        public void MarkRemoved(object entity)
        {
            this.session.Delete(entity);
        }

        /// <summary>Commits the transaction begun in the constructor.</summary>
        public void Commit()
        {
            this.transaction.Commit();
        }

        /// <summary>Rolls back the transaction begun in the constructor.</summary>
        public void Rollback()
        {
            this.transaction.Rollback();
        }

        /// <summary>Disposes the transaction; the session itself is left alone.</summary>
        public void Dispose()
        {
            this.transaction.Dispose();
        }
    }

    /// <summary>
    /// Demo: loads a <c>Foo</c> through a repository, changes it, and
    /// persists the change through a unit of work.
    /// </summary>
    public class Demo
    {
        /// <summary>Runs the demo scenario.</summary>
        /// <param name="unitOfWorkFactory">Factory used to create the unit of work.</param>
        /// <param name="fooId">Identifier of the <c>Foo</c> to modify.</param>
        /// <param name="fooRepository">
        /// Repository used to read the <c>Foo</c>. The original snippet
        /// referenced this name without declaring it anywhere.
        /// NOTE(review): assumes <c>Foo</c> derives from <c>Entity</c> - confirm.
        /// </param>
        public Demo(IUnitOfWorkFactory unitOfWorkFactory, int fooId, IRepository<Foo> fooRepository)
        {
            using (var worker = unitOfWorkFactory.Create())
            {
                try
                {
                    Foo foo = fooRepository.SelectById(fooId);
                    foo.Bar = "woot";
                    worker.MarkDirty(foo);
                    worker.Commit();
                }
                catch
                {
                    worker.Rollback();
                    // Rethrow so the caller learns the update failed instead of
                    // the error being silently swallowed after the rollback.
                    throw;
                }
            }
        }
    }
}

Entity Framework 4.1

namespace EFUnitOfWorkSample
{
    /// <summary>
    /// Minimal base type for entities identified by a value-type key.
    /// </summary>
    /// <typeparam name="TId">Value type used for the identifier.</typeparam>
    public abstract class Entity<TId> where TId : struct
    {
        /// <summary>
        /// Identifier of the entity. The getter is public and the setter is
        /// only reachable from derived types; the property is virtual.
        /// </summary>
        public virtual TId Id
        {
            get;
            protected set;
        }
    }
    /// <summary>
    /// Convenience base for the most commonly used identity: an
    /// <see cref="Entity{TId}"/> keyed by <c>int</c>.
    /// </summary>
    public abstract class Entity : Entity<int>
    {
    }
    /// <summary>
    /// Bare-bones repository contract. It exposes query operations only.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    /// <typeparam name="TId">Value type of the entity identifier.</typeparam>
    public interface IRepository<TEntity, TId>
        where TId : struct
        where TEntity : Entity<TId>
    {
        /// <summary>Returns the entity with the given identifier.</summary>
        /// <param name="id">Identifier to look up.</param>
        TEntity SelectById(TId id);

        /// <summary>Returns the entities of type <typeparamref name="TEntity"/>.</summary>
        IEnumerable<TEntity> Select();
    }
    /// <summary>
    /// Repository contract for the most commonly used identity: entities
    /// keyed by <c>int</c>. Adds no members of its own.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    public interface IRepository<TEntity> : IRepository<TEntity, int> where TEntity : Entity
    {
    }
    /// <summary>
    /// Bare-bones Entity Framework implementation of
    /// <see cref="IRepository{TEntity, TId}"/>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    /// <typeparam name="TId">Value type of the entity identifier.</typeparam>
    public class Repository<TEntity, TId> : IRepository<TEntity, TId>
        where TEntity : Entity<TId>
        where TId : struct
    {
        /// <summary>
        /// Context obtained from the factory when the repository was constructed.
        /// </summary>
        protected readonly DbContext context;

        /// <summary>Captures the factory's current context.</summary>
        /// <param name="contextFactory">Factory that supplies the current context.</param>
        public Repository(IContextFactory contextFactory)
        {
            context = contextFactory.GetCurrentContext();
        }

        /// <summary>
        /// Returns the entity set for <typeparamref name="TEntity"/>. The set
        /// itself is returned, so the query runs when the result is enumerated.
        /// </summary>
        public IEnumerable<TEntity> Select()
        {
            return context.Set<TEntity>();
        }

        /// <summary>Looks an entity up by its key.</summary>
        /// <param name="id">Key value to look up.</param>
        /// <returns>The matching entity, or <c>null</c> when none exists.</returns>
        public TEntity SelectById(TId id)
        {
            // Find is the key lookup built into DbSet. It replaces a
            // SingleOrDefault(x => Equals(x.Id, id)) predicate, whose boxed,
            // object-typed comparison is fragile to translate in a LINQ to
            // Entities query.
            // NOTE(review): assumes Id is mapped as the entity's single
            // primary key - confirm against the entity mappings.
            return context.Set<TEntity>().Find(id);
        }
    }
    /// <summary>
    /// Repository for the most commonly used identity: entities keyed by
    /// <c>int</c>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type served by the repository.</typeparam>
    public class Repository<TEntity> : Repository<TEntity, int>, IRepository<TEntity>
        where TEntity : Entity
    {
        /// <summary>
        /// Forwards the context factory to the base repository. The base
        /// class declares no parameterless constructor, so this pass-through
        /// is required for the class to compile.
        /// </summary>
        /// <param name="contextFactory">Factory that supplies the current context.</param>
        public Repository(IContextFactory contextFactory)
            : base(contextFactory)
        {
        }
    }
    /// <summary>
    /// Factory contract for creating <see cref="IUnitOfWork"/> instances.
    /// </summary>
    public interface IUnitOfWorkFactory
    {
        /// <summary>Creates a unit of work.</summary>
        /// <returns>The created unit of work.</returns>
        IUnitOfWork Create();
    }
    /// <summary>
    /// Entity Framework implementation of <see cref="IUnitOfWorkFactory"/>.
    /// </summary>
    public class UnitOfWorkFactory : IUnitOfWorkFactory
    {
        private IContextFactory contextFactory;

        /// <summary>Stores the context factory handed to every unit of work.</summary>
        /// <param name="contextFactory">Factory passed on to each created unit of work.</param>
        public UnitOfWorkFactory(IContextFactory contextFactory)
        {
            this.contextFactory = contextFactory;
        }

        /// <summary>Creates a new <see cref="UnitOfWork"/> over the stored context factory.</summary>
        public IUnitOfWork Create()
        {
            IUnitOfWork unitOfWork = new UnitOfWork(this.contextFactory);
            return unitOfWork;
        }
    }
    /// <summary>
    /// Unit of work contract: entities are marked as new, dirty or removed,
    /// and the work is then committed or rolled back.
    /// </summary>
    public interface IUnitOfWork : IDisposable
    {
        /// <summary>Marks an entity as new.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkNew(object entity);

        /// <summary>Marks an entity as dirty.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkDirty(object entity);

        /// <summary>Marks an entity as removed.</summary>
        /// <param name="entity">Entity to mark.</param>
        void MarkRemoved(object entity);

        /// <summary>Commits the unit of work.</summary>
        void Commit();

        /// <summary>Rolls the unit of work back.</summary>
        void Rollback();
    }
    /// <summary>
    /// Entity Framework implementation of <see cref="IUnitOfWork"/>. Marked
    /// entities are cached in memory and only applied to the context, inside
    /// a <see cref="TransactionScope"/>, when <see cref="Commit"/> is called.
    /// </summary>
    public class UnitOfWork : IUnitOfWork
    {
        private readonly DbContext context;
        private readonly List<object> markedNew = new List<object>();
        private readonly List<object> markedDirty = new List<object>();
        private readonly List<object> markedRemoved = new List<object>();

        /// <summary>Captures the factory's current context.</summary>
        /// <param name="contextFactory">Factory that supplies the current context.</param>
        public UnitOfWork(IContextFactory contextFactory)
        {
            context = contextFactory.GetCurrentContext();
        }

        /// <summary>Queues an entity to be set to the Added state on commit.</summary>
        /// <param name="entity">Entity to queue.</param>
        public void MarkNew(object entity)
        {
            markedNew.Add(entity);
        }

        /// <summary>Queues an entity to be set to the Modified state on commit.</summary>
        /// <param name="entity">Entity to queue.</param>
        public void MarkDirty(object entity)
        {
            markedDirty.Add(entity);
        }

        /// <summary>Queues an entity to be set to the Deleted state on commit.</summary>
        /// <param name="entity">Entity to queue.</param>
        public void MarkRemoved(object entity)
        {
            markedRemoved.Add(entity);
        }

        /// <summary>
        /// Applies the queued states to the context and saves the changes
        /// inside a transaction scope, then empties the queues.
        /// </summary>
        public void Commit()
        {
            using (var transaction = new TransactionScope())
            {
                ApplyState(markedNew, EntityState.Added);
                ApplyState(markedDirty, EntityState.Modified);
                ApplyState(markedRemoved, EntityState.Deleted);
                context.SaveChanges();
                transaction.Complete();
            }

            // Reached only when the scope completed without throwing. Without
            // this, a second Commit on the same unit of work would re-apply
            // the same states, e.g. flag already-saved entities as Added again.
            ClearPending();
        }

        /// <summary>
        /// Discards the queued entities. Nothing has been applied to the
        /// context before a commit, so there is nothing else to undo here.
        /// </summary>
        public void Rollback()
        {
            ClearPending();
        }

        /// <summary>Nothing to release; the context is not disposed here.</summary>
        public void Dispose()
        {
            // nothing to do but exit silently
        }

        // Sets the given state on the context entry of every queued entity.
        private void ApplyState(List<object> entities, EntityState state)
        {
            foreach (var entity in entities)
            {
                context.Entry(entity).State = state;
            }
        }

        // Empties all three queues.
        private void ClearPending()
        {
            markedNew.Clear();
            markedDirty.Clear();
            markedRemoved.Clear();
        }
    }

    /// <summary>
    /// Demo: loads a <c>Foo</c> through a repository, changes it, and
    /// persists the change through a unit of work.
    /// </summary>
    public class Demo
    {
        /// <summary>Runs the demo scenario.</summary>
        /// <param name="unitOfWorkFactory">Factory used to create the unit of work.</param>
        /// <param name="fooId">Identifier of the <c>Foo</c> to modify.</param>
        /// <param name="fooRepository">
        /// Repository used to read the <c>Foo</c>. The original snippet
        /// referenced this name without declaring it anywhere.
        /// NOTE(review): assumes <c>Foo</c> derives from <c>Entity</c> - confirm.
        /// </param>
        public Demo(IUnitOfWorkFactory unitOfWorkFactory, int fooId, IRepository<Foo> fooRepository)
        {
            using (var worker = unitOfWorkFactory.Create())
            {
                try
                {
                    Foo foo = fooRepository.SelectById(fooId);
                    foo.Bar = "woot";
                    worker.MarkDirty(foo);
                    worker.Commit();
                }
                catch
                {
                    worker.Rollback();
                    // Rethrow so the caller learns the update failed instead of
                    // the error being silently swallowed after the rollback.
                    throw;
                }
            }
        }
    }
}

There are only a few differences between the two, but they both show how to use these patterns together.

Oh, and if you use EF, you’ll probably want to look at my ContextFactory classes as well, since they weren’t included above.

namespace Sample
{
    /// <summary>
    /// Supplies the <see cref="DbContext"/> that repositories and units of
    /// work operate on.
    /// </summary>
    public interface IContextFactory
    {
        /// <summary>Returns the current context.</summary>
        DbContext GetCurrentContext();
    }
    /// <summary>
    /// <see cref="IContextFactory"/> that keeps one <see cref="DbContext"/>
    /// per HTTP request in <c>HttpContext.Current.Items</c>.
    /// </summary>
    public class ContextFactory : IContextFactory
    {
        // Key under which the context is stored in HttpContext.Items.
        private const string ContextItemKey = "CurrentContext";

        // Building and compiling the model is costly, so it is done once and
        // shared by every context this factory creates, not once per request.
        private static readonly System.Lazy<DbCompiledModel> compiledModel =
            new System.Lazy<DbCompiledModel>(BuildCompiledModel);

        /// <summary>
        /// Returns the context stored for the current HTTP request, creating
        /// and storing it on first use.
        /// </summary>
        /// <returns>The context associated with the current request.</returns>
        /// <exception cref="System.InvalidOperationException">
        /// There is no current <see cref="HttpContext"/>.
        /// </exception>
        public DbContext GetCurrentContext()
        {
            HttpContext httpContext = HttpContext.Current;
            if (httpContext == null)
            {
                // Previously this surfaced as a bare NullReferenceException.
                throw new System.InvalidOperationException(
                    "ContextFactory requires an active HttpContext to store the current DbContext.");
            }

            DbContext context = httpContext.Items[ContextItemKey] as DbContext;
            if (context == null)
            {
                context = new DbContext(
                    ConfigurationManager.AppSettings["ConnectionString"],
                    compiledModel.Value);
                httpContext.Items[ContextItemKey] = context;
            }
            // NOTE(review): nothing in this class disposes the stored context;
            // confirm it is disposed elsewhere when the request ends.
            return context;
        }

        // Builds the model from the Foo and Bar mappings and compiles it.
        private static DbCompiledModel BuildCompiledModel()
        {
            DbProviderInfo providerInfo = new DbProviderInfo("System.Data.SqlClient", "2008");
            DbModelBuilder modelBuilder = new DbModelBuilder();
            modelBuilder.Configurations.Add(new FooMap());
            modelBuilder.Configurations.Add(new BarMap());
            DbModel model = modelBuilder.Build(providerInfo);
            return model.Compile();
        }
    }
}

It’s very easy to implement and you no longer have to derive a dedicated DbContext class for your application.

I hope this made things a bit clearer. Enjoy!