Custom .NET Data Provider

The .NET Framework offers a provider model for ADO.NET. ADO.NET providers can be registered and then used, through DbProviderFactories, to create connections and other kinds of objects dynamically:

// Resolve a factory by its registered provider name, then let it create the provider-specific objects.
DbProviderFactory factory = DbProviderFactories.GetFactory("MyProvider");

using (DbConnection connection = factory.CreateConnection())
using (DbCommand command = connection.CreateCommand())
{
    connection.ConnectionString = "My connection string";
    connection.Open();

    //do something with command
}

Connection strings declared in configuration files can also carry a provider name, which can be passed to DbProviderFactories to obtain the matching factory:

<connectionStrings>
    <add name="MyConnection" providerName="MyProvider" connectionString="My connection string"/>
</connectionStrings>

// Look up the factory registered under the provider name declared for this connection string.
var provider = DbProviderFactories.GetFactory(ConfigurationManager.ConnectionStrings["MyConnection"].ProviderName);

Before .NET 4.5, this was controlled through the DbProviderFactories section of Machine.config, and it was possible to register custom providers in a local Web.config or App.config file. Now the built-in providers – SqlClientFactory, OdbcFactory, OleDbFactory and OracleClientFactory – are no longer configured through a file. They are registered automatically by the framework and therefore cannot easily be changed. The registered providers can be inspected through the GetFactoryClasses() method of DbProviderFactories:

// GetFactoryClasses returns a DataTable describing every registered provider.
var providers = DbProviderFactories.GetFactoryClasses();

Using some reflection, however, it is possible to replace a built-in provider with a custom one. Why would you want to do that? For example, to add custom profiling and logging to a connection or command.

The most important base classes in the ADO.NET model are DbConnection and DbCommand. A DbCommand can be obtained either from an existing DbConnection or from a DbProviderFactory. So, if we want to intercept DbCommand, we need to create custom DbConnection and DbCommand classes:

/// <summary>
/// A <see cref="DbCommand"/> that forwards every member to an underlying provider-specific command,
/// giving a single place to add interception such as profiling or logging.
/// </summary>
/// <remarks>
/// Protected <see cref="Component"/> members (<c>CanRaiseEvents</c>, <c>GetService</c>) cannot be
/// called on another instance directly, so they are forwarded through cached reflection handles.
/// </remarks>
public class WrapperCommand : DbCommand
{
    private static readonly PropertyInfo canRaiseEventsProp = typeof(Component).GetProperty("CanRaiseEvents", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.GetProperty);

    private static readonly MethodInfo getServiceMethod = typeof(Component).GetMethod("GetService", BindingFlags.NonPublic | BindingFlags.Instance);

    private readonly DbCommand original;

    /// <summary>Wraps an existing command.</summary>
    /// <param name="original">The command that performs the actual work.</param>
    /// <exception cref="ArgumentNullException"><paramref name="original"/> is <c>null</c>.</exception>
    public WrapperCommand(DbCommand original)
    {
        if (original == null)
        {
            throw new ArgumentNullException("original");
        }

        this.original = original;
    }

    public override void Prepare()
    {
        this.original.Prepare();
    }

    public override string CommandText
    {
        get { return this.original.CommandText; }
        set { this.original.CommandText = value; }
    }

    public override int CommandTimeout
    {
        get { return this.original.CommandTimeout; }
        set { this.original.CommandTimeout = value; }
    }

    public override CommandType CommandType
    {
        get { return this.original.CommandType; }
        set { this.original.CommandType = value; }
    }

    public override UpdateRowSource UpdatedRowSource
    {
        get { return this.original.UpdatedRowSource; }
        set { this.original.UpdatedRowSource = value; }
    }

    // NOTE(review): the getter exposes the underlying command's connection, not a wrapper, and the
    // setter hands the value straight to the underlying command. Provider commands typically require
    // their own connection type, so assigning a wrapper connection here may be rejected — confirm.
    protected override DbConnection DbConnection
    {
        get { return this.original.Connection; }
        set { this.original.Connection = value; }
    }

    protected override DbParameterCollection DbParameterCollection
    {
        get { return this.original.Parameters; }
    }

    protected override DbTransaction DbTransaction
    {
        get { return this.original.Transaction; }
        set { this.original.Transaction = value; }
    }

    public override bool DesignTimeVisible
    {
        get { return this.original.DesignTimeVisible; }
        set { this.original.DesignTimeVisible = value; }
    }

    public override void Cancel()
    {
        this.original.Cancel();
    }

    protected override DbParameter CreateDbParameter()
    {
        return this.original.CreateParameter();
    }

    protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
    {
        return this.original.ExecuteReader(behavior);
    }

    public override int ExecuteNonQuery()
    {
        return this.original.ExecuteNonQuery();
    }

    public override object ExecuteScalar()
    {
        return this.original.ExecuteScalar();
    }

    protected override bool CanRaiseEvents
    {
        get { return (bool)canRaiseEventsProp.GetValue(this.original, null); }
    }

    public override ObjRef CreateObjRef(Type requestedType)
    {
        return this.original.CreateObjRef(requestedType);
    }

    /// <summary>
    /// Releases the underlying command and then runs the base <see cref="Component"/> cleanup for
    /// this wrapper (which also raises the wrapper's own <c>Disposed</c> event).
    /// </summary>
    /// <param name="disposing">
    /// <c>true</c> when called through <see cref="Component.Dispose()"/>; <c>false</c> when called
    /// from the finalizer.
    /// </param>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            // Use the public Dispose() so the wrapped command runs its full dispose pattern,
            // including finalizer suppression. Other managed objects must not be touched from
            // the finalizer path, so nothing is forwarded when disposing is false.
            this.original.Dispose();
        }

        base.Dispose(disposing);
    }

    protected override Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken)
    {
        return this.original.ExecuteReaderAsync(behavior, cancellationToken);
    }

    public override Task<int> ExecuteNonQueryAsync(CancellationToken cancellationToken)
    {
        return this.original.ExecuteNonQueryAsync(cancellationToken);
    }

    public override Task<object> ExecuteScalarAsync(CancellationToken cancellationToken)
    {
        return this.original.ExecuteScalarAsync(cancellationToken);
    }

    protected override object GetService(Type service)
    {
        return getServiceMethod.Invoke(this.original, new object[] { service });
    }

    public override object InitializeLifetimeService()
    {
        return this.original.InitializeLifetimeService();
    }

    public override ISite Site
    {
        get { return this.original.Site; }
        set { this.original.Site = value; }
    }
}

 

/// <summary>
/// A <see cref="DbConnection"/> that forwards every member to an underlying provider-specific
/// connection and hands out <see cref="WrapperCommand"/> instances from CreateCommand.
/// </summary>
/// <remarks>
/// Protected members that cannot be called on another instance directly (<c>CanRaiseEvents</c>,
/// <c>GetService</c>, <c>OnStateChange</c>) are forwarded through cached reflection handles.
/// </remarks>
public class WrapperConnection : DbConnection
{
    private static readonly MethodInfo getServiceMethod = typeof(Component).GetMethod("GetService", BindingFlags.NonPublic | BindingFlags.Instance);

    private static readonly PropertyInfo canRaiseEventsProp = typeof(Component).GetProperty("CanRaiseEvents", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.GetProperty);

    private static readonly MethodInfo onStateChangeMethod = typeof(DbConnection).GetMethod("OnStateChange", BindingFlags.NonPublic | BindingFlags.Instance);

    private readonly DbConnection original;

    /// <summary>Wraps an existing connection.</summary>
    /// <param name="original">The connection that performs the actual work.</param>
    /// <exception cref="ArgumentNullException"><paramref name="original"/> is <c>null</c>.</exception>
    public WrapperConnection(DbConnection original)
    {
        if (original == null)
        {
            throw new ArgumentNullException("original");
        }

        this.original = original;
    }

    // NOTE(review): returns the underlying provider transaction unwrapped; its Connection presumably
    // reports the underlying connection rather than this wrapper — confirm if callers compare them.
    protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
    {
        return this.original.BeginTransaction(isolationLevel);
    }

    public override void Close()
    {
        this.original.Close();
    }

    public override void ChangeDatabase(string databaseName)
    {
        this.original.ChangeDatabase(databaseName);
    }

    public override void Open()
    {
        this.original.Open();
    }

    public override string ConnectionString
    {
        get { return this.original.ConnectionString; }
        set { this.original.ConnectionString = value; }
    }

    public override string Database
    {
        get { return this.original.Database; }
    }

    public override ConnectionState State
    {
        get { return this.original.State; }
    }

    public override string DataSource
    {
        get { return this.original.DataSource; }
    }

    public override string ServerVersion
    {
        get { return this.original.ServerVersion; }
    }

    /// <summary>Creates a command on the underlying connection and wraps it.</summary>
    protected override DbCommand CreateDbCommand()
    {
        return new WrapperCommand(this.original.CreateCommand());
    }

    public override ISite Site
    {
        get { return this.original.Site; }
        set { this.original.Site = value; }
    }

    /// <summary>Reports the wrapping factory so code that asks the connection for its factory keeps getting wrappers.</summary>
    protected override DbProviderFactory DbProviderFactory
    {
        get { return WrapperProviderFactory.Instance; }
    }

    public override object InitializeLifetimeService()
    {
        return this.original.InitializeLifetimeService();
    }

    /// <summary>
    /// Releases the underlying connection and then runs the base <see cref="Component"/> cleanup for
    /// this wrapper (which also raises the wrapper's own <c>Disposed</c> event).
    /// </summary>
    /// <param name="disposing">
    /// <c>true</c> when called through <see cref="Component.Dispose()"/>; <c>false</c> when called
    /// from the finalizer.
    /// </param>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            // Use the public Dispose() so the wrapped connection runs its full dispose pattern,
            // including finalizer suppression. Other managed objects must not be touched from
            // the finalizer path, so nothing is forwarded when disposing is false.
            this.original.Dispose();
        }

        base.Dispose(disposing);
    }

    protected override object GetService(Type service)
    {
        return getServiceMethod.Invoke(this.original, new object[] { service });
    }

    protected override bool CanRaiseEvents
    {
        get { return (bool)canRaiseEventsProp.GetValue(this.original, null); }
    }

    public override ObjRef CreateObjRef(Type requestedType)
    {
        return this.original.CreateObjRef(requestedType);
    }

    public override int ConnectionTimeout
    {
        get { return this.original.ConnectionTimeout; }
    }

    public override void EnlistTransaction(System.Transactions.Transaction transaction)
    {
        this.original.EnlistTransaction(transaction);
    }

    public override DataTable GetSchema()
    {
        return this.original.GetSchema();
    }

    public override DataTable GetSchema(string collectionName)
    {
        return this.original.GetSchema(collectionName);
    }

    public override DataTable GetSchema(string collectionName, string[] restrictionValues)
    {
        return this.original.GetSchema(collectionName, restrictionValues);
    }

    protected override void OnStateChange(StateChangeEventArgs stateChange)
    {
        onStateChangeMethod.Invoke(this.original, new object[] { stateChange });
    }

    public override Task OpenAsync(CancellationToken cancellationToken)
    {
        return this.original.OpenAsync(cancellationToken);
    }

    // Handlers are attached to the underlying connection. NOTE(review): the sender they receive is
    // therefore presumably the underlying connection, not this wrapper — confirm with subscribers.
    public override event StateChangeEventHandler StateChange
    {
        add { this.original.StateChange += value; }
        remove { this.original.StateChange -= value; }
    }
}

Our implementations take preexisting objects and delegate the actual work to them. Some reflection is needed to access non-public members. To make the wrappers actually useful, we could add virtual methods or events — for example, to run custom code before and after SQL commands are executed.

Next, we need a provider factory that can be registered in place of an existing provider. Here’s a possible implementation:

/// <summary>
/// A <see cref="DbProviderFactory"/> that wraps an existing provider factory and takes its place in
/// the <see cref="DbProviderFactories"/> registry. Connections and commands created through it are
/// <see cref="WrapperConnection"/> and <see cref="WrapperCommand"/> instances.
/// </summary>
/// <remarks>
/// Registration rewrites the private <c>_providerTable</c> field of <see cref="DbProviderFactories"/>,
/// so it depends on the runtime's internal implementation. Initialize throws
/// <see cref="NotSupportedException"/> when that table cannot be reached.
/// </remarks>
public sealed class WrapperProviderFactory : DbProviderFactory
{
    private const string AssemblyQualifiedNameColumn = "AssemblyQualifiedName";

    private static readonly FieldInfo providerTableField = typeof(DbProviderFactories).GetField("_providerTable", BindingFlags.NonPublic | BindingFlags.Static);

    private static readonly FieldInfo instanceField = typeof(WrapperProviderFactory).GetField("Instance", BindingFlags.Public | BindingFlags.Static);

    /// <summary>The singleton wrapper, assigned by Initialize.</summary>
    /// <remarks>
    /// This must stay a public static field named <c>Instance</c>, because DbProviderFactories reads
    /// that field by reflection when it resolves a registered factory type. It is readonly so other
    /// code cannot replace it, which is why Initialize assigns it through reflection.
    /// </remarks>
    public static readonly WrapperProviderFactory Instance;

    private readonly DbProviderFactory original;

    private WrapperProviderFactory(DbProviderFactory original)
    {
        this.original = original;
    }

    /// <summary>
    /// Wraps <paramref name="original"/> and registers the wrapper in place of the provider whose
    /// invariant name equals the namespace of <paramref name="original"/>'s type.
    /// </summary>
    /// <param name="original">The factory to wrap (e.g. <c>SqlClientFactory.Instance</c>).</param>
    /// <exception cref="ArgumentNullException"><paramref name="original"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">No provider is registered under that namespace.</exception>
    /// <exception cref="NotSupportedException">The provider table is not accessible on this runtime.</exception>
    public static void Initialize(DbProviderFactory original)
    {
        if (original == null)
        {
            throw new ArgumentNullException("original");
        }

        Initialize(original, original.GetType().Namespace);
    }

    /// <summary>
    /// Wraps <paramref name="original"/> and registers the wrapper in place of the provider
    /// registered under <paramref name="invariantName"/>. Use this overload when a provider's
    /// invariant name differs from its factory's namespace.
    /// </summary>
    /// <param name="original">The factory to wrap.</param>
    /// <param name="invariantName">The invariant name of the registered provider to replace.</param>
    /// <exception cref="ArgumentNullException"><paramref name="original"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="invariantName"/> is empty or no provider is registered under it.
    /// </exception>
    /// <exception cref="NotSupportedException">The provider table is not accessible on this runtime.</exception>
    public static void Initialize(DbProviderFactory original, string invariantName)
    {
        if (original == null)
        {
            throw new ArgumentNullException("original");
        }

        if (string.IsNullOrEmpty(invariantName))
        {
            throw new ArgumentException("An invariant provider name is required.", "invariantName");
        }

        // The first call makes DbProviderFactories build its internal table and populate it with the default providers.
        DbProviderFactories.GetFactoryClasses();

        var table = providerTableField == null ? null : providerTableField.GetValue(null) as DataTable;
        if (table == null)
        {
            throw new NotSupportedException("The DbProviderFactories provider table is not accessible on this runtime.");
        }

        // Rows.Find searches the table's primary key (the invariant provider name, e.g. "System.Data.SqlClient").
        var row = table.Rows.Find(invariantName);
        if (row == null)
        {
            throw new ArgumentException("No data provider is registered with the invariant name '" + invariantName + "'.", "invariantName");
        }

        // Publish the instance before redirecting the registry entry, so the registry never points
        // at this type while Instance is still null.
        instanceField.SetValue(null, new WrapperProviderFactory(original));

        // The column is read-only in the framework's table; lift that only for this assignment and
        // always restore it, even if the assignment fails.
        var column = table.Columns[AssemblyQualifiedNameColumn];
        var wasReadOnly = column.ReadOnly;
        column.ReadOnly = false;
        try
        {
            row[column] = typeof(WrapperProviderFactory).AssemblyQualifiedName;
        }
        finally
        {
            column.ReadOnly = wasReadOnly;
        }
    }

    public override DbConnection CreateConnection()
    {
        return new WrapperConnection(this.original.CreateConnection());
    }

    public override DbCommand CreateCommand()
    {
        return new WrapperCommand(this.original.CreateCommand());
    }

    // The members below are forwarded unchanged. Without these overrides the DbProviderFactory base
    // implementations would be used, which return null (or false) instead of the wrapped provider's objects.

    public override DbParameter CreateParameter()
    {
        return this.original.CreateParameter();
    }

    public override DbConnectionStringBuilder CreateConnectionStringBuilder()
    {
        return this.original.CreateConnectionStringBuilder();
    }

    // NOTE: command builders and data adapters come straight from the wrapped provider and are not wrapped.
    public override DbCommandBuilder CreateCommandBuilder()
    {
        return this.original.CreateCommandBuilder();
    }

    public override DbDataAdapter CreateDataAdapter()
    {
        return this.original.CreateDataAdapter();
    }

    public override bool CanCreateDataSourceEnumerator
    {
        get { return this.original.CanCreateDataSourceEnumerator; }
    }

    public override DbDataSourceEnumerator CreateDataSourceEnumerator()
    {
        return this.original.CreateDataSourceEnumerator();
    }
}

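The requirements for a custom DbProviderFactory are simple. The class must derive from DbProviderFactory and expose a public static field named Instance that holds the factory instance. That is the field DbProviderFactories reads, through reflection, when it resolves a registered provider.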

If you are curious, the first call to GetFactoryClasses() is required so that .NET sets up its internal structures and populates them with the default providers.

The factory first lets the framework build its list of providers, then modifies that list to replace the entry for the provider we want to wrap. Before doing anything else, we need to initialize it by passing an existing provider to Initialize:

// Replace the built-in SQL Server provider with the wrapper before the provider is looked up.
var sqlServerFactory = SqlClientFactory.Instance;
WrapperProviderFactory.Initialize(sqlServerFactory);

And that’s it. As soon as you call this, your wrapper is registered, and whenever the wrapped provider is requested by name, .NET returns the wrapper instead.