Entity Framework 4 and MultiResultSet: executing multiple queries with a single database access

The EFExtensions project illustrates how to use multiple result sets (MultiResultSet) with EF, which is a good idea.

However, it only works with stored procedures.

Moreover, unlike that approach, I don’t want to specify the mapping again when my EDM already defines it.

So what I want is to execute multiple LINQ to Entities (L2E) queries with only one database access.

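I wrote a proof of concept for it. It concatenates the store SQL of each query into a single command, then reads the result sets one after another, materializing each row with a delegate built from the EDM metadata.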

/// <summary>
/// Materializes entities of type <typeparamref name="T"/> from raw <see cref="DbDataReader"/> rows,
/// using a delegate compiled once per <typeparamref name="T"/> from the EDM metadata of a query's context.
/// </summary>
/// <remarks>
/// Very basic on this version. Columns are read strictly by ordinal position. The row must therefore contain
/// one column per scalar property of the entity type, in the order of <c>EntityType.Properties</c>.
/// Presumably this is the order in which Entity Framework emits the SELECT list of a plain entity query.
/// TODO: confirm for inheritance hierarchies; projections are not supported.
/// </remarks>
public static class Materializer<T>
    where T : class, new()
{
    // Compiled factory shared by every caller for this T. Writes are not synchronized: concurrent first calls
    // may each build a factory. That is harmless, because the factories are equivalent and the last write wins.
    private static Func<DbDataReader, T> _tFactory;

    /// <summary>
    /// Returns the cached row-to-entity factory for <typeparamref name="T"/>, building it on first use.
    /// </summary>
    /// <param name="query">
    /// A query over <typeparamref name="T"/>. Only its context's metadata is used, and only while the factory
    /// has not been built yet.
    /// </param>
    /// <returns>A delegate that creates a <typeparamref name="T"/> from the reader's current row.</returns>
    /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not mapped to an entity type.</exception>
    /// <exception cref="InvalidOperationException">
    /// An EDM property has no public CLR property with the same name on <typeparamref name="T"/>.
    /// </exception>
    /// <remarks>This works only for entity types and basic scenarios. It is a proof of concept.</remarks>
    public static Func<DbDataReader, T> GetFactory(ObjectQuery<T> query)
    {
        // Checked first: Materialize calls this once per row, and the metadata lookup below is not free.
        Func<DbDataReader, T> cached = _tFactory;
        if (cached != null)
            return cached;

        if (query == null)
            throw new ArgumentNullException("query");
        EntityType entityType = query.Context.GetEntityType<T>();
        if (entityType == null)
            throw new NotImplementedException(string.Format(
                "{0} is not mapped to an entity type; only entity types are supported.", typeof(T).FullName));

        // The mapping (MSL) file is deliberately not read. Values are bound by column position, so the store
        // column names it would provide are not needed.
        ParameterExpression reader = Expression.Parameter(typeof(DbDataReader), "reader");
        MethodInfo getValueMethod = typeof(DbDataReader).GetMethod("GetValue", new[] { typeof(int) });
        MethodInfo isDBNullMethod = typeof(DbDataReader).GetMethod("IsDBNull", new[] { typeof(int) });

        // The indexed Select pairs each EDM property with its column ordinal.
        List<MemberBinding> bindings = entityType.Properties
            .Select((edmProperty, ordinal) =>
                BindProperty(reader, getValueMethod, isDBNullMethod, edmProperty.Name, ordinal))
            .ToList();

        // reader => new T { Prop0 = ..., Prop1 = ..., ... }
        Func<DbDataReader, T> factory = Expression.Lambda<Func<DbDataReader, T>>(
            Expression.MemberInit(Expression.New(typeof(T)), bindings),
            reader).Compile();
        _tFactory = factory;
        return factory;
    }

    /// <summary>
    /// Builds the binding <c>Property = (PropertyType)reader.GetValue(ordinal)</c>. For reference types and
    /// <see cref="Nullable{T}"/> properties, <see cref="DBNull"/> is mapped to <c>null</c>.
    /// </summary>
    private static MemberBinding BindProperty(ParameterExpression reader, MethodInfo getValueMethod,
        MethodInfo isDBNullMethod, string propertyName, int ordinal)
    {
        PropertyInfo property = typeof(T).GetProperty(propertyName);
        if (property == null)
            throw new InvalidOperationException(string.Format(
                "{0} has no public property named {1}.", typeof(T).FullName, propertyName));

        Expression ordinalExpression = Expression.Constant(ordinal);
        Expression value = Expression.Call(reader, getValueMethod, ordinalExpression);
        Type propertyType = property.PropertyType;
        bool acceptsNull = !propertyType.IsValueType || Nullable.GetUnderlyingType(propertyType) != null;
        if (acceptsNull)
        {
            // reader.IsDBNull(ordinal) ? (object)null : reader.GetValue(ordinal)
            value = Expression.Condition(
                Expression.Call(reader, isDBNullMethod, ordinalExpression),
                Expression.Constant(null),
                value);
        }
        // Non-nullable value types are unboxed directly, so a DBNull there throws InvalidCastException at run time.
        return Expression.Bind(property, Expression.Convert(value, propertyType));
    }

    /// <summary>Creates a <typeparamref name="T"/> from the current row of <paramref name="dbDataReader"/>.</summary>
    /// <param name="query">Used only when <paramref name="tFactory"/> is null, to obtain the cached factory.</param>
    /// <param name="dbDataReader">A reader positioned on the row to materialize.</param>
    /// <param name="tFactory">Optional prebuilt factory; when null, <see cref="GetFactory"/> supplies the cached one.</param>
    public static T Materialize(ObjectQuery<T> query, DbDataReader dbDataReader, Func<DbDataReader, T> tFactory = null)
    {
        Func<DbDataReader, T> factory = tFactory ?? GetFactory(query);
        return factory(dbDataReader);
    }
}


/// <summary>
/// Extension methods on <see cref="ObjectContext"/> that run several LINQ to Entities queries in a single
/// store command and merge the materialized entities into the context.
/// </summary>
/// <remarks>
/// Proof of concept. The class name keeps its original spelling so that existing callers keep compiling.
/// </remarks>
public static class ObjectContextExrension
{
    /// <summary>
    /// Finds the conceptual (CSpace) <see cref="EntityType"/> that the CLR type <typeparamref name="T"/> maps to.
    /// </summary>
    /// <returns>
    /// The mapped entity type, or <c>null</c> in two cases: no object/conceptual mapping for
    /// <typeparamref name="T"/> is found in the metadata workspace, or the mapping does not target an entity type.
    /// </returns>
    /// <remarks>
    /// Relies on the <c>ToString()</c> of OCSpace mapping items having the form
    /// "ClrTypeFullName:EdmTypeFullName", which is an undocumented implementation detail.
    /// NOTE(review): OCSpace items may only be present after the context has loaded the CLR types.
    /// Confirm this for a freshly created context.
    /// </remarks>
    public static EntityType GetEntityType<T>(this ObjectContext context) where T : class
    {
        string prefix = typeof(T).FullName + ":";
        var mapping = context.MetadataWorkspace.GetItems(DataSpace.OCSpace)
            .FirstOrDefault(item => item.ToString().StartsWith(prefix, StringComparison.Ordinal));
        // Tested before calling ToString(), so an unmapped T yields null instead of a NullReferenceException.
        if (mapping == null)
            return null;

        string description = mapping.ToString();
        string edmTypeName = description.Substring(description.IndexOf(':') + 1);
        return context.MetadataWorkspace.GetItems(DataSpace.CSpace)
            .OfType<EntityType>()
            .FirstOrDefault(et => et.FullName == edmTypeName);
    }

    /// <summary>
    /// Builds the <see cref="EntityKey"/> of <paramref name="entity"/> from its key property values,
    /// using the entity type that <typeparamref name="T"/> maps to.
    /// </summary>
    public static EntityKey GetEntityKey<T>(this ObjectContext context, T entity)
        where T : class
    {
        EntityType entityType = context.GetEntityType<T>();
        return GetEntityKey(context, entity, entityType);
    }

    /// <summary>
    /// Builds the <see cref="EntityKey"/> of <paramref name="entity"/>. The key is qualified with the context's
    /// default container and the entity set of <paramref name="entityType"/> (or of its nearest base type).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> or <paramref name="entityType"/> is null.</exception>
    /// <exception cref="InvalidOperationException">No entity set exists for the type or any of its base types.</exception>
    public static EntityKey GetEntityKey<T>(this ObjectContext context, T entity, EntityType entityType)
        where T : class
    {
        if (entityType == null)
            throw new ArgumentNullException("entityType",
                string.Format("{0} is not mapped to an entity type.", typeof(T).FullName));
        if (entity == null)
            throw new ArgumentNullException("entity");

        EntitySet entitySet = FindEntitySet(context, entityType);
        return new EntityKey(
            string.Format("{0}.{1}", context.DefaultContainerName, entitySet.Name),
            entityType.KeyMembers.Select(member =>
                new KeyValuePair<string, object>(member.Name, typeof(T).GetProperty(member.Name).GetValue(entity, null))));
    }

    /// <summary>
    /// Returns the entity set whose element type is <paramref name="entityType"/> or its nearest base type.
    /// The search covers the context's default CSpace container, falling back to the first container.
    /// </summary>
    private static EntitySet FindEntitySet(ObjectContext context, EntityType entityType)
    {
        // Use the same container that GetEntityKey uses to qualify the key.
        EntityContainer container;
        string containerName = context.DefaultContainerName;
        if (string.IsNullOrEmpty(containerName) ||
            !context.MetadataWorkspace.TryGetEntityContainer(containerName, DataSpace.CSpace, out container))
        {
            container = context.MetadataWorkspace.GetItems(DataSpace.CSpace).OfType<EntityContainer>().First();
        }

        List<EntitySet> entitySets = container.BaseEntitySets.OfType<EntitySet>().ToList();
        for (EntityType current = entityType; current != null; current = current.BaseType as EntityType)
        {
            EntityType candidate = current;
            EntitySet entitySet = entitySets.FirstOrDefault(es => es.ElementType == candidate);
            if (entitySet != null)
                return entitySet;
        }
        throw new InvalidOperationException(string.Format(
            "No entity set found for entity type {0} or any of its base types.", entityType.FullName));
    }

    /// <summary>
    /// Merges <paramref name="item"/>, a detached instance read from the store, into <paramref name="context"/>
    /// according to <paramref name="mergeOption"/>.
    /// </summary>
    /// <returns>
    /// <paramref name="item"/> itself when it is newly attached or when <paramref name="mergeOption"/> is
    /// <see cref="MergeOption.NoTracking"/>; otherwise the instance already tracked under the same key.
    /// </returns>
    /// <remarks>
    /// <para>NoTracking: nothing is attached.</para>
    /// <para>AppendOnly: an already-tracked instance is returned untouched.</para>
    /// <para>OverwriteChanges: the tracked instance receives the values of <paramref name="item"/> as current
    /// values and its changes are accepted.</para>
    /// <para>PreserveChanges: the values of <paramref name="item"/> become the original values. Properties whose
    /// current value then differs from the original value are marked modified. When
    /// <c>UseLegacyPreserveChangesBehavior</c> is false, properties that are not reported as modified also take
    /// <paramref name="item"/>'s value as their current value.</para>
    /// </remarks>
    public static T Attach<T>(this ObjectContext context, T item, MergeOption mergeOption = MergeOption.AppendOnly)
        where T : class
    {
        if (mergeOption == MergeOption.NoTracking)
            return item;
        EntityKey entityKey = context.GetEntityKey(item);
        ObjectStateEntry ose;
        context.ObjectStateManager.TryGetObjectStateEntry(entityKey, out ose);
        if (ose == null)
        {
            context.AttachTo(entityKey.EntitySetName, item);
            return item;
        }
        T entity = (T)ose.Entity;
        switch (mergeOption)
        {
            case MergeOption.OverwriteChanges:
                ose.ApplyCurrentValues(item);
                ose.AcceptChanges();
                break;
            case MergeOption.PreserveChanges:
                EntityType entityType = context.GetEntityType<T>();
                if (!context.ContextOptions.UseLegacyPreserveChangesBehavior)
                {
                    ose.ApplyOriginalValues(item);
                    List<string> modifiedProperties = ose.GetModifiedProperties().ToList();
                    // NOTE(review): this second call looks redundant; it is kept to preserve the statement order.
                    ose.ApplyOriginalValues(item);
                    foreach (string propertyName in entityType.Properties.Except(entityType.KeyMembers).Select(p => p.Name).Except(modifiedProperties))
                    {
                        PropertyInfo property = typeof(T).GetProperty(propertyName);
                        property.SetValue(entity, property.GetValue(item, null), null);
                    }
                }
                ose.AcceptChanges();
                ose.ApplyOriginalValues(item);
                // The values are boxed objects, so compare them with Equals. A reference comparison (!=) would
                // report every value-type property as modified.
                foreach (string propertyName in (from p in entityType.Properties.Except(entityType.KeyMembers)
                                                 where !object.Equals(ose.OriginalValues[p.Name], ose.CurrentValues[p.Name])
                                                 select p.Name).ToList())
                    ose.SetModifiedProperty(propertyName);
                break;
        }
        return entity;
    }

    /// <summary>
    /// Runs two LINQ to Entities queries in one store command. Every materialized entity is lazily yielded,
    /// paired with the query that produced it, after being merged into <paramref name="context"/>.
    /// </summary>
    /// <param name="mergeOption">How each entity is merged into <paramref name="context"/> (see <c>Attach</c>).</param>
    /// <remarks>Nothing runs until the result is enumerated.</remarks>
    public static IEnumerable<KeyValuePair<IQueryable, object>> Execute<T1, T2>(this ObjectContext context, IQueryable<T1> query1,
        IQueryable<T2> query2, MergeOption mergeOption = MergeOption.AppendOnly)
        where T1 : class, new()
        where T2 : class, new()
    {
        return ExecuteSubQueries(context, new[]
        {
            CreateSubQuery(context, query1),
            CreateSubQuery(context, query2)
        }, mergeOption);
    }

    /// <summary>
    /// Runs three LINQ to Entities queries in one store command. Every materialized entity is lazily yielded,
    /// paired with the query that produced it, after being merged into <paramref name="context"/>.
    /// </summary>
    /// <param name="mergeOption">How each entity is merged into <paramref name="context"/> (see <c>Attach</c>).</param>
    /// <remarks>Nothing runs until the result is enumerated.</remarks>
    public static IEnumerable<KeyValuePair<IQueryable, object>> Execute<T1, T2, T3>(this ObjectContext context, IQueryable<T1> query1,
        IQueryable<T2> query2, IQueryable<T3> query3, MergeOption mergeOption = MergeOption.AppendOnly)
        where T1 : class, new()
        where T2 : class, new()
        where T3 : class, new()
    {
        return ExecuteSubQueries(context, new[]
        {
            CreateSubQuery(context, query1),
            CreateSubQuery(context, query2),
            CreateSubQuery(context, query3)
        }, mergeOption);
    }

    /// <summary>
    /// Packs a query with its row materializer and its attach callback, as consumed by ExecuteSubQueries.
    /// </summary>
    private static Tuple<IQueryable, Func<DbDataReader, object>, Func<object, MergeOption, object>> CreateSubQuery<T>(
        ObjectContext context, IQueryable<T> query)
        where T : class, new()
    {
        // Resolved on the first row and then reused, so GetFactory's metadata lookup runs once per
        // sub-query rather than once per row.
        Func<DbDataReader, T> factory = null;
        return new Tuple<IQueryable, Func<DbDataReader, object>, Func<object, MergeOption, object>>(
            query,
            dr =>
            {
                ObjectQuery<T> objectQuery = (ObjectQuery<T>)query;
                if (factory == null)
                    factory = Materializer<T>.GetFactory(objectQuery);
                return (object)Materializer<T>.Materialize(objectQuery, dr, factory);
            },
            (entity, attachMergeOption) => context.Attach<T>((T)entity, attachMergeOption));
    }

    /// <summary>
    /// Concatenates the store SQL of every sub-query into one command and executes it. The result sets are read
    /// in order, the i-th result set being materialized by the i-th sub-query.
    /// </summary>
    /// <remarks>
    /// <para>Iterator: validation and execution happen when enumeration starts.</para>
    /// <para>If the store connection was closed, it is opened here and closed again once enumeration completes
    /// or the enumerator is disposed. A connection the caller had already opened is left open.</para>
    /// <para>NOTE(review): the command is not enlisted in any transaction active on the connection.</para>
    /// </remarks>
    /// <exception cref="ArgumentException">
    /// A query is null, is not an <see cref="ObjectQuery"/>, or belongs to a different context.
    /// </exception>
    private static IEnumerable<KeyValuePair<IQueryable, object>> ExecuteSubQueries(ObjectContext context,
        IList<Tuple<IQueryable, Func<DbDataReader, object>, Func<object, MergeOption, object>>> subQueries,
        MergeOption mergeOption)
    {
        if (context == null)
            throw new ArgumentNullException("context");
        if (subQueries.Count == 0)
            throw new ArgumentException("At least one query is required.");

        var parameterValues = new List<KeyValuePair<string, object>>();
        var sql = new StringBuilder();
        foreach (var subQuery in subQueries)
        {
            ObjectQuery objectQuery = subQuery.Item1 as ObjectQuery;
            if (objectQuery == null)
                throw new ArgumentException("Every query must be a non-null LINQ to Entities (ObjectQuery) query.");
            // Entities are attached to `context`, so every query must run against it.
            if (objectQuery.Context != context)
                throw new ArgumentException("All queries must belong to the ObjectContext they are executed on.");

            sql.Append(RenameParameters(objectQuery, parameterValues));
            sql.Append(";");
            sql.AppendLine();
        }

        EntityConnection entityConnection = (EntityConnection)context.Connection;
        DbConnection storeConnection = entityConnection.StoreConnection;
        bool openedHere = false;
        if (storeConnection.State != ConnectionState.Open)
        {
            storeConnection.Open();
            openedHere = true;
        }
        try
        {
            using (DbCommand command = storeConnection.CreateCommand())
            {
                command.CommandText = sql.ToString();
                foreach (KeyValuePair<string, object> parameterValue in parameterValues)
                {
                    DbParameter parameter = command.CreateParameter();
                    parameter.ParameterName = parameterValue.Key;
                    parameter.Value = parameterValue.Value ?? DBNull.Value;
                    command.Parameters.Add(parameter);
                }

                using (DbDataReader dataReader = command.ExecuteReader())
                {
                    int resultSetIndex = 0;
                    do
                    {
                        if (resultSetIndex >= subQueries.Count)
                            throw new InvalidOperationException("The batch returned more result sets than there are queries.");
                        var subQuery = subQueries[resultSetIndex++];
                        while (dataReader.Read())
                            yield return new KeyValuePair<IQueryable, object>(subQuery.Item1,
                                subQuery.Item3(subQuery.Item2(dataReader), mergeOption));
                    } while (dataReader.NextResult());
                }
            }
        }
        finally
        {
            // Runs on normal completion, on exceptions, and when the consumer stops enumerating early.
            if (openedHere)
                storeConnection.Close();
        }
    }

    /// <summary>
    /// Returns the store SQL of <paramref name="query"/> with each of its parameters renamed to a name that is
    /// unique within the batch ("@p0", "@p1", ... in batch order). The new names and their values are appended
    /// to <paramref name="parameterValues"/>.
    /// </summary>
    /// <remarks>Assumes "@name" placeholders, as emitted for SQL Server. TODO: confirm for other providers.</remarks>
    private static string RenameParameters(ObjectQuery query, List<KeyValuePair<string, object>> parameterValues)
    {
        // ToTraceString() is called before Parameters is read. Presumably LINQ closure parameters only
        // appear once the query has been translated (TODO confirm).
        string storeSql = query.ToTraceString();
        var renames = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var parameter in query.Parameters)
        {
            string batchName = string.Format(System.Globalization.CultureInfo.InvariantCulture, "@p{0}", parameterValues.Count);
            renames.Add("@" + parameter.Name, batchName);
            parameterValues.Add(new KeyValuePair<string, object>(batchName, parameter.Value));
        }
        if (renames.Count == 0)
            return storeSql;

        // Done in a single pass, so a name produced by one rename can never be rewritten by another.
        // The trailing \b stops "@p__linq__1" from matching the start of "@p__linq__10".
        string pattern = "(?:" + string.Join("|",
            renames.Keys.Select(name => System.Text.RegularExpressions.Regex.Escape(name)).ToArray()) + @")\b";
        return System.Text.RegularExpressions.Regex.Replace(storeSql, pattern, match => renames[match.Value]);
    }
}


Hope that helps.

This entry was posted in 7671, 7674, 9104. Bookmark the permalink.

Leave a Reply

Your email address will not be published. Required fields are marked *

You may use these HTML tags and attributes: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>