I want to learn Python, so I ported a C# Enigma implementation of mine. It has unit tests and it runs.
I'm looking for a review that tells me where I don't follow best practices, where I break naming conventions, and where I look like a C# programmer writing Python :-)
Thanks & have fun, Harry
https://github.com/HaraldLeitner/Enigma BusinessLogic.py
from Enums import Mode
class BusinessLogic:
    """Push byte buffers and whole files through an ordered chain of rolls.

    Encryption applies the rolls in the order given, decryption applies them
    in reverse order. After every processed byte the rolls are stepped with
    ``roll_on``.
    """

    def __init__(self, rolls):
        """Keep the rolls and a reversed copy that is used for decryption.

        Args:
            rolls: Sequence of roll objects offering ``encrypt(buffer, index)``,
                ``decrypt(buffer, index)`` and ``roll_on()``.
        """
        self._rolls = rolls
        self._rolls_reverse = list(reversed(rolls))

    def transform_file(self, infile, outfile, mode):
        """Read ``infile`` in chunks, transform each chunk and write it to ``outfile``.

        Args:
            infile: Path of the file to read (opened in binary mode).
            outfile: Path of the file to write (opened in binary mode, truncated).
            mode: ``Mode.ENC`` to encrypt or ``Mode.DEC`` to decrypt.
        """
        buffer_size = 65536
        # The context managers close both files even if a transformation raises.
        with open(infile, 'rb', buffer_size) as in_file, \
                open(outfile, 'wb', buffer_size) as out_file:
            while chunk := in_file.read(buffer_size):
                # read() returns immutable bytes; the rolls modify the buffer in place.
                buffer = bytearray(chunk)
                self.transform_buffer(buffer, mode)
                out_file.write(buffer)

    def transform_buffer(self, buffer, mode):
        """Encrypt or decrypt ``buffer`` in place, one byte at a time.

        Args:
            buffer: Mutable byte sequence (e.g. ``bytearray``) that is overwritten.
            mode: ``Mode.ENC`` to encrypt or ``Mode.DEC`` to decrypt.

        Raises:
            ValueError: If ``mode`` is neither ``Mode.ENC`` nor ``Mode.DEC``;
                without this check the buffer would pass through unchanged.
        """
        if mode == Mode.ENC:
            for index in range(len(buffer)):
                for roll in self._rolls:
                    roll.encrypt(buffer, index)
                self.roll_on()
        elif mode == Mode.DEC:
            for index in range(len(buffer)):
                for roll in self._rolls_reverse:
                    roll.decrypt(buffer, index)
                self.roll_on()
        else:
            raise ValueError(f"Unsupported mode: {mode!r}")

    def roll_on(self):
        """Step the rolls, carrying to the next one only while a roll reports a turnover."""
        for roll in self._rolls:
            if not roll.roll_on():
                break
Enums.py
from enum import Enum
class Mode(Enum):
    """Direction in which data is pushed through the rolls."""

    ENC = 0  # encrypt
    DEC = 1  # decrypt
Roll.py
class Roll:
    """One substitution roll: a byte permutation plus a rotating position."""

    def __init__(self, transitions, turn_over_indices):
        """Store the substitution table and precompute its inverse.

        Args:
            transitions: Indexable sequence mapping each input byte (0-255) to
                an output byte; validated later by ``check_input``.
            turn_over_indices: Positions at which ``roll_on`` reports a turnover.
        """
        self._transitions = transitions
        self._turn_over_indices = turn_over_indices
        # Inverse lookup used by decrypt: maps an output byte back to its input.
        # zip() with range(256) pairs every table position with its value and
        # ignores surplus entries, which check_input() rejects afterwards.
        self._re_transitions = bytearray(256)
        for source, target in zip(range(256), transitions):
            self._re_transitions[target] = source
        self._position = 0

    def check_input(self, turnover_indices_count):
        """Validate the data handed to the constructor.

        Args:
            turnover_indices_count: Required number of turn-over indices.

        Raises:
            ValueError: If the transition table does not have 256 entries, is
                not a permutation of 0-255, or if the turn-over indices have
                the wrong length or contain any value more than once.
        """
        if len(self._transitions) != 256:
            raise ValueError("Wrong Transition length ")
        if set(self._transitions) != set(range(256)):
            raise ValueError("Transitions not 1-1 complete")
        if len(self._turn_over_indices) != turnover_indices_count:
            raise ValueError("Wrong TurnOverIndices length ")
        # Comparing against the set size also catches non-adjacent duplicates.
        if len(set(self._turn_over_indices)) != len(self._turn_over_indices):
            raise ValueError("turn_over_indices has doubles")

    def encrypt(self, buffer, index):
        """Replace ``buffer[index]`` by its substitution at the current position."""
        buffer[index] = self._transitions[(buffer[index] + self._position) & 0xff]

    def decrypt(self, buffer, index):
        """Undo ``encrypt`` for ``buffer[index]`` at the current position."""
        buffer[index] = (self._re_transitions[buffer[index]] - self._position) & 0xff

    def roll_on(self):
        """Advance the position by one (mod 256).

        Returns:
            A truthy value if the new position is a turn-over index.
        """
        self._position = (self._position + 1) & 0xff
        return self._position in self._turn_over_indices
Enigma.ini
[DEFAULT]
TransitionCount = 53
Main.py
from configparser import ConfigParser
import os
import sys
from random import random, randbytes, randint
from time import sleep
from BusinessLogic import BusinessLogic
from Enums import Mode
from Roll import Roll
class Program:
def __init__(self, transition_count=0):
self._mode = None
self._rolls = []
self._inputFilename = None
self._keyFilename = None
self._transitionCount = None
config = ConfigParser()
config.read("enigma.ini")
if transition_count < 1:
self._transitionCount = config.getint("DEFAULT", "TransitionCount")
else:
self._transitionCount = transition_count
def main(self):
if len(sys.argv) != 4:
print("Generate key with 'keygen x key.file' where x > 3 is the number of rolls.")
print("Encrypt a file with 'enc a.txt key.file'")
print("Decrypt a file with 'dec a.txt key.file'")
exit(1)
self.run_main(sys.argv[1], sys.argv[2], sys.argv[3])
def run_main(self, arg1, arg2, arg3):
self._keyFilename = arg3
if arg1 == "keygen":
self.keygen(int(arg2))
return
self._inputFilename = arg2
if arg1.lower() == 'enc':
self._mode = Mode.ENC
elif arg1 == 'dec':
self._mode = Mode.DEC
else:
raise Exception("Undefined Encryption Mode.")
self.create_rolls()
BusinessLogic(self._rolls).transform_file(self._inputFilename, self._inputFilename + '.' + self._mode.name,
self._mode)
def keygen(self, roll_count):
if roll_count < 4:
raise Exception("Not enough rolls.")
if os.path.exists(self._keyFilename):
os.remove(self._keyFilename)
key = bytearray()
for i in range(roll_count):
transform = bytearray(256)
for j in range(256):
transform[j] = j
while not self.is_twisted(transform):
for j in range(256):
rand1 = randint(0, 255)
rand2 = randint(0, 255)
temp = transform[rand1]
transform[rand1] = transform[rand2]
transform[rand2] = temp
key += transform
transitions = bytearray()
while len(transitions) < self._transitionCount:
rand = randint(0, 255)
if not transitions.count(rand):
transitions.append(rand)
key += transitions
file = open(self._keyFilename, 'wb')
file.write(key)
file.close()
print("Keys generated.")
sleep(1)
def is_twisted(self, trans):
for i in range(256):
if trans[i] == i:
return 0
return 1
def create_rolls(self):
roll_key_length = 256 + self._transitionCount
file = open(self._keyFilename, 'rb')
key = file.read()
file.close()
if len(key) % roll_key_length:
raise Exception('Invalid key_size')
roll_count = int(len(key) / roll_key_length)
for rollNumber in range(roll_count):
self._rolls.append(Roll(key[rollNumber * roll_key_length: rollNumber * roll_key_length + 256],
key[
rollNumber * roll_key_length + 256: rollNumber * roll_key_length + 256 + self._transitionCount]))
for roll in self._rolls:
roll.check_input(self._transitionCount)
if __name__ == "__main__":
    # Script entry point: default configuration, command taken from sys.argv.
    program = Program()
    program.main()
UnitTest.py
import os
import unittest
from BusinessLogic import BusinessLogic
from Enums import Mode
from Roll import Roll
from main import Program
class MyTestCase(unittest.TestCase):
    """Tests for Roll and BusinessLogic plus one keygen/enc/dec round trip via Program."""

    def setUp(self):
        """Create the substitution tables and reset the per-test state."""
        self.trans_linear = bytearray(256)  # here every char is mapped to itself
        self.trans_linear_invert = bytearray(256)  # match the first to the last etc
        self.trans_shift_1 = bytearray(256)  # 'a' is mapped to 'b' etc
        self.trans_shift_2 = bytearray(256)  # 'a' is mapped to 'c' etc
        self.businesslogic_encode = None
        self.businesslogic_decode = None
        self.encrypted_message = bytearray()
        self.decrypted_message = bytearray()
        self.init_test_rolls()

    def init_test_rolls(self):
        """Fill the four substitution tables used by the tests."""
        self.trans_linear = bytearray(range(256))
        self.trans_linear_invert = bytearray(255 - i for i in range(256))
        self.trans_shift_1 = bytearray((i + 1) % 256 for i in range(256))
        self.trans_shift_2 = bytearray((i + 2) % 256 for i in range(256))

    def init_business_logic(self, transitions, turnovers):
        """Build separate but identically configured encoder and decoder.

        Two independent roll sets are needed because rolls change their
        position while a buffer is processed.
        """
        rolls_encrypt = [Roll(table, turns) for table, turns in zip(transitions, turnovers)]
        rolls_decrypt = [Roll(table, turns) for table, turns in zip(transitions, turnovers)]
        self.businesslogic_encode = BusinessLogic(rolls_encrypt)
        self.businesslogic_decode = BusinessLogic(rolls_decrypt)

    def crypt(self, msg):
        """Encrypt ``msg`` and decrypt the result; both buffers are kept on ``self``."""
        self.encrypted_message = bytearray(msg)
        self.businesslogic_encode.transform_buffer(self.encrypted_message, Mode.ENC)
        self.decrypted_message = bytearray(self.encrypted_message)
        self.businesslogic_decode.transform_buffer(self.decrypted_message, Mode.DEC)

    def _assert_crypt(self, transitions, turnovers, msg, expected_encrypted):
        """Set up fresh rolls, run ``crypt`` and compare cipher text and round trip."""
        self.init_business_logic(transitions, turnovers)
        self.crypt(msg)
        self.assertEqual(bytearray(expected_encrypted), self.encrypted_message)
        self.assertEqual(bytearray(msg), self.decrypted_message)

    def test_one_byte_one_roll_linear(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear], [[0]], [i], [i])

    def test_one_byte_one_roll_shift_one(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_shift_1], [[0]], [i], [(i + 1) % 256])

    def test_one_byte_one_roll_shift_two(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_shift_2], [[0]], [i], [(i + 2) % 256])

    def test_two_byte_one_roll_linear(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear], [[0]],
                                   [i, (i + 1) % 256],
                                   [i, (i + 2) % 256])

    def test_two_byte_one_roll_shift1(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_shift_1], [[0]],
                                   [i, (i + 1) % 256],
                                   [(i + 1) % 256, (i + 3) % 256])

    def test_two_byte_one_roll_invert(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear_invert], [[0]],
                                   [i, i],
                                   [255 - i, (256 + 255 - i - 1) & 0xff])

    def test_two_byte_two_roll_linear(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear, self.trans_linear], [[0], [0]],
                                   [i, (i + 1) & 0xff],
                                   [i, (i + 2) & 0xff])

    def test_two_byte_two_roll_shift1(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_shift_1, self.trans_shift_1], [[0], [0]],
                                   [i, (i + 1) & 0xff],
                                   [(i + 2) & 0xff, (i + 4) & 0xff])

    def test_two_byte_two_roll_shift2(self):
        self._assert_crypt([self.trans_shift_2, self.trans_shift_2], [[0], [0]],
                           [7, 107], [11, 112])

    def test_two_byte_two_roll_invert(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear_invert, self.trans_linear_invert], [[0], [0]],
                                   [i, (i + 1) & 0xff],
                                   [i, (i + 2) & 0xff])

    def test_three_byte_two_roll_turnover(self):
        for i in range(256):
            with self.subTest(i=i):
                self._assert_crypt([self.trans_linear, self.trans_linear], [range(256), range(256)],
                                   [i, (i + 1) & 0xff, (i + 2) & 0xff],
                                   [i, (i + 3) & 0xff, (i + 6) & 0xff])

    def test_three_byte_two_different_roll_turnover(self):
        self._assert_crypt([self.trans_linear, self.trans_shift_1], [range(4), range(4)],
                           [7, 107], [8, 110])

    def test_three_byte_two_different_roll_turnover3(self):
        self._assert_crypt([self.trans_linear, self.trans_linear_invert], [range(4), range(4)],
                           [7, 107], [248, 146])

    def test_real_live(self):
        msg_size = 65536
        msg = bytearray(i & 0xff for i in range(msg_size))
        self.init_business_logic(
            [self.trans_linear, self.trans_linear_invert, self.trans_shift_1, self.trans_shift_2],
            [[0, 22, 44, 100], [11, 44, 122, 200], [33, 77, 99, 222], [55, 67, 79, 240]])
        self.crypt(msg)
        self.assertEqual(msg, self.decrypted_message)

    def test_integration(self):
        # Local import keeps this test self-contained; everything it writes
        # lives in a temporary directory that is removed afterwards.
        import tempfile

        msg_size = 65536 * 5
        msg = bytes(i & 0xff for i in range(msg_size))

        with tempfile.TemporaryDirectory() as folder:
            key_file_name = os.path.join(folder, "any.key")
            msg_file_name = os.path.join(folder, "msg.file")
            # Output names carry the Mode member name as suffix; deriving them
            # from the enum keeps the test valid on case-sensitive file systems.
            encrypted_file_name = msg_file_name + '.' + Mode.ENC.name
            decrypted_file_name = encrypted_file_name + '.' + Mode.DEC.name

            with open(msg_file_name, 'wb') as file:
                file.write(msg)

            program = Program(55)
            program.run_main("keygen", 5, key_file_name)
            program.run_main("enc", msg_file_name, key_file_name)

            # A second instance makes sure decryption starts from freshly loaded rolls.
            program2 = Program(55)
            program2.run_main("dec", encrypted_file_name, key_file_name)

            with open(decrypted_file_name, 'rb') as file:
                decrypted = file.read()

        # Length first, so a truncated result is reported as such.
        self.assertEqual(msg_size, len(decrypted))
        self.assertEqual(msg, decrypted)
if __name__ == "__main__":
    # Run all test cases of this module when the file is executed directly.
    unittest.main()
```