Part III: RSA in Python - Preventing Side Channel Attacks

Photo by Mauro Sbicego / Unsplash

If you haven't read the previous two articles in this series, here are the links:

Part I: Implementing RSA in Python from Scratch
Part II: Implementing RSA in Python from Scratch

Side-channel attacks

The first two articles focused on the ideas behind RSA and on implementing it, and so left out a topic that matters a great deal in modern encryption. Side-channel attacks exploit information we would usually dismiss as noise, such as hardware sounds, electromagnetic emissions and execution time. Timing attacks are the most common because they are the easiest to perform and can be carried out remotely, for example against an online server.

Timing attacks

To perform this timing attack, the attacker must be able to choose which messages the target device decrypts and to measure how long each decryption takes. This is often the case, especially with an online server, where the attacker does not even need to be near the machine to extract sensitive information from it.

Idea behind timing attacks

The attack was first conceived by Paul Kocher and is based on the following observation:

Many systems speed up the modulo operation `a % b` by first checking whether `a` is less than `b`, in which case `a` can be returned immediately. This also means that when `a` is larger than `b`, the operation takes more time. How can this be exploited? Let's look at the following implementation of modular exponentiation (left-to-right square-and-multiply).

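```python
def modpow(b, e, n):
    # b: base (the message m in the text), e: exponent (the private key d),
    # n: modulus (N in the text)
    # find the index of the highest set bit of e (same as e.bit_length() - 1)
    tst = 1
    siz = 0
    while e >= tst:
        tst <<= 1
        siz += 1
    siz -= 1
    # calculate the result, walking the bits of e from the top
    r = 1
    for i in range(siz, -1, -1):
        r = (r * r) % n          # square for every bit
        if (e >> i) & 1:
            r = (r * b) % n      # multiply only when the bit is 1
    return r
```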
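To make the leak visible without wall-clock noise, the sketch below mirrors `modpow` but counts the reductions that cannot use the `a < b` shortcut. It assumes, as the text does, that this shortcut is the only data-dependent cost, so the count stands in for time; `modpow_count` is a hypothetical helper, not part of the original program.

```python
def modpow_count(b, e, n):
    """Square-and-multiply like modpow, returning (result, slow_reductions).

    A reduction is counted as slow when its operand is >= n, i.e. when a
    `%` with an `a < b` shortcut could not return early.
    """
    reductions = 0

    def reduce(x):
        nonlocal reductions
        if x >= n:              # shortcut impossible: real division needed
            reductions += 1
            return x % n
        return x                # fast path: x < n, returned unchanged

    r = 1
    for i in range(e.bit_length() - 1, -1, -1):
        r = reduce(r * r)       # square for every bit
        if (e >> i) & 1:
            r = reduce(r * b)   # multiply only when the bit is 1
    return r, reductions
```

For the toy modulus N = 3233, `modpow_count(56, 0b11, 3233)` performs one slow reduction (the multiply, since 56³ > 3233) while `modpow_count(56, 0b10, 3233)` performs none, which is exactly the difference the attack looks for in bit `d₁`.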

In the code, `b` plays the role of the message `m` and `e` the role of the private exponent `d`. The loop walks through the bits of `d` starting from the most significant one. After the iteration for the first non-zero bit (`d₀`), `r` holds `m`. In the next iteration this value is squared and, if the second bit of `d` (`d₁`) equals 1, multiplied by `m`. So if we can craft two messages `a` and `b` such that for one of them this multiplication needs an actual modulo reduction and for the other it can be avoided, we can tell from their decryption times whether the multiplication takes place, which reveals `d₁`.

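These `a` and `b` need to satisfy `a² < N < a³` and `b³ < N`. For `a`, squaring gives `a² < N`, so no reduction is needed, but multiplying by `a` gives `a³ > N`, which does need one. For `b`, neither the square nor the multiplication exceeds `N`. Decryption of `a` should therefore, on average, take longer than decryption of `b` if `d₁` equals 1.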
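For the first bit, the two conditions `a² < N < a³` and `b³ < N` can be met with integer roots of `N`, found by binary search. This is a minimal sketch; `iroot` and `first_bit_candidates` are hypothetical helper names and `N` is assumed to be the target's public modulus.

```python
def iroot(n, k):
    """Largest integer x with x**k <= n, found by binary search."""
    lo, hi = 0, 1 << (n.bit_length() // k + 1)   # hi**k > n is guaranteed
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** k <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


def first_bit_candidates(N):
    """Return (a, b) with a**2 < N < a**3 and b**3 < N."""
    a = iroot(N, 2)          # square stays below N, cube exceeds it
    if a * a == N:
        a -= 1
    b = iroot(N, 3)          # even the cube stays below N
    if b ** 3 == N:
        b -= 1
    return a, b
```

With the toy modulus 3233 = 61 * 53 this gives `a = 56` and `b = 14`.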

Once `d₁` is known, we also know the temporary value that will be squared and, depending on the next bit `d₂`, possibly multiplied by `m`. We repeat the same process, but this time we choose `a` and `b` such that `temp_val(a)² < N < temp_val(a)² * a` and `temp_val(b)² * b < N`, where `temp_val(x)` is that temporary value for a given message `x`.
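The two conditions can be written as predicates. This sketch assumes `known` is the integer formed by the bits of `d` recovered so far (so `known = 1` right after `d₀`), which makes `temp_val(x)` equal to `pow(x, known, N)`; the function names are hypothetical.

```python
def is_a_candidate(x, known, N):
    """The square at the next step stays below N, the multiply exceeds it."""
    t = pow(x, known, N)          # temp_val(x): value of r before this step
    return t * t < N < t * t * x


def is_b_candidate(x, known, N):
    """Neither the square nor the multiply at the next step exceeds N."""
    t = pow(x, known, N)          # temp_val(x)
    return t * t * x < N
```

With `known = 1` these reduce to the first-bit conditions `a² < N < a³` and `b³ < N`.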

Implementation

The initial `a` and `b` are relatively easy to find with a binary search approximation (integer square and cube roots of `N`). Every subsequent pair is harder to find, because `temp_val(x)` depends on the candidate itself and computing it involves a modulo reduction: `temp_val(x) = pow(x, known_bits, N)`, where `known_bits` is the integer formed by the bits of `d` recovered so far. Instead, the program draws 1 000 000 random numbers between 1 and N and keeps those that satisfy the conditions for `a` and `b`.
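The random search can be sketched as follows. `find_candidates` is a hypothetical name; `known` is assumed to be the integer formed by the exponent bits recovered so far, and `samples` defaults to the 1 000 000 draws described above. The usage below relies on the toy modulus 3233 = 61 * 53, for which both lists fill quickly.

```python
import random


def find_candidates(known, N, samples=1_000_000, rng=None):
    """Sort random messages into a- and b-candidates for the bit that
    follows the recovered exponent prefix `known`."""
    rng = rng or random.Random()
    a_cands, b_cands = [], []
    for _ in range(samples):
        x = rng.randrange(1, N)
        t = pow(x, known, N)      # temp_val(x): value of r before this step
        sq = t * t
        if sq < N < sq * x:       # square stays below N, multiply exceeds it
            a_cands.append(x)
        elif sq * x < N:          # neither square nor multiply exceeds N
            b_cands.append(x)
    return a_cands, b_cands
```

For example, `find_candidates(1, 3233, samples=20_000)` returns the messages that reproduce the first-bit conditions.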

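Next, the program measures the average decryption time of the `a` candidates and of the `b` candidates and concludes that the current bit is 1 if the average for `a` exceeds the average for `b` by more than a threshold `f`, i.e. `t_a > t_b + f`. The value of `f` has to be determined empirically.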
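The decision rule can be sketched like this, assuming `decrypt` is a hypothetical callable that submits a message to the target and returns once it has been decrypted. Times come from `time.perf_counter_ns`, so `f` is in nanoseconds; `mean_decrypt_time` and `guess_bit` are illustrative names, and a real measurement would need far more repetitions than the default shown to overcome noise.

```python
import statistics
import time


def mean_decrypt_time(decrypt, messages, repeats=5):
    """Average wall-clock time (ns) that `decrypt` takes per message."""
    samples = []
    for m in messages:
        for _ in range(repeats):
            start = time.perf_counter_ns()
            decrypt(m)
            samples.append(time.perf_counter_ns() - start)
    return statistics.mean(samples)


def guess_bit(decrypt, a_cands, b_cands, f):
    """Return 1 if the a-candidates are slower than the b-candidates by more than f."""
    t_a = mean_decrypt_time(decrypt, a_cands)
    t_b = mean_decrypt_time(decrypt, b_cands)
    return 1 if t_a > t_b + f else 0
```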

After 1024 repetitions (for a 1024-bit modulus) every bit position has been tested, and all that is left is to determine how many leading zeros `d` has, i.e. its true length. This can be done by shifting the 1024-bit result to the right one bit at a time and checking the validity of each candidate: choose a random message `m`, raise it to the product of the public (encryption) exponent `e` and the candidate modulo `N`, and stop when the result equals `m` again. That candidate is the correct value of `d`.
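A sketch of this final step, relying on the fact that the correct `d` satisfies `m^(e*d) ≡ m (mod N)` for every message `m`. `recover_d` is a hypothetical name; checking a few random messages (`trials`) instead of one is an added assumption that makes a false positive very unlikely.

```python
import random


def recover_d(guess, e, N, bits=1024, trials=3, rng=None):
    """Shift the `bits`-bit guess right until m**(e*d) % N == m holds
    for several random messages m; return that candidate or None."""
    rng = rng or random.Random()
    for shift in range(bits):
        d = guess >> shift        # drop `shift` trailing (meaningless) bits
        if d == 0:
            break
        msgs = [rng.randrange(2, N) for _ in range(trials)]
        if all(pow(m, e * d, N) == m for m in msgs):
            return d
    return None
```

With the toy key p = 61, q = 53 (N = 3233, e = 17, d = 2753) and a 16-bit guess that has four garbage bits appended, `recover_d((2753 << 4) | 0b1011, 17, 3233, bits=16)` should return 2753.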

How to protect against timing attacks?

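There is a commonly used countermeasure called RSA blinding. Given a ciphertext `C` to be decrypted with the private exponent `d` (public exponent `e`, modulus `N`):

  • Pick a random `r` that is invertible modulo `N` (i.e. `gcd(r, N) = 1`)
  • Calculate the blinded value `x = C * rᵉ mod N`
  • Calculate `xᵈ * r⁻¹ mod N`, which equals the plaintext `m`, because `xᵈ = Cᵈ * rᵉᵈ = m * r mod N`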
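The three steps above can be sketched as a single function. This is a minimal sketch assuming Python 3.8+ (for the modular inverse `pow(r, -1, N)`) and the `secrets` module as the source of `r`; `blinded_decrypt` is a hypothetical name, and the usage relies on the toy key N = 3233, e = 17, d = 2753.

```python
import math
import secrets


def blinded_decrypt(C, d, e, N):
    """RSA decryption with blinding: the exponentiation runs on C * r**e,
    a value the attacker cannot predict, instead of on C itself."""
    while True:
        r = secrets.randbelow(N - 2) + 2       # random r in [2, N - 1]
        if math.gcd(r, N) == 1:                # r must be invertible mod N
            break
    x = (C * pow(r, e, N)) % N                 # blind: x = C * r^e mod N
    y = pow(x, d, N)                           # y = C^d * r = m * r mod N
    return (y * pow(r, -1, N)) % N             # unblind: multiply by r^-1
```

For example, `blinded_decrypt(pow(65, 17, 3233), 2753, 17, 3233)` returns 65, even though a different `r` is used on every call.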

With this approach the attacker no longer knows which value is passed to the modular exponentiation function, because it is randomized by `r` on every decryption. Knowing that value is a prerequisite for this attack.

Conclusion

Even though the mathematics behind RSA is sound, side channels can make individual implementations insecure and unreliable. When writing software that handles sensitive data, always consider the possibility of side channels being exploited.