Implementing all the Algorithms - Day 4: ChaCha20

29.05.2022

It's been a busy week, but I'm finally able to write the next post. This time I wanted to implement a pseudo-random number generator, and I decided to go for ChaCha, which can serve as a cryptographically secure PRNG.

ChaCha

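ChaCha is a stream cipher. It takes a key, a nonce and a block counter and produces a stream of pseudo-random bytes. That stream can either be XORed with a message to encrypt it, or be used directly as the output of a cryptographically secure PRNG. It is pretty straightforward to implement based on RFC 8439. The fun part about this one was my blatant use of unsafe, converting freely between u8 and u32 buffers using std::mem::transmute.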

use super::PrnGenerator;

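// #[repr(C)] keeps the declared field order. As a result, random_bytes sits at offset 0,
// and the struct is 4-byte aligned because of counter: u32. perform_rounds relies on
// this when it reinterprets random_bytes as [u32; 16].
#[repr(C)]
pub struct ChaChaGenerator<const ROUNDS: usize = 20> {
    /// the current 64-byte keystream block
    random_bytes: [u8; 64],
    key: [u8; 32],
    nonce: [u8; 12],
    /// index of the next unused byte in random_bytes; 64 means the block is used up
    next_random_byte: u8,
    /// block counter, incremented once per 64-byte block
    counter: u32,
}

impl<const ROUNDS: usize> ChaChaGenerator<ROUNDS> {
    /// Constructs the ChaChaGenerator by reading key and nonce from the random device
    pub fn from_system() -> std::io::Result<Self> {
        let mut key = [0; 32];
        let mut nonce = [0; 12];
        super::get_system_random_bytes(&mut key)?;
        super::get_system_random_bytes(&mut nonce)?;

        // Self keeps the caller's ROUNDS instead of silently falling back to the default of 20
        Ok(Self::from_key(key, nonce))
    }

    /// Constructs the generator from an explicit key and nonce; the block counter starts at 1
    pub fn from_key(key: [u8; 32], nonce: [u8; 12]) -> Self {
        Self {
            random_bytes: [0; 64],
            key,
            nonce,
            next_random_byte: 64,
            counter: 1,
        }
    }

    /// Computes the next 64-byte keystream block into random_bytes
    fn perform_rounds(&mut self) {
        // SAFETY: [u8; 64] and [u32; 16] have the same size, every bit pattern is a valid u32,
        // and #[repr(C)] on the struct guarantees that random_bytes is 4-byte aligned
        let working_vec: &mut [u32; 16] = unsafe { std::mem::transmute(&mut self.random_bytes) };
        // use random bytes as working state
        let initial = init_state(&self.key, self.counter, &self.nonce);
        *working_vec = initial;
        // each chacha_round call is a double round, so ROUNDS / 2 calls give ROUNDS rounds
        for _ in 0..(ROUNDS / 2) {
            chacha_round(working_vec);
        }
        // first add the initial state, then store each word in little-endian byte order
        // (to_le is a no-op on little-endian systems and a byte swap on big-endian ones)
        for i in 0..16 {
            working_vec[i] = working_vec[i].wrapping_add(initial[i]).to_le();
        }

        // after 2^32 blocks the 32-bit counter would wrap and repeat the keystream
        self.counter = self.counter.checked_add(1).expect("ChaCha block counter exhausted");
    }
}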
}

impl<const ROUNDS: usize> PrnGenerator for ChaChaGenerator<ROUNDS> {
    fn next_byte(&mut self) -> u8 {
        if self.next_random_byte >= 64 {
            self.perform_rounds();
            self.next_random_byte = 0;
        }
        let result = self.random_bytes[self.next_random_byte as usize];

        self.next_random_byte += 1;
        result
    }
}

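fn init_state(key: &[u8; 32], counter: u32, nonce: &[u8; 12]) -> [u32; 16] {
    let mut state: [u32; 16] = [0; 16];

    // the constants spell "expand 32-byte k" as little-endian ASCII words
    state[0] = 0x61707865;
    state[1] = 0x3320646e;
    state[2] = 0x79622d32;
    state[3] = 0x6b206574;

    // key and nonce are read as little-endian words. from_le_bytes replaces the
    // *const u8 as *const u32 reads, which are undefined behavior whenever the
    // byte array happens not to be 4-byte aligned.
    for i in 0..8 {
        state[i + 4] = u32::from_le_bytes([key[i * 4], key[i * 4 + 1], key[i * 4 + 2], key[i * 4 + 3]]);
    }
    state[12] = counter;
    for i in 0..3 {
        state[i + 13] = u32::from_le_bytes([nonce[i * 4], nonce[i * 4 + 1], nonce[i * 4 + 2], nonce[i * 4 + 3]]);
    }
    state
}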

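// ChaCha quarter round (RFC 8439, section 2.1), written as a macro so that it can assign to state[i] directly
macro_rules! chacha_quarter_round {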
    ($a: expr, $b:expr, $c:expr, $d:expr) => {
        $a = $a.wrapping_add($b);
        $d ^= $a;
        $d = $d.rotate_left(16);
        $c = $c.wrapping_add($d);
        $b ^= $c;
        $b = $b.rotate_left(12);
        $a = $a.wrapping_add($b);
        $d ^= $a;
        $d = $d.rotate_left(8);
        $c = $c.wrapping_add($d);
        $b ^= $c;
        $b = $b.rotate_left(7);
    };
}

/// One ChaCha double round: four column quarter rounds followed by four diagonal ones
fn chacha_round(state: &mut [u32; 16]) {
    chacha_quarter_round!(state[0], state[4], state[8], state[12]);
    chacha_quarter_round!(state[1], state[5], state[9], state[13]);
    chacha_quarter_round!(state[2], state[6], state[10], state[14]);
    chacha_quarter_round!(state[3], state[7], state[11], state[15]);
    chacha_quarter_round!(state[0], state[5], state[10], state[15]);
    chacha_quarter_round!(state[1], state[6], state[11], state[12]);
    chacha_quarter_round!(state[2], state[7], state[8], state[13]);
    chacha_quarter_round!(state[3], state[4], state[9], state[14]);
}
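For comparison, the same conversions can be written without unsafe. In this minimal sketch, u32::from_le_bytes and u32::to_le_bytes spell out the byte order explicitly, so neither the alignment of the arrays nor the host's endianness matters. The names init_state_safe, load_le and serialize_block are hypothetical and are not part of the code above.

```rust
/// Builds the initial ChaCha20 state (RFC 8439, section 2.3) using only safe code.
/// Hypothetical helper: every word is assembled from four little-endian bytes,
/// so alignment and host endianness don't matter.
fn init_state_safe(key: &[u8; 32], counter: u32, nonce: &[u8; 12]) -> [u32; 16] {
    // Reads consecutive 4-byte chunks of `bytes` into `words` as little-endian u32s
    fn load_le(words: &mut [u32], bytes: &[u8]) {
        for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(4)) {
            *word = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        }
    }

    let mut state = [0u32; 16];
    // The four constants are just the ASCII string "expand 32-byte k" read as LE words
    load_le(&mut state[0..4], b"expand 32-byte k");
    load_le(&mut state[4..12], key);
    state[12] = counter;
    load_le(&mut state[13..16], nonce);
    state
}

/// Serializes 16 state words into a 64-byte keystream block, little-endian.
/// Hypothetical helper standing in for the transmute in perform_rounds.
fn serialize_block(words: &[u32; 16]) -> [u8; 64] {
    let mut block = [0u8; 64];
    for (chunk, word) in block.chunks_exact_mut(4).zip(words.iter()) {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    block
}
```

With this version, the four magic constants stop being magic. They are the bytes of "expand 32-byte k" read as little-endian words.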

Another fun part was using a macro to implement the chacha_quarter_round. Originally I wanted to pass mutable references to the individual elements. Unfortunately, just as I suspected, that didn't work because of the borrow checker. It treats &mut state[0] and &mut state[4] as two mutable borrows of the whole array, so it refuses to hand out both at once. This is a bit unfortunate, because with constant indices it should be easy to verify that the elements are distinct, though I have no idea about the compiler internals in that regard or whether it would be hard to implement.

So, to turn the problem into a solution, I converted the whole thing into a macro. It felt a lot like writing code in a dynamic programming language. On top of that, the macro effectively inlines everything, although in a release build the compiler would most likely have inlined a function this small anyway.
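For comparison, the borrow checker can be satisfied without a macro. Matching the array against an array pattern binds each element to its own &mut u32, and the compiler accepts those references as disjoint. This is a minimal sketch that assumes a plain function is acceptable in place of the macro. The names quarter_round and double_round are hypothetical and are used only for this illustration.

```rust
/// ChaCha quarter round (RFC 8439, section 2.1) on four distinct words
fn quarter_round(a: &mut u32, b: &mut u32, c: &mut u32, d: &mut u32) {
    *a = a.wrapping_add(*b);
    *d = (*d ^ *a).rotate_left(16);
    *c = c.wrapping_add(*d);
    *b = (*b ^ *c).rotate_left(12);
    *a = a.wrapping_add(*b);
    *d = (*d ^ *a).rotate_left(8);
    *c = c.wrapping_add(*d);
    *b = (*b ^ *c).rotate_left(7);
}

/// One column round plus one diagonal round, without a macro
fn double_round(state: &mut [u32; 16]) {
    // Matching the pattern against &mut [u32; 16] binds every element as its own
    // &mut u32, so the borrow checker knows the sixteen references don't overlap
    let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15] = state;
    // column round
    quarter_round(s0, s4, s8, s12);
    quarter_round(s1, s5, s9, s13);
    quarter_round(s2, s6, s10, s14);
    quarter_round(s3, s7, s11, s15);
    // diagonal round
    quarter_round(s0, s5, s10, s15);
    quarter_round(s1, s6, s11, s12);
    quarter_round(s2, s7, s8, s13);
    quarter_round(s3, s4, s9, s14);
}
```

Because the parameter types are spelled out as &mut u32, each call implicitly reborrows s0, s4 and so on. That is why the same bindings can be reused in the diagonal round.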

Performance Optimizations

While writing this article, I had an idea for improving the performance further. Glancing over the generated assembly, it looked like a lot of values were being moved from registers to memory and back. So I changed the function to copy the state into local variables, run the quarter rounds on those, and write them back at the end:

fn chacha_round(state: &mut [u32; 16]) {
    let mut state_0 = state[0];
    let mut state_1 = state[1];
    let mut state_2 = state[2];
    let mut state_3 = state[3];
    let mut state_4 = state[4];
    let mut state_5 = state[5];
    let mut state_6 = state[6];
    let mut state_7 = state[7];
    let mut state_8 = state[8];
    let mut state_9 = state[9];
    let mut state_10 = state[10];
    let mut state_11 = state[11];
    let mut state_12 = state[12];
    let mut state_13 = state[13];
    let mut state_14 = state[14];
    let mut state_15 = state[15];

    chacha_quarter_round!(state_0, state_4, state_8, state_12);
    chacha_quarter_round!(state_1, state_5, state_9, state_13);
    chacha_quarter_round!(state_2, state_6, state_10, state_14);
    chacha_quarter_round!(state_3, state_7, state_11, state_15);
    chacha_quarter_round!(state_0, state_5, state_10, state_15);
    chacha_quarter_round!(state_1, state_6, state_11, state_12);
    chacha_quarter_round!(state_2, state_7, state_8, state_13);
    chacha_quarter_round!(state_3, state_4, state_9, state_14);

    state[0] = state_0;
    state[1] = state_1;
    state[2] = state_2;
    state[3] = state_3;
    state[4] = state_4;
    state[5] = state_5;
    state[6] = state_6;
    state[7] = state_7;
    state[8] = state_8;
    state[9] = state_9;
    state[10] = state_10;
    state[11] = state_11;
    state[12] = state_12;
    state[13] = state_13;
    state[14] = state_14;
    state[15] = state_15;
}

and got some performance improvements:

ChaCha20                time:   [407.29 ns 407.43 ns 407.57 ns]                     
                        change: [-36.943% -36.861% -36.770%] (p = 0.00 < 0.05)
                        Performance has improved.

Pretty decent ones, actually. Encouraged by this success, I moved the code into perform_rounds. That way the 16 local variables are loaded and stored once per block instead of once per double round. This gave roughly another 10 %, in exchange for a very long and unreadable function.

ChaCha20                time:   [366.49 ns 366.76 ns 367.04 ns]                     
                        change: [-10.022% -9.9279% -9.8121%] (p = 0.00 < 0.05)
                        Performance has improved.
//...
pub fn perform_rounds(&mut self) {
    // chacha_quarter_round! now has to be defined textually above this impl block,
    // because macro_rules! macros are only visible after their definition.
    // SAFETY: requires random_bytes to be 4-byte aligned, e.g. via #[repr(C)] with random_bytes as the first field
    let working_vec: &mut [u32; 16] = unsafe { std::mem::transmute(&mut self.random_bytes) };
    // the initial state seeds the locals and is added back at the end
    let initial = init_state(&self.key, self.counter, &self.nonce);

    let mut state_0 = initial[0];
    let mut state_1 = initial[1];
    let mut state_2 = initial[2];
    let mut state_3 = initial[3];
    let mut state_4 = initial[4];
    let mut state_5 = initial[5];
    let mut state_6 = initial[6];
    let mut state_7 = initial[7];
    let mut state_8 = initial[8];
    let mut state_9 = initial[9];
    let mut state_10 = initial[10];
    let mut state_11 = initial[11];
    let mut state_12 = initial[12];
    let mut state_13 = initial[13];
    let mut state_14 = initial[14];
    let mut state_15 = initial[15];
    for _ in 0..(ROUNDS / 2) {
        chacha_quarter_round!(state_0, state_4, state_8, state_12);
        chacha_quarter_round!(state_1, state_5, state_9, state_13);
        chacha_quarter_round!(state_2, state_6, state_10, state_14);
        chacha_quarter_round!(state_3, state_7, state_11, state_15);
        chacha_quarter_round!(state_0, state_5, state_10, state_15);
        chacha_quarter_round!(state_1, state_6, state_11, state_12);
        chacha_quarter_round!(state_2, state_7, state_8, state_13);
        chacha_quarter_round!(state_3, state_4, state_9, state_14);
    }
    // add the initial state first, then convert each word to little-endian
    // (the other way round gives wrong results on big-endian systems)
    working_vec[0] = u32::to_le(state_0.wrapping_add(initial[0]));
    working_vec[1] = u32::to_le(state_1.wrapping_add(initial[1]));
    working_vec[2] = u32::to_le(state_2.wrapping_add(initial[2]));
    working_vec[3] = u32::to_le(state_3.wrapping_add(initial[3]));
    working_vec[4] = u32::to_le(state_4.wrapping_add(initial[4]));
    working_vec[5] = u32::to_le(state_5.wrapping_add(initial[5]));
    working_vec[6] = u32::to_le(state_6.wrapping_add(initial[6]));
    working_vec[7] = u32::to_le(state_7.wrapping_add(initial[7]));
    working_vec[8] = u32::to_le(state_8.wrapping_add(initial[8]));
    working_vec[9] = u32::to_le(state_9.wrapping_add(initial[9]));
    working_vec[10] = u32::to_le(state_10.wrapping_add(initial[10]));
    working_vec[11] = u32::to_le(state_11.wrapping_add(initial[11]));
    working_vec[12] = u32::to_le(state_12.wrapping_add(initial[12]));
    working_vec[13] = u32::to_le(state_13.wrapping_add(initial[13]));
    working_vec[14] = u32::to_le(state_14.wrapping_add(initial[14]));
    working_vec[15] = u32::to_le(state_15.wrapping_add(initial[15]));

    // after 2^32 blocks the 32-bit counter would wrap and repeat the keystream
    self.counter = self.counter.checked_add(1).expect("ChaCha block counter exhausted");
}
//...

I could probably find a way to make this a bit prettier without degrading performance too much, but I think I'm going to stop here for today.
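Here is one possible direction, sketched under the assumption that the initial state is passed in rather than read from self (chacha_block is a hypothetical name). Destructuring the state array by value produces the same sixteen locals in a single statement. Serializing with to_le_bytes replaces both the transmute and the endianness handling. The quarter-round macro is copied verbatim from above. This variant hasn't been benchmarked, so it may or may not match the numbers above.

```rust
// The post's quarter-round macro, unchanged
macro_rules! chacha_quarter_round {
    ($a: expr, $b:expr, $c:expr, $d:expr) => {
        $a = $a.wrapping_add($b);
        $d ^= $a;
        $d = $d.rotate_left(16);
        $c = $c.wrapping_add($d);
        $b ^= $c;
        $b = $b.rotate_left(12);
        $a = $a.wrapping_add($b);
        $d ^= $a;
        $d = $d.rotate_left(8);
        $c = $c.wrapping_add($d);
        $b ^= $c;
        $b = $b.rotate_left(7);
    };
}

/// Hypothetical helper: turns an initial state (as built by init_state) into one
/// 64-byte keystream block. ROUNDS is the total number of rounds (20 for ChaCha20).
fn chacha_block<const ROUNDS: usize>(initial: &[u32; 16]) -> [u8; 64] {
    // Destructuring by value yields sixteen independent locals, just like the
    // hand-written `let mut state_N = ...` lines
    let [mut x0, mut x1, mut x2, mut x3, mut x4, mut x5, mut x6, mut x7,
         mut x8, mut x9, mut x10, mut x11, mut x12, mut x13, mut x14, mut x15] = *initial;
    for _ in 0..ROUNDS / 2 {
        chacha_quarter_round!(x0, x4, x8, x12);
        chacha_quarter_round!(x1, x5, x9, x13);
        chacha_quarter_round!(x2, x6, x10, x14);
        chacha_quarter_round!(x3, x7, x11, x15);
        chacha_quarter_round!(x0, x5, x10, x15);
        chacha_quarter_round!(x1, x6, x11, x12);
        chacha_quarter_round!(x2, x7, x8, x13);
        chacha_quarter_round!(x3, x4, x9, x14);
    }
    // Re-pack, add the initial state and serialize little-endian: no transmute
    // and no special case for big-endian hosts
    let mixed = [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15];
    let mut block = [0u8; 64];
    for (i, chunk) in block.chunks_exact_mut(4).enumerate() {
        chunk.copy_from_slice(&mixed[i].wrapping_add(initial[i]).to_le_bytes());
    }
    block
}
```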

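Also, looking at other implementations of ChaCha, it seems that to make this code really fast I'd have to use vector instructions, which in hindsight is exactly what the algorithm is designed for. The four quarter rounds in each half of a double round touch disjoint words, so they can run side by side in the four lanes of a 128-bit SIMD register.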
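To make this concrete, SIMD implementations commonly keep each row of the 4x4 state in one vector register. A single quarter-round call on the four rows then performs all four column quarter rounds at once. Rotating the lanes of rows b, c and d by one, two and three positions lines the diagonals up as columns for the second half of the double round. The sketch below models that layout with plain [u32; 4] arrays. Row, add, xor_rotl, shuffle, quarter_round_x4, double_round_x4 and chacha20_block_x4 are all names made up for this illustration. The sketch shows the data movement only. Real speedups would require actual vector instructions, for example through the std::arch intrinsics.

```rust
/// Models one 128-bit vector register as four u32 lanes (plain arrays, no intrinsics)
type Row = [u32; 4];

/// Lane-wise wrapping addition
fn add(a: Row, b: Row) -> Row {
    let mut r = a;
    for i in 0..4 {
        r[i] = r[i].wrapping_add(b[i]);
    }
    r
}

/// Lane-wise XOR, then rotate every lane left by `n` bits
fn xor_rotl(a: Row, b: Row, n: u32) -> Row {
    let mut r = a;
    for i in 0..4 {
        r[i] = (r[i] ^ b[i]).rotate_left(n);
    }
    r
}

/// Rotates the lanes left by `n` positions (a lane shuffle in real SIMD code)
fn shuffle(r: Row, n: usize) -> Row {
    [r[n % 4], r[(n + 1) % 4], r[(n + 2) % 4], r[(n + 3) % 4]]
}

/// Four quarter rounds at once: lane i of a, b, c and d forms one quarter round
fn quarter_round_x4(a: &mut Row, b: &mut Row, c: &mut Row, d: &mut Row) {
    *a = add(*a, *b);
    *d = xor_rotl(*d, *a, 16);
    *c = add(*c, *d);
    *b = xor_rotl(*b, *c, 12);
    *a = add(*a, *b);
    *d = xor_rotl(*d, *a, 8);
    *c = add(*c, *d);
    *b = xor_rotl(*b, *c, 7);
}

/// One double round on the state stored as four rows
fn double_round_x4(rows: &mut [Row; 4]) {
    let [a, b, c, d] = rows;
    // column round: lane i of every row already belongs to column i
    quarter_round_x4(a, b, c, d);
    // shift rows b, c, d by 1, 2, 3 lanes so that each diagonal sits in one lane
    *b = shuffle(*b, 1);
    *c = shuffle(*c, 2);
    *d = shuffle(*d, 3);
    quarter_round_x4(a, b, c, d);
    // undo the shift
    *b = shuffle(*b, 3);
    *c = shuffle(*c, 2);
    *d = shuffle(*d, 1);
}

/// ChaCha20 block function on this layout: 10 double rounds, then add the input state
fn chacha20_block_x4(initial: &[u32; 16]) -> [u32; 16] {
    let mut rows: [Row; 4] = [[0; 4]; 4];
    for i in 0..16 {
        rows[i / 4][i % 4] = initial[i];
    }
    for _ in 0..10 {
        double_round_x4(&mut rows);
    }
    let mut out = [0u32; 16];
    for i in 0..16 {
        out[i] = rows[i / 4][i % 4].wrapping_add(initial[i]);
    }
    out
}
```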