PlaidCTF Compression

PlaidCTF 2013 had a level called "Compression". Here's the provided code for this level:

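import os
import struct
import SocketServer
import zlib
from Crypto.Cipher import AES
from Crypto.Util import Counter
# Not the real keys!
ENCRYPT_KEY = '0000000000000000000000000000000000000000000000000000000000000000'.decode('hex')
# Determine this key.
# Character set: lowercase letters and underscore
PROBLEM_KEY = 'XXXXXXXXXXXXXXXXXXXX'  # placeholder; the real value lives on the server
def encrypt(data, ctr):
    # Compress first, then encrypt with AES in CTR mode (length-preserving).
    aes = AES.new(ENCRYPT_KEY, AES.MODE_CTR, counter=ctr)
    return aes.encrypt(zlib.compress(data))
class ProblemHandler(SocketServer.StreamRequestHandler):
    def handle(self):
        # Fresh 8-byte nonce per connection, sent to the client in the clear.
        nonce = os.urandom(8)
        self.wfile.write(nonce)
        ctr = Counter.new(64, prefix=nonce)
        while True:
            # Each request is a 4-byte length followed by that many bytes.
            data = self.rfile.read(4)
            if not data:
                break
            try:
                length = struct.unpack('I', data)[0]
                if length > (1<<20):
                    break
                data = self.rfile.read(length)
                # The secret is appended to our input before compression.
                data += PROBLEM_KEY
                ciphertext = encrypt(data, ctr)
                self.wfile.write(struct.pack('I', len(ciphertext)))
                self.wfile.write(ciphertext)
            except Exception:
                break
class ReusableTCPServer(SocketServer.ForkingMixIn, SocketServer.TCPServer):
    allow_reuse_address = True
if __name__ == '__main__':
    HOST = ''
    PORT = 4433
    SocketServer.TCPServer.allow_reuse_address = True
    server = ReusableTCPServer((HOST, PORT), ProblemHandler)
    server.serve_forever()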

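A few things stand out here: the server uses AES in CTR mode with a random 8-byte nonce per connection, it caps each request at 1 MB (1<<20 bytes), and it appends the secret key to our input and compresses the result with zlib before encrypting it.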

My first thought was to attack CTR mode itself, since reusing an IV with a CTR-mode cipher is fatal. But the nonce has reasonably high entropy (64 bits), and the input length limit means we cannot realistically send enough data to wrap the counter, so a reused IV was not going to happen anytime soon.
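As a rough sanity check on that claim, here is the arithmetic as a short sketch. It assumes the counter occupies the 8 bytes of the AES block not taken by the nonce, and that each request consumes at most about 1 MiB of keystream. Both are my reading of the server code rather than anything the challenge states, and the variable names are only illustrative.

```python
# Back-of-the-envelope estimate of how many requests it would take to wrap
# the CTR counter. Assumptions: a 64-bit counter (16-byte AES block minus the
# 8-byte nonce) and roughly 1 MiB of keystream per request at most.
AES_BLOCK_BYTES = 16
NONCE_BYTES = 8
MAX_REQUEST_BYTES = 1 << 20  # the server's length limit

counter_bits = (AES_BLOCK_BYTES - NONCE_BYTES) * 8
blocks_before_wrap = 2 ** counter_bits
blocks_per_request = MAX_REQUEST_BYTES // AES_BLOCK_BYTES
requests_to_wrap = blocks_before_wrap // blocks_per_request

print("counter bits:       %d" % counter_bits)
print("blocks per request: 2^%d" % (blocks_per_request.bit_length() - 1))
print("requests to wrap:   2^%d" % (requests_to_wrap.bit_length() - 1))
```

Even at maximum size, and with incompressible input so the compressed data stays large, the number of requests needed on one connection is far beyond anything practical during a CTF.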

So I started thinking about the fact that AES (or any block cipher) in CTR mode is really a stream cipher: it encrypts successive counter values with the key to produce a keystream, then XORs the keystream with the plaintext to get the ciphertext. In particular, pycrypto guarantees that len(input) == len(output), so the length of each response is exactly the length of the compressed data. Given the name of the level (Compression), I started looking for ways to get information out of the ciphertext length.
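The length-preserving property is easy to see in a toy model. The sketch below is written for Python 3 and uses only the standard library. It is not the challenge's cipher: random bytes from os.urandom stand in for the AES-CTR keystream, and xor_stream and toy_encrypt are hypothetical helper names.

```python
import os
import zlib

def xor_stream(data, keystream):
    """XOR data with a keystream of at least the same length."""
    return bytes(a ^ b for a, b in zip(data, keystream))

def toy_encrypt(plaintext):
    """Compress, then apply a stream cipher.

    os.urandom is a stand-in for the AES-CTR keystream; any stream cipher
    produces output of the same length as its input.
    """
    compressed = zlib.compress(plaintext)
    keystream = os.urandom(len(compressed))
    return xor_stream(compressed, keystream)

# Two inputs of equal length: one with no repetition, one highly repetitive.
for message in (b"abcdefghijklmnopqrstuvwxyz", b"aaaaaaaaaaaaaaaaaaaaaaaaaa"):
    ciphertext = toy_encrypt(message)
    print(len(message), len(zlib.compress(message)), len(ciphertext))
```

The ciphertext bytes themselves are unreadable, but the last two columns always agree, so anyone who sees the ciphertext learns how well the plaintext compressed.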

At this point, it's worth revisiting the design of the DEFLATE algorithm (used by the zlib.compress call in the program). DEFLATE uses a combination of Huffman coding and LZ77-style duplicate string elimination. In this context, I believe the duplicate string elimination plays the bigger role: it takes repeated sections of the input and, for the second and later instances, emits a pointer back to the first instance that is shorter than the original string. For our purposes, that means if we provide input that contains substrings of the unknown key, we will get a shorter response than if our string is completely different from the flag.

To test my theory, I fired off 27 tries to the server, each containing one of the valid ([a-z_]) characters repeated 3 times. All responses, save one, were the same length (30 bytes). Only the repeated 'c' string came back at 29 bytes. This led me to believe that 'c' probably started the flag. (If the character only needed to appear somewhere in the flag, more than one character would likely have returned a different length.)
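The same experiment can be simulated offline. The sketch below is written for Python 3 and models the server's response length as the length of zlib.compress(guess + secret), which holds because CTR mode preserves length. SECRET is a made-up stand-in, not the real flag, and oracle_len and probe are hypothetical helper names.

```python
import zlib

ALPHABET = "abcdefghijklmnopqrstuvwxyz_"
SECRET = b"example_secret_value"  # hypothetical stand-in for the server's key

def oracle_len(guess, secret=SECRET):
    """Length the server would report for this guess (CTR preserves length)."""
    return len(zlib.compress(guess + secret))

def probe(prefix=b"", repeat=3):
    """Return {candidate: length} for every one-character extension of prefix."""
    results = {}
    for ch in ALPHABET:
        candidate = prefix + ch.encode("ascii")
        results[candidate] = oracle_len(candidate * repeat)
    return results

lengths = probe()
shortest = min(lengths.values())
print("shortest:", sorted(c.decode() for c, n in lengths.items() if n == shortest))
```

With single characters the difference is only a few bits, so whether one candidate stands out depends on where the byte boundary falls; a local run will not necessarily reproduce the 30-versus-29 split. Longer correct guesses, such as oracle_len(b"example_se") against an unrelated 10-character string, give a much clearer gap.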

I put together a script to go through character by character and look for lengths that were shorter than the rest. During my first couple of runs, it would frequently get stuck, until I hit upon the idea of repeating the test string multiple times, which increases the likelihood that duplicate string elimination will use the entire thing. Eventually, I had a few candidate flags, but from glancing at them, it was clear what the answer was...

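import struct
import socket
import sys
import collections
HOST = 'ip.add.res.s'
PORT = 4433
def try_val(val):
  # One connection per guess: read the nonce, send a length-prefixed guess,
  # then read back the length-prefixed ciphertext.
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  sock.connect((HOST, PORT))
  nonce = sock.recv(8)
  sock.sendall(struct.pack('I', len(val)))
  sock.sendall(val)
  data = sock.recv(4)
  recv_len = struct.unpack('I', data)[0]
  data = sock.recv(recv_len)
  sock.close()
  return (nonce, data)
def get_candidate_len(can):
  nonce, data = try_val(can)
  return len(data)
def try_layer(prefix):
  if len(prefix) == 20:
    print "Found candidate %s!" % prefix
    return
  candidates = 'abcdefghijklmnopqrstuvwxyz_'
  print "Trying %s" % prefix
  samples = {}
  for c in candidates:
    val = prefix + c
    # Repeat the guess so the compressor is more likely to match all of it.
    samples[val] = get_candidate_len(val*2 if len(val)>9 else val*5)
  m = mode(samples.values())
  for k in samples:
    if samples[k] == m:
      continue  # typical length: nothing extra matched
    if samples[k] < m:
      try_layer(k)  # shorter than typical: extend this guess
def mode(l):
  # Most common value in l.
  c = collections.Counter(l)
  return c.most_common(1)[0][0]
if __name__ == '__main__':
  try_layer('')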
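One fragile spot in a client like this is that a single recv call on a TCP socket may return fewer bytes than requested. The sketch below, written for Python 3, shows one way to read length-prefixed frames reliably. recv_exact, send_frame and recv_frame are hypothetical helpers, not part of the script I ran during the CTF, and the framing simply mirrors the 4-byte struct 'I' prefix the server uses.

```python
import socket
import struct

def recv_exact(sock, n):
    """Read exactly n bytes from sock, or raise if the peer closes early."""
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise EOFError("connection closed with %d bytes still expected" % remaining)
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def send_frame(sock, payload):
    """Send a native-endian 4-byte length followed by the payload."""
    sock.sendall(struct.pack("I", len(payload)) + payload)

def recv_frame(sock):
    """Receive one length-prefixed frame."""
    (length,) = struct.unpack("I", recv_exact(sock, 4))
    return recv_exact(sock, length)

# Local round trip over a connected socket pair; no server needed.
a, b = socket.socketpair()
send_frame(a, b"hello")
print(recv_frame(b))
a.close()
b.close()
```

For the short responses in this level a plain recv usually returns everything at once, which is presumably why the simpler version worked, but a short read would silently produce a wrong length and therefore a wrong guess.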