Reimplementing LINQ to Objects: Part 19 – Join

You might expect that as joining is such a fundamental concept in SQL, it would be a complex operation to both describe and implement in LINQ. Fortunately, nothing could be further from the truth…

What is it?

Join has a mere two overloads, with the familiar pattern of one taking a custom equality comparer and the other not:

public static IEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
    this IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector)

public static IEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
    this IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector,
    IEqualityComparer<TKey> comparer)

These are daunting signatures to start with – four type parameters, and up to six normal method parameters (including the first "this" parameter indicating an extension method). Don’t panic. It’s actually quite simple to understand each individual bit:

  • We have two sequences (outer and inner). The two sequences can have different element types (TOuter and TInner).
  • For each sequence, there’s also a key selector – a projection from a single item to the key for that item. Note that although the key type can be different from the two sequence types, there’s only one key type – both key selectors have to return the same kind of key (TKey).
  • An optional equality comparer is used to compare keys.
  • A delegate (resultSelector) is used to project a pair of items whose keys are equal (one from each sequence) to the result type (TResult).

The idea is that we look through the two input sequences for pairs which correspond to equal keys, and yield one output element for each pair. This is an equijoin operation: we can only deal with equal keys, not pairs of keys which meet some arbitrary condition. It’s also an inner join in database terms – we will only see an item from one sequence if there’s a "matching" item from the other sequence. I’ll talk about mimicking left joins when I implement GroupJoin.

The documentation gives us details of the order in which we need to return items:

Join preserves the order of the elements of outer, and for each of these elements, the order of the matching elements of inner.

For the sake of clarity, it’s probably worth including a naive implementation which at least gives the right results in the right order:

// Bad implementation, only provided for the purposes of discussion
public static IEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
    this IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector,
    IEqualityComparer<TKey> comparer)
{
    // Argument validation omitted
    foreach (TOuter outerElement in outer)
    {
        TKey outerKey = outerKeySelector(outerElement);
        foreach (TInner innerElement in inner)
        {
            TKey innerKey = innerKeySelector(innerElement);
            if (comparer.Equals(outerKey, innerKey))
            {
                yield return resultSelector(outerElement, innerElement);
            }
        }
    }
}

Aside from the missing argument validation, there are two important problems with this:

  • It iterates over the inner sequence multiple times. I always advise anyone implementing a LINQ-like operator to only iterate over any input sequence once. There are sequences which are impossible to iterate over multiple times, or which may give different results each time. That’s bad news.
  • It always has a complexity of O(N * M) for N items in the inner sequence and M items in the outer sequence. Eek. Admittedly that’s always a possible complexity – two sequences which have the same key for all elements will always have that complexity – but in a typical situation we can do considerably better.

The real Join operator uses the same behaviour as Except and Intersect when it comes to how the input sequences are consumed:

  • Argument validation occurs eagerly – both sequences and all the "selector" delegates have to be non-null; the comparer argument can be null, leading to the default equality comparer for TKey being used.
  • The overall operation uses deferred execution: it doesn’t iterate over either input sequence until something starts iterating over the result sequence.
  • When MoveNext is called on the result sequence for the first time, it immediately consumes the whole of the inner sequence, buffering it.
  • The outer sequence is streamed – it’s only read one element at a time. By the time the result sequence has started yielding results from the second element of outer, it’s forgotten about the first element.
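
A quick illustrative snippet (mine, not from the Edulinq test suite) demonstrating those last three points, using a logging wrapper around each input sequence. Note that the exact interleaving of the log entries can vary between runtimes, but on the first MoveNext call the whole of "inner" is buffered while only the first element of "outer" has been read:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class ConsumptionDemo
{
    // Logs each element of a sequence as it is actually read.
    static IEnumerable<T> Log<T>(IEnumerable<T> source, string name, List<string> log)
    {
        foreach (T item in source)
        {
            log.Add(name + ":" + item);
            yield return item;
        }
    }

    internal static List<string> RunFirstMoveNext()
    {
        var log = new List<string>();
        var outer = Log(new[] { 1, 2 }, "outer", log);
        var inner = Log(new[] { "x", "yy" }, "inner", log);

        var query = outer.Join(inner, o => o, i => i.Length,
                               (o, i) => o + ":" + i);
        // Deferred execution: at this point the log is still empty.

        using (var iterator = query.GetEnumerator())
        {
            iterator.MoveNext();
        }
        // Now the whole of "inner" has been buffered, but only the
        // first element of "outer" has been read.
        return log;
    }

    static void Main()
    {
        Console.WriteLine(string.Join(", ", RunFirstMoveNext()));
    }
}
```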

We’ve started veering towards an implementation already, so let’s think about tests.

What are we going to test?

I haven’t bothered with argument validation tests this time – even with cut and paste, the 10 tests required to completely check everything feel like overkill.

However, I have tested:

  • Joining two invalid sequences, but not using the results (i.e. testing deferred execution)
  • The way that the two sequences are consumed
  • Using a custom comparer
  • Not specifying a comparer
  • Using sequences of different types

See the source code for more details, but here’s a flavour – the final test:

[Test]
public void DifferentSourceTypes()
{
    int[] outer = { 5, 3, 7 };
    string[] inner = { "bee", "giraffe", "tiger", "badger", "ox", "cat", "dog" };

    var query = outer.Join(inner,
                           outerElement => outerElement,
                           innerElement => innerElement.Length,
                           (outerElement, innerElement) => outerElement + ":" + innerElement);
    query.AssertSequenceEqual("5:tiger", "3:bee", "3:cat", "3:dog", "7:giraffe");
}

To be honest, the tests aren’t very exciting. The implementation is remarkably simple though.

Let’s implement it!

Trivial decision 1: make the comparer-less overload call the one with the comparer.

Trivial decision 2: use the split-method technique to validate arguments eagerly but defer the rest of the operation
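
Those two trivial decisions imply wrapper methods along the following lines. This is a sketch rather than code from the post (the "this" modifiers are omitted here purely so the snippet doesn’t clash with System.Linq when compiled standalone, and JoinImpl – the iterator block the post presents next – is reproduced so the sketch compiles):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class JoinSketch
{
    public static IEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
        IEnumerable<TOuter> outer,
        IEnumerable<TInner> inner,
        Func<TOuter, TKey> outerKeySelector,
        Func<TInner, TKey> innerKeySelector,
        Func<TOuter, TInner, TResult> resultSelector)
    {
        // Trivial decision 1: delegate to the comparer-accepting overload.
        return Join(outer, inner, outerKeySelector, innerKeySelector,
                    resultSelector, EqualityComparer<TKey>.Default);
    }

    public static IEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
        IEnumerable<TOuter> outer,
        IEnumerable<TInner> inner,
        Func<TOuter, TKey> outerKeySelector,
        Func<TInner, TKey> innerKeySelector,
        Func<TOuter, TInner, TResult> resultSelector,
        IEqualityComparer<TKey> comparer)
    {
        // Trivial decision 2: validate eagerly, then defer the real work
        // to an iterator block. (A null comparer is fine: ToLookup treats
        // it as the default comparer for TKey.)
        if (outer == null) throw new ArgumentNullException("outer");
        if (inner == null) throw new ArgumentNullException("inner");
        if (outerKeySelector == null) throw new ArgumentNullException("outerKeySelector");
        if (innerKeySelector == null) throw new ArgumentNullException("innerKeySelector");
        if (resultSelector == null) throw new ArgumentNullException("resultSelector");
        return JoinImpl(outer, inner, outerKeySelector, innerKeySelector,
                        resultSelector, comparer);
    }

    private static IEnumerable<TResult> JoinImpl<TOuter, TInner, TKey, TResult>(
        IEnumerable<TOuter> outer,
        IEnumerable<TInner> inner,
        Func<TOuter, TKey> outerKeySelector,
        Func<TInner, TKey> innerKeySelector,
        Func<TOuter, TInner, TResult> resultSelector,
        IEqualityComparer<TKey> comparer)
    {
        var lookup = inner.ToLookup(innerKeySelector, comparer);
        foreach (var outerElement in outer)
        {
            foreach (var innerElement in lookup[outerKeySelector(outerElement)])
            {
                yield return resultSelector(outerElement, innerElement);
            }
        }
    }
}
```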

That leaves us with the actual implementation in an iterator block, which is all I’m going to provide the code for here. So, what are we going to do?

Well, we know that we’re going to have to read in the whole of the "inner" sequence – but let’s wait a minute before deciding how to store it. We’re then going to iterate over the "outer" sequence. For each item in the outer sequence, we need to find the key and then work out all the "inner" sequence items which match that key. Now, the idea of finding all the items in a sequence which match a particular key should sound familiar to you – that’s exactly what a lookup is for. If we build a lookup from the inner sequence as our very first step, the rest becomes easy: we can fetch the sequence of matches, then iterate over them and yield the return value of calling result selector on the pair of elements.

All of this is easier to see in code than in words, so here’s the method:

private static IEnumerable<TResult> JoinImpl<TOuter, TInner, TKey, TResult>(
    IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector,
    IEqualityComparer<TKey> comparer)
{
    var lookup = inner.ToLookup(innerKeySelector, comparer);
    foreach (var outerElement in outer)
    {
        var key = outerKeySelector(outerElement);
        foreach (var innerElement in lookup[key])
        {
            yield return resultSelector(outerElement, innerElement);
        }
    }
}

Personally, I think this is rather beautiful… in particular, I like the way that it uses every parameter exactly once. Everything is just set up to work nicely.

But wait, there’s more… If you look at the nested foreach loops, that should remind you of something: for each outer sequence element, we’re computing a nested sequence, then applying a delegate to each pair, and yielding the result. That’s almost exactly the definition of SelectMany! If only we had a "yield foreach" or "yield!" I’d be tempted to use an implementation like this:

private static IEnumerable<TResult> JoinImpl<TOuter, TInner, TKey, TResult>(
    IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector,
    IEqualityComparer<TKey> comparer)
{
    var lookup = inner.ToLookup(innerKeySelector, comparer);
    // Warning: not really valid C#
    yield foreach outer.SelectMany(outerElement => lookup[outerKeySelector(outerElement)],
                                   resultSelector);
}

Unfortunately there’s no such thing as a "yield foreach" statement. We can’t just call SelectMany and return the result directly, because then we wouldn’t be deferring execution. The best we can sensibly do is loop:

private static IEnumerable<TResult> JoinImpl<TOuter, TInner, TKey, TResult>(
    IEnumerable<TOuter> outer,
    IEnumerable<TInner> inner,
    Func<TOuter, TKey> outerKeySelector,
    Func<TInner, TKey> innerKeySelector,
    Func<TOuter, TInner, TResult> resultSelector,
    IEqualityComparer<TKey> comparer)
{
    var lookup = inner.ToLookup(innerKeySelector, comparer);
    var results = outer.SelectMany(outerElement => lookup[outerKeySelector(outerElement)],
                                   resultSelector);
    foreach (var result in results)
    {
        yield return result;
    }
}

At that stage, I think I prefer the original form – which is still pretty simple, thanks to the lookup.

While I have avoided using other operators for implementations in the past, in this case it feels so completely natural, it would be silly not to use ToLookup to implement Join.

One final point before I wrap up for the night…

Query expressions

Most of the operators I’ve been implementing recently don’t have any direct mapping in C# query expressions. Join does, however. The final test I showed before can also be written like this:

int[] outer = { 5, 3, 7 };
string[] inner = { "bee", "giraffe", "tiger", "badger", "ox", "cat", "dog" };

var query = from x in outer
            join y in inner on x equals y.Length
            select x + ":" + y;

query.AssertSequenceEqual("5:tiger", "3:bee", "3:cat", "3:dog", "7:giraffe");

The compiler will effectively generate the same code (although admittedly I’ve used shorter range variable names here – x and y instead of outerElement and innerElement respectively). In this case the resultSelector delegate it supplies is simply the final projection from the "select" clause – but if we had anything else (such as a where clause) between the join and the select, the compiler would introduce a transparent identifier to propagate the values of both x and y through the query. It looks like I haven’t blogged explicitly about transparent identifiers, although I cover them in C# in Depth. Maybe once I’ve finished actually implementing the operators, I’ll have a few more general posts on this sort of thing.
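
As a rough sketch of that translation (the names the compiler actually generates differ; this is illustrative only), here’s what happens when a where clause sits between the join and the select – the join’s result selector packages x and y into an anonymous type, the transparent identifier, which the later clauses then unpack:

```csharp
using System;
using System.Linq;

class TransparentIdentifierDemo
{
    static void Main()
    {
        int[] outer = { 5, 3, 7 };
        string[] inner = { "bee", "giraffe", "tiger", "badger", "ox", "cat", "dog" };

        // Query expression form:
        var query1 = from x in outer
                     join y in inner on x equals y.Length
                     where y != "bee"
                     select x + ":" + y;

        // Roughly what the compiler generates for it:
        var query2 = outer.Join(inner,
                                x => x,
                                y => y.Length,
                                (x, y) => new { x, y })     // transparent identifier
                          .Where(t => t.y != "bee")
                          .Select(t => t.x + ":" + t.y);

        Console.WriteLine(string.Join(", ", query1)); // 5:tiger, 3:cat, 3:dog, 7:giraffe
        Console.WriteLine(query1.SequenceEqual(query2)); // True
    }
}
```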

Anyway, the point is that for simple join clauses (as opposed to join…into) we’ve implemented everything we need to. Hoorah.

Conclusion

I spoiled some of the surprise around how easy Join would be to implement by mentioning it in the ToLookup post, but I’m still impressed by how neat it is. Again I should emphasize that this is due to the design of LINQ – it’s got nothing to do with my own prowess.

This won’t be the end of our use of lookups though… the other grouping constructs can all use them too. I’ll try to get them all out of the way before moving on to operators which feel a bit different…

Addendum

It turns out this wasn’t quite as simple as I’d expected. Although ToLookup and GroupBy handle null keys without blinking, Join and GroupJoin ignore them. I had to write an alternative version of ToLookup which ignores null keys while populating the lookup, and then replace the calls to "ToLookup" in the code above with calls to "ToLookupNoNullKeys". This isn’t documented anywhere, and is inconsistent with ToLookup/GroupBy. I’ve opened up a Connect issue about it, in the hope that it at least gets documented properly. (It’s too late to change the behaviour now, I suspect.)

Reimplementing LINQ to Objects: Part 18 – ToLookup

I’ve had a request to implement Join, which seems a perfectly reasonable operator to aim towards. However, it’s going to be an awful lot easier to implement if we write ToLookup first. That will also help with GroupBy and GroupJoin, later on.

In the course of this post we’ll create a certain amount of infrastructure – some of which we may want to tweak later on.

What is it?

ToLookup has four overloads, with these signatures:

public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector)

public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    IEqualityComparer<TKey> comparer)

public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    Func<TSource, TElement> elementSelector)

public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    Func<TSource, TElement> elementSelector,
    IEqualityComparer<TKey> comparer)

Essentially these boil down to two required parameters and two optional ones:

  • The source is required, and must not be null
  • The keySelector is required, and must not be null
  • The elementSelector is optional, and defaults to an identity projection of a source element to itself. If it is specified, it must not be null.
  • The comparer is optional, and defaults to the default equality comparer for the key type. It may be null, which is equivalent to specifying the default equality comparer for the key type.

Now we can just consider the most general case – the final overload. However, in order to understand what ToLookup does, we have to know what ILookup<TKey, TElement> means – which in turn means knowing about IGrouping<TKey, TElement>:

public interface IGrouping<out TKey, out TElement> : IEnumerable<TElement>
{
    TKey Key { get; }
}

public interface ILookup<TKey, TElement> : IEnumerable<IGrouping<TKey, TElement>>
{
    int Count { get; }
    IEnumerable<TElement> this[TKey key] { get; }
    bool Contains(TKey key);
}

The generics make these interfaces seem somewhat scary at first sight, but they’re really not so bad. A grouping is simply a sequence with an associated key. This is a pretty simple concept to transfer to the real world – something like "Plays by Oscar Wilde" could be an IGrouping<string, Play> with a key of "Oscar Wilde". The key doesn’t have to be "embedded" within the element type though – it would also be reasonable to have an IGrouping<string, string> representing just the names of plays by Oscar Wilde.

A lookup is essentially a map or dictionary where each key is associated with a sequence of values instead of a single one. Note that the interface is read-only, unlike IDictionary<TKey, TValue>. As well as looking up a single sequence of values associated with a key, you can also iterate over the whole lookup in terms of groupings (instead of the key/value pair from a dictionary). There’s one other important difference between a lookup and a dictionary: if you ask a lookup for the sequence corresponding to a key which it doesn’t know about, it will return an empty sequence, rather than throwing an exception. (A key which the lookup does know about will never yield an empty sequence.)
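
A quick illustrative snippet (mine, not from the post) showing that last difference:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class LookupVersusDictionary
{
    static void Main()
    {
        var lookup = new[] { "bee", "cat", "giraffe" }.ToLookup(word => word.Length);

        // A known key yields its (non-empty) sequence of values.
        Console.WriteLine(string.Join(", ", lookup[3])); // bee, cat

        // An unknown key yields an empty sequence, not an exception.
        Console.WriteLine(lookup[42].Count()); // 0

        // The equivalent dictionary access throws instead.
        var dictionary = new[] { "bee", "cat" }.ToDictionary(word => word);
        try
        {
            Console.WriteLine(dictionary["dog"]);
        }
        catch (KeyNotFoundException)
        {
            Console.WriteLine("KeyNotFoundException");
        }
    }
}
```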

One slightly odd point to note is that while IGrouping is covariant in TKey and TElement, ILookup is invariant in both of its type parameters. While TKey has to be invariant, it would be reasonable for TElement to be covariant – to go back to our "plays" example, an IGrouping<string, Play> could be sensibly regarded as an IGrouping<string, IWrittenWork> (with the obvious type hierarchy). However, the interface declarations above are the ones in .NET 4, so that’s what I’ve used in Edulinq.

Now that we understand the signature of ToLookup, let’s talk about what it actually does. Firstly, it uses immediate execution – the lookup returned by the method is effectively divorced from the original sequence; changes to the sequence after the method has returned won’t change the lookup. (Obviously changes to the objects within the original sequence may still be seen, depending on what the element selector does, etc.) The rest is actually fairly straightforward, when you consider what parameters we have to play with:

  • The keySelector is applied to each item in the input sequence. Keys are always compared for equality using the "comparer" parameter.
  • The elementSelector is applied to each item in the input sequence, to project the item to the value which will be returned within the lookup.

To demonstrate this a little further, here are two applications of ToLookup – one using the first overload, and one using the last. In each case we’re going to group plays by author and then display some information about them. Hopefully this will make some of the concepts I’ve described a little more concrete:

ILookup<string, Play> lookup = allPlays.ToLookup(play => play.Author);
foreach (IGrouping<string, Play> group in lookup)
{
    Console.WriteLine("Plays by {0}", group.Key);
    foreach (Play play in group)
    {
        Console.WriteLine("  {0} ({1})", play.Name, play.CopyrightDate);
    }
}

// Or…

ILookup<string, string> lookup = allPlays.ToLookup(play => play.Author,
                                                   play => play.Name,
                                                   StringComparer.OrdinalIgnoreCase);
foreach (IGrouping<string, string> group in lookup)
{
    Console.WriteLine("Plays by {0}:", group.Key);
    foreach (string playName in group)
    {
        Console.WriteLine("  {0}", playName);
    }
}
        
// And demonstrating looking up by a key:
IEnumerable<string> playsByWilde = lookup["Oscar Wilde"];
Console.WriteLine("Plays by Oscar Wilde: {0}",
                  string.Join("; ", playsByWilde));

Note that although I’ve used explicit typing for all the variables here, I would usually use implicit typing with var, just to keep the code more concise.

Now we just have the small matter of working out how to go from a key/element pair for each item, to a data structure which lets us look up sequences by key.

Before we move onto the implementation and tests, there are two aspects of ordering to consider:

  • How are the groups within the lookup ordered?
  • How are the elements within each group ordered?

The documentation for ToLookup is silent on these matters – but the docs for the closely-related GroupBy operator are more specific:

The IGrouping<TKey, TElement> objects are yielded in an order based on the order of the elements in source that produced the first key of each IGrouping<TKey, TElement>. Elements in a grouping are yielded in the order they appear in source.

Admittedly "in an order based on…" isn’t as clear as it might be, but I think it’s reasonable to make sure that our implementation yields groups such that the first group returned has the same key as the first element in the source, and so on.
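
So, for instance (an illustrative snippet rather than one of the tests):

```csharp
using System;
using System.Linq;

class GroupOrdering
{
    static void Main()
    {
        var words = new[] { "bee", "giraffe", "cat", "badger", "dog" };
        var lookup = words.ToLookup(word => word.Length);

        // Groups come out in the order their keys were first seen in the
        // source (3, then 7, then 6), and the elements within each group
        // keep their source order.
        foreach (var group in lookup)
        {
            Console.WriteLine("{0}: {1}", group.Key, string.Join(", ", group));
        }
        // 3: bee, cat, dog
        // 7: giraffe
        // 6: badger
    }
}
```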

What are we going to test?

I’ve actually got relatively few tests this time. I test each of the overloads, but not terribly exhaustively. There are three tests worth looking at though. The first two show the "eager" nature of the operator, and the final one demonstrates the most complex overload by grouping people into families.

[Test]
public void SourceSequenceIsReadEagerly()
{
    var source = new ThrowingEnumerable();
    Assert.Throws<InvalidOperationException>(() => source.ToLookup(x => x));
}

[Test]
public void ChangesToSourceSequenceAfterToLookupAreNotNoticed()
{
    List<string> source = new List<string> { "abc" };
    var lookup = source.ToLookup(x => x.Length);
    Assert.AreEqual(1, lookup.Count);

    // Potential new key is ignored
    source.Add("x");
    Assert.AreEqual(1, lookup.Count);

    // Potential new value for existing key is ignored
    source.Add("xyz");
    lookup[3].AssertSequenceEqual("abc");
}


[Test]
public void LookupWithComparerAndElementSelector()
{
    var people = new[] {
        new { First = "Jon", Last = "Skeet" },
        new { First = "Tom", Last = "SKEET" }, // Note upper-cased name
        new { First = "Juni", Last = "Cortez" },
        new { First = "Holly", Last = "Skeet" },
        new { First = "Abbey", Last = "Bartlet" },
        new { First = "Carmen", Last = "Cortez" },                
        new { First = "Jed", Last = "Bartlet" }
    };

    var lookup = people.ToLookup(p => p.Last, p => p.First, StringComparer.OrdinalIgnoreCase);

    lookup["Skeet"].AssertSequenceEqual("Jon", "Tom", "Holly");
    lookup["Cortez"].AssertSequenceEqual("Juni", "Carmen");
    // The key comparer is used for lookups too
    lookup["BARTLET"].AssertSequenceEqual("Abbey", "Jed");

    lookup.Select(x => x.Key).AssertSequenceEqual("Skeet", "Cortez", "Bartlet");
}

I’m sure there are plenty of other tests I could have come up with. If you want to see any in particular (either as an edit to this post if they already exist in source control, or as entirely new tests), please leave a comment.

EDIT: It turned out there were a few important tests missing… see the addendum.

Let’s implement it!

There’s quite a lot of code involved in implementing this – but it should be reusable later on, potentially with a few tweaks. Let’s get the first three overloads out of the way to start with:

public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector)
{
    return source.ToLookup(keySelector, element => element, EqualityComparer<TKey>.Default);
}

public static ILookup<TKey, TSource> ToLookup<TSource, TKey>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    IEqualityComparer<TKey> comparer)
{
    return source.ToLookup(keySelector, element => element, comparer);
}

public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    Func<TSource, TElement> elementSelector)
{
    return source.ToLookup(keySelector, elementSelector, EqualityComparer<TKey>.Default);
}

Now we can just worry about the final one. Before we implement that though, we’re definitely going to need an implementation of IGrouping. Let’s come up with a really simple one to start with:

internal sealed class Grouping<TKey, TElement> : IGrouping<TKey, TElement>
{
    private readonly TKey key;
    private readonly IEnumerable<TElement> elements;

    internal Grouping(TKey key, IEnumerable<TElement> elements)
    {
        this.key = key;
        this.elements = elements;
    }

    public TKey Key { get { return key; } }

    public IEnumerator<TElement> GetEnumerator()
    {
        return elements.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}

How do we feel about this? Well, groupings should be immutable. We have no guarantee that the "elements" sequence won’t be changed after creation – but it’s an internal class, so we’ve just got to be careful about how we use it. We also have to be careful that we don’t use a mutable sequence type which allows a caller to get from the iterator (the IEnumerator<TElement> returned by GetEnumerator) back to the sequence and then mutate it. In our case we’re actually going to use List<T> to provide the sequences, and while List<T>.Enumerator is a public type, it doesn’t expose the underlying list. Of course a caller could use reflection to mess with things, but we’re not going to try to protect against that.

Okay, so now we can pair a sequence with a key… but we still need to implement ILookup. This is where there are multiple options. We want our lookup to be immutable, but there are various degrees of immutability we could use:

  • Mutable internally, immutable in public API
  • Mutable privately, immutable to the internal API
  • Totally immutable, even within the class itself

The first option is the simplest to implement, and it’s what I’ve gone for at the moment. I’ve created a Lookup class which allows a key/element pair to be added to it from within the Edulinq assembly. It uses a Dictionary<TKey, List<TElement>> to map the keys to sequences efficiently, and a List<TKey> to remember the order in which we first saw the keys. Here’s the complete implementation:

internal sealed class Lookup<TKey, TElement> : ILookup<TKey, TElement>
{
    private readonly Dictionary<TKey, List<TElement>> map;
    private readonly List<TKey> keys;

    internal Lookup(IEqualityComparer<TKey> comparer)
    {
        map = new Dictionary<TKey, List<TElement>>(comparer);
        keys = new List<TKey>();
    }

    internal void Add(TKey key, TElement element)
    {
        List<TElement> list;
        if (!map.TryGetValue(key, out list))
        {
            list = new List<TElement>();
            map[key] = list;
            keys.Add(key);
        }
        list.Add(element);
    }

    public int Count
    {
        get { return map.Count; }
    }

    public IEnumerable<TElement> this[TKey key]
    {
        get
        {
            List<TElement> list;
            if (!map.TryGetValue(key, out list))
            {
                return Enumerable.Empty<TElement>();
            }
            return list.Select(x => x);
        }
    }

    public bool Contains(TKey key)
    {
        return map.ContainsKey(key);
    }

    public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
    {
        return keys.Select(key => new Grouping<TKey, TElement>(key, map[key]))
                   .GetEnumerator();
    }
        
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}

All of this is really quite straightforward. Note that we provide an equality comparer to the constructor, which is then passed onto the dictionary – that’s the only thing that needs to know how to compare keys.

There are only two points of interest, really:

  • In the indexer, we don’t return the list itself – that would allow the caller to mutate it, after casting it to List<TElement>. Instead, we just call Select with an identity projection, as a simple way of inserting a sort of "buffering" layer between the list and the caller. There are other ways of doing this, of course – including implementing the indexer with an iterator block.
  • In GetEnumerator, we’re retaining the key order by using our list of keys and performing a lookup on each key.

We’re currently creating the new Grouping objects lazily – which will lead to fewer of them being created if the caller doesn’t actually iterate over the lookup, but more of them if the caller iterates over it several times. Again, there are alternatives here – but without any good information about where to make the trade-off, I’ve just gone for the simplest code which works for the moment.

One last thing to note about Lookup – I’ve left it internal. In .NET, it’s actually public – but the only way of getting at an instance of it is to call ToLookup and then cast the result. I see no particular reason to make it public, so I haven’t.

Now we’re finally ready to implement the last ToLookup overload – and it becomes pretty much trivial:

public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(
    this IEnumerable<TSource> source,
    Func<TSource, TKey> keySelector,
    Func<TSource, TElement> elementSelector,
    IEqualityComparer<TKey> comparer)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (keySelector == null)
    {
        throw new ArgumentNullException("keySelector");
    }
    if (elementSelector == null)
    {
        throw new ArgumentNullException("elementSelector");
    }

    Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer ?? EqualityComparer<TKey>.Default);

    foreach (TSource item in source)
    {
        TKey key = keySelector(item);
        TElement element = elementSelector(item);
        lookup.Add(key, element);
    }
    return lookup;
}

When you look past the argument validation, we’re just creating a Lookup, populating it, and then returning it. Simple.

Thread safety

Something I haven’t addressed anywhere so far is thread safety. In particular, although all of this is nicely immutable when viewed from a single thread, I have a nasty feeling that theoretically, if the return value of our implementation of ToLookup were exposed to another thread, that thread could observe the internal mutations we’re making, as we’re not doing anything special in terms of the memory model.

I’m basically scared by lock-free programming these days, unless I’m using building blocks provided by someone else. While investigating the exact guarantees offered here would be interesting, I don’t think it would really help with our understanding of LINQ as a whole. I’m therefore declaring thread safety to be out of scope for Edulinq in general :)

Conclusion

So that’s ToLookup. Two new interfaces, two new classes, all for one new operator… so far. We can reuse almost all of this in Join though, which will make it very simple to implement. Stay tuned…

Addendum

It turns out that I missed something quite important: ToLookup has to handle null keys, as do various other LINQ operators (GroupBy etc). We’re currently using a Dictionary<TKey, TValue> to organize the groups… and that doesn’t support null keys. Oops.

So, first steps: write some tests proving that it fails as we expect it to. Fetch by a null key of a lookup. Include a null key in the source of a lookup. Use GroupBy, GroupJoin and Join with a null key. Watch it go bang. That’s the easy part…

Now, we can do all the special-casing in Lookup itself – but it gets ugly. Our Lookup code was pretty simple before; it seems a shame to spoil it with checks everywhere. What we really need is a dictionary which does support null keys. Step forward NullKeyFriendlyDictionary, a new internal class in Edulinq. Now you might expect this to implement IDictionary<TKey, TValue>, but it turns out that’s a pain in the neck. We hardly use any of the members of Dictionary – TryGetValue, the indexer, ContainsKey, and the Count property. That’s it! So those are the only members I’ve implemented.

The class contains a Dictionary<TKey, TValue> to delegate most requests to, and it just handles null keys itself. Here’s a quick sample:

internal sealed class NullKeyFriendlyDictionary<TKey, TValue>
{
    private readonly Dictionary<TKey, TValue> map;
    private bool haveNullKey = false;
    private TValue valueForNullKey;

    internal NullKeyFriendlyDictionary(IEqualityComparer<TKey> comparer)
    {
        map = new Dictionary<TKey, TValue>(comparer);
    }

    internal bool TryGetValue(TKey key, out TValue value)
    {
        if (key == null)
        {
            // This will be default(TValue) if haveNullKey is false,
            // which is what we want.
            value = valueForNullKey;
            return haveNullKey;
        }
        return map.TryGetValue(key, out value);
    }

    // etc
}

There’s one potential flaw here, that I can think of: if you provide an IEqualityComparer<TKey> which treats some non-null key as equal to a null key, we won’t spot that. If your source then contains both those keys, we’ll end up splitting them into two groups instead of keeping them together. I’m not too worried about that – and I suspect there are all kinds of ways that could cause problems elsewhere anyway.

With this in place, and Lookup adjusted to use NullKeyFriendlyDictionary instead of just Dictionary, all the tests pass. Hooray!

At the same time as implementing this, I’ve tidied up Grouping itself – it now implements IList<T> itself, in a way which is immutable to the outside world. The Lookup now contains groups directly, and can return them very easily. The code is generally tidier, and anything using a group can take advantage of the optimizations applied to IList<T> – particularly the Count() operator, which is often applied to groups.

Reimplementing LINQ to Objects: Part 17 – Except

I’m really pleased with this one. Six minutes ago (at the time of writing this), I tweeted about the Intersect blog post. I then started writing the tests and implementation for Except… and I’m now done.

The tests are cut/paste/replace with a few tweaks – but it’s the implementation that I’m most pleased with. You’ll see what I mean later – it’s beautiful.

What is it?

Except is our final set operator, with the customary two overloads:

public static IEnumerable<TSource> Except<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second)

public static IEnumerable<TSource> Except<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)

The result of calling Except is the elements of the first sequence which are not in the second sequence.

Just for completeness, here’s a quick summary of the behaviour:

  • The first overload uses the default equality comparer for TSource; so does the second overload if you pass in null as the comparer.
  • Neither "first" nor "second" can be null
  • The method does use deferred execution
  • The result sequence only contains distinct elements; even if the first sequence contains duplicates, the result sequence won’t

This time, the documentation doesn’t say how "first" and "second" will be used, other than to describe the result as a set difference. In practice though, it’s exactly the same as Intersect: when we ask for the first result, the "second" input sequence is fully evaluated, then the "first" input sequence is streamed.
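A quick demonstration of the behaviour described above, including duplicate removal:

```csharp
using System;
using System.Linq;

class ExceptDemo
{
    static void Main()
    {
        int[] first = { 1, 2, 2, 3, 4 };
        int[] second = { 2, 4, 5 };

        // Elements of "first" which aren't in "second"; the duplicate 2
        // is removed along with the matching elements.
        Console.WriteLine(string.Join(", ", first.Except(second))); // 1, 3
    }
}
```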

What are we going to test?

I literally cut and pasted the tests for Intersect, did a find/replace on Intersect/Except, and then looked at the data in each test. In particular, I made sure that there were potential duplicates in the "first" input sequence which would have to be removed in the result sequence. I also tweaked the details of the data in the final tests shown before (which proved the way in which the two sequences were read) but the main thrust of the tests is the same.

Nothing to see here, move on…

Let’s implement it!

I’m not even going to bother showing the comparer-free overload this time. It just calls the other overload, as you’d expect. Likewise the argument validation part of the implementation is tedious. Let’s focus on the part which does the work. First, we’ll think back to Distinct and Intersect:

  • In Distinct, we started with an empty set and populated it as we went along, making sure we never returned anything already in the set
  • In Intersect, we started with a populated set (from the second input sequence) and removed elements from it as we went along, only returning elements where an equal element was previously in the set

Except is simply a cross between the two: from Distinct we keep the idea of a "banned" set of elements that we add to; from Intersect we take the idea of starting off with a set populated from the second input sequence. Here’s the implementation:

private static IEnumerable<TSource> ExceptImpl<TSource>(
    IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{
    HashSet<TSource> bannedElements = new HashSet<TSource>(second, comparer);
    foreach (TSource item in first)
    {
        if (bannedElements.Add(item))
        {
            yield return item;
        }
    }
}

The only differences between this and Intersect are:

  • The name of the method (ExceptImpl instead of IntersectImpl)
  • The name of the local variable holding the set (bannedElements instead of potentialElements)
  • The method called in the loop (Add instead of Remove)

Isn’t that just wonderful? Perhaps it shouldn’t make me quite as happy as it does, but hey…

Conclusion

That concludes the set operators – and indeed my blogging for the night. It’s unsurprising that all of the set operators have used a set implementation internally… but I’ve been quite surprised at just how simple the implementations all were. Again, the joy of LINQ resides in the ability for such simple operators to be combined in useful ways.

Reimplementing LINQ to Objects: Part 16 – Intersect (and build fiddling)

Okay, this is more like it – after the dullness of Union, Intersect has a new pattern to offer… and one which we’ll come across repeatedly.

First, however, I should explain some more changes I’ve made to the solution structure…

Building the test assembly

I’ve just had an irritating time sorting out something I thought I’d fixed this afternoon. Fed up of accidentally testing against the wrong implementation, my two project configurations ("Normal LINQ" and "Edulinq implementation") now target different libraries from the test project: only the "Normal LINQ" configuration refers to System.Core, and only the "Edulinq implementation" configuration refers to the Edulinq project itself. Or so I thought. Unfortunately, msbuild automatically adds System.Core in unless you’re careful. I had to add this into the "Edulinq implementation" property group part of my project file to avoid accidentally pulling in System.Core:

<AddAdditionalExplicitAssemblyReferences>false</AddAdditionalExplicitAssemblyReferences>

Unfortunately, at that point all the extension methods I’d written within the tests project – and the references to HashSet<T> – failed. I should have noticed them not failing before, and been suspicious. Hey ho.

Now I’m aware that you can add your own version of ExtensionAttribute, but I believe that can become a problem if at execution time you do actually load System.Core… which I will end up doing, as the Edulinq assembly itself references it (for HashSet<T> aside from anything else).

There may well be multiple solutions to this problem, but I’ve come up with one which appears to work:

  • Add an Edulinq.TestSupport project, and refer to that from both configurations of Edulinq.Tests
  • Make Edulinq.TestSupport refer to System.Core so it can use whatever collections it likes, as well as extension methods
  • Put all the non-test classes (those containing extension methods, and the weird and wonderful collections) into the Edulinq.TestSupport project.
  • Add a DummyClass to Edulinq.TestSupport in the namespace System.Linq, so that using directives within Edulinq.Tests don’t need to be conditionalized
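The DummyClass itself can be essentially empty – something like this sketch (the exact shape in Edulinq may differ) is enough to make "using System.Linq;" directives compile even when System.Core isn’t referenced:

```csharp
// In the Edulinq.TestSupport project: gives the System.Linq namespace at
// least one member, so that "using System.Linq;" always compiles.
namespace System.Linq
{
    internal static class DummyClass
    {
    }
}
```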

All seems to be working now – and finally, if I try to refer to an extension method I haven’t implemented in Edulinq yet, it will fail to compile instead of silently using the System.Core one. Phew. Now, back to Intersect…

What is it?

Intersect has a now-familiar pair of overloads:

public static IEnumerable<TSource> Intersect<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second)

public static IEnumerable<TSource> Intersect<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)

Fairly obviously, it computes the intersection of two sequences: the elements that occur in both "first" and "second". Points to note:

  • Again, the first overload uses the default equality comparer for TSource; so does the second overload if you pass in null as the comparer.
  • Neither "first" nor "second" can be null
  • The method does use deferred execution – but in a way which may seem unusual at first sight
  • The result sequence only contains distinct elements – even if an element occurs multiple times in both input sequences, it will only occur once in the result

Now for the interesting bit – exactly when the two sequences are used. MSDN claims this:

When the object returned by this method is enumerated, Intersect enumerates first, collecting all distinct elements of that sequence. It then enumerates second, marking those elements that occur in both sequences. Finally, the marked elements are yielded in the order in which they were collected.

This is demonstrably incorrect. Indeed, I have a test which would fail under LINQ to Objects if this were the case. What actually happens is this:

  • Nothing happens until the first result element is requested. I know I’ve said so already, but it’s worth repeating.
  • As soon as the first element of the result is requested, the whole of the "second" input sequence is read, as well as the first (and possibly more) elements of the "first" input sequence – enough to return the first result, basically.
  • Results are read from the "first" input sequence as and when they are required. Only elements which were originally in the "second" input sequence and haven’t already been yielded are returned.

We’ll see this pattern of "greedily read the second sequence, but stream the first sequence" again when we look at Join… but let’s get back to the tests.

What are we going to test?

Obviously I’ve got simple tests including:

  • Argument validation
  • Elements which occur multiple times in each sequence
  • The overload without a comparer specified
  • Specifying a null comparer
  • Specifying a "custom" comparer (I’m using the case-insensitive string comparer again; news at 11)

However, the interesting tests are the ones which show how the sequences are actually consumed. Here they are:

[Test]
public void NoSequencesUsedBeforeIteration()
{
    var first = new ThrowingEnumerable();
    var second = new ThrowingEnumerable();
    // No exceptions!
    var query = first.Intersect(second);
    // Still no exceptions… we’re not calling MoveNext.
    using (var iterator = query.GetEnumerator())
    {
    }
}

[Test]
public void SecondSequenceReadFullyOnFirstResultIteration()
{
    int[] first = { 1 };
    var secondQuery = new[] { 10, 2, 0 }.Select(x => 10 / x);

    var query = first.Intersect(secondQuery);
    using (var iterator = query.GetEnumerator())
    {
        Assert.Throws<DivideByZeroException>(() => iterator.MoveNext());
    }
}

[Test]
public void FirstSequenceOnlyReadAsResultsAreRead()
{
    var firstQuery = new[] { 10, 2, 0, 2 }.Select(x => 10 / x);
    int[] second = { 1 };

    var query = firstQuery.Intersect(second);
    using (var iterator = query.GetEnumerator())
    {
        // We can get the first value with no problems
        Assert.IsTrue(iterator.MoveNext());
        Assert.AreEqual(1, iterator.Current);

        // Getting at the *second* value of the result sequence requires
        // reading from the first input sequence until the "bad" division
        Assert.Throws<DivideByZeroException>(() => iterator.MoveNext());
    }
}

The first test just proves that execution really is deferred. That’s straightforward.

The second test proves that the "second" input sequence is completely read as soon as we ask for our first result. If the operator had really read "first" and then started reading "second", it could have yielded the first result (1) without throwing an exception… but it didn’t.

The third test proves that the "first" input sequence isn’t read in its entirety before we start returning results. We manage to read the first result with no problems – it’s only when we ask for the second result that we get an exception.

Let’s implement it!

I have chosen to implement the same behaviour as LINQ to Objects rather than the behaviour described by MSDN, because it makes more sense to me. In general, the "first" sequence is given more importance than the "second" sequence in LINQ: it’s generally the one which is streamed, with the second sequence being buffered if necessary.

Let’s start off with the comparer-free overload – as before, it will just call the other one:

public static IEnumerable<TSource> Intersect<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second)
{
    return Intersect(first, second, EqualityComparer<TSource>.Default);
}

Now for the more interesting part. Obviously we’ll have an argument-validating method, but what should we do in the guts? I find the duality between this and Distinct fascinating: there, we started with an empty set of elements, and tried to add each source element to it, yielding the element if we successfully added it (meaning it wasn’t there before).

This time, we’ve effectively got a limited set of elements which we can yield, but we only want to yield each of them once – so as we see items, we can remove them from the set, yielding only if that operation was successful. The initial set is formed from the "second" input sequence, and then we just iterate over the "first" input sequence, removing and yielding appropriately:

public static IEnumerable<TSource> Intersect<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{           
    if (first == null)
    {
        throw new ArgumentNullException("first");
    }
    if (second == null)
    {
        throw new ArgumentNullException("second");
    }
    return IntersectImpl(first, second, comparer ?? EqualityComparer<TSource>.Default);
}

private static IEnumerable<TSource> IntersectImpl<TSource>(
    IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{
    HashSet<TSource> potentialElements = new HashSet<TSource>(second, comparer);
    foreach (TSource item in first)
    {
        if (potentialElements.Remove(item))
        {
            yield return item;
        }
    }
}

Next time we’ll see a cross between the two techniques.

Ta-da – it all works as expected. I don’t know whether this is how Intersect really works in LINQ to Objects, but I expect it does something remarkably similar.

Conclusion

After suffering from some bugs earlier today where my implementation didn’t follow the documentation, it’s nice to find some documentation which doesn’t follow the real implementation :)

Seriously though, there’s an efficiency point to be noted here. If you have two sequences, one long and one short, then it’s much more efficient (mostly in terms of space) to use longSequence.Intersect(shortSequence) than shortSequence.Intersect(longSequence). The latter will require the whole of the long sequence to be in memory at once.
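To make that concrete (the sequences here are purely illustrative):

```csharp
using System;
using System.Linq;

class IntersectOrderDemo
{
    static void Main()
    {
        var longSequence = Enumerable.Range(0, 100000);
        var shortSequence = Enumerable.Range(0, 10);

        // Good: only the 10 elements of shortSequence are buffered in the set.
        var cheap = longSequence.Intersect(shortSequence);

        // Works, but buffers all 100,000 elements of longSequence first.
        var costly = shortSequence.Intersect(longSequence);

        // Both produce the same results, of course.
        Console.WriteLine(cheap.SequenceEqual(costly)); // True
    }
}
```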

Next up – and I might just manage it tonight – our final set operator, Except.

Reimplementing LINQ to Objects: Part 15 – Union

I’m continuing the set-based operators with Union. I may even have a couple of hours tonight – possibly enough to finish off Intersect and Except as well… let’s see.

What is it?

Union is another extension method with two overloads; one taking an equality comparer and one just using the default:

public static IEnumerable<TSource> Union<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second)

public static IEnumerable<TSource> Union<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)

This "two overloads, one taking an equality comparer" pattern is familiar from Distinct; we’ll see that and Intersect and Except do exactly the same thing.

Simply put, Union returns the union of the two sequences – all items that are in either input sequence. The result sequence has no duplicates in it, even if one of the input sequences contains duplicates. (I’m using the term "duplicate" here to mean an element which is equal to another according to the equality comparer we’re using in the operator.)

Characteristics:

  • Union uses deferred execution: argument validation is basically all that happens when the method is first called; it only starts iterating over the input sequences when you iterate over the result sequence
  • Neither first nor second can be null; the comparer argument can be null, in which case the default equality comparer is used
  • The input sequences are only read as and when they’re needed; to return the first result element, only the first input element is read

It’s worth noting that the documentation for Union specifies a lot more than the Distinct documentation does:

When the object returned by this method is enumerated, Union enumerates first and second in that order and yields each element that has not already been yielded.

To me, that actually looks like a guarantee of the rules I proposed for Distinct. In particular, it’s guaranteeing that the implementation iterates over "first" before "second", and if it’s yielding elements as it goes, that guarantees that distinct elements will retain their order from the original input sequences. Whether others would read it in the same way or not, I can’t say… input welcome.
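On that reading of the documentation, the result ordering is entirely predictable. For example:

```csharp
using System;
using System.Linq;

class UnionOrderDemo
{
    static void Main()
    {
        string[] first = { "a", "b", "B", "c" };
        string[] second = { "d", "e", "b" };

        // Elements are yielded in discovery order: all of "first" is
        // considered before "second", and duplicates keep their first form.
        Console.WriteLine(string.Join(", ",
            first.Union(second, StringComparer.OrdinalIgnoreCase)));
        // a, b, c, d, e
    }
}
```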

What are we going to test?

I’ve written quite a few tests for Union – possibly more than we really need, but they demonstrate a few points of usage. The tests are:

  • Arguments are validated eagerly
  • Finding the union of two sequences without specifying a comparer; both inputs have duplicate elements, and there’s one element in both
  • The same test as above but explicitly specifying null as the comparer, to force the default to be used
  • The same test as above but using a case-insensitive string comparer
  • Taking the union of an empty sequence with a non-empty one
  • Taking the union of a non-empty sequence with an empty one
  • Taking the union of two empty sequences
  • Proving that the first sequence isn’t used until we start iterating over the result sequence (using ThrowingEnumerable)
  • Proving that the second sequence isn’t used until we’ve exhausted the first

No new collections or comparers needed this time though – it’s all pretty straightforward. I haven’t written any tests for null elements this time – I’m convinced enough by what I saw when implementing Distinct to believe they won’t be a problem.

Let’s implement it!

First things first: we can absolutely implement the simpler overload in terms of the more complex one, and I’ll do the same for Intersect and Except. Here’s the Union method:

public static IEnumerable<TSource> Union<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second)
{
    return Union(first, second, EqualityComparer<TSource>.Default);
}

So how do we implement the more complex overload? Well, I’ve basically been a bit disappointed by Union in terms of its conceptual weight. It doesn’t really give us anything that the obvious combination of Concat and Distinct doesn’t – so let’s implement it that way first:

public static IEnumerable<TSource> Union<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{
    return first.Concat(second).Distinct(comparer);
}

The argument validation can be implemented by Concat with no problems here – the parameters have the same name, so any exceptions thrown will be fine in every way.

That’s how I think about Union, but as I’ve mentioned before, I’d rather not actually see Concat and Distinct showing up in the stack trace – so here’s the fuller implementation.

public static IEnumerable<TSource> Union<TSource>(
    this IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{
    if (first == null)
    {
        throw new ArgumentNullException("first");
    }
    if (second == null)
    {
        throw new ArgumentNullException("second");
    }
    return UnionImpl(first, second, comparer ?? EqualityComparer<TSource>.Default);
}

private static IEnumerable<TSource> UnionImpl<TSource>(
    IEnumerable<TSource> first,
    IEnumerable<TSource> second,
    IEqualityComparer<TSource> comparer)
{
    HashSet<TSource> seenElements = new HashSet<TSource>(comparer);
    foreach (TSource item in first)
    {
        if (seenElements.Add(item))
        {
            yield return item;
        }
    }
    foreach (TSource item in second)
    {
        if (seenElements.Add(item))
        {
            yield return item;
        }
    }
}

That feels like an absurd waste of code when we can achieve the same result so simply, admittedly. This is the first time my resolve against implementing one operator in terms of completely different ones has wavered. Just looking at it in black and white (so to speak), I’m close to going over the edge…

Conclusion

Union was a disappointingly bland operator in my view. (Maybe I should start awarding operators marks out of ten for being interesting, challenging etc.) It doesn’t feel like it’s really earned its place in LINQ, as calls to Concat/Distinct can replace it so easily. Admittedly as I’ve mentioned in several other places, a lot of operators can be implemented in terms of each other – but rarely quite so simply.

Still, I think Intersect and Except should be more interesting.

Reimplementing LINQ to Objects: Part 14 – Distinct

I’m going to implement the set-based operators next, starting with Distinct.

What is it?

Distinct has two overloads, which differ only in the simplest possible way:

public static IEnumerable<TSource> Distinct<TSource>(
    this IEnumerable<TSource> source)

public static IEnumerable<TSource> Distinct<TSource>(
    this IEnumerable<TSource> source,
    IEqualityComparer<TSource> comparer)

The point of the operator is straightforward: the result sequence contains the same items as the input sequence, but with the duplicates removed – so an input of { 0, 1, 3, 1, 5 } will give a result sequence of { 0, 1, 3, 5 } – the second occurrence of 1 is ignored.

This time I’ve checked and double-checked the documentation – and in this case it really is appropriate to think of the first overload as just a simplified version of the second. If you don’t specify an equality comparer, the default comparer for the type will be used. The same will happen if you pass in a null reference for the comparer. The default equality comparer for a type can be obtained with the handy EqualityComparer<T>.Default property.

Just to recap, an equality comparer (represented by the IEqualityComparer<T> interface) is able to do two things:

  • Determine the hash code for a single item of type T
  • Compare any two items of type T for equality

It doesn’t have to give any sort of ordering – that’s what IComparer<T> is for, although that doesn’t have the ability to provide a hash code.

One interesting point about IEqualityComparer<T> is that the GetHashCode() method is meant to throw an exception if it’s provided with a null argument, but in practice the EqualityComparer<T>.Default implementations appear not to. This leads to an interesting question about Distinct: how should it handle null elements? It’s not documented either way, but in reality both the LINQ to Objects implementation and the simplest way of implementing it ourselves simply throw a NullReferenceException if you use a not-null-safe comparer and have a null element present. Note that the default equality comparer for any type (EqualityComparer<T>.Default) does cope with nulls.

There are other undocumented aspects of Distinct, too. Both the ordering of the result sequence and the choice of which exact element is returned when there are equal options are unspecified. In the case of ordering, it’s explicitly unspecified. From the documentation: "The Distinct method returns an unordered sequence that contains no duplicate values." However, there’s a natural approach which answers both of these questions. Distinct is specified to use deferred execution (so it won’t look at the input sequence until you start reading from the output sequence) but it also streams the results to some extent: to return the first element in the result sequence, it only needs to read the first element from the input sequence. Some other operators (such as OrderBy) have to read all their data before yielding any results.

When you implement Distinct in a way which only reads as much data as it has to, the answer to the ordering and element choice is easy:

  • The result sequence is in the same order as the input sequence
  • When there are multiple equal elements, it is the one which occurs earliest in the input sequence which is returned as part of the result sequence.

Remember that it’s perfectly possible to have elements which are considered equal under a particular comparer, but are still clearly different when looked at another way. The simplest example of this is case-insensitive string equality. Taking the above rules into account, the distinct sequence returned for { "ABC", "abc", "xyz" } with a case-insensitive comparer is { "ABC", "xyz" }.
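In code, that example looks like this:

```csharp
using System;
using System.Linq;

class DistinctDemo
{
    static void Main()
    {
        string[] source = { "ABC", "abc", "xyz" };

        // The earliest occurrence of each (case-insensitively) equal element
        // wins, and the input ordering is preserved.
        Console.WriteLine(string.Join(", ",
            source.Distinct(StringComparer.OrdinalIgnoreCase)));
        // ABC, xyz
    }
}
```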

What are we going to test?

All of the above :)

All the tests use sequences of strings for clarity, but I’m using four different comparers:

  • The default string comparer (which is a case-sensitive ordinal comparer)
  • The case-insensitive ordinal comparer
  • A comparer which uses object identity (so will treat two equal but distinct strings as different)
  • A comparer which explicitly doesn’t try to cope with null values

The tests assume that the undocumented aspects listed above are implemented with the rules that I’ve given. This means they’re over-sensitive, in that an implementation of Distinct which matches all the documented behaviour but returns elements in a different order would fail the tests. This highlights an interesting aspect of unit testing in general… what exactly are we trying to test? I can think of three options in our case:

  • Just the documented behaviour: anything conforming to that, however oddly, should pass
  • The LINQ to Objects behaviour: the framework implementation should pass all our tests, and then our implementation should as well
  • Our implementation’s known (designed) behaviour: we can specify that our implementation will follow particular rules above and beyond the documented contracts

In production projects, these different options are valid in different circumstances, depending on exactly what you’re trying to do. At the moment, I don’t have any known differences in behaviour between LINQ to Objects and Edulinq, although that may well change later in terms of optimizations.

None of the tests themselves are particularly interesting – although I find it interesting that I had to implement a deliberately fragile (but conformant) implementation of IEqualityComparer<T> in order to test Distinct fully.
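Such a comparer might look like this – a sketch of the idea rather than the exact test class from Edulinq:

```csharp
using System;
using System.Collections.Generic;

// Conformant with the documented IEqualityComparer<T> contract, but
// deliberately fragile: GetHashCode throws when given a null argument,
// as the documentation permits.
internal sealed class NullHostileComparer<T> : IEqualityComparer<T>
    where T : class
{
    public bool Equals(T x, T y)
    {
        return object.Equals(x, y);
    }

    public int GetHashCode(T obj)
    {
        if (obj == null)
        {
            throw new ArgumentNullException("obj");
        }
        return obj.GetHashCode();
    }
}
```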

Let’s implement it!

I’m absolutely confident in implementing the overload that doesn’t take a custom comparer using the one that does. We have two options for how to specify the custom comparer in the delegating call though – we could pass null or EqualityComparer<T>.Default, as the two are explicitly defined to behave the same way in the second overload. I’ve chosen to pass in EqualityComparer<T>.Default just for the sake of clarity – it means that anyone reading the first method doesn’t need to check the behaviour of the second to understand what it will do.

We need to use the "private iterator block method" approach again, so that the arguments can be evaluated eagerly but still let the result sequence use deferred execution. The real work method uses HashSet<T> to keep track of all the elements we’ve already returned – it takes an IEqualityComparer<T> in its constructor, and the Add method adds an element to the set if there isn’t already an equal one, and returns whether or not it really had to add anything. All we have to do is iterate over the input sequence, call Add, and yield the item as part of the result sequence if Add returned true. Simple!

public static IEnumerable<TSource> Distinct<TSource>(
    this IEnumerable<TSource> source)
{
    return source.Distinct(EqualityComparer<TSource>.Default);
}

public static IEnumerable<TSource> Distinct<TSource>(
    this IEnumerable<TSource> source,
    IEqualityComparer<TSource> comparer)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    return DistinctImpl(source, comparer ?? EqualityComparer<TSource>.Default);
}

private static IEnumerable<TSource> DistinctImpl<TSource>(
    IEnumerable<TSource> source,
    IEqualityComparer<TSource> comparer)
{
    HashSet<TSource> seenElements = new HashSet<TSource>(comparer);
    foreach (TSource item in source)
    {
        if (seenElements.Add(item))
        {
            yield return item;
        }
    }
}

So what about the behaviour with nulls? Well, it seems that HashSet<T> just handles that automatically, if the comparer it uses does. So long as the comparer returns the same hash code each time it’s passed null, and considers null and null to be equal, it can be present in the sequence. Without HashSet<T>, we’d have had a much uglier implementation – especially as Dictionary<TKey, TValue> doesn’t allow null keys.

Conclusion

I’m frankly bothered by the lack of specificity in the documentation for Distinct. Should you rely on the ordering rules that I’ve given here? I think that in reality, you’re reasonably safe to rely on it – it’s the natural result of the most obvious implementation, after all. I wouldn’t rely on the same results when using a different LINQ provider, mind you – when fetching the results back from a database, for example, I wouldn’t be at all surprised to see the ordering change. And of course, the fact that the documentation explicitly states that the result is unordered should act as a bit of a deterrent from relying on this.

We’ll have to make similar decisions for the other set-based operators: Union, Intersect and Except. And yes, they’re very likely to use HashSet<T> too…

Reimplementing LINQ to Objects: Part 13 – Aggregate

EDIT: I’ve had to edit this quite a bit now that a second bug was discovered… basically my implementation of the first overload was completely broken :(

Last night’s tweet asking for suggestions around which operator to implement next resulted in a win for Aggregate, so here we go.

What is it?

Aggregate has three overloads, effectively allowing two defaults:

public static TSource Aggregate<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, TSource, TSource> func)

public static TAccumulate Aggregate<TSource, TAccumulate>(
    this IEnumerable<TSource> source,
    TAccumulate seed,
    Func<TAccumulate, TSource, TAccumulate> func)

public static TResult Aggregate<TSource, TAccumulate, TResult>(
    this IEnumerable<TSource> source,
    TAccumulate seed,
    Func<TAccumulate, TSource, TAccumulate> func,
    Func<TAccumulate, TResult> resultSelector)

Aggregate is an extension method using immediate execution, returning a single result. The generalised behaviour is as follows:

  • Start off with a seed. For the first overload, this defaults to the first value of the input sequence. The seed is used as the first accumulator value.
  • For each item in the sequence, apply the aggregation function, which takes the current accumulator value and the newly found item, and returns a new accumulator value.
  • Once the sequence has been exhausted, optionally apply a final projection to obtain a result. If no projection has been specified, we can imagine that the identity function has been provided.
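As a quick worked example of those steps:

```csharp
using System;
using System.Linq;

class AggregateDemo
{
    static void Main()
    {
        int[] source = { 1, 2, 3, 4 };

        // First overload: the seed defaults to the first element (1).
        // ((1 + 2) + 3) + 4 = 10
        Console.WriteLine(source.Aggregate((acc, item) => acc + item)); // 10

        // Second overload: explicit seed of 10.
        // (((10 * 1) * 2) * 3) * 4 = 240
        Console.WriteLine(source.Aggregate(10, (acc, item) => acc * item)); // 240
    }
}
```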

The signatures make all of this look a bit more complicated because of the various type parameters involved. You can consider all the overloads as dealing with three different types, even though the first two actually have fewer type parameters:

  • TSource is the element type of the sequence, always.
  • TAccumulate is the type of the accumulator – and thus the seed. For the first overload where no seed is provided, TAccumulate is effectively the same as TSource.
  • TResult is the return type when there’s a final projection involved. For the first two overloads, TResult is effectively the same as TAccumulate (again, think of a default "identity projection" as being used when nothing else is specified)

In the first overload, which uses the first input element as the seed, an InvalidOperationException is thrown if the input sequence is empty.
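A quick usage example (mine, not part of the post) shows all three overloads side by side, including the empty-sequence behaviour of the first:

```csharp
using System;
using System.Linq;

class AggregateOverloads
{
    static void Main()
    {
        int[] numbers = { 1, 2, 3, 4 };

        // First overload: the first element (1) acts as the seed.
        int product = numbers.Aggregate((acc, x) => acc * x);        // 24

        // Second overload: explicit seed, possibly of a different type.
        long sum = numbers.Aggregate(0L, (acc, x) => acc + x);       // 10

        // Third overload: seed plus a final projection.
        string csv = numbers.Aggregate("", (acc, x) => acc + x + ";",
                                       acc => acc.TrimEnd(';'));     // "1;2;3;4"

        Console.WriteLine("{0} {1} {2}", product, sum, csv);

        // The seedless overload throws for an empty sequence...
        try
        {
            new int[0].Aggregate((acc, x) => acc + x);
        }
        catch (InvalidOperationException)
        {
            Console.WriteLine("Seedless overload threw as expected");
        }

        // ...whereas the seeded overload just returns the seed.
        Console.WriteLine(new int[0].Aggregate(100, (acc, x) => acc + x)); // 100
    }
}
```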

What are we going to test?

Obviously the argument validation is reasonably simple to test – source, func and resultSelector can’t be null. But there are two different approaches to testing the "success" cases.

We could work out exactly when each delegate should be called and with what values – effectively mock every step of the iteration. This would be a bit of a pain, but a very robust way of proceeding.

The alternative approach is just to take some sample data and aggregation function, work out what the result should be, and assert that result. If the result is sufficiently unlikely to be achieved by chance, this is probably good enough – and it’s a lot simpler to implement. Here’s a sample from the most complicated test, where we have a seed and a final projection:

[Test]
public void SeededAggregationWithResultSelector()
{
    int[] source = { 1, 4, 5 };
    int seed = 5;
    Func<int, int, int> func = (current, value) => current * 2 + value;
    Func<int, string> resultSelector = result => result.ToInvariantString();
    // First iteration: 5 * 2 + 1 = 11
    // Second iteration: 11 * 2 + 4 = 26
    // Third iteration: 26 * 2 + 5 = 57
    // Result projection: 57.ToString() = "57"
    Assert.AreEqual("57", source.Aggregate(seed, func, resultSelector));
}

Now admittedly I’m not testing this to the absolute full – I’m using the same types for TSource and TAccumulate – but frankly it gives me plenty of confidence that the implementation is correct.

EDIT: My result selector now calls ToInvariantString. It used to just call ToString, but as I’ve now been persuaded that there are some cultures where that wouldn’t give us the right results, I’ve implemented an extension method which effectively means that x.ToInvariantString() is equivalent to x.ToString(CultureInfo.InvariantCulture) – so we don’t need to worry about cultures with different numeric representations etc.
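The extension method itself isn’t shown here; a plausible sketch (my guess at the shape, constraining T to IConvertible so that ToString(IFormatProvider) is available) might look like this:

```csharp
using System;
using System.Globalization;

public static class InvariantStringExtensions
{
    // One possible implementation: IConvertible exposes ToString(IFormatProvider),
    // which all the built-in numeric types implement.
    public static string ToInvariantString<T>(this T value) where T : IConvertible
    {
        return value.ToString(CultureInfo.InvariantCulture);
    }
}
```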

Just for the sake of completeness (I’ve convinced myself to improve the code while writing this blog post), here’s an example which sums integers, but results in a long – so it copes with a result which is bigger than Int32.MaxValue. I haven’t bothered with a final projection though:

[Test]
public void DifferentSourceAndAccumulatorTypes()
{
    int largeValue = 2000000000;
    int[] source = { largeValue, largeValue, largeValue };
    long sum = source.Aggregate(0L, (acc, value) => acc + value);
    Assert.AreEqual(6000000000L, sum);
    // Just to prove we haven’t missed off a zero…
    Assert.IsTrue(sum > int.MaxValue); 
}

Since I first wrote this post, I’ve also added tests for empty sequences (where the first overload should throw an exception) and a test which relies on the first overload using the first element of the sequence as the seed, rather than the default value of the input sequence’s element type.

Okay, enough about the testing… what about the real code?

Let’s implement it!

I’m still feeling my way around when it’s a good idea to implement one method by using another, but at the moment my gut feeling is that it’s okay to do so when:

  • You’re implementing one operator by reusing another overload of the same operator; in other words, no unexpected operators will end up in the stack trace of callers
  • There are no significant performance penalties for doing so
  • The observed behaviour is exactly the same – including argument validation
  • The code ends up being simpler to understand (obviously)

Contrary to an earlier version of this post, the first overload can’t be implemented in terms of the second or third ones, because of its behaviour regarding the seed and empty sequences.

public static TSource Aggregate<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, TSource, TSource> func)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (func == null)
    {
        throw new ArgumentNullException("func");
    }

    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (!iterator.MoveNext())
        {
            throw new InvalidOperationException("Source sequence was empty");
        }
        TSource current = iterator.Current;
        while (iterator.MoveNext())
        {
            current = func(current, iterator.Current);
        }
        return current;
    }
}

It still makes sense to share an implementation for the second and third overloads though. There’s a choice around whether to implement the second operator in terms of the third (giving it an identity projection) or to implement the third operator in terms of the second (by just calling the second overload and then applying a projection). Obviously applying an unnecessary identity projection has a performance penalty in itself – but it’s a tiny penalty. So which is more readable? I’m in two minds about this. I like code where various methods all call one "central" method where all the real work occurs (suggesting implementing the second overload using the third) but equally I suspect I really think about aggregation in terms of getting the final value of the accumulator… with just a twist in the third overload, of an extra projection. I guess it depends on whether you think of the final projection as part of the general form or an "extra" step.

For the moment, I’ve gone with the "keep all logic in one place" approach:

public static TAccumulate Aggregate<TSource, TAccumulate>(
    this IEnumerable<TSource> source,
    TAccumulate seed,
    Func<TAccumulate, TSource, TAccumulate> func)
{
    return source.Aggregate(seed, func, x => x);
}

public static TResult Aggregate<TSource, TAccumulate, TResult>(
    this IEnumerable<TSource> source,
    TAccumulate seed,
    Func<TAccumulate, TSource, TAccumulate> func,
    Func<TAccumulate, TResult> resultSelector)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (func == null)
    {
        throw new ArgumentNullException("func");
    }
    if (resultSelector == null)
    {
        throw new ArgumentNullException("resultSelector");
    }
    TAccumulate current = seed;
    foreach (TSource item in source)
    {
        current = func(current, item);
    }
    return resultSelector(current);
}

The bulk of the "real work" method is argument validation – the actual iteration is almost painfully simple.

Conclusion

The moral of today’s story is to read the documentation carefully – sometimes there’s unexpected behaviour to implement. I still don’t really know why this difference in behaviour exists… it feels to me as if the first overload really should behave like the second one, just with a default initial seed. EDIT: it seems that you need to read it really carefully. You know, every word of it. Otherwise you could make an embarrassing goof in a public blog post. <sigh>

The second moral should really be about the use of Aggregate – it’s a very generalized operator, and you can implement any number of other operators (Sum, Max, Min, Average etc) using it. In some ways it’s the scalar equivalent of SelectMany, just in terms of its diversity. Maybe I’ll show some later operators implemented using Aggregate…
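To illustrate that generality (my sketches, not operators from the post), Sum, Max and even Count over a sequence fall out of Aggregate almost trivially:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class AggregateBasedOperators
{
    // Sum is just repeated addition with a seed of zero.
    public static int SumViaAggregate(this IEnumerable<int> source)
    {
        return source.Aggregate(0, (acc, x) => acc + x);
    }

    // Max uses the seedless overload, so it throws for an empty
    // sequence - matching the behaviour of the real Max operator.
    public static int MaxViaAggregate(this IEnumerable<int> source)
    {
        return source.Aggregate((acc, x) => x > acc ? x : acc);
    }

    // Count ignores the element entirely and just increments the accumulator.
    public static int CountViaAggregate<T>(this IEnumerable<T> source)
    {
        return source.Aggregate(0, (acc, item) => acc + 1);
    }
}
```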

Next up, there have been requests for some of the set-based operators – Distinct, Union, etc – so I’ll probably look at those soon.

Reimplementing LINQ to Objects: Part 12 – DefaultIfEmpty

After the masses of code required for all the permutations of First/Last/etc, DefaultIfEmpty is a bit of a relief.

What is it?

Even this simple operator has two overloads:

public static IEnumerable<TSource> DefaultIfEmpty<TSource>(
    this IEnumerable<TSource> source)

public static IEnumerable<TSource> DefaultIfEmpty<TSource>(
    this IEnumerable<TSource> source,
    TSource defaultValue)

The behaviour is very simple to describe:

  • If the input sequence is empty, the result sequence has a single element in it, the default value. This is default(TSource) for the overload without an extra parameter, or the specified value otherwise.
  • If the input sequence isn’t empty, the result sequence is the same as the input sequence
  • The source argument must not be null, and this is validated eagerly
  • The result sequence itself uses deferred execution – the input sequence isn’t read at all until the result sequence is read
  • The input sequence is streamed; any values read are yielded immediately; no buffering is used

Dead easy.
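A couple of quick usage examples (my own) covering both overloads:

```csharp
using System;
using System.Linq;

class DefaultIfEmptyExamples
{
    static void Main()
    {
        int[] empty = new int[0];
        int[] nonEmpty = { 3, 1, 4 };

        // Empty input: a single-element result sequence.
        Console.WriteLine(string.Join(",", empty.DefaultIfEmpty()));      // "0"
        Console.WriteLine(string.Join(",", empty.DefaultIfEmpty(42)));    // "42"

        // Non-empty input: the result is just the input, untouched.
        Console.WriteLine(string.Join(",", nonEmpty.DefaultIfEmpty(42))); // "3,1,4"
    }
}
```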

What are we going to test?

Despite being relatively late in the day, I decided to test for argument validation – and a good job too, as my first attempt failed to split the implementation into an argument validation method and an iterator block method for the real work! It just shows how easy it is to slip up.

Other than that, I can really only see four cases worth testing:

  • No default value specified, empty input sequence
  • Default value specified, empty input sequence
  • No default value specified, non-empty input sequence
  • Default value specified, non-empty input sequence

So I have tests for all of those, and that’s it. I don’t have anything testing for streaming, lazy evaluation etc.

Let’s implement it!

Despite my reluctance to implement one operator in terms of another elsewhere, this felt like such an obvious case, I figured it would make sense just this once. I even applied DRY to the argument validation aspect. Here’s the implementation in all its brief glory:

public static IEnumerable<TSource> DefaultIfEmpty<TSource>(
    this IEnumerable<TSource> source)
{
    // This will perform an appropriate test for source being null first.
    return source.DefaultIfEmpty(default(TSource));
}

public static IEnumerable<TSource> DefaultIfEmpty<TSource>(
    this IEnumerable<TSource> source,
    TSource defaultValue)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    return DefaultIfEmptyImpl(source, defaultValue);
}

private static IEnumerable<TSource> DefaultIfEmptyImpl<TSource>(
    IEnumerable<TSource> source,
    TSource defaultValue)
{
    bool foundAny = false;
    foreach (TSource item in source)
    {
        yield return item;
        foundAny = true;
    }
    if (!foundAny)
    {
        yield return defaultValue;
    }
}

Of course, now that I’ve said how easy it was, someone will find a bug…

Aside from the use of default(TSource) to call the more complex overload from the simpler one, the only aspect of any interest is the bottom method. It irks me slightly that we’re assigning "true" to "foundAny" on every iteration… but the alternative is fairly unpleasant:

private static IEnumerable<TSource> DefaultIfEmptyImpl<TSource>(
    IEnumerable<TSource> source,
    TSource defaultValue)
{
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (!iterator.MoveNext())
        {
            yield return defaultValue;
            yield break; // Like a "return"
        }
        yield return iterator.Current;
        while (iterator.MoveNext())
        {
            yield return iterator.Current;
        }
    }
}

This may be slightly more efficient, but it feels a little clumsier. We could get rid of the "yield break" by putting the remainder of the method in an "else" block, but I’m not dead keen on that, either. We could use a do/while loop instead of a simple while loop – that would at least remove the repetition of "yield return iterator.Current" but I’m not really a fan of do/while loops. I use them sufficiently rarely that they cause me more mental effort to read than I really like.

If any readers have suggestions which are significantly nicer than either of the above implementations, I’d be interested to hear them. This feels a little inelegant at the moment. It’s far from a readability disaster – it’s just not quite neat.

Conclusion

Aside from the slight annoyance at the final lack of elegance, there’s not much of interest here. However, we could now implement FirstOrDefault/LastOrDefault/SingleOrDefault using DefaultIfEmpty. For example, here’s an implementation of FirstOrDefault:

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source)
{
    return source.DefaultIfEmpty().First();
}

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Can’t just use source.DefaultIfEmpty().First(predicate)
    return source.Where(predicate).DefaultIfEmpty().First();
}

Note the comment in the predicated version – the defaulting has to be the very last step after we’ve applied the predicate… otherwise if we pass in an empty sequence and a predicate which doesn’t match with default(TSource), we’ll get an exception instead of the default value. The other two …OrDefault operators could be implemented in the same way, of course. (I haven’t done so yet, but the above code is in source control.)
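To see why the ordering matters, consider an empty source and a predicate that rejects default(int) (my example, not from the post):

```csharp
using System;
using System.Linq;

class FirstOrDefaultOrdering
{
    static void Main()
    {
        int[] empty = new int[0];
        Func<int, bool> predicate = x => x > 0;

        // Correct: filter first, then default, then First.
        // The defaulted 0 is returned even though it fails the predicate.
        Console.WriteLine(empty.Where(predicate).DefaultIfEmpty().First()); // 0

        // Wrong: defaulting first means First(predicate) sees only the
        // value 0, which fails the predicate - so it throws.
        try
        {
            empty.DefaultIfEmpty().First(predicate);
        }
        catch (InvalidOperationException)
        {
            Console.WriteLine("Threw, as feared");
        }
    }
}
```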

I’m currently unsure what I’ll implement next. I’ll see whether I get inspired by any particular method in the morning.

Reimplementing LINQ to Objects: Part 11 – First/Single/Last and the …OrDefault versions

Today I’ve implemented six operators, each with two overloads. At first I expected the implementations to be very similar, but they’ve all turned out slightly differently…

What are they?

We’ve got three axes of permutation here: {First, Last, Single}, {with/without OrDefault}, {with/without a predicate}. That gives us these twelve signatures:

public static TSource First<TSource>(
    this IEnumerable<TSource> source)

public static TSource First<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source)

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static TSource Last<TSource>(
    this IEnumerable<TSource> source)

public static TSource Last<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static TSource LastOrDefault<TSource>(
    this IEnumerable<TSource> source)

public static TSource LastOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static TSource Single<TSource>(
    this IEnumerable<TSource> source)

public static TSource Single<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static TSource SingleOrDefault<TSource>(
    this IEnumerable<TSource> source)

public static TSource SingleOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

The shared behaviour is as follows:

  • They’re all extension methods with a single generic type argument
  • They’re all implemented with immediate execution
  • They all validate that their parameters are non-null
  • The overloads with a predicate are equivalent to calling source.Where(predicate).SameOperator() – in other words, they just add a filter before applying the operator.

With those rules applied, we simply need to consider three possibilities for each operator: what happens if the source sequence is empty, contains a single element, or contains multiple elements. (This is after applying the filter if a predicate is specified, of course.) We can draw the results up in a simple table:

Operator        | Empty sequence           | Single element  | Multiple elements
----------------|--------------------------|-----------------|----------------------
First           | Throws exception         | Returns element | Returns first element
FirstOrDefault  | Returns default(TSource) | Returns element | Returns first element
Last            | Throws exception         | Returns element | Returns last element
LastOrDefault   | Returns default(TSource) | Returns element | Returns last element
Single          | Throws exception         | Returns element | Throws exception
SingleOrDefault | Returns default(TSource) | Returns element | Throws exception

As you can see, for an input sequence with a single element, the results are remarkably uniform :) Likewise, for an empty input sequence, any operator without "OrDefault" throws an exception (InvalidOperationException, in fact) and any operator with "OrDefault" returns the default value for the element type (null for reference types, 0 for int etc). The operators really differ if the (potentially filtered) input sequence contains multiple elements – First and Last do the obvious thing, and Single throws an exception. It’s worth noting that SingleOrDefault also throws an exception – it’s not like it’s saying, "If the sequence is a single element, return it – otherwise return the default value." If you want an operator which handles multiple elements, you should be using First or Last, with the "OrDefault" version if the sequence can legitimately have no elements. Note that if you do use an "OrDefault" operator, the result is exactly the same for an empty input sequence as for an input sequence containing exactly one element which is the default value. (I’ll be looking at the DefaultIfEmpty operator next.)
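The distinction between First and Single for multiple elements is worth seeing in action (my example):

```csharp
using System;
using System.Linq;

class FirstVersusSingle
{
    static void Main()
    {
        int[] several = { 1, 2, 3 };

        Console.WriteLine(several.First()); // 1 - happily ignores the rest

        try
        {
            several.Single(); // more than one element: not a valid "single"
        }
        catch (InvalidOperationException)
        {
            Console.WriteLine("Single threw for multiple elements");
        }

        // SingleOrDefault only defaults the *empty* case; multiple
        // elements still throw.
        Console.WriteLine(new int[0].SingleOrDefault()); // 0
    }
}
```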

Now we know what the operators do, let’s test them.

What are we going to test?

This morning I tweeted that I had written 72 tests before writing any implementation. In fact I ended up with 80, for reasons we’ll come to in a minute. For each operator, I tested the following 12 cases:

  • Null source (predicate-less overload)
  • Null source (predicated overload)
  • Null predicate
  • Empty sequence, no predicate
  • Empty sequence, with a predicate
  • Single element sequence, no predicate
  • Single element sequence where the element matched the predicate
  • Single element sequence where the element didn’t match the predicate
  • Multiple element sequence, no predicate
  • Multiple element sequence where no elements matched the predicate
  • Multiple element sequence where one element matched the predicate
  • Multiple element sequence where multiple elements matched the predicate

These were pretty much cut and paste jobs – I used the same data for each test against each operator, and just changed the expected results.

There are two extra tests for each of First and FirstOrDefault, and two for each of Last and LastOrDefault:

  • First/FirstOrDefault should return as soon as they’ve seen the first element, when there’s no predicate; they shouldn’t iterate over the rest of the sequence
  • First/FirstOrDefault should return as soon as they’ve seen the first matching element, when there is a predicate
  • Last/LastOrDefault are optimized for the case where the source implements IList<T> and there’s no predicate: it uses Count and the indexer to access the final element
  • Last/LastOrDefault is not optimized for the case where the source implements IList<T> but there is a predicate: it iterates through the entire sequence

The last two tests involved writing a new collection called NonEnumerableList which implements IList<T> by delegating everything to a backing List<T>, except for GetEnumerator() (both the generic and nongeneric forms) which simply throws a NotSupportedException. That should be a handy one for testing optimizations in the future. I’ll discuss the optimization for Last when we get there.
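The test class itself isn’t shown in the post; a sketch of the idea (mine – the real class may differ in detail) would be something like:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

// Everything delegates to a backing List<T> except GetEnumerator,
// which throws - so a test using this collection fails if the code
// under test iterates instead of using Count and the indexer.
public sealed class NonEnumerableList<T> : IList<T>
{
    private readonly List<T> backingList;

    public NonEnumerableList(params T[] items)
    {
        backingList = new List<T>(items);
    }

    public T this[int index]
    {
        get { return backingList[index]; }
        set { backingList[index] = value; }
    }

    public int Count { get { return backingList.Count; } }
    public bool IsReadOnly { get { return false; } }

    public void Add(T item) { backingList.Add(item); }
    public void Clear() { backingList.Clear(); }
    public bool Contains(T item) { return backingList.Contains(item); }
    public void CopyTo(T[] array, int arrayIndex) { backingList.CopyTo(array, arrayIndex); }
    public int IndexOf(T item) { return backingList.IndexOf(item); }
    public void Insert(int index, T item) { backingList.Insert(index, item); }
    public bool Remove(T item) { return backingList.Remove(item); }
    public void RemoveAt(int index) { backingList.RemoveAt(index); }

    // The whole point: enumeration is forbidden, in both forms.
    public IEnumerator<T> GetEnumerator()
    {
        throw new NotSupportedException();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        throw new NotSupportedException();
    }
}
```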

Let’s implement them!

These operators were more interesting to implement than I’d expected, so I’m actually going to show all twelve methods. It was rarely just a matter of cutting and pasting, other than for the argument validation.

Of course, if we chose to implement the predicated versions using "Where" and the non-predicated form, and the "OrDefault" versions by using "DefaultIfEmpty" followed by the non-defaulting version, we would only have had the three non-predicated, non-defaulting versions to deal with… but as I’ve said before, there are some virtues to implementing each operator separately.

For the sake of avoiding fluff, I’ve removed the argument validation from each method – but obviously it’s there in the real code. Let’s start with First:

public static TSource First<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (iterator.MoveNext())
        {
            return iterator.Current;
        }
        throw new InvalidOperationException("Sequence was empty");
    }
}

public static TSource First<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            return item;
        }
    }
    throw new InvalidOperationException("No items matched the predicate");
}

These look surprisingly different – and that was actually a deliberate decision. I could easily have implemented the predicate-less version with a foreach loop as well, just returning unconditionally from its body. However, I chose to emphasize the fact that we’re not looping in First: we simply move to the first element if we can, and return it or throw an exception. There’s no hint that we might ever call MoveNext again. In the predicated form, of course, we have to keep looping until we find a matching value – only throwing the exception when we’ve exhausted all possibilities.

Now let’s see how it looks when we return a default for empty sequences:

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        return iterator.MoveNext() ? iterator.Current : default(TSource);
    }
}

public static TSource FirstOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            return item;
        }
    }
    return default(TSource);
}

Here the predicated form looks very similar to First, but the predicate-less one is slightly different: instead of using an if block (which would be a perfectly valid approach, of course) I’ve used the conditional operator. We’re going to return something, whether we manage to move to the first element or not. Arguably it would be nice if the conditional operator allowed the second or third operands to be "throw" expressions, taking the overall type of the expression from the other result operand… but it’s no great hardship.

Next up we’ll implement Single, which is actually closer to First than Last is, in some ways:

public static TSource Single<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (!iterator.MoveNext())
        {
            throw new InvalidOperationException("Sequence was empty");
        }
        TSource ret = iterator.Current;
        if (iterator.MoveNext())
        {
            throw new InvalidOperationException("Sequence contained multiple elements");
        }
        return ret;
    }
}

public static TSource Single<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    TSource ret = default(TSource);
    bool foundAny = false;
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            if (foundAny)
            {
                throw new InvalidOperationException("Sequence contained multiple matching elements");
            }
            foundAny = true;
            ret = item;
        }
    }
    if (!foundAny)
    {
        throw new InvalidOperationException("No items matched the predicate");
    }
    return ret;
}

This is already significantly more complex than First. The predicate-less version starts off the same way, but if we manage to move to the first element, we have to remember that value (as we’re hoping to return it) and then try to move to the second element. This time, if the move succeeds, we have to throw an exception – otherwise we can return our saved value.

The predicated version is even hairier. We still need to remember the first matching value we find, but this time we’re looping – so we need to keep track of whether we’ve already seen a matching value or not. If we see a second match, we have to throw an exception… and we also have to throw an exception if we reach the end without finding any matches at all. Note that although we assign an initial value of default(TSource) to ret, we’ll never reach a return statement without assigning a value to it. However, the rules around definite assignment aren’t smart enough to cope with this, so we need to provide a "dummy" value to start with… and default(TSource) is really the only value available. There is an alternative approach without using a foreach statement, where you loop until you find the first match, assign it to a local variable which is only declared at that point, followed by a second loop ensuring that there aren’t any other matches. I personally think that’s a bit more complex, which is why I’ve just used the foreach here.
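For comparison, that two-loop alternative might look like this (my sketch, with the argument validation elided as in the other listings):

```csharp
using System;
using System.Collections.Generic;

public static class SingleAlternative
{
    public static TSource SingleTwoLoops<TSource>(
        this IEnumerable<TSource> source,
        Func<TSource, bool> predicate)
    {
        // Argument validation elided
        using (IEnumerator<TSource> iterator = source.GetEnumerator())
        {
            // First loop: find the first match, or throw.
            while (true)
            {
                if (!iterator.MoveNext())
                {
                    throw new InvalidOperationException("No items matched the predicate");
                }
                if (predicate(iterator.Current))
                {
                    break;
                }
            }
            // Only now do we need a variable for the result - no dummy
            // default(TSource) required.
            TSource ret = iterator.Current;

            // Second loop: make sure nothing else matches.
            while (iterator.MoveNext())
            {
                if (predicate(iterator.Current))
                {
                    throw new InvalidOperationException("Sequence contained multiple matching elements");
                }
            }
            return ret;
        }
    }
}
```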

The difference when we implement SingleOrDefault isn’t quite as pronounced this time though:

public static TSource SingleOrDefault<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (!iterator.MoveNext())
        {
            return default(TSource);
        }
        TSource ret = iterator.Current;
        if (iterator.MoveNext())
        {
            throw new InvalidOperationException("Sequence contained multiple elements");
        }
        return ret;
    }
}

public static TSource SingleOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    TSource ret = default(TSource);
    bool foundAny = false;
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            if (foundAny)
            {
                throw new InvalidOperationException("Sequence contained multiple matching elements");
            }
            foundAny = true;
            ret = item;
        }
    }
    return ret;
}

This time we’ve just replaced a "throw" statement in the predicate-less method with a "return" statement, and removed the test for no matches being found in the predicated method. Here our assignment of default(TSource) to ret really works in our favour – if we don’t end up assigning anything else to it, we’ve already got the right return value!

Next up is Last:

public static TSource Last<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    IList<TSource> list = source as IList<TSource>;
    if (list != null)
    {
        if (list.Count == 0)
        {
            throw new InvalidOperationException("Sequence was empty");
        }
        return list[list.Count - 1];
    }

    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        if (!iterator.MoveNext())
        {
            throw new InvalidOperationException("Sequence was empty");
        }
        TSource last = iterator.Current;
        while (iterator.MoveNext())
        {
            last = iterator.Current;
        }
        return last;
    }
}

public static TSource Last<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    bool foundAny = false;
    TSource last = default(TSource);
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            foundAny = true;
            last = item;
        }
    }
    if (!foundAny)
    {
        throw new InvalidOperationException("No items matched the predicate");
    }
    return last;
}

Let’s start off with the optimization at the beginning of the predicate-less method. If we find out we’re dealing with a list, we can simply fetch the count, and then either throw an exception or return the element at the final index. As an extra bit of optimization I could store the count in a local variable – but I’m assuming that the count of an IList<T> is cheap to compute. If there are significant objections to that assumption, I’m happy to change it :) Note that this is another situation where I’m assuming that anything implementing IList<T> will only hold at most Int32.MaxValue items – otherwise the optimization will fail.

If we don’t follow the optimized path, we simply iterate over the sequence, updating a local variable with the last-known element on every single iteration. This time I’ve avoided the foreach loop for no particularly good reason – we could easily have had a foundAny variable which was just set to "true" on every iteration, and then tested at the end. In fact, that’s exactly the pattern the predicated method takes. Admittedly that decision is forced upon us to some extent – we can’t just move once and then take the first value as the "first last-known element", because it might not match the predicate.

There’s no optimization for the predicated form of Last. This follows LINQ to Objects, but I don’t honestly know the reason for it there. We could easily iterate backwards from the end of the sequence using the indexer on each iteration. One possible reason which makes a certain amount of sense is that when there’s a predicate, that predicate could throw an exception for some values – and if we just skipped to the end if the collection implements IList<T>, that would be an observable difference. I’d be interested to know whether or not that is the reason – if anyone has any inside information which they can share, I’ll update this post.
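To make the trade-off concrete, here’s what such an optimization might look like (my sketch – this is not what LINQ to Objects does). Note that if the predicate would throw for some earlier element, this version never observes that, which could well be the objection:

```csharp
using System;
using System.Collections.Generic;

public static class LastOptimization
{
    // Hypothetical: iterate backwards over the list, testing the predicate,
    // and return the first match found from the end. Elements before the
    // last match are never passed to the predicate at all.
    public static TSource LastBackwards<TSource>(
        this IList<TSource> list,
        Func<TSource, bool> predicate)
    {
        for (int i = list.Count - 1; i >= 0; i--)
        {
            if (predicate(list[i]))
            {
                return list[i];
            }
        }
        throw new InvalidOperationException("No items matched the predicate");
    }
}
```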

From here, we only have one more operator to implement – LastOrDefault:

public static TSource LastOrDefault<TSource>(
    this IEnumerable<TSource> source)
{
    // Argument validation elided
    IList<TSource> list = source as IList<TSource>;
    if (list != null)
    {
        return list.Count == 0 ? default(TSource) : list[list.Count - 1];
    }

    TSource last = default(TSource);
    foreach (TSource item in source)
    {
        last = item;
    }
    return last;
}

public static TSource LastOrDefault<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    // Argument validation elided
    TSource last = default(TSource);
    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            last = item;
        }
    }
    return last;
}

This time, aside from the optimization, the predicated and non-predicated forms look very similar… more so than for any of the other operators. In each case, we start with a return value of default(TSource), and iterate over the whole sequence, updating it – only doing so when it matches the predicate, if we’ve got one.

Conclusion

This was a longer post than I anticipated when I got up this morning, but I hope the slight differences between implementations – and the mystery of the unoptimized predicated "Last/LastOrDefault" operators have made it worth slogging through.

As a contrast – and because I’ve already mentioned it in this post – I’ll implement DefaultIfEmpty next. I reckon I can still do that this evening, if I hurry…

Addendum

It turns out I was missing some tests for Single and SingleOrDefault: what should they do if evaluating the sequence fully throws an exception? It turns out that in LINQ to Objects, the overloads without a predicate throw InvalidOperationException as soon as they see a second element, but the overloads with a predicate keep iterating even when they’ve seen a second element matching a predicate. This seems ludicrously inconsistent to me – I’ve opened a Connect issue about it; we’ll see what happens.

Reimplementing LINQ to Objects: Part 10 – Any and All

Another day, another blog post. I should emphasize that this rate of posting is likely to be short-lived… although if I get into the habit of writing a post on the morning commute when I go back to work after the Christmas holidays, I could keep ploughing through until we’re done.

Anyway, today we have a pair of operators: Any and All.

What are they?

"Any" has two overloads; there’s only one for "All":

public static bool Any<TSource>(
    this IEnumerable<TSource> source)

public static bool Any<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

public static bool All<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)

The names really are fairly self-explanatory:

  • "Any" without a predicate returns whether there are any elements in the input sequence
  • "Any" with a predicate returns whether any elements in the input sequence match the predicate
  • "All" returns whether all the elements in the input sequence match the given predicate

Both operators use immediate execution – they don’t return until they’ve got the answer, basically.

Importantly, "All" has to read through the entire input sequence to return true, but can return false as soon as it finds a non-matching element; "Any" can return true as soon as it finds a matching element, but has to iterate over the entire input sequence in order to return false. This gives rise to one very simple LINQ performance tip: it’s almost never a good idea to use a query like

// Don’t use this
if (query.Count() != 0)

That has to iterate over *all* the results in the query… when you really only care whether or not there are any results. So use "Any" instead:

// Use this instead
if (query.Any())

If this is part of a bigger LINQ to SQL query, it may not make a difference – but in LINQ to Objects it can certainly be a huge boon.
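A quick way to see the difference is an instrumented iterator which counts how many elements actually get fetched – a sketch of mine (the class and method names are invented), not code from the post:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class AnyVsCountDemo
{
    public static int pulled;

    // An iterator instrumented to count how many elements are actually fetched.
    public static IEnumerable<int> Numbers(int count)
    {
        for (int i = 0; i < count; i++)
        {
            pulled++;
            yield return i;
        }
    }

    static void Main()
    {
        pulled = 0;
        Console.WriteLine(Numbers(1000000).Any()); // True
        Console.WriteLine(pulled);                 // 1: Any stopped at the first element

        pulled = 0;
        Console.WriteLine(Numbers(1000000).Count() != 0); // True
        Console.WriteLine(pulled);                        // 1000000: Count had to see them all
    }
}
```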

Anyway, let’s get on to testing the three methods…

What are we going to test?

Feeling virtuous tonight, I’ve even tested argument validation again… although it’s easy to get that right here, as we’re using immediate execution.

Beyond that, I’ve tested a few scenarios:

  • An empty sequence will return false with Any, but true with All. (Whatever the predicate is for All, there are no elements which fail it.)
  • A sequence with any elements at all will make the predicate-less Any return true.
  • If no elements match the predicate, both Any and All return false.
  • If some elements match the predicate, Any will return true but All will return false.
  • If all elements match the predicate, All will return true.

Those are all straightforward, so I won’t give the code. One final test is interesting though: we prove that Any returns as soon as it’s got the result by giving it a query which will throw an exception if it’s iterated over completely. The easiest way of doing this is to start out with a sequence of integers including 0, and then use Select with a projection which divides some constant value by each element. In this test case, I’ve given it a value which will match the predicate before the value which will cause the exception to be thrown:

[Test]
public void SequenceIsNotEvaluatedAfterFirstMatch()
{
    int[] src = { 10, 2, 0, 3 };
    var query = src.Select(x => 10 / x);
    // This will finish at the second element (x = 2, so 10/x = 5)
    // It won’t evaluate 10/0, which would throw an exception
    Assert.IsTrue(query.Any(y => y > 2));
}

There’s an equivalent test for All, where a non-matching element occurs before the exceptional one.
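That equivalent All test might look something like this (my reconstruction, not the post's exact code) – the second element projects to 10/10 = 1, which fails the predicate before the division by zero is reached:

```csharp
[Test]
public void SequenceIsNotEvaluatedAfterFirstNonMatch()
{
    int[] src = { 2, 10, 0, 3 };
    var query = src.Select(x => 10 / x);
    // This will finish at the second element (x = 10, so 10/x = 1, failing y > 2)
    // It won't evaluate 10/0, which would throw an exception
    Assert.IsFalse(query.All(y => y > 2));
}
```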

So, with all the tests written, let’s get on with the interesting bit:

Let’s implement them!

The first thing to note is that all of these could be implemented in terms of either Any-with-a-predicate or All. For example, given All, we could implement Any as:

public static bool Any<TSource>(
    this IEnumerable<TSource> source)
{
    return source.Any(x => true);
}

public static bool Any<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    if (predicate == null)
    {
        throw new ArgumentNullException("predicate");
    }
    return !source.All(x => !predicate(x));
}

It’s simplest to implement the predicate-less Any in terms of the predicated one – using a predicate which returns true for any element means that Any will return true for any element at all, which is what we want.

The inversions in the call to All take a minute to get your head round, but it’s basically De Morgan’s law in LINQ form: we effectively invert the predicate to find out if all of the elements don’t match the original predicate… then return the inverse. Due to the inversion, this still returns early in all the appropriate situations, too.
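For symmetry, the reverse direction works too: given Any-with-a-predicate, All is the same De Morgan trick inverted. A quick sketch of mine (the class name is invented, and it isn't the implementation the post ends up with):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class EdulinqSketch
{
    public static bool All<TSource>(
        this IEnumerable<TSource> source,
        Func<TSource, bool> predicate)
    {
        if (predicate == null)
        {
            throw new ArgumentNullException("predicate");
        }
        // All elements match if and only if Any can't find a counterexample.
        // A null source is caught by Any's own argument validation.
        return !source.Any(x => !predicate(x));
    }
}
```

The early out is preserved here as well: as soon as Any finds a non-matching element, it returns true, so All returns false without looking any further.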

While we could do that, I’ve actually preferred a straightforward implementation of all of the separate methods:

public static bool Any<TSource>(
    this IEnumerable<TSource> source)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
            
    using (IEnumerator<TSource> iterator = source.GetEnumerator())
    {
        return iterator.MoveNext();
    }
}

public static bool Any<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (predicate == null)
    {
        throw new ArgumentNullException("predicate");
    }

    foreach (TSource item in source)
    {
        if (predicate(item))
        {
            return true;
        }
    }
    return false;
}

public static bool All<TSource>(
    this IEnumerable<TSource> source,
    Func<TSource, bool> predicate)
{
    if (source == null)
    {
        throw new ArgumentNullException("source");
    }
    if (predicate == null)
    {
        throw new ArgumentNullException("predicate");
    }

    foreach (TSource item in source)
    {
        if (!predicate(item))
        {
            return false;
        }
    }
    return true;
}

Aside from anything else, this makes it obvious where the "early out" comes in each case – and also means that any stack traces generated are rather easier to understand. It would be quite odd from a client developer’s point of view to call Any but see All in the stack trace, or vice versa.

One interesting point to note is that I don’t actually use a foreach loop in Any – although I could, of course. Instead, I just get the iterator and then return whether the very first call to MoveNext indicates that there are any elements. I like the fact that reading this method it’s obvious (at least to me) that we really couldn’t care less what the value of the first element is – because we never ask for it.

Conclusion

Probably the most important lesson here is the advice to use Any (without a predicate) instead of Count when you can. The rest was pretty simple – although it’s always fun to see one operator implemented in terms of another.

So, what next? Possibly Single/SingleOrDefault/First/FirstOrDefault/Last/LastOrDefault. I might as well do them all together – partly as they’re so similar, and partly to emphasize the differences which do exist.