RSA Oracle - Getting the flag by using chosen ciphertext attack

I am trying to solve a simple RSA CTF challenge, but I am running into problems that go beyond the theory behind the attack (or at least I think so). Basically, I have an oracle at my disposal that first prints the encrypted flag and then encrypts and decrypts whatever I want (except the encrypted flag itself). The idea for the attack is to multiply the encrypted flag by the encryption of 2, have the oracle decrypt the result, and divide the obtained number by 2 to get the flag. Am I missing something? Below you can find both the oracle's code and my exploit.

The attack is a chosen-ciphertext attack, and I "stole" the equations from this video: https://www.youtube.com/watch?v=ZjYzrn8M3w4&ab_channel=BillBuchananOBE.
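
Written out, the multiplicative property behind this chosen-ciphertext attack looks as follows. N, e and d denote the oracle's modulus and exponents, M the flag as an integer and C its ciphertext; this notation is chosen here for illustration and may differ from the video's.

```latex
\begin{aligned}
C &\equiv M^{e} \pmod{N} \\
C' &\equiv C \cdot 2^{e} \pmod{N} \\
\mathrm{Dec}(C') &\equiv (C \cdot 2^{e})^{d} \equiv M^{ed} \cdot 2^{ed} \equiv 2M \pmod{N} \\
M &= \mathrm{Dec}(C') / 2 \qquad \text{(provided } 2M < N\text{)}
\end{aligned}
```

Since the oracle computes `pow(c, rsa.d, rsa.n)`, it reduces its input modulo N itself, so C times the oracle's encryption of 2 could even be sent unreduced. The last step only recovers M when 2M < N, which holds whenever the flag is noticeably shorter than the 1024-bit modulus.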

Oracle's code:

```python
#!/usr/bin/env python3

import signal
from binascii import hexlify
from Crypto.PublicKey import RSA
from Crypto.Util.number import *
from random import randint
from secret import FLAG
import string

TIMEOUT = 300 # 5 minutes time-out

def menu():
    print()
    print('Choice:')
    print('  [0] Exit')
    print('  [1] Encrypt')
    print('  [2] Decrypt')
    print('')
    return input('> ')

def encrypt(m):
    return pow(m, rsa.e, rsa.n)

def decrypt(c):
    return pow(c, rsa.d, rsa.n)

rsa = RSA.generate(1024)
flag_encrypted = pow(bytes_to_long(FLAG.encode()), rsa.e, rsa.n)
used = [bytes_to_long(FLAG.encode())]

def handle():
  print("================================================================================")
  print("=                      RSA Encryption & Decryption oracle                      =")
  print("=                                Find the flag!                                =")
  print("================================================================================")
  print("")
  print("Encrypted flag:", flag_encrypted)

  while True:
    choice = menu()

    # Exit
    if choice == '0':
      print("Goodbye!")
      break

    # Encrypt
    elif choice == '1':
      m = int(input('\nPlaintext > ').strip())
      print('\nEncrypted: ' + str(encrypt(m)))

    # Decrypt
    elif choice == '2':
      c = int(input('\nCiphertext > ').strip())

      if c == flag_encrypted:
        print("Wait. That's illegal.")
      else:
        m = decrypt(c)
        # Refuse any plaintext that is a multiple of the flag
        for no in used:
          if m % no == 0:
            print("Wait. That's illegal.")
            break
        else:
          print('\nDecrypted: ' + str(m))

    # Invalid
    else:
      print('bye!')
      break

if __name__ == "__main__":
    signal.alarm(TIMEOUT)
    handle()
```
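
To see how the decryption branch treats these ideas, here is a local mirror of its checks. It is a sketch under assumptions: a toy key built from two small, well-known primes instead of the 1024-bit one, a made-up flag `b"CTF{}"`, and a decrypt branch that computes `m = decrypt(c)` before the `used` test. The helpers `oracle_decrypt` and `blinded_decrypt` are hypothetical names.

```python
from typing import Optional

# Toy parameters (assumption): small known primes instead of the oracle's
# 1024-bit key, and a hypothetical 5-byte "flag".
p, q = 1000000007, 998244353
n = p * q
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))

flag = int.from_bytes(b"CTF{}", "big")
flag_encrypted = pow(flag, e, n)
used = [flag]


def oracle_decrypt(c: int) -> Optional[int]:
    """Mirror of the oracle's option [2]; None means "Wait. That's illegal."."""
    if c == flag_encrypted:
        return None
    m = pow(c, d, n)
    # Same filter as the oracle: reject any multiple of the flag
    if any(m % no == 0 for no in used):
        return None
    return m


def blinded_decrypt(s: int) -> Optional[int]:
    """Ask for Dec(C * s^e) = s*M mod N, then unblind with s^-1 mod N."""
    x = oracle_decrypt(flag_encrypted * pow(s, e, n) % n)
    if x is None:
        return None
    return x * pow(s, -1, n) % n


# Doubling decrypts to 2*M, adding N decrypts to M: both are multiples of M
print(oracle_decrypt(flag_encrypted * pow(2, e, n) % n))
print(oracle_decrypt(flag_encrypted + n))

# s = N - 1 (i.e. -1) makes s*M wrap around N, so the filter does not fire
print(blinded_decrypt(n - 1) == flag)
```

In this mirror both tricks are rejected, because they decrypt to 2·M and M. A blinding factor s with s·M ≥ N makes the decrypted value s·M mod N wrap around, so in practice it is no longer a multiple of M; unblinding with s⁻¹ mod N, however, requires knowing N.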

My current approach:

```python
from functools import reduce
import sys

import gmpy2
from Crypto.Util.number import long_to_bytes
from pwn import remote  # pwntools

# Python 3.11+ refuses to turn ints with more than 4300 digits into str;
# uncomment if printing the huge differences raises ValueError.
#sys.set_int_max_str_digits(0)

r = remote('oracle.challs.cyberchallenge.it', 9041)
r.recvuntil(b'Encrypted flag: ')
encrypted_flag = int(r.recvline().strip().decode())
public_exponent = 65537

# Let's first gather the ciphertexts of a few small plaintexts
"""
Here's another hint: suppose I encrypt 2. The oracle will give me back c2 = pow(2, 65537, rsa.n). Now I can also compute 2**65537 as an integer. We know that 2**65537 - c2 is divisible by N. So we can try to factor 2**65537 - c2 using, say, the elliptic curve method (ECM). If we are incredibly lucky, 2**65537 - c2 = N * (bunch of relatively small primes), and after ECM finds all the small factors we'll be left with N. But, suppose, instead, that I also encrypt 3, so I get c3 = pow(3, 65537, rsa.n). And maybe even c5 = pow(5, 65537, rsa.n). How can I combine these to find rsa.n?
"""

numbers = [2, 3, 4, 5, 6]
ciphers = []
diffs = []
for x in numbers:
    r.recvuntil(b'>')
    r.sendline(b'1')
    r.recvuntil(b'Plaintext > ')
    r.sendline(str(x).encode())
    r.recvuntil(b'Encrypted: ')
    cipher = int(r.recvline().strip().decode())
    ciphers.append(cipher)
    # x**e - (x**e mod N) is an exact multiple of N
    diffs.append(gmpy2.sub(pow(x, public_exponent), cipher))

print(diffs)
# The gcd of all these multiples of N is N times a (usually tiny) cofactor
common_factor = reduce(gmpy2.gcd, diffs)
print(common_factor)
# Let's check whether the common factor is N: re-encrypting 2 must match the oracle
print(ciphers[0] == pow(2, public_exponent, int(common_factor)))
# We have found N if True
# Adding N gives a different integer that is still congruent to the flag
# ciphertext mod N, so it gets past the `c == flag_encrypted` check
encrypted_flag += int(common_factor)
r.recvuntil(b'>')
r.sendline(b'2')
r.recvuntil(b'Ciphertext > ')
r.sendline(str(encrypted_flag).encode())
r.recvuntil(b'Decrypted: ')
flag = int(r.recvline().decode())
print(long_to_bytes(flag))
```
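
The hint quoted inside the script can be checked locally: for every x, x^e − (x^e mod N) is an exact multiple of N, so the gcd of several such differences is N times a cofactor that is usually tiny. Below is a sketch with a toy key; the primes, the `oracle_encrypt` stand-in and the trial-division bound of 1000 are assumptions for illustration, and it uses `math.gcd` instead of gmpy2.

```python
from functools import reduce
from math import gcd

# Toy key standing in for the oracle's 1024-bit RSA key (assumption)
p, q = 1000000007, 998244353
n = p * q
e = 65537


def oracle_encrypt(x: int) -> int:
    """Hypothetical stand-in for the oracle's menu option [1]."""
    return pow(x, e, n)


numbers = [2, 3, 4, 5, 6]
# x**e - (x**e mod n) is an exact multiple of n for every x
diffs = [pow(x, e) - oracle_encrypt(x) for x in numbers]
candidate = reduce(gcd, diffs)

# The gcd may still carry a few tiny extra factors. RSA primes are far larger
# than 1000, so dividing these out cannot remove a factor of n.
for small in range(2, 1000):
    while candidate % small == 0:
        candidate //= small

# Same sanity check as in the exploit: re-encrypting 2 must match the oracle
print(pow(2, e, candidate) == oracle_encrypt(2))
print(candidate == n)
```

With the real 1024-bit key the same reduction applies, only on larger integers; gmpy2 mainly speeds up the arithmetic.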