How to include recursion / table valued functions in LINQ To Entities queries with EF4? v2

I posted a first version on this topic a few days ago.

However, the regex pattern wasn’t perfect: it didn’t support every case (joins on table valued functions, for example).

Moreover, with the ExecuteStoreQuery method, properties and columns must have the same names. With EF4, computed columns are named C1, C2, etc., which are hopefully not your property names :)
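To make the naming problem concrete, here is a small illustrative sketch that is not part of the original code. A DataTable stands in for the store result, with EF4-style column names C1 and C2; the Node class and the helper names are hypothetical. Looking up a column by the property name Id finds nothing, while reading by position works whatever the columns are called, which is the approach the factory in the new version takes.

```csharp
using System;
using System.Data;

// Hypothetical entity whose property names differ from the generated column names.
public class Node
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public static class OrdinalMappingSketch
{
    // Stand-in for the store result of a computed projection: columns are named C1, C2.
    public static DataTable CreateResult()
    {
        DataTable table = new DataTable();
        table.Columns.Add("C1", typeof(int));
        table.Columns.Add("C2", typeof(string));
        table.Rows.Add(1, "root");
        return table;
    }

    // True if the record has a column with the given name (case-insensitive comparison).
    public static bool HasColumn(IDataRecord record, string name)
    {
        for (int i = 0; i < record.FieldCount; i++)
        {
            if (string.Equals(record.GetName(i), name, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    // Maps a row by column position, so the generated names C1, C2 do not matter.
    public static Node MapByOrdinal(IDataRecord record)
    {
        return new Node { Id = record.GetInt32(0), Name = record.GetString(1) };
    }
}
```

After `reader.Read()` on `CreateResult().CreateDataReader()`, `HasColumn(reader, "Id")` is false, while `MapByOrdinal(reader)` still returns a Node with Id 1 and Name "root".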

So I wrote a new version:

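using System;
using System.Collections.Generic;
using System.Data.EntityClient;
using System.Data.Objects;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;

public static class ObjectQueryExtension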
{
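    // tFactory: pass null on the first call; it is then built from the query's projection and
    // returned through the ref parameter so that later calls can reuse it.
    // nbParamsPerFunction: table valued function name -> number of parameters it takes.
    public static IEnumerable<T> UseTableValuedFunction<T>(this ObjectQuery<T> query, ref Func<SqlDataReader, T> tFactory, IEnumerable<KeyValuePair<string, int>> nbParamsPerFunction)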
        where T : new()
    {
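        // Let EF4 generate the SQL, then rename its parameters (@p__linq__0, @p__linq__1, ...)
        // to @p0, @p1, ... so they match the SqlParameters created in LoopOnSqlDataReader.
        string sql = query.ToTraceString();
        sql = sql.Replace("@p__linq__", "@p");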
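        // For each function, turn "[Function] AS [Alias] ... WHERE [Alias].[Column] = value"
        // into "[Function](value, ...) AS [Alias] ...", looping until the SQL no longer changes.
        foreach (KeyValuePair<string, int> function in nbParamsPerFunction)
        {
            string newSql = sql;
            do
            {
                sql = newSql;
                newSql = new Regex(string.Format(@"(\[?{0}\]?)(\s+AS\s+(\[(\w+)\]))\s*(.*)", Regex.Escape(function.Key)), RegexOptions.Singleline | RegexOptions.IgnoreCase).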
                    Replace(sql, m =>
                    {
                        string end = m.Groups[5].Value;
                        string alias = m.Groups[4].Value;
                        List<string> functionParams = new List<string>();
                        int nbParam = function.Value;
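                        // Collect the first nbParam equality conditions on [alias] in the WHERE clause:
                        // their right-hand sides become the function arguments and the conditions are removed.
                        string newEnd = Regex.Replace(end, string.Format(@"WHERE\s+((\(*([\w\s\[\]].)*\[{0}\][\w\s\[\].]+\s*=\s*([\w\s\[\]@.]+)\)*(\s*(AND|OR)\s*)?)+)", alias), m2 =>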
                            {
                                string condition = m2.Value;
                                while (nbParam-- != 0)
                                {
                                    Match match = Regex.Match(condition, string.Format(@"\(*[\w\s\[\].]*\[{0}\][\w\s\[\].]+\s*=\s*([\w\s\[\]@.]+)\)*(\s+AND\s+)?", alias));
                                    functionParams.Add(match.Groups[1].Value);
                                    condition = condition.Substring(condition.IndexOf(match.Value) + match.Value.Length);
                                }
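                                // Keep the closing parentheses left unbalanced by the removed conditions. When nothing
                                // remains, return only those parentheses; but if there are none and the clause ended
                                // with AND/OR (more conditions follow), keep "WHERE ". Otherwise keep WHERE + the rest.
                                string parenthesis = new string(')', m2.Value.Count(c => c == ')') - m2.Value.Count(c => c == '('));
                                return ((condition = condition.TrimStart()).Length == 0 && !(parenthesis.Length == 0 && Regex.IsMatch(m2.Value, @"(AND|OR)\s*$"))) ?
                                    parenthesis :
                                    string.Concat("WHERE ", condition);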
                            });
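                        // Rebuild "[Function](arg1, arg2, ...) AS [Alias] <rest of the query>".
                        // string.Join also handles functions without parameters (Aggregate would throw).
                        return string.Format("{0}({1}){2} {3}",
                            m.Groups[1].Value,
                            string.Join(", ", functionParams),
                            m.Groups[2].Value,
                            newEnd);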
                    });
            } while (sql != newSql);
            sql = newSql;
        }
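        if (tFactory == null)
        {
            // Build the materializer from the query's projection. This assumes the query ends with
            // .Select(x => new T { Prop1 = ..., Prop2 = ... }): a Select call whose last argument is
            // the quoted lambda, whose body is a MemberInitExpression.
            ParameterExpression parameterExpression = Expression.Parameter(typeof(SqlDataReader));
            var bindings = ((MemberInitExpression)((LambdaExpression)((UnaryExpression)((MethodCallExpression)(((IQueryable<T>)query).Expression)).Arguments.Last()).Operand).Body).Bindings;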
            tFactory = Expression.Lambda<Func<SqlDataReader, T>>(
                Expression.MemberInit(
                    Expression.New(
                        typeof(T)
                    ),
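                    // Property i is read from ordinal i + FieldCount - bindings.Count, i.e. from the last
                    // columns of the result (presumably because EF may emit extra leading columns such as
                    // a constant [C1]); DBNull becomes null and the value is converted to the property type.
                    bindings.Select((binding, index) =>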
                        Expression.Bind(
                            binding.Member,
                            Expression.Convert(
                                Expression.Condition(
                                    Expression.Call(
                                        parameterExpression,
                                        typeof(SqlDataReader).GetMethod("IsDBNull"),
                                        Expression.Subtract(
                                            Expression.Add(
                                                Expression.Constant(index),
                                                Expression.MakeMemberAccess(
                                                    parameterExpression,
                                                    typeof(SqlDataReader).GetProperty("FieldCount")
                                                )
                                            ),
                                            Expression.Constant(bindings.Count)
                                        )
                                    ),
                                    Expression.Constant(null),
                                    Expression.Call(
                                        parameterExpression,
                                        typeof(SqlDataReader).GetMethod("GetValue"),
                                        Expression.Subtract(
                                            Expression.Add(
                                                Expression.Constant(index),
                                                Expression.MakeMemberAccess(
                                                    parameterExpression,
                                                    typeof(SqlDataReader).GetProperty("FieldCount")
                                                )
                                            ),
                                            Expression.Constant(bindings.Count)
                                        )
                                    )
                                ),
                                ((PropertyInfo)binding.Member).PropertyType)
                            )
                        ).ToArray()
                    ),
                    parameterExpression
                ).Compile();
        }
        return LoopOnSqlDataReader(query, sql, tFactory);
    }



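    // Runs the rewritten SQL directly against SQL Server and materializes each row with tFactory.
    // The query only executes when the result is enumerated.
    private static IEnumerable<T> LoopOnSqlDataReader<T>(ObjectQuery<T> query, string sql, Func<SqlDataReader, T> tFactory)
    {
        // With SQL Server authentication, the password is removed from this connection string once the
        // store connection has been opened, unless the connection string sets Persist Security Info=True.
        string connectionString = ((EntityConnection)query.Context.Connection).StoreConnection.ConnectionString;
        using (SqlConnection connection = new SqlConnection(connectionString))
        using (SqlCommand command = connection.CreateCommand())
        {
            command.CommandText = sql;
            // Same renaming as in UseTableValuedFunction (p__linq__0 -> @p0, ...). A null Value would
            // not be sent to the server, so nulls are passed as DBNull.Value.
            command.Parameters.AddRange(query.Parameters
                .Select(p => new SqlParameter("@" + p.Name.Replace("p__linq__", "p"), p.Value ?? DBNull.Value))
                .ToArray());
            connection.Open();
            using (SqlDataReader dataReader = command.ExecuteReader())
            {
                while (dataReader.Read())
                    yield return tFactory(dataReader);
            }
        }
    }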
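
    // Convenience overloads for IQueryable<T>; the query must actually be an ObjectQuery<T>,
    // otherwise the cast throws an InvalidCastException.
    public static IEnumerable<T> UseTableValuedFunction<T>(this IQueryable<T> query, ref Func<SqlDataReader, T> tFactory, IEnumerable<KeyValuePair<string, int>> nbParamsPerFunction)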
        where T : new()
    {
        return ((ObjectQuery<T>)query).UseTableValuedFunction(ref tFactory, nbParamsPerFunction);
    }
    public static IEnumerable<T> UseTableValuedFunction<T>(this IQueryable<T> query, ref Func<SqlDataReader, T> tFactory, params KeyValuePair<string, int>[] nbParamsPerFunction)
        where T : new()
    {
        return UseTableValuedFunction<T>(query, ref tFactory, (IEnumerable<KeyValuePair<string, int>>)nbParamsPerFunction);
    }
}
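To see what the SQL rewriting does without a database, here is a sketch that copies the regex loop of UseTableValuedFunction into a standalone method (TvfSqlRewriteSketch.Rewrite is a hypothetical name). The sample SQL is hand-written to mimic the shape of EF4 output rather than captured from EF: a table valued function GetChildren is joined to a Nodes table and filtered on its ParentId column, all of which are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class TvfSqlRewriteSketch
{
    // The SQL-rewriting part of UseTableValuedFunction, with no EF4 or SQL Server dependency.
    public static string Rewrite(string sql, IEnumerable<KeyValuePair<string, int>> nbParamsPerFunction)
    {
        // EF4 names its parameters @p__linq__0, @p__linq__1, ...; shorten them to @p0, @p1, ...
        sql = sql.Replace("@p__linq__", "@p");
        foreach (KeyValuePair<string, int> function in nbParamsPerFunction)
        {
            string newSql = sql;
            do
            {
                sql = newSql;
                // Groups: 1 = function name, 2 = " AS [Alias]", 4 = alias, 5 = rest of the query.
                Regex functionRegex = new Regex(
                    string.Format(@"(\[?{0}\]?)(\s+AS\s+(\[(\w+)\]))\s*(.*)", Regex.Escape(function.Key)),
                    RegexOptions.Singleline | RegexOptions.IgnoreCase);
                newSql = functionRegex.Replace(sql, m =>
                {
                    string end = m.Groups[5].Value;
                    string alias = m.Groups[4].Value;
                    List<string> functionParams = new List<string>();
                    int nbParam = function.Value;
                    // Consume the first nbParam "[alias].[Column] = value" conditions of the WHERE clause.
                    string newEnd = Regex.Replace(end, string.Format(@"WHERE\s+((\(*([\w\s\[\]].)*\[{0}\][\w\s\[\].]+\s*=\s*([\w\s\[\]@.]+)\)*(\s*(AND|OR)\s*)?)+)", alias), m2 =>
                    {
                        string condition = m2.Value;
                        while (nbParam-- != 0)
                        {
                            Match match = Regex.Match(condition, string.Format(@"\(*[\w\s\[\].]*\[{0}\][\w\s\[\].]+\s*=\s*([\w\s\[\]@.]+)\)*(\s+AND\s+)?", alias));
                            functionParams.Add(match.Groups[1].Value);
                            condition = condition.Substring(condition.IndexOf(match.Value) + match.Value.Length);
                        }
                        string parenthesis = new string(')', m2.Value.Count(c => c == ')') - m2.Value.Count(c => c == '('));
                        return ((condition = condition.TrimStart()).Length == 0 && !(parenthesis.Length == 0 && Regex.IsMatch(m2.Value, @"(AND|OR)\s*$")))
                            ? parenthesis
                            : string.Concat("WHERE ", condition);
                    });
                    // "[Function](args) AS [Alias] <rest of the query without the consumed conditions>"
                    return string.Format("{0}({1}){2} {3}", m.Groups[1].Value, string.Join(", ", functionParams), m.Groups[2].Value, newEnd);
                });
            } while (sql != newSql);
        }
        return sql;
    }
}
```

For the sample, the WHERE condition on [Extent2] is consumed and the joined function becomes `[dbo].[GetChildren](@p0) AS [Extent2]`, while the join condition `ON [Extent1].[Id] = [Extent2].[Id]` is left untouched (the result keeps a trailing space).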


Now, if you want to track changes on your entities, you just have to attach them to the context.
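The returned objects are built by tFactory with `new T()`, outside the ObjectContext, which is why they have to be attached before changes are tracked. For readers who want to see how that factory is built from the projection, here is a standalone sketch of the same expression-tree technique. To run without SQL Server it targets IDataRecord instead of SqlDataReader and reads from a DataTable; FactorySketch, NodeDto and the sample rows are hypothetical.

```csharp
using System;
using System.Collections.ObjectModel;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical projection target.
public class NodeDto
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public static class FactorySketch
{
    // Builds record => new TResult { M0 = (T0)(IsDBNull(o0) ? null : GetValue(o0)), ... }
    // where o_i = i + FieldCount - bindings.Count, so the last bindings.Count columns are used.
    public static Func<IDataRecord, TResult> BuildFactory<TSource, TResult>(Expression<Func<TSource, TResult>> projection)
        where TResult : new()
    {
        ReadOnlyCollection<MemberBinding> bindings = ((MemberInitExpression)projection.Body).Bindings;
        ParameterExpression record = Expression.Parameter(typeof(IDataRecord), "record");
        MethodInfo isDbNull = typeof(IDataRecord).GetMethod("IsDBNull");
        MethodInfo getValue = typeof(IDataRecord).GetMethod("GetValue");
        PropertyInfo fieldCount = typeof(IDataRecord).GetProperty("FieldCount");

        MemberBinding[] newBindings = bindings.Select((binding, index) =>
        {
            Expression ordinal = Expression.Subtract(
                Expression.Add(Expression.Constant(index), Expression.Property(record, fieldCount)),
                Expression.Constant(bindings.Count));
            // Both branches are typed object: a null constant or the boxed column value.
            Expression value = Expression.Condition(
                Expression.Call(record, isDbNull, ordinal),
                Expression.Constant(null),
                Expression.Call(record, getValue, ordinal));
            return (MemberBinding)Expression.Bind(
                binding.Member,
                Expression.Convert(value, ((PropertyInfo)binding.Member).PropertyType));
        }).ToArray();

        return Expression.Lambda<Func<IDataRecord, TResult>>(
            Expression.MemberInit(Expression.New(typeof(TResult)), newBindings),
            record).Compile();
    }

    // Mirrors the ref tFactory parameter: compile once, then reuse the cached delegate.
    public static Func<IDataRecord, TResult> GetOrBuild<TSource, TResult>(
        ref Func<IDataRecord, TResult> cache, Expression<Func<TSource, TResult>> projection)
        where TResult : new()
    {
        if (cache == null)
            cache = BuildFactory(projection);
        return cache;
    }
}
```

Only the member bindings of the projection (Id, then Name) are used; their right-hand sides are never evaluated, as in the original, where the values come from the SQL columns instead. A nullable property type is needed wherever a value-type column can be NULL, since converting null to a plain int fails at run time.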



Hope that helps.


One Response to How to include recursion / table valued functions in LINQ To Entities queries with EF4? v2

  1. R. S. says:

    I tried to implement the program given in the blog (both the June and July posts). In version 1, I am getting an error: ExecuteStoreQuery() method not supported. In version 2 of this article, how can I call the UseTableValuedFunction() method? Can you please give the implementation of the ‘method call statement’, or the complete code as it was given in version 1 (June post)?
