苟利国家生死以,岂因祸福避趋之

Panda Home

Base64 编码学习笔记( Java 实现)

发布于 # 聊聊技术
标签: # Base64 # Java
Base64 编码学习笔记( Java 实现)
Photo by Markus Spiske on Unsplash

但凡从事码工这一行,多多少少会遇到 Base64 编码这个概念,因为我们总要接触互联网,而 Base64 编码诞生的目的就是为了让二进制数据能够在只支持文本的媒介上传输,比如说在网络上传输一张图片或者一段音频。而 Base64 本身是一种无损编码转换规则,同时编码后的内容与原始内容差别非常大,所以很多时候大家在网上留联系方式的时候也喜欢用 Base64 转换一下,既能把信息传达给网友,又避免了充斥在网络上的各种机器人的骚扰,比如这位朋友的自我介绍,甚至还贴心地给出了完整的解码命令行。

理论

那么 Base64 的编码规则是怎样的?因为在计算机的世界中,所有的内容都以字节数组( Byte Array )的形式呈现,所以 Base64 的输入数据自然是字节数组,每一个字节有 8 位 bit , Base64 以每三个字节为一组,然后均分成四份,每一份有 6 个 bit ,而这 6 个 bit ,恰好就能对应到 64 个指定的 ASCII 字符上去。转换规则如下(抄自维基百科

IndexBinaryCharIndexBinaryCharIndexBinaryCharIndexBinaryChar
0000000A16010000Q32100000g48110000w
1000001B17010001R33100001h49110001x
2000010C18010010S34100010i50110010y
3000011D19010011T35100011j51110011z
4000100E20010100U36100100k521101000
5000101F21010101V37100101l531101011
6000110G22010110W38100110m541101102
7000111H23010111X39100111n551101113
8001000I24011000Y40101000o561110004
9001001J25011001Z41101001p571110015
10001010K26011010a42101010q581110106
11001011L27011011b43101011r591110117
12001100M28011100c44101100s601111008
13001101N29011101d45101101t611111019
14001110O30011110e46101110u62111110+
15001111P31011111f47101111v63111111/

那么如果输入的字节长度不是 3 的整数倍,最后的一个或两个字节岂不是就无法应用上述规则了吗?所以 Base64 规定,对于末尾的空位,用等号 = 补齐,因此转换而成的 Base64 编码的长度总是 4 的整数倍。

简而言之,给一段二进制数据进行 Base64 编码时,需要以下三步

  1. 将每三个字节归并到一组,末尾不足三位的留出来单独处理
  2. 将每一组 24 个 bit 分成四份,每份 6 个 bit ,从上表中找到相应的 ASCII 字符放到该位置
  3. 将末尾的一位或两位字节分别用零补齐到 12 位 bit 或 18 位 bit ,从表中找到相应的字符填空,末尾用等号 = 补齐四位

维基上也给出了不同情况下的转换示例

解码的过程就是编码的逆过程,同样地,也可以用三步来概括

  1. 将 Base64 字符串每四个分为一组,从上表中找到每个字符对应的 6 位二进制码,拼在一起成 24 位 bit 串
  2. 将这 24 位 bit 串均分为三份,每部分 8 个 bit 作为一个字节,直接放到解码结果相应的位置
  3. 最后的四位字符,拿掉末尾所有的等号,根据末尾等号的个数(一位还是两位)判断需要从末尾拿走几个零,最后解码为两位或一位字节

实践

为了证明自己会写 Java ,闲暇时用 Java + Maven 简单写了一个 Base64 的编码和解码方法。为了节约篇幅,这里省掉了类的定义以及依赖的引入。

编码

private final static byte[] encodeMap = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".getBytes();
public static byte[] encode(byte[] plainBytes) {
    if (plainBytes.length == 0) {
        return "".getBytes();
    }
    int encodedLength = (plainBytes.length % 3 == 0) ? plainBytes.length/3 * 4 : (plainBytes.length/3 + 1) * 4;
    byte[] encodedBytes = new byte[encodedLength];
    int i = 0, j = 0;
    while (i < plainBytes.length / 3 * 3) {
        int value = plainBytes[i] << 16 | plainBytes[i+1] << 8 | plainBytes[i+2];
        encodedBytes[j] = encodeMap[value>>18&0x3f];
        encodedBytes[j+1] = encodeMap[value>>12&0x3f];
        encodedBytes[j+2] = encodeMap[value>>6&0x3f];
        encodedBytes[j+3] = encodeMap[value&0x3f];
        i += 3;
        j += 4;
    }
    int remains = plainBytes.length - i;
    if (remains > 0) {
        int value = plainBytes[i] << 16;
        if (remains == 2) {
            value |= plainBytes[i+1] << 8;
        }
        encodedBytes[j] = encodeMap[value>>18&0x3f];
        encodedBytes[j+1] = encodeMap[value>>12&0x3f];
        if (remains == 1) {
            encodedBytes[j+2] = '=';
            encodedBytes[j+3] = '=';
        } else if (remains == 2) {
            encodedBytes[j+2] = encodeMap[value>>6&0x3f];
            encodedBytes[j+3] = '=';
        }
    }
    return encodedBytes;
}

解码

// decodeMap 的初始化需借助上面的 encodeMap ,其实质上是一个 ASCII 字符到它在 encodeMap 中位置的映射
private final static Map<Byte, Integer> decodeMap = new HashMap<Byte, Integer>();
static {
    for (int i = 0; i < encodeMap.length; i++) {
        decodeMap.put(encodeMap[i], i);
    }
}
public static byte[] decode(byte[] encodedBytes) {
    if (encodedBytes.length == 0) {
        return "".getBytes();
    }
    int decodedLength = (encodedBytes.length - 4) / 4 * 3;
    if (encodedBytes[encodedBytes.length-1] == '=' && encodedBytes[encodedBytes.length-2] == '=') {
        decodedLength += 1;
    } else if (encodedBytes[encodedBytes.length-1] == '=') {
        decodedLength += 2;
    } else {
        decodedLength += 3;
    }
    byte[] decodedBytes = new byte[decodedLength];
    int i = 0, j = 0;
    while (i < encodedBytes.length - 4) {
        int value = decodeMap.get(encodedBytes[i])<<18 | decodeMap.get(encodedBytes[i+1])<<12 | decodeMap.get(encodedBytes[i+2])<<6 | decodeMap.get(encodedBytes[i+3]);
        decodedBytes[j] = (byte)(value>>16&0xff);
        decodedBytes[j+1] = (byte)(value>>8&0xff);
        decodedBytes[j+2] = (byte)(value&0xff);
        i += 4;
        j += 3;
    }
    if (decodedLength - j == 1) {
        int value = decodeMap.get(encodedBytes[i])<<18 | decodeMap.get(encodedBytes[i+1])<<12;
        decodedBytes[j] = (byte)(value>>16&0xff);
    } else if (decodedLength - j == 2) {
        int value = decodeMap.get(encodedBytes[i])<<18 | decodeMap.get(encodedBytes[i+1])<<12 | decodeMap.get(encodedBytes[i+2])<<6;
        decodedBytes[j] = (byte)(value>>16&0xff);
        decodedBytes[j+1] = (byte)(value>>8&0xff);
    } else {
        int value = decodeMap.get(encodedBytes[i])<<18 | decodeMap.get(encodedBytes[i+1])<<12 | decodeMap.get(encodedBytes[i+2])<<6 | decodeMap.get(encodedBytes[i+3]);
        decodedBytes[j] = (byte)(value>>16&0xff);
        decodedBytes[j+1] = (byte)(value>>8&0xff);
        decodedBytes[j+2] = (byte)(value&0xff);
    }
    return decodedBytes;
}