密码学课设1--简单spn实现

in this article

概述

本次密码学课设使用新的oj系统,要求学生自行实现密码学课本上的spn加解密算法。

[!NOTE] 这次课设允许使用的语言是c/cpp,要求使用内存不超过16mb,使用时间不超过2000ms。
测试分为七个,据推断,数据量分别为1,10,100,1k,10k,100k,1m条。

思路

开始之前

首先,我们需要明确的是:根据书本例子的要求,密码长度为32位,明文长度为16位,这几个数据类型会多次用到,所以我们在开始前,定义它们的宏:

typedef unsigned int UINT32;
typedef unsigned short UINT16;
typedef unsigned char UINT8;

密钥编排算法

样例中密钥编排算法较为简单,轮密钥是原始密钥每次左移四位后取前16位,我们可以用位运算轻松实现这个算法:

UINT16* getKeys(UINT32 key, UINT16* keys) {
    for (size_t i = 0; i < 5; i++)
    {
        keys[i] = key >> (16 - 4 * i);
    }
    return keys;
}

s盒代换

在我的设计中,s盒的定义是一个数组,直接通过地址进行映射。这样转换速度快,定义方便。代换和逆代换都可以使用一个函数。

UINT16 sTransform(UINT16 input, const UINT16* trans) {
    UINT16 fin = 0x0000;
    auto a = input >> 12;
    fin |= trans[a] << 12;
    a = (input & 0x0f00) >> 8;
    fin |= trans[a] << 8;
    a = (input & 0x00f0) >> 4;
    fin |= trans[a] << 4;
    a = input & 0x000f;
    fin |= trans[a] << 0;
    return fin;
}

p盒置换

这部分也不难,p盒定义同样由数组给出。

UINT16 pTransForm(UINT16 input, const UINT16* trans) {
    UINT16 fin = 0x0000;
    for (size_t i = 0; i < 16; i++)
    {
        auto pos = trans[i];
        auto mask = getBitSingle(16 - (i));
        auto bit = input & mask;
        int j = pos - 1 - i;
        if (j > 0)
        {
            fin |= bit >> j;
        }
        else {
            fin |= bit << -j;
        }
    }
    return fin;
}

加密和解密

然后就是将s盒和p盒组成spn网络了,这部分照书画瓢就可以,只需注意最后一个循环没有p置换:

UINT16 encrypt(UINT16 x, const UINT16* keys) {

    for (size_t i = 0; i < 4; i++)
    {
        x ^= keys[i];

        x = sTransform(x, s);
        if (i == 3)
        {
            break;
        }
        x = pTransForm(x, p);
    }
    x ^= keys[4];
    return x;
}

解密则是加密的逆过程,要注意传入的s盒的定义是原来s盒的逆:

UINT16 decrypt(UINT16 x, const UINT16* keys) {
    x ^= keys[4];
    x = sTransform(x, sReverse);
    for (size_t i = 3; i > 0; i--)
    {
        x ^= keys[i];
        x = pTransForm(x, p);
        x = sTransform(x, sReverse);
    }
    x ^= keys[0];
    return x;
}

main函数代码

按照要求完成功能即可

int main()
{
    UINT16 x = 0x0;
    UINT32 key = 0x0;
    UINT16 keys[5] = { 0,0,0,0,0 };
    UINT16 y;
    UINT16 mask;
    UINT16 temp;
    char skey[20];
    char sx[10];
    int num = 0;
    scanf("%d", &num);
    getchar();
    for (size_t i = 0; i < num; i++)
    {
        key = readKey();
        x = readx();
        getKeys(key, keys);
        y = encrypt(x, keys);
        mask = 1;
        temp = (!(y & mask)) | (y & (~mask));
        x = decrypt(temp, keys);
        printf("%04x %04x\n", y, x);
    }
}

[!IMPORTANT] oj系统的第七个测试的数据量极大,如果使用cincout会导致oj出现RE。因此务必使用c的IO完成输入和输出功能。

[!IMPORTANT] main中作为循环计数器的变量不能太小,因为oj的测试数据量非常大。如果采用shortunsigned short类型会导致测试的时候溢出。

[!TIP] 尽管oj系统的第七次测试无法通过含有cincout的代码,实际使用中使用cincout的代码无疑更加优雅,普遍,甚至效率更高。因此这里附上main函数采用cincout进行IO的写法:

int main()
{
    UINT16 x = 0x0;
    UINT32 key = 0x0;
    UINT16* keys = (UINT16*)malloc(5 * sizeof(UINT16));
    UINT16 y;
    UINT16 mask;
    UINT16 temp;
    int num = 0;
    cin >> num;
    for (size_t i = 0; i < num; i++)
    {
        cin >> hex >> key;
        cin >> hex >> x;
        getKeys(key, keys);
        y = encrypt(key, x, keys);
        mask = getBitSingle(1);
        temp = (!(y & mask)) | (y & (~mask));
        x = decrypt(key, temp, keys);
        cout << setw(4) << setfill('0') << hex << y << " "
            << setw(4) << setfill('0') << hex << x << endl;
    }
    free(keys);
}

源码

#include <iostream>
#pragma warning (disable:4996)
typedef unsigned int UINT32;
typedef unsigned short UINT16;
typedef unsigned char UINT8;

const UINT16 sReverse[] = { 14,3,4,8,1,12,10,15,7,13,9,6,11,2,0,5 };
const UINT16 p[] = { 1,5,9,13,2,6,10,14,3,7,11,15,4,8,12,16 };
const UINT16 s[] = { 14,4,13,1,2,15,11,8,3,10,6,12,5,9,0,7 };


UINT16 getBitSingle(int i) {
    UINT16 temp = 1;
    temp <<= (i - 1);
    return temp;
}
UINT32 readKey()
{
    UINT8 k = 0x0;
    char c;
    UINT32 key = 0x0;
    while ((c = getchar()) != '\n')
    {
        if (c==' ')
        {
            break;
        }
        if (c >= '0' && c <= '9')
            k = c - '0';
        else if (c >= 'a' && c <= 'f')
            k = c - 87;
        else if (c >= 'A' && c <= 'F')
            k = c - 55;
        key = key << 4;
        key = key | k;
    }
    return key;
}
UINT16 readx()
{
    UINT8 k = 0x0;
    char c;
    UINT16 x = 0x0;
    while ((c = getchar()) != '\n')
    {
        if (c == ' ')
        {
            break;
        }
        if (c >= '0' && c <= '9')
            k = c - '0';
        else if (c >= 'a' && c <= 'f')
            k = c - 87;
        else if (c >= 'A' && c <= 'F')
            k = c - 55;
        x = x << 4;
        x = x | k;
    }
    return x;
}

UINT16 sTransform(UINT16 input, const UINT16* trans) {
    UINT16 fin = 0x0000;
    auto a = input >> 12;
    fin |= trans[a] << 12;
    a = (input & 0x0f00) >> 8;
    fin |= trans[a] << 8;
    a = (input & 0x00f0) >> 4;
    fin |= trans[a] << 4;
    a = input & 0x000f;
    fin |= trans[a] << 0;
    return fin;
}
UINT16 pTransForm(UINT16 input, const UINT16* trans) {
    UINT16 fin = 0x0000;
    for (size_t i = 0; i < 16; i++)
    {
        auto pos = trans[i];
        auto mask = getBitSingle(16 - (i));
        auto bit = input & mask;
        int j = pos - 1 - i;
        if (j > 0)
        {
            fin |= bit >> j;
        }
        else {
            fin |= bit << -j;
        }
    }
    return fin;
}

//密匙编排算法
UINT16* getKeys(UINT32 key, UINT16* keys) {
    for (size_t i = 0; i < 5; i++)
    {
        keys[i] = key >> (16 - 4 * i);
    }
    return keys;
}
UINT16 encrypt(UINT16 x, const UINT16* keys) {

    for (size_t i = 0; i < 4; i++)
    {
        x ^= keys[i];

        x = sTransform(x, s);
        if (i == 3)
        {
            break;
        }
        x = pTransForm(x, p);
    }
    x ^= keys[4];
    return x;
}
UINT16 decrypt(UINT16 x, const UINT16* keys) {
    x ^= keys[4];
    x = sTransform(x, sReverse);
    for (size_t i = 3; i > 0; i--)
    {
        x ^= keys[i];
        x = pTransForm(x, p);
        x = sTransform(x, sReverse);
    }
    x ^= keys[0];
    return x;
}

int main()
{
    UINT16 x = 0x0;
    UINT32 key = 0x0;
    UINT16 keys[5] = { 0,0,0,0,0 };
    UINT16 y;
    UINT16 mask;
    UINT16 temp;
    char skey[20];
    char sx[10];
    int num = 0;
    scanf("%d", &num);
    getchar();
    for (size_t i = 0; i < num; i++)
    {
        key = readKey();
        x = readx();
        getKeys(key, keys);
        y = encrypt(x, keys);
        mask = 1;
        temp = (!(y & mask)) | (y & (~mask));
        x = decrypt(temp, keys);
        printf("%04x %04x\n", y, x);
    }
}

本文章使用limfx的vsocde插件快速发布