AES Encryption in C#

Have you ever wanted to encrypt some sensitive data? Then you have probably come across various articles about AES (Advanced Encryption Standard). As of August 2019, AES is still the recommended algorithm to use, so let’s look at how you can use it.

Intro and a little bit of theory

To borrow from Fight Club, the first rule of cryptography is: don’t try to invent your own (unless you have the math skills, of course). Instead, use a battle-tested standard like AES.

In a lot of cases, examples of how to use AES are incomplete or even severely reduce the level of security the algorithm provides. After spending some time researching the topic and going through docs and RFCs, I think I have a better understanding of how it should be used and what should be avoided. So in this article I will start with a (bad) basic example and improve it step by step. The final version of the code is at the bottom of the article if you just want to grab it.

Sometimes AES and Rijndael get used interchangeably. For a more extensive description of the difference between them, I recommend reading the Wikipedia article on AES, but to summarize: Rijndael is the underlying algorithm, and AES prescribes which of its parameters should be used. These are a fixed block size of 128 bits and a key size of 128, 192 or 256 bits, whereas Rijndael itself allows block and key sizes in multiples of 32 bits between 128 and 256.
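If you want to check these limits in .NET, the Aes class exposes them through BlockSize, LegalKeySizes and ValidKeySize. The sketch below (AesParameters is a hypothetical name used only for illustration) asks the platform implementation what it accepts; treat what it prints on your machine as the authority.

```csharp
using System;
using System.Security.Cryptography;

// Hypothetical helper that inspects the sizes the platform's AES
// implementation accepts.
public static class AesParameters
{
    public static int BlockSizeBits()
    {
        using (var aes = Aes.Create())
        {
            return aes.BlockSize;
        }
    }

    public static bool IsValidKeySize(int bits)
    {
        using (var aes = Aes.Create())
        {
            // Checks the size against LegalKeySizes (min, max and step)
            return aes.ValidKeySize(bits);
        }
    }

    public static void Print()
    {
        using (var aes = Aes.Create())
        {
            Console.WriteLine($"block size: {aes.BlockSize} bits");
            foreach (var sizes in aes.LegalKeySizes)
            {
                Console.WriteLine(
                    $"key sizes: {sizes.MinSize}-{sizes.MaxSize} bits, step {sizes.SkipSize}");
            }
        }
    }
}
```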

Wrong, bad, just don’t …

I was hesitant to even put it here but we have to start somewhere. Below is an example of how NOT to do it.

using System.Security.Cryptography;
using System.Text;

public static class SymmetricEncryptor
{
    // don't use this

    static string password = "very strong password 123412;,[p;[; 172634812";

    public static byte[] EncryptString(string toEncrypt)
    {
        var key = GetKey(password);

        using (var aes = Aes.Create())
        using (var encryptor = aes.CreateEncryptor(key, key))
        {
            var plainText = Encoding.UTF8.GetBytes(toEncrypt);
            return encryptor.TransformFinalBlock(plainText, 0, plainText.Length);
        }
    }

    public static string DecryptToString(byte[] encryptedData)
    {
        var key = GetKey(password);

        using (var aes = Aes.Create())
        using (var encryptor = aes.CreateDecryptor(key, key))
        {
            var decryptedBytes = encryptor
                .TransformFinalBlock(encryptedData, 0, encryptedData.Length);
            return Encoding.UTF8.GetString(decryptedBytes);
        }
    }

    // converts password to 128 bit hash
    private static byte[] GetKey(string password)
    {
        var keyBytes = Encoding.UTF8.GetBytes(password);
        using (var md5 = MD5.Create())
        {
            return md5.ComputeHash(keyBytes);
        }
    }
}

… and how you could then use it:

class Program
{
    static void Main(string[] args)
    {
        var textToEncrypt = "something you want to hide";
        Console.WriteLine("original text: {0}{1}{0}", 
            Environment.NewLine, textToEncrypt);

        var encryptedData = SymmetricEncryptor.EncryptString(textToEncrypt);
        Console.WriteLine("encrypted data:{0}{1}{0}", 
            Environment.NewLine, Convert.ToBase64String(encryptedData));

        var decryptedText = SymmetricEncryptor.DecryptToString(encryptedData);
        Console.WriteLine("decrypted text:{0}{1}{0}", 
            Environment.NewLine, decryptedText);
    }
}

which gives:

original text:
something you want to hide

encrypted data:
zXnS9f+LqO6myn2BxxniMUmfzzU82d74GA35CwpgNqw=

decrypted text:
something you want to hide

At first glance this doesn’t look bad, but from a security perspective there are multiple issues. So let’s get to fixing!

Hard coded password

With tools like Ildasm.exe or dotPeek, it’s very easy to decompile the binaries and see the code … and the password. Not to mention it also stays in your source control history. So we will need to pass it as a parameter. You will have to load it from some external source - config file, environment variable, Azure Key Vault, etc.
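As one possible way to do that, here is a minimal sketch that reads the password from an environment variable. The PasswordSource class and the APP_ENCRYPTION_PASSWORD variable name are hypothetical; use whatever your deployment actually defines.

```csharp
using System;

// Hypothetical helper: loads the encryption password from an environment
// variable instead of hard-coding it in the source.
public static class PasswordSource
{
    // Assumed variable name, chosen only for this example
    public const string VariableName = "APP_ENCRYPTION_PASSWORD";

    public static string GetPassword()
    {
        var password = Environment.GetEnvironmentVariable(VariableName);
        if (string.IsNullOrEmpty(password))
        {
            // Fail loudly instead of silently encrypting with an empty password
            throw new InvalidOperationException(
                $"Environment variable {VariableName} is not set.");
        }
        return password;
    }
}
```

The call site then becomes something like SymmetricEncryptor.EncryptString(textToEncrypt, PasswordSource.GetPassword()).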

public static class SymmetricEncryptor
{
    public static byte[] EncryptString(string toEncrypt, string password)
    {
        // ...
    }

    public static string DecryptToString(byte[] encryptedData, string password)
    {
        // ...
    }

    // ...
}

Introduce a little randomness

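An issue with the current solution is that it always returns the same result (the same sequence of bytes) when given the same password and the same data to encrypt. Because of that, an attacker who collects your encrypted messages can tell when two of them contain the same text (and, in CBC mode, when they start with the same 16-byte blocks), which leaks information about your data without the attacker ever knowing the key. There is a parameter to initialize the algorithm, intuitively named the Initialization Vector (IV), which solves this problem when a new random IV is used for every message. The IV must be the same size as the block size, which for AES is always 128 bits.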
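To see the problem concretely, here is a small sketch (IvDemo is a hypothetical name) that encrypts the same text with a fixed key. Reusing one IV produces byte-for-byte identical cipher text, while a fresh random IV produces different cipher text each time. It relies on Aes.Create() defaulting to CBC mode with PKCS7 padding.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical demo class comparing a reused IV with fresh random IVs.
public static class IvDemo
{
    public static byte[] Encrypt(byte[] key, byte[] iv, string message)
    {
        using (var aes = Aes.Create())   // CBC + PKCS7 by default
        using (var encryptor = aes.CreateEncryptor(key, iv))
        {
            var plainText = Encoding.UTF8.GetBytes(message);
            return encryptor.TransformFinalBlock(plainText, 0, plainText.Length);
        }
    }

    public static byte[] RandomBytes(int count)
    {
        var bytes = new byte[count];
        using (var rng = RandomNumberGenerator.Create())
        {
            rng.GetBytes(bytes);
        }
        return bytes;
    }
}
```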

using System.Linq;
using System.Security.Cryptography;
using System.Text;

public static class SymmetricEncryptor
{
    private const int AesBlockByteSize = 128 / 8;
    private static readonly RandomNumberGenerator Random = RandomNumberGenerator.Create();

    public static byte[] EncryptString(string toEncrypt, string password)
    {
        var key = GetKey(password);

        using (var aes = Aes.Create())
        {
            var iv = GenerateRandomBytes(AesBlockByteSize);

            using (var encryptor = aes.CreateEncryptor(key, iv))
            {
                var plainText = Encoding.UTF8.GetBytes(toEncrypt);
                var cipherText = encryptor
                    .TransformFinalBlock(plainText, 0, plainText.Length);

                var result = new byte[iv.Length + cipherText.Length];
                iv.CopyTo(result, 0);
                cipherText.CopyTo(result, iv.Length);

                return result;
            }
        }
    }

    public static string DecryptToString(byte[] encryptedData, string password)
    {
        var key = GetKey(password);

        using (var aes = Aes.Create())
        {
            var iv = encryptedData.Take(AesBlockByteSize).ToArray();
            var cipherText = encryptedData.Skip(AesBlockByteSize).ToArray();

            using (var encryptor = aes.CreateDecryptor(key, iv))
            {
                var decryptedBytes = encryptor
                    .TransformFinalBlock(cipherText, 0, cipherText.Length);
                return Encoding.UTF8.GetString(decryptedBytes);
            }
        }
    }

    private static byte[] GetKey(string password)
    {
        var keyBytes = Encoding.UTF8.GetBytes(password);
        using (var md5 = MD5.Create())
        {
            return md5.ComputeHash(keyBytes);
        }
    }

    private static byte[] GenerateRandomBytes(int numberOfBytes)
    {
        var randomBytes = new byte[numberOfBytes];
        Random.GetBytes(randomBytes);
        return randomBytes;
    }
}

When you run the program now, it still works, but the sequence of bytes representing your encrypted data is different every time. This change introduces a new problem, though: how do we handle the randomly generated IV, or more specifically, where do we store it? You have probably noticed that I’m simply prepending the IV to the encrypted data. That might seem strange or even scary, but an IV is not considered secret (in CBC mode it only has to be unpredictable, which a random IV is), so this is not a problem from a security perspective.
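Reading the IV back is just a matter of slicing the first block off the message. As a minimal sketch, assuming the [IV | cipher text] layout above, spans can replace the LINQ Take/Skip calls and make it easy to reject messages that are too short to contain an IV. IvLayout and Split are hypothetical names.

```csharp
using System;

// Hypothetical helper for the [IV | cipher text] layout used above.
public static class IvLayout
{
    public const int AesBlockByteSize = 128 / 8;

    public static (byte[] Iv, byte[] CipherText) Split(byte[] encryptedData)
    {
        // Need at least the IV plus one block of (padded) cipher text
        if (encryptedData is null || encryptedData.Length < 2 * AesBlockByteSize)
        {
            throw new ArgumentException("Invalid length of encrypted data");
        }

        var iv = encryptedData.AsSpan(0, AesBlockByteSize).ToArray();
        var cipherText = encryptedData.AsSpan(AesBlockByteSize).ToArray();
        return (iv, cipherText);
    }
}
```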

Better key handling

To make sure our password is usable as an AES key, we are currently simply hashing it with MD5. That turns a password of any length into a key of exactly the size we need. However, MD5 is designed to be fast, so an attacker can test an enormous number of password guesses per second. Even though the entropy of our chosen password will not increase, we can still strengthen the resulting key against brute-force and dictionary attacks by making every guess deliberately expensive. Algorithms for this purpose belong to a category named password-based key derivation functions; examples are PBKDF2, scrypt, Argon2 and others. While Argon2 is the modern alternative to PBKDF2, its use might be prohibitive in some applications for performance reasons, and its availability across platforms is also not as universal.

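As of 2019, PBKDF2 can still offer good resistance against attacks if configured well. When you open the specification of PBKDF2 in RFC 2898 (section 5.2), you can see the options and arguments of the function:

PBKDF2 (P, S, c, dkLen)
Options: PRF, the underlying pseudorandom function (hLen denotes the length in octets of its output)
Input:   P, the password (an octet string)
         S, the salt (an octet string)
         c, the iteration count (a positive integer)
         dkLen, the intended length in octets of the derived key (a positive integer, at most (2^32 - 1) * hLen)
Output:  DK, the derived key (a dkLen-octet string)

So let’s see how you can use it: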

using System.Linq;
using System.Security.Cryptography;
using System.Text;

public static class SymmetricEncryptor
{
    private const int AesBlockByteSize = 128 / 8;

    private const int PasswordSaltByteSize = 128 / 8;
    private const int PasswordByteSize = 256 / 8;
    private const int PasswordIterationCount = 100_000;

    private static readonly Encoding StringEncoding = Encoding.UTF8;
    private static readonly RandomNumberGenerator Random = RandomNumberGenerator.Create();

    public static byte[] EncryptString(string toEncrypt, string password)
    {
        using (var aes = Aes.Create())
        {
            var keySalt = GenerateRandomBytes(PasswordSaltByteSize);
            var key = GetKey(password, keySalt);
            var iv = GenerateRandomBytes(AesBlockByteSize);

            using (var encryptor = aes.CreateEncryptor(key, iv))
            {
                var plainText = StringEncoding.GetBytes(toEncrypt);
                var cipherText = encryptor
                    .TransformFinalBlock(plainText, 0, plainText.Length);

                var result = MergeArrays(keySalt, iv, cipherText);
                return result;
            }
        }
    }

    public static string DecryptToString(byte[] encryptedData, string password)
    {
        using (var aes = Aes.Create())
        {
            var keySalt = encryptedData.Take(PasswordSaltByteSize).ToArray();
            var key = GetKey(password, keySalt);
            var iv = encryptedData
                .Skip(PasswordSaltByteSize).Take(AesBlockByteSize).ToArray();
            var cipherText = encryptedData
                .Skip(PasswordSaltByteSize + AesBlockByteSize).ToArray();

            using (var encryptor = aes.CreateDecryptor(key, iv))
            {
                var decryptedBytes = encryptor
                    .TransformFinalBlock(cipherText, 0, cipherText.Length);
                return StringEncoding.GetString(decryptedBytes);
            }
        }
    }

    private static byte[] GetKey(string password, byte[] passwordSalt)
    {
        var keyBytes = StringEncoding.GetBytes(password);

        using (var derivator = new Rfc2898DeriveBytes(
            keyBytes, passwordSalt, 
            PasswordIterationCount, HashAlgorithmName.SHA256))
        {
            return derivator.GetBytes(PasswordByteSize);
        }
    }

    private static byte[] GenerateRandomBytes(int numberOfBytes)
    {
        var randomBytes = new byte[numberOfBytes];
        Random.GetBytes(randomBytes);
        return randomBytes;
    }

    private static byte[] MergeArrays(params byte[][] arrays)
    {
        var merged = new byte[arrays.Sum(a => a.Length)];
        var mergeIndex = 0;
        for (int i = 0; i < arrays.GetLength(0); i++)
        {
            arrays[i].CopyTo(merged, mergeIndex);
            mergeIndex += arrays[i].Length;
        }

        return merged;
    }
}

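As you can see, I’m now using the Rfc2898DeriveBytes class in the GetKey method. The keyBytes and passwordSalt arguments are self-explanatory (the salt is freshly generated for every message and, like the IV, it is not secret, so it is simply prepended to the output), but the next two parameters deserve some comments.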
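To see what the salt buys you, here is a minimal sketch (SaltDemo is a hypothetical name) using the same PBKDF2 settings as GetKey. The same password and salt always give the same key, so decryption works, but a different salt gives an unrelated key, so two messages encrypted with the same password end up with different keys and an attacker cannot reuse precomputed work across them.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical demo class mirroring GetKey: PBKDF2 with HMAC-SHA256,
// 100,000 iterations, 256-bit output.
public static class SaltDemo
{
    public static byte[] DeriveKey(string password, byte[] salt)
    {
        using (var derivator = new Rfc2898DeriveBytes(
            Encoding.UTF8.GetBytes(password), salt,
            100_000, HashAlgorithmName.SHA256))
        {
            return derivator.GetBytes(256 / 8);
        }
    }
}
```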

PasswordIterationCount says how many times you want to run the pseudorandom function: a higher number means a longer computation time, which means every guess is harder to check. If your application is running on a relatively modern CPU, I would suggest starting with 100k and decreasing it only if the computation time is unacceptable. 1k is the default in Rfc2898DeriveBytes, but that is far too little for today’s hardware, especially when you take into consideration tools like hashcat running on machines with multiple GPUs.
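To find an acceptable value for your own hardware, you can simply time a derivation. This is a rough sketch assuming a hypothetical IterationBenchmark helper and a single measurement; a proper benchmark would warm up and repeat the measurement several times.

```csharp
using System;
using System.Diagnostics;
using System.Security.Cryptography;

// Hypothetical helper: times one PBKDF2-HMAC-SHA256 derivation of a
// 256-bit key with the given iteration count.
public static class IterationBenchmark
{
    public static TimeSpan Measure(int iterations)
    {
        // Dummy inputs; only the iteration count matters for this measurement
        var password = new byte[16];
        var salt = new byte[16];

        var stopwatch = Stopwatch.StartNew();
        using (var derivator = new Rfc2898DeriveBytes(
            password, salt, iterations, HashAlgorithmName.SHA256))
        {
            derivator.GetBytes(256 / 8);
        }
        stopwatch.Stop();
        return stopwatch.Elapsed;
    }
}
```

For example, foreach (var n in new[] { 10_000, 100_000, 300_000 }) Console.WriteLine($"{n}: {IterationBenchmark.Measure(n).TotalMilliseconds} ms"); gives a quick feel for how the cost scales on your machine.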

I chose SHA-256 as the pseudorandom function (PRF) because I’m deriving 256 bits for the key. I could use the default SHA-1, but it produces only 160 bits, which isn’t technically a problem for PBKDF2, but it is less than optimal. When you ask PBKDF2 for more bits than the PRF produces, PBKDF2 computes another output block with the full iteration count, then another, and so on, until it has at least as many bits as you requested. These blocks are independent of each other, so an attacker can compute them in parallel; the extra blocks cost you time without adding a proportional amount of security. That is why it’s better to choose a PRF whose output is at least as long as the key you need.
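The block structure is easier to see in code. RFC 2898 defines the output as DK = T_1 || T_2 || ..., where T_i = U_1 xor U_2 xor ... xor U_c, U_1 = PRF(P, S || INT(i)) and U_j = PRF(P, U_(j-1)). The following is an educational sketch of that definition with HMAC-SHA256 as the PRF (Pbkdf2Sketch is a hypothetical name); in real code you should keep using Rfc2898DeriveBytes, which it is meant to match.

```csharp
using System;
using System.Security.Cryptography;

// Educational PBKDF2-HMAC-SHA256 following RFC 2898, section 5.2.
// Each output block T_i runs the full iteration count on its own.
public static class Pbkdf2Sketch
{
    public static byte[] Derive(byte[] password, byte[] salt, int iterations, int length)
    {
        using (var prf = new HMACSHA256(password))
        {
            var hLen = prf.HashSize / 8;                  // 32 bytes for SHA-256
            var blockCount = (length + hLen - 1) / hLen;  // ceil(length / hLen)
            var output = new byte[blockCount * hLen];

            for (int i = 1; i <= blockCount; i++)
            {
                // U_1 = PRF(P, S || INT(i)), INT(i) is a 4-byte big-endian counter
                var saltAndIndex = new byte[salt.Length + 4];
                salt.CopyTo(saltAndIndex, 0);
                saltAndIndex[salt.Length] = (byte)(i >> 24);
                saltAndIndex[salt.Length + 1] = (byte)(i >> 16);
                saltAndIndex[salt.Length + 2] = (byte)(i >> 8);
                saltAndIndex[salt.Length + 3] = (byte)i;

                var u = prf.ComputeHash(saltAndIndex);
                var t = (byte[])u.Clone();

                // U_j = PRF(P, U_(j-1)); T_i = U_1 xor U_2 xor ... xor U_c
                for (int j = 2; j <= iterations; j++)
                {
                    u = prf.ComputeHash(u);
                    for (int k = 0; k < hLen; k++)
                    {
                        t[k] ^= u[k];
                    }
                }

                t.CopyTo(output, (i - 1) * hLen);
            }

            // The last block may be only partially used
            var result = new byte[length];
            Array.Copy(output, result, length);
            return result;
        }
    }
}
```

Asking for 48 bytes here computes two full blocks, which is exactly the extra cost described above.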

Signing

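One additional security feature we need to add is signing (strictly speaking, message authentication, since we will use a shared secret key rather than a digital signature). Right now we have no way to detect whether the encrypted data has been tampered with: bytes modified, removed, or injected. By signing it, we will make sure our message hasn’t been altered in any way before we decrypt it.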
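Tampering is not just a theoretical concern with CBC. In CBC decryption, the first plaintext block is XORed with the IV, so flipping a bit in the IV flips the same bit in the decrypted text, and decryption still "succeeds". The sketch below (TamperDemo is a hypothetical name) assumes an attacker who knows which character sits at a given position in the first 16 bytes and can modify the stored IV, without ever knowing the key.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical demo of CBC malleability: XORing the IV changes the first
// plaintext block in a controlled way, with no key and no error.
public static class TamperDemo
{
    public static string EncryptTamperDecrypt(string message, int position, char oldChar, char newChar)
    {
        using (var aes = Aes.Create())   // CBC + PKCS7 by default
        {
            aes.GenerateKey();
            aes.GenerateIV();
            var iv = aes.IV;   // the property returns a copy

            byte[] cipherText;
            using (var encryptor = aes.CreateEncryptor(aes.Key, iv))
            {
                var plainText = Encoding.UTF8.GetBytes(message);
                cipherText = encryptor.TransformFinalBlock(plainText, 0, plainText.Length);
            }

            // Attacker: flip exactly the bits that turn oldChar into newChar
            iv[position] ^= (byte)(oldChar ^ newChar);

            using (var decryptor = aes.CreateDecryptor(aes.Key, iv))
            {
                var decrypted = decryptor.TransformFinalBlock(cipherText, 0, cipherText.Length);
                return Encoding.UTF8.GetString(decrypted);
            }
        }
    }
}
```

For example, TamperDemo.EncryptTamperDecrypt("amount=100", 7, '1', '9') decrypts to "amount=900".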

To implement this, we will use a message authentication code (MAC), more specifically HMAC-SHA256, which belongs to the MAC subcategory called hash-based message authentication codes (HMAC).

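In a simplified way, the idea behind HMAC is to hash your message (here, the encrypted data) together with a secret authentication key, and then hash that result together with the key once more. Without the key, nobody can produce a valid hash for modified data. It’s important to note that this key must be different from the one used for encrypting your message.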
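For the curious, here is an educational sketch of that construction as defined in RFC 2104: HMAC(K, m) = H((K' xor opad) || H((K' xor ipad) || m)), where K' is the key zero-padded to the hash block size (64 bytes for SHA-256), after hashing it first if it is longer. HmacSketch is a hypothetical name; in real code use the built-in HMACSHA256 class, which this is meant to match.

```csharp
using System;
using System.Security.Cryptography;

// Educational HMAC-SHA256 following RFC 2104.
public static class HmacSketch
{
    private const int BlockSize = 64;   // SHA-256 processes 64-byte blocks

    public static byte[] Compute(byte[] key, byte[] message)
    {
        using (var sha256 = SHA256.Create())
        {
            // Keys longer than the block size are hashed first, then zero-padded
            if (key.Length > BlockSize)
            {
                key = sha256.ComputeHash(key);
            }
            var paddedKey = new byte[BlockSize];
            key.CopyTo(paddedKey, 0);

            var inner = new byte[BlockSize + message.Length];
            var outerPrefix = new byte[BlockSize];
            for (int i = 0; i < BlockSize; i++)
            {
                inner[i] = (byte)(paddedKey[i] ^ 0x36);        // ipad
                outerPrefix[i] = (byte)(paddedKey[i] ^ 0x5c);  // opad
            }
            message.CopyTo(inner, BlockSize);

            // Inner hash: H((K' xor ipad) || message)
            var innerHash = sha256.ComputeHash(inner);

            // Outer hash: H((K' xor opad) || innerHash)
            var outer = new byte[BlockSize + innerHash.Length];
            outerPrefix.CopyTo(outer, 0);
            innerHash.CopyTo(outer, BlockSize);
            return sha256.ComputeHash(outer);
        }
    }
}
```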

using System;
using System.Linq;
using System.Security.Cryptography;
using System.Text;

public static class SymmetricEncryptor
{
    private const int AesBlockByteSize = 128 / 8;

    private const int PasswordSaltByteSize = 128 / 8;
    private const int PasswordByteSize = 256 / 8;
    private const int PasswordIterationCount = 100_000;

    private const int SignatureByteSize = 256 / 8;

    private const int MinimumEncryptedMessageByteSize =
        PasswordSaltByteSize + // auth salt
        PasswordSaltByteSize + // key salt
        AesBlockByteSize + // IV
        AesBlockByteSize + // cipher text min length
        SignatureByteSize; // signature tag

    private static readonly Encoding StringEncoding = Encoding.UTF8;
    private static readonly RandomNumberGenerator Random = RandomNumberGenerator.Create();

    public static byte[] EncryptString(string toEncrypt, string password)
    {
        using (var aes = Aes.Create())
        {
            // encrypt
            var keySalt = GenerateRandomBytes(PasswordSaltByteSize);
            var key = GetKey(password, keySalt);
            var iv = GenerateRandomBytes(AesBlockByteSize);

            byte[] cipherText;
            using (var encryptor = aes.CreateEncryptor(key, iv))
            {
                var plainText = StringEncoding.GetBytes(toEncrypt);
                cipherText = encryptor
                    .TransformFinalBlock(plainText, 0, plainText.Length);
            }

            // sign
            var authKeySalt = GenerateRandomBytes(PasswordSaltByteSize);
            var authKey = GetKey(password, authKeySalt);
            var result = MergeArrays(
                additionalCapacity: SignatureByteSize, 
                authKeySalt, keySalt, iv, cipherText);

            using (var hmac = new HMACSHA256(authKey))
            {
                var payloadToSignLength = result.Length - SignatureByteSize;
                var signatureTag = hmac.ComputeHash(result, 0, payloadToSignLength);
                signatureTag.CopyTo(result, payloadToSignLength);
            }

            return result;
        }
    }

    public static string DecryptToString(byte[] encryptedData, string password)
    {
        if (encryptedData is null 
            || encryptedData.Length < MinimumEncryptedMessageByteSize)
        {
            throw new ArgumentException("Invalid length of encrypted data");
        }

        var authKeySalt = encryptedData
            .AsSpan(0, PasswordSaltByteSize).ToArray();
        var keySalt = encryptedData
            .AsSpan(PasswordSaltByteSize, PasswordSaltByteSize).ToArray();
        var iv = encryptedData
            .AsSpan(2 * PasswordSaltByteSize, AesBlockByteSize).ToArray();
        var signatureTag = encryptedData
            .AsSpan(encryptedData.Length - SignatureByteSize, SignatureByteSize).ToArray();

        var cipherTextIndex = authKeySalt.Length + keySalt.Length + iv.Length;
        var cipherTextLength = 
            encryptedData.Length - cipherTextIndex - signatureTag.Length;

        var authKey = GetKey(password, authKeySalt);
        var key = GetKey(password, keySalt);

        // verify signature
        using (var hmac = new HMACSHA256(authKey))
        {
            var payloadToSignLength = encryptedData.Length - SignatureByteSize;
            var signatureTagExpected = hmac
                .ComputeHash(encryptedData, 0, payloadToSignLength);

            // constant time checking to prevent timing attacks
            var signatureVerificationResult = 0;
            for (int i = 0; i < signatureTag.Length; i++)
            {
                signatureVerificationResult |= signatureTag[i] ^ signatureTagExpected[i];
            }

            if (signatureVerificationResult != 0)
            {
                throw new CryptographicException("Invalid signature");
            }
        }

        // decrypt
        using (var aes = Aes.Create())
        {
            using (var encryptor = aes.CreateDecryptor(key, iv))
            {
                var decryptedBytes = encryptor
                    .TransformFinalBlock(encryptedData, cipherTextIndex, cipherTextLength);
                return StringEncoding.GetString(decryptedBytes);
            }
        }
    }

    private static byte[] GetKey(string password, byte[] passwordSalt)
    {
        var keyBytes = StringEncoding.GetBytes(password);

        using (var derivator = new Rfc2898DeriveBytes(
            keyBytes, passwordSalt, 
            PasswordIterationCount, HashAlgorithmName.SHA256))
        {
            return derivator.GetBytes(PasswordByteSize);
        }
    }

    private static byte[] GenerateRandomBytes(int numberOfBytes)
    {
        var randomBytes = new byte[numberOfBytes];
        Random.GetBytes(randomBytes);
        return randomBytes;
    }

    private static byte[] MergeArrays(int additionalCapacity = 0, params byte[][] arrays)
    {
        var merged = new byte[arrays.Sum(a => a.Length) + additionalCapacity];
        var mergeIndex = 0;
        for (int i = 0; i < arrays.GetLength(0); i++)
        {
            arrays[i].CopyTo(merged, mergeIndex);
            mergeIndex += arrays[i].Length;
        }

        return merged;
    }
}

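There is not much to say about the implementation: we derive a separate authentication key from the password using its own salt, create a new HMACSHA256 instance, compute the hash over everything we are about to return (both salts, the IV and the cipher text) and then finally append it to the result.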

One thing maybe worth mentioning is the decryption part, which handles verification of the signature. There we compare all the bytes of the signature without exiting early at the first mismatch, so no matter which byte differs, the verification should take the same amount of time. That mitigates timing attacks, in which an attacker measures how long verification takes to learn how many leading bytes of a forged signature were correct.
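If you target .NET Core 2.1 or later, the framework already ships a constant-time comparison, CryptographicOperations.FixedTimeEquals, which could replace the hand-written loop. A minimal sketch (TagCheck is a hypothetical name):

```csharp
using System;
using System.Security.Cryptography;

// Hypothetical helper performing the same tag check with the built-in
// constant-time comparison.
public static class TagCheck
{
    public static void Verify(byte[] signatureTag, byte[] signatureTagExpected)
    {
        // Returns false for different lengths; for equal lengths the running
        // time does not depend on where (or whether) the bytes differ.
        if (!CryptographicOperations.FixedTimeEquals(signatureTag, signatureTagExpected))
        {
            throw new CryptographicException("Invalid signature");
        }
    }
}
```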

Additional options

We are nearly at the end, but I want to briefly touch on a few additional things. Aes.Create() is the recommended way to get an instance of the best available implementation of the Aes abstract class, and it also gives you good defaults, but I still prefer to be explicit, so I have added a small helper function which I can then reuse.

private static Aes CreateAes()
{
    var aes = Aes.Create();
    aes.Mode = CipherMode.CBC;
    aes.Padding = PaddingMode.PKCS7;
    return aes;
}

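Important! Always make sure you are using CBC mode rather than ECB. ECB encrypts every block independently with the same key, so identical plaintext blocks produce identical cipher text blocks and patterns in your data remain visible.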
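The difference is easy to demonstrate. This sketch (EcbDemo is a hypothetical name) encrypts two identical 16-byte blocks and checks whether the two resulting cipher text blocks are equal; padding is disabled only because the input is already block-aligned.

```csharp
using System;
using System.Security.Cryptography;

// Hypothetical demo: does the mode leak repeated plaintext blocks?
public static class EcbDemo
{
    public static bool BlocksRepeat(CipherMode mode)
    {
        var plainText = new byte[32];   // two identical all-zero blocks

        using (var aes = Aes.Create())   // random key and IV are generated
        {
            aes.Mode = mode;
            aes.Padding = PaddingMode.None;   // input is exactly two blocks
            using (var encryptor = aes.CreateEncryptor())
            {
                var cipherText = encryptor.TransformFinalBlock(plainText, 0, plainText.Length);
                var firstBlock = Convert.ToBase64String(cipherText, 0, 16);
                var secondBlock = Convert.ToBase64String(cipherText, 16, 16);
                return firstBlock == secondBlock;
            }
        }
    }
}
```

With CipherMode.ECB the two blocks come out identical; with CipherMode.CBC each block is chained to the previous one, so they differ.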

Alternatives

An alternative to this encrypt-then-sign approach is to use something like AES-GCM, which belongs to the category of authenticated encryption algorithms and has integrity checking built in. Unfortunately, an implementation is not available in CoreFX yet, but an AesGcm class is coming with .NET Core 3.0.
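For comparison, here is a sketch of what authenticated encryption with AesGcm can look like on .NET Core 3.0 or later, on platforms where AesGcm is supported. It takes a raw 256-bit key (you would still derive it from the password with PBKDF2 as above), and the [nonce | cipher text | tag] layout and the GcmSketch name are my own choices for this example. On newer runtimes the AesGcm(byte[]) constructor may produce an obsolete warning in favor of an overload that also takes the tag size.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical AES-GCM sketch; output layout is [nonce | cipher text | tag].
public static class GcmSketch
{
    private const int NonceByteSize = 12;   // 96-bit nonce, the standard size for GCM
    private const int TagByteSize = 16;     // 128-bit authentication tag

    public static byte[] Encrypt(byte[] key, string text)
    {
        var plainText = Encoding.UTF8.GetBytes(text);
        var nonce = new byte[NonceByteSize];
        RandomNumberGenerator.Fill(nonce);   // never reuse a nonce with the same key

        var cipherText = new byte[plainText.Length];
        var tag = new byte[TagByteSize];
        using (var aesGcm = new AesGcm(key))
        {
            aesGcm.Encrypt(nonce, plainText, cipherText, tag);
        }

        var result = new byte[NonceByteSize + cipherText.Length + TagByteSize];
        nonce.CopyTo(result, 0);
        cipherText.CopyTo(result, NonceByteSize);
        tag.CopyTo(result, NonceByteSize + cipherText.Length);
        return result;
    }

    public static string Decrypt(byte[] key, byte[] data)
    {
        if (data is null || data.Length < NonceByteSize + TagByteSize)
        {
            throw new ArgumentException("Invalid length of encrypted data");
        }

        var nonce = data.AsSpan(0, NonceByteSize);
        var tag = data.AsSpan(data.Length - TagByteSize, TagByteSize);
        var cipherText = data.AsSpan(NonceByteSize, data.Length - NonceByteSize - TagByteSize);
        var plainText = new byte[cipherText.Length];

        using (var aesGcm = new AesGcm(key))
        {
            // Throws a CryptographicException if the data or tag was modified
            aesGcm.Decrypt(nonce, cipherText, tag, plainText);
        }
        return Encoding.UTF8.GetString(plainText);
    }
}
```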

If you are looking for something ready right now, check out the well-known Bouncy Castle library, which also provides many more cryptographic algorithms.

Complete code

As promised at the start of the post, here is the final version of the code. To keep it short and the changes easy to follow, I have intentionally omitted most error handling and argument validation and kept only the most important checks. To make it more robust and production-ready, some additional validation will have to be added.

using System;
using System.Linq;
using System.Security.Cryptography;
using System.Text;

public static class SymmetricEncryptor
{
    private const int AesBlockByteSize = 128 / 8;

    private const int PasswordSaltByteSize = 128 / 8;
    private const int PasswordByteSize = 256 / 8;
    private const int PasswordIterationCount = 100_000;

    private const int SignatureByteSize = 256 / 8;

    private const int MinimumEncryptedMessageByteSize =
        PasswordSaltByteSize + // auth salt
        PasswordSaltByteSize + // key salt
        AesBlockByteSize + // IV
        AesBlockByteSize + // cipher text min length
        SignatureByteSize; // signature tag

    private static readonly Encoding StringEncoding = Encoding.UTF8;
    private static readonly RandomNumberGenerator Random = RandomNumberGenerator.Create();

    public static byte[] EncryptString(string toEncrypt, string password)
    {
        // encrypt
        var keySalt = GenerateRandomBytes(PasswordSaltByteSize);
        var key = GetKey(password, keySalt);
        var iv = GenerateRandomBytes(AesBlockByteSize);

        byte[] cipherText;
        using (var aes = CreateAes())
        using (var encryptor = aes.CreateEncryptor(key, iv))
        {
            var plainText = StringEncoding.GetBytes(toEncrypt);
            cipherText = encryptor
                .TransformFinalBlock(plainText, 0, plainText.Length);
        }

        // sign
        var authKeySalt = GenerateRandomBytes(PasswordSaltByteSize);
        var authKey = GetKey(password, authKeySalt);

        var result = MergeArrays(
            additionalCapacity: SignatureByteSize,
            authKeySalt, keySalt, iv, cipherText);

        using (var hmac = new HMACSHA256(authKey))
        {
            var payloadToSignLength = result.Length - SignatureByteSize;
            var signatureTag = hmac.ComputeHash(result, 0, payloadToSignLength);
            signatureTag.CopyTo(result, payloadToSignLength);
        }

        return result;
    }

    public static string DecryptToString(byte[] encryptedData, string password)
    {
        if (encryptedData is null
            || encryptedData.Length < MinimumEncryptedMessageByteSize)
        {
            throw new ArgumentException("Invalid length of encrypted data");
        }

        var authKeySalt = encryptedData
            .AsSpan(0, PasswordSaltByteSize).ToArray();
        var keySalt = encryptedData
            .AsSpan(PasswordSaltByteSize, PasswordSaltByteSize).ToArray();
        var iv = encryptedData
            .AsSpan(2 * PasswordSaltByteSize, AesBlockByteSize).ToArray();
        var signatureTag = encryptedData
            .AsSpan(encryptedData.Length - SignatureByteSize, SignatureByteSize).ToArray();

        var cipherTextIndex = authKeySalt.Length + keySalt.Length + iv.Length;
        var cipherTextLength =
            encryptedData.Length - cipherTextIndex - signatureTag.Length;

        var authKey = GetKey(password, authKeySalt);
        var key = GetKey(password, keySalt);

        // verify signature
        using (var hmac = new HMACSHA256(authKey))
        {
            var payloadToSignLength = encryptedData.Length - SignatureByteSize;
            var signatureTagExpected = hmac
                .ComputeHash(encryptedData, 0, payloadToSignLength);

            // constant time checking to prevent timing attacks
            var signatureVerificationResult = 0;
            for (int i = 0; i < signatureTag.Length; i++)
            {
                signatureVerificationResult |= signatureTag[i] ^ signatureTagExpected[i];
            }

            if (signatureVerificationResult != 0)
            {
                throw new CryptographicException("Invalid signature");
            }
        }

        // decrypt
        using (var aes = CreateAes())
        {
            using (var decryptor = aes.CreateDecryptor(key, iv))
            {
                var decryptedBytes = decryptor
                    .TransformFinalBlock(encryptedData, cipherTextIndex, cipherTextLength);
                return StringEncoding.GetString(decryptedBytes);
            }
        }
    }

    private static Aes CreateAes()
    {
        var aes = Aes.Create();
        aes.Mode = CipherMode.CBC;
        aes.Padding = PaddingMode.PKCS7;
        return aes;
    }

    private static byte[] GetKey(string password, byte[] passwordSalt)
    {
        var keyBytes = StringEncoding.GetBytes(password);

        using (var derivator = new Rfc2898DeriveBytes(
            keyBytes, passwordSalt, 
            PasswordIterationCount, HashAlgorithmName.SHA256))
        {
            return derivator.GetBytes(PasswordByteSize);
        }
    }

    private static byte[] GenerateRandomBytes(int numberOfBytes)
    {
        var randomBytes = new byte[numberOfBytes];
        Random.GetBytes(randomBytes);
        return randomBytes;
    }

    private static byte[] MergeArrays(int additionalCapacity = 0, params byte[][] arrays)
    {
        var merged = new byte[arrays.Sum(a => a.Length) + additionalCapacity];
        var mergeIndex = 0;
        for (int i = 0; i < arrays.GetLength(0); i++)
        {
            arrays[i].CopyTo(merged, mergeIndex);
            mergeIndex += arrays[i].Length;
        }

        return merged;
    }
}

And remember …

Keep it secret, keep it safe!
