import os
import random
import rsa
from math import ceil
from base64 import b64encode, b64decode
from Crypto.Cipher import AES
class RSAWrapper:
def __init__(self, keygen=True, keylength=1024, localpem=False) -> None:
self.keylength = keylength
self.chunk = keylength//8-11
if keygen:
pubkey_pem, privkey_pem = 'pubkey.pem', 'privkey.pem'
if not localpem:
usr_ssh_home = os.path.join(os.path.expanduser('~'), '.ssh')
pubkey_pem = os.path.join(usr_ssh_home, pubkey_pem)
privkey_pem = os.path.join(usr_ssh_home, privkey_pem)
if os.path.exists(pubkey_pem) and os.path.exists(privkey_pem):
with open(pubkey_pem, 'rb') as f:
self.pubkey = rsa.PublicKey.load_pkcs1(f.read())
with open(privkey_pem, 'rb') as f:
self.privkey = rsa.PrivateKey.load_pkcs1(f.read())
else:
pubkey, privkey = rsa.newkeys(keylength)
with open(pubkey_pem, 'wb') as f:
f.write(pubkey.save_pkcs1())
with open(privkey_pem, 'wb') as f:
f.write(privkey.save_pkcs1())
self.pubkey = pubkey
self.privkey = privkey
def get_pubkey(self) -> str:
return str(self.pubkey.save_pkcs1(), 'utf-8')
def load_pubkey(self, pubkey) -> None:
self.pubkey = rsa.PublicKey.load_pkcs1(pubkey)
def encrypt(self, message) -> str:
chunk = self.chunk
divide = ceil(len(message)/float(chunk))
cryptolalia = b''
for i in range(divide):
cryptolalia += rsa.encrypt(message[i *
chunk:(i+1)*chunk].encode(), self.pubkey)
return str(b64encode(cryptolalia), 'utf-8')
def decrypt(self, cryptolalia) -> str:
cryptolalia = b64decode(cryptolalia)
chunk = self.keylength//8
divide = len(cryptolalia)//chunk
message = ''
for i in range(divide):
message += rsa.decrypt(cryptolalia[i *
chunk:(i+1)*chunk], self.privkey).decode()
return message
def sign(self, message) -> str:
signature = rsa.sign(message.encode(), self.privkey, 'SHA-256')
return str(b64encode(signature), 'utf-8')
def verify(self, message, signature):
signature = b64decode(signature)
return rsa.verify(message.encode(), signature, self.pubkey)
class AESWrapper:
alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ/=+'
def __init__(self, keygen=True, keylength=16) -> None:
self.keylength = keylength
if keygen:
key = ''
for _ in range(keylength):
key += random.choice(self.alphabet)
self.key = key
self.cryptor = AES.new(key, AES.MODE_ECB)
def get_key(self) -> str:
return self.key
def load_key(self, key) -> None:
self.key = key
self.keylength = len(key)
self.cryptor = AES.new(key, AES.MODE_ECB)
def encrypt(self, message) -> str:
chunk = self.keylength
divide = ceil(len(message)/float(chunk))
r = len(message) % chunk
message += '' if r == 0 else ' '*(chunk-r)
cryptolalia = b''
for i in range(divide):
cryptolalia += self.cryptor.encrypt(
message[i*chunk:(i+1)*chunk].encode())
return str(b64encode(cryptolalia), 'utf-8')
def decrypt(self, cryptolalia) -> str:
cryptolalia = b64decode(cryptolalia)
chunk = self.keylength
divide = len(cryptolalia)//chunk
message = ''
for i in range(divide):
message += self.cryptor.decrypt(
cryptolalia[i*chunk:(i+1)*chunk]).decode()
return message.rstrip()
if __name__ == '__main__':
import time
print('RSA Test', end='\n\n')
myrsa = RSAWrapper()
text = 'my rsa pubkey is\n' + myrsa.get_pubkey()
print('text:')
print(text, end='\n\n')
encrypt_start = time.time()
cryptolalia = myrsa.encrypt(text)
sign_start = time.time()
signature = myrsa.sign(text)
decrypt_start = time.time()
decrypt_text = myrsa.decrypt(cryptolalia)
decrypt_end = time.time()
print('cryptolalia:', sign_start-encrypt_start)
print(cryptolalia, end='\n\n')
print('signature:', decrypt_start-sign_start)
print(signature, end='\n\n')
print('decrypt_text:', decrypt_end-decrypt_start)
print(decrypt_text, end='\n\n')
myrsa.verify(decrypt_text, signature)
print('AES Test', end='\n\n')
myaes = AESWrapper()
encrypt_start = time.time()
cryptolalia = myaes.encrypt(text)
decrypt_start = time.time()
decrypt_text = myaes.decrypt(cryptolalia)
decrypt_end = time.time()
print('cryptolalia:', decrypt_start-encrypt_start)
print(cryptolalia, end='\n\n')
print('decrypt_text:', decrypt_end-decrypt_start)
print(decrypt_text, end='\n\n')