nghttp2_hd: Simplify huffman encode

This commit is contained in:
Tatsuhiro Tsujikawa 2013-12-19 22:54:34 +09:00
parent 0af0bd4362
commit 7a9eca1f7d
4 changed files with 560 additions and 579 deletions

View File

@ -26,6 +26,7 @@
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>
#include <stdio.h>
#include "nghttp2_hd.h" #include "nghttp2_hd.h"
@ -44,14 +45,12 @@ extern const int16_t res_huff_decode_table[][256];
static uint8_t get_prefix_byte(const uint8_t *in, size_t len, size_t bitoff) static uint8_t get_prefix_byte(const uint8_t *in, size_t len, size_t bitoff)
{ {
uint8_t b; uint8_t b;
size_t bitleft;
if(bitoff == 0) { if(bitoff == 0) {
return *in; return *in;
} }
bitleft = 8 - bitoff; b = *in << bitoff;
b = (*in & ((1 << bitleft) - 1)) << bitoff;
if(len > 1) { if(len > 1) {
b |= *(in + 1) >> bitleft; b |= *(in + 1) >> (8 - bitoff);
} }
return b; return b;
} }
@ -90,57 +89,33 @@ static int huff_decode(const uint8_t *in, size_t len, size_t bitoff,
} }
return rv; return rv;
} }
/* /*
* Returns next LSB aligned |nbits| bits from huffman symbol |sym|, * Encodes huffman code |sym| into |*dest_ptr|, whose least |rembits|
* starting |codebitoff| bit offset (from beginning of code sequence, * bits are not filled yet. The |rembits| must be in range [1, 8],
* so it could be more than 8). * inclusive. At the end of the process, the |*dest_ptr| is updated
* and points where next output should be placed. The number of
* unfilled bits in the pointed location is returned.
*/ */
static uint8_t huff_get_lsb_aligned(const nghttp2_huff_sym *sym, static size_t huff_encode_sym(uint8_t **dest_ptr, size_t rembits,
size_t codebitoff,
size_t nbits)
{
size_t codeidx = codebitoff / 8;
uint8_t a = sym->code[codeidx];
size_t localbitoff = codebitoff & 0x7;
size_t bitleft = 8 - localbitoff;
if(bitleft >= nbits) {
return (a >> (bitleft - nbits)) & ((1 << nbits) - 1);
} else {
size_t right = nbits - bitleft;
a &= ((1 << bitleft) - 1);
a <<= right;
if((sym->nbits + 7) / 8 > codeidx + 1) {
a |= sym->code[codeidx + 1] >> (8 - right);
}
return a;
}
}
/*
* Encodes huffman code |sym| into |*dest_ptr|,starting |bitoff|
* offset. The |bitoff| must be strictly less than 8. At the end of
* the process, the |*dest_ptr| is updated and points where next
* output should be placed. The bit offset of the pointed location is
* returned.
*/
static size_t huff_encode_sym(uint8_t **dest_ptr, size_t bitoff,
const nghttp2_huff_sym *sym) const nghttp2_huff_sym *sym)
{ {
size_t b = 0; size_t nbits = sym->nbits;
if(bitoff == 0) **dest_ptr = 0; for(;;) {
**dest_ptr |= huff_get_lsb_aligned(sym, b, 8 - bitoff); if(rembits > nbits) {
b += 8 - bitoff; **dest_ptr |= sym->code << (rembits - nbits);
++*dest_ptr; rembits -= nbits;
for(; b < sym->nbits; b += 8, ++*dest_ptr) { break;
**dest_ptr = huff_get_lsb_aligned(sym, b, 8); }
**dest_ptr |= sym->code >> (nbits - rembits);
++*dest_ptr;
nbits -= rembits;
rembits = 8;
if(nbits == 0) {
break;
}
**dest_ptr = 0;
} }
bitoff = 8 - (b - sym->nbits); return rembits;
if(bitoff > 0) {
--*dest_ptr;
}
return bitoff;
} }
size_t nghttp2_hd_huff_encode_count(const uint8_t *src, size_t len, size_t nghttp2_hd_huff_encode_count(const uint8_t *src, size_t len,
@ -166,7 +141,7 @@ ssize_t nghttp2_hd_huff_encode(uint8_t *dest, size_t destlen,
const uint8_t *src, size_t srclen, const uint8_t *src, size_t srclen,
nghttp2_hd_side side) nghttp2_hd_side side)
{ {
int bitoff = 0; int rembits = 8;
uint8_t *dest_first = dest; uint8_t *dest_first = dest;
size_t i; size_t i;
const nghttp2_huff_sym *huff_sym_table; const nghttp2_huff_sym *huff_sym_table;
@ -178,13 +153,18 @@ ssize_t nghttp2_hd_huff_encode(uint8_t *dest, size_t destlen,
} }
for(i = 0; i < srclen; ++i) { for(i = 0; i < srclen; ++i) {
const nghttp2_huff_sym *sym = &huff_sym_table[src[i]]; const nghttp2_huff_sym *sym = &huff_sym_table[src[i]];
bitoff = huff_encode_sym(&dest, bitoff, sym); if(rembits == 8) {
*dest = 0;
}
rembits = huff_encode_sym(&dest, rembits, sym);
} }
/* 256 is special terminal symbol, pad with its prefix */ /* 256 is special terminal symbol, pad with its prefix */
if(bitoff > 0) { if(rembits < 8) {
*dest |= huff_sym_table[256].code[0] >> bitoff; const nghttp2_huff_sym *sym = &huff_sym_table[256];
*dest |= sym->code >> (sym->nbits - rembits);
++dest;
} }
return dest - dest_first + (bitoff > 0); return dest - dest_first;
} }
static int check_last_byte(const uint8_t *src, size_t srclen, size_t idx, static int check_last_byte(const uint8_t *src, size_t srclen, size_t idx,

View File

@ -35,9 +35,9 @@ typedef int16_t huff_decode_table_type[256];
typedef struct { typedef struct {
/* The number of bits in this code */ /* The number of bits in this code */
size_t nbits; uint32_t nbits;
/* Code sequence padded with 0 */ /* Huffman code aligned to LSB */
uint8_t code[4]; uint32_t code;
} nghttp2_huff_sym; } nghttp2_huff_sym;
#endif /* NGHTTP2_HD_HUFFMAN_H */ #endif /* NGHTTP2_HD_HUFFMAN_H */

File diff suppressed because it is too large Load Diff

View File

@ -44,23 +44,25 @@ root = Node(0)
nodes.append(root) nodes.append(root)
for line in sys.stdin: for line in sys.stdin:
m = re.match(r'.*\(\s*(\d+)\) ([|01]+) \[(\d+)\] .*', line) m = re.match(r'.*\(\s*(\d+)\) ([|01]+) \[(\d+)\]\s+(\S+).*', line)
if m: if m:
#print m.group(1), m.group(2), m.group(3) #print m.group(1), m.group(2), m.group(3)
if len(m.group(4)) > 8:
raise Error('Code is more than 4 bytes long')
sym = int(m.group(1)) sym = int(m.group(1))
pat = re.sub(r'\|', '', m.group(2)) pat = re.sub(r'\|', '', m.group(2))
nbits = int(m.group(3)) nbits = int(m.group(3))
assert(len(pat) == nbits) assert(len(pat) == nbits)
binpat = to_bin(pat) binpat = to_bin(pat)
assert(len(binpat) == (nbits+7)/8) assert(len(binpat) == (nbits+7)/8)
symbol_tbl[sym] = (binpat, nbits) symbol_tbl[sym] = (binpat, nbits, m.group(4))
#print "Inserting", sym #print "Inserting", sym
insert(root, sym, binpat, nbits, 0) insert(root, sym, binpat, nbits, 0)
print '''\ print '''\
typedef struct { typedef struct {
size_t nbits; uint32_t nbits;
uint8_t code[4]; uint32_t code;
} nghttp2_huff_sym; } nghttp2_huff_sym;
''' '''
@ -70,9 +72,8 @@ for i in range(257):
pat = list(symbol_tbl[i][0]) pat = list(symbol_tbl[i][0])
pat += [0]*(4 - len(pat)) pat += [0]*(4 - len(pat))
print '''\ print '''\
{{ {}, {{ {} }} }}{}\ {{ {}, 0x{}u }}{}\
'''.format(symbol_tbl[i][1], ', '.join([str(k) for k in pat]), '''.format(symbol_tbl[i][1], symbol_tbl[i][2], ',' if i < 256 else '')
',' if i < 256 else '')
print '};' print '};'
print '' print ''