Crypto++
|
00001 // zinflate.cpp - written and placed in the public domain by Wei Dai 00002 00003 // This is a complete reimplementation of the DEFLATE decompression algorithm. 00004 // It should not be affected by any security vulnerabilities in the zlib 00005 // compression library. In particular it is not affected by the double free bug 00006 // (http://www.kb.cert.org/vuls/id/368819). 00007 00008 #include "pch.h" 00009 #include "zinflate.h" 00010 00011 NAMESPACE_BEGIN(CryptoPP) 00012 00013 struct CodeLessThan 00014 { 00015 inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) 00016 {return lhs < rhs.code;} 00017 // needed for MSVC .NET 2005 00018 inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) 00019 {return lhs.code < rhs.code;} 00020 }; 00021 00022 inline bool LowFirstBitReader::FillBuffer(unsigned int length) 00023 { 00024 while (m_bitsBuffered < length) 00025 { 00026 byte b; 00027 if (!m_store.Get(b)) 00028 return false; 00029 m_buffer |= (unsigned long)b << m_bitsBuffered; 00030 m_bitsBuffered += 8; 00031 } 00032 assert(m_bitsBuffered <= sizeof(unsigned long)*8); 00033 return true; 00034 } 00035 00036 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length) 00037 { 00038 bool result = FillBuffer(length); 00039 assert(result); 00040 return m_buffer & (((unsigned long)1 << length) - 1); 00041 } 00042 00043 inline void LowFirstBitReader::SkipBits(unsigned int length) 00044 { 00045 assert(m_bitsBuffered >= length); 00046 m_buffer >>= length; 00047 m_bitsBuffered -= length; 00048 } 00049 00050 inline unsigned long LowFirstBitReader::GetBits(unsigned int length) 00051 { 00052 unsigned long result = PeekBits(length); 00053 SkipBits(length); 00054 return result; 00055 } 00056 00057 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits) 00058 { 00059 return code << (MAX_CODE_BITS - codeBits); 00060 } 00061 00062 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes) 00063 { 00064 // the Huffman codes are represented in 3 ways in this code: 00065 // 00066 // 1. most significant code bit (i.e. top of code tree) in the least significant bit position 00067 // 2. most significant code bit (i.e. top of code tree) in the most significant bit position 00068 // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position, 00069 // where n is the maximum code length for this code tree 00070 // 00071 // (1) is the way the codes come in from the deflate stream 00072 // (2) is used to sort codes so they can be binary searched 00073 // (3) is used in this function to compute codes from code lengths 00074 // 00075 // a code in representation (2) is called "normalized" here 00076 // The BitReverse() function is used to convert between (1) and (2) 00077 // The NormalizeCode() function is used to convert from (3) to (2) 00078 00079 if (nCodes == 0) 00080 throw Err("null code"); 00081 00082 m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes); 00083 00084 if (m_maxCodeBits > MAX_CODE_BITS) 00085 throw Err("code length exceeds maximum"); 00086 00087 if (m_maxCodeBits == 0) 00088 throw Err("null code"); 00089 00090 // count number of codes of each length 00091 SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1); 00092 std::fill(blCount.begin(), blCount.end(), 0); 00093 unsigned int i; 00094 for (i=0; i<nCodes; i++) 00095 blCount[codeBits[i]]++; 00096 00097 // compute the starting code of each length 00098 code_t code = 0; 00099 SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1); 00100 nextCode[1] = 0; 00101 for (i=2; i<=m_maxCodeBits; i++) 00102 { 00103 // compute this while checking for overflow: code = (code + blCount[i-1]) << 1 00104 if (code > code + blCount[i-1]) 00105 throw Err("codes oversubscribed"); 00106 code += blCount[i-1]; 00107 if (code > (code << 1)) 00108 throw Err("codes oversubscribed"); 00109 code <<= 1; 00110 nextCode[i] = code; 00111 } 00112 00113 if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 00114 throw Err("codes oversubscribed"); 00115 else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 00116 throw Err("codes incomplete"); 00117 00118 // compute a vector of <code, length, value> triples sorted by code 00119 m_codeToValue.resize(nCodes - blCount[0]); 00120 unsigned int j=0; 00121 for (i=0; i<nCodes; i++) 00122 { 00123 unsigned int len = codeBits[i]; 00124 if (len != 0) 00125 { 00126 code = NormalizeCode(nextCode[len]++, len); 00127 m_codeToValue[j].code = code; 00128 m_codeToValue[j].len = len; 00129 m_codeToValue[j].value = i; 00130 j++; 00131 } 00132 } 00133 std::sort(m_codeToValue.begin(), m_codeToValue.end()); 00134 00135 // initialize the decoding cache 00136 m_cacheBits = STDMIN(9U, m_maxCodeBits); 00137 m_cacheMask = (1 << m_cacheBits) - 1; 00138 m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits); 00139 assert(m_normalizedCacheMask == BitReverse(m_cacheMask)); 00140 00141 if (m_cache.size() != size_t(1) << m_cacheBits) 00142 m_cache.resize(1 << m_cacheBits); 00143 00144 for (i=0; i<m_cache.size(); i++) 00145 m_cache[i].type = 0; 00146 } 00147 00148 void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const 00149 { 00150 normalizedCode &= m_normalizedCacheMask; 00151 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1); 00152 if (codeInfo.len <= m_cacheBits) 00153 { 00154 entry.type = 1; 00155 entry.value = codeInfo.value; 00156 entry.len = codeInfo.len; 00157 } 00158 else 00159 { 00160 entry.begin = &codeInfo; 00161 const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1); 00162 if (codeInfo.len == last->len) 00163 { 00164 entry.type = 2; 00165 entry.len = codeInfo.len; 00166 } 00167 else 00168 { 00169 entry.type = 3; 00170 entry.end = last+1; 00171 } 00172 } 00173 } 00174 00175 inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const 00176 { 00177 assert(m_codeToValue.size() > 0); 00178 LookupEntry &entry = m_cache[code & m_cacheMask]; 00179 00180 code_t normalizedCode; 00181 if (entry.type != 1) 00182 normalizedCode = BitReverse(code); 00183 00184 if (entry.type == 0) 00185 FillCacheEntry(entry, normalizedCode); 00186 00187 if (entry.type == 1) 00188 { 00189 value = entry.value; 00190 return entry.len; 00191 } 00192 else 00193 { 00194 const CodeInfo &codeInfo = (entry.type == 2) 00195 ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))] 00196 : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1); 00197 value = codeInfo.value; 00198 return codeInfo.len; 00199 } 00200 } 00201 00202 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const 00203 { 00204 reader.FillBuffer(m_maxCodeBits); 00205 unsigned int codeBits = Decode(reader.PeekBuffer(), value); 00206 if (codeBits > reader.BitsBuffered()) 00207 return false; 00208 reader.SkipBits(codeBits); 00209 return true; 00210 } 00211 00212 // ************************************************************* 00213 00214 Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation) 00215 : AutoSignaling<Filter>(propagation) 00216 , m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue) 00217 { 00218 Detach(attachment); 00219 } 00220 00221 void Inflator::IsolatedInitialize(const NameValuePairs ¶meters) 00222 { 00223 m_state = PRE_STREAM; 00224 parameters.GetValue("Repeat", m_repeat); 00225 m_inQueue.Clear(); 00226 m_reader.SkipBits(m_reader.BitsBuffered()); 00227 } 00228 00229 void Inflator::OutputByte(byte b) 00230 { 00231 m_window[m_current++] = b; 00232 if (m_current == m_window.size()) 00233 { 00234 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); 00235 m_lastFlush = 0; 00236 m_current = 0; 00237 m_wrappedAround = true; 00238 } 00239 } 00240 00241 void Inflator::OutputString(const byte *string, size_t length) 00242 { 00243 while (length) 00244 { 00245 size_t len = UnsignedMin(length, m_window.size() - m_current); 00246 memcpy(m_window + m_current, string, len); 00247 m_current += len; 00248 if (m_current == m_window.size()) 00249 { 00250 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); 00251 m_lastFlush = 0; 00252 m_current = 0; 00253 m_wrappedAround = true; 00254 } 00255 string += len; 00256 length -= len; 00257 } 00258 } 00259 00260 void Inflator::OutputPast(unsigned int length, unsigned int distance) 00261 { 00262 size_t start; 00263 if (distance <= m_current) 00264 start = m_current - distance; 00265 else if (m_wrappedAround && distance <= m_window.size()) 00266 start = m_current + m_window.size() - distance; 00267 else 00268 throw BadBlockErr(); 00269 00270 if (start + length > m_window.size()) 00271 { 00272 for (; start < m_window.size(); start++, length--) 00273 OutputByte(m_window[start]); 00274 start = 0; 00275 } 00276 00277 if (start + length > m_current || m_current + length >= m_window.size()) 00278 { 00279 while (length--) 00280 OutputByte(m_window[start++]); 00281 } 00282 else 00283 { 00284 memcpy(m_window + m_current, m_window + start, length); 00285 m_current += length; 00286 } 00287 } 00288 00289 size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking) 00290 { 00291 if (!blocking) 00292 throw BlockingInputOnly("Inflator"); 00293 00294 LazyPutter lp(m_inQueue, inString, length); 00295 ProcessInput(messageEnd != 0); 00296 00297 if (messageEnd) 00298 if (!(m_state == PRE_STREAM || m_state == AFTER_END)) 00299 throw UnexpectedEndErr(); 00300 00301 Output(0, NULL, 0, messageEnd, blocking); 00302 return 0; 00303 } 00304 00305 bool Inflator::IsolatedFlush(bool hardFlush, bool blocking) 00306 { 00307 if (!blocking) 00308 throw BlockingInputOnly("Inflator"); 00309 00310 if (hardFlush) 00311 ProcessInput(true); 00312 FlushOutput(); 00313 00314 return false; 00315 } 00316 00317 void Inflator::ProcessInput(bool flush) 00318 { 00319 while (true) 00320 { 00321 switch (m_state) 00322 { 00323 case PRE_STREAM: 00324 if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize()) 00325 return; 00326 ProcessPrestreamHeader(); 00327 m_state = WAIT_HEADER; 00328 m_wrappedAround = false; 00329 m_current = 0; 00330 m_lastFlush = 0; 00331 m_window.New(1 << GetLog2WindowSize()); 00332 break; 00333 case WAIT_HEADER: 00334 { 00335 // maximum number of bytes before actual compressed data starts 00336 const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15); 00337 if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE)) 00338 return; 00339 DecodeHeader(); 00340 break; 00341 } 00342 case DECODING_BODY: 00343 if (!DecodeBody()) 00344 return; 00345 break; 00346 case POST_STREAM: 00347 if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize()) 00348 return; 00349 ProcessPoststreamTail(); 00350 m_state = m_repeat ? PRE_STREAM : AFTER_END; 00351 Output(0, NULL, 0, GetAutoSignalPropagation(), true); // TODO: non-blocking 00352 if (m_inQueue.IsEmpty()) 00353 return; 00354 break; 00355 case AFTER_END: 00356 m_inQueue.TransferTo(*AttachedTransformation()); 00357 return; 00358 } 00359 } 00360 } 00361 00362 void Inflator::DecodeHeader() 00363 { 00364 if (!m_reader.FillBuffer(3)) 00365 throw UnexpectedEndErr(); 00366 m_eof = m_reader.GetBits(1) != 0; 00367 m_blockType = (byte)m_reader.GetBits(2); 00368 switch (m_blockType) 00369 { 00370 case 0: // stored 00371 { 00372 m_reader.SkipBits(m_reader.BitsBuffered() % 8); 00373 if (!m_reader.FillBuffer(32)) 00374 throw UnexpectedEndErr(); 00375 m_storedLen = (word16)m_reader.GetBits(16); 00376 word16 nlen = (word16)m_reader.GetBits(16); 00377 if (nlen != (word16)~m_storedLen) 00378 throw BadBlockErr(); 00379 break; 00380 } 00381 case 1: // fixed codes 00382 m_nextDecode = LITERAL; 00383 break; 00384 case 2: // dynamic codes 00385 { 00386 if (!m_reader.FillBuffer(5+5+4)) 00387 throw UnexpectedEndErr(); 00388 unsigned int hlit = m_reader.GetBits(5); 00389 unsigned int hdist = m_reader.GetBits(5); 00390 unsigned int hclen = m_reader.GetBits(4); 00391 00392 FixedSizeSecBlock<unsigned int, 286+32> codeLengths; 00393 unsigned int i; 00394 static const unsigned int border[] = { // Order of the bit length code lengths 00395 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; 00396 std::fill(codeLengths.begin(), codeLengths+19, 0); 00397 for (i=0; i<hclen+4; i++) 00398 codeLengths[border[i]] = m_reader.GetBits(3); 00399 00400 try 00401 { 00402 HuffmanDecoder codeLengthDecoder(codeLengths, 19); 00403 for (i = 0; i < hlit+257+hdist+1; ) 00404 { 00405 unsigned int k, count, repeater; 00406 bool result = codeLengthDecoder.Decode(m_reader, k); 00407 if (!result) 00408 throw UnexpectedEndErr(); 00409 if (k <= 15) 00410 { 00411 count = 1; 00412 repeater = k; 00413 } 00414 else switch (k) 00415 { 00416 case 16: 00417 if (!m_reader.FillBuffer(2)) 00418 throw UnexpectedEndErr(); 00419 count = 3 + m_reader.GetBits(2); 00420 if (i == 0) 00421 throw BadBlockErr(); 00422 repeater = codeLengths[i-1]; 00423 break; 00424 case 17: 00425 if (!m_reader.FillBuffer(3)) 00426 throw UnexpectedEndErr(); 00427 count = 3 + m_reader.GetBits(3); 00428 repeater = 0; 00429 break; 00430 case 18: 00431 if (!m_reader.FillBuffer(7)) 00432 throw UnexpectedEndErr(); 00433 count = 11 + m_reader.GetBits(7); 00434 repeater = 0; 00435 break; 00436 } 00437 if (i + count > hlit+257+hdist+1) 00438 throw BadBlockErr(); 00439 std::fill(codeLengths + i, codeLengths + i + count, repeater); 00440 i += count; 00441 } 00442 m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257); 00443 if (hdist == 0 && codeLengths[hlit+257] == 0) 00444 { 00445 if (hlit != 0) // a single zero distance code length means all literals 00446 throw BadBlockErr(); 00447 } 00448 else 00449 m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1); 00450 m_nextDecode = LITERAL; 00451 } 00452 catch (HuffmanDecoder::Err &) 00453 { 00454 throw BadBlockErr(); 00455 } 00456 break; 00457 } 00458 default: 00459 throw BadBlockErr(); // reserved block type 00460 } 00461 m_state = DECODING_BODY; 00462 } 00463 00464 bool Inflator::DecodeBody() 00465 { 00466 bool blockEnd = false; 00467 switch (m_blockType) 00468 { 00469 case 0: // stored 00470 assert(m_reader.BitsBuffered() == 0); 00471 while (!m_inQueue.IsEmpty() && !blockEnd) 00472 { 00473 size_t size; 00474 const byte *block = m_inQueue.Spy(size); 00475 size = UnsignedMin(m_storedLen, size); 00476 OutputString(block, size); 00477 m_inQueue.Skip(size); 00478 m_storedLen -= (word16)size; 00479 if (m_storedLen == 0) 00480 blockEnd = true; 00481 } 00482 break; 00483 case 1: // fixed codes 00484 case 2: // dynamic codes 00485 static const unsigned int lengthStarts[] = { 00486 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 00487 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258}; 00488 static const unsigned int lengthExtraBits[] = { 00489 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 00490 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0}; 00491 static const unsigned int distanceStarts[] = { 00492 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 00493 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 00494 8193, 12289, 16385, 24577}; 00495 static const unsigned int distanceExtraBits[] = { 00496 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 00497 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 00498 12, 12, 13, 13}; 00499 00500 const HuffmanDecoder& literalDecoder = GetLiteralDecoder(); 00501 const HuffmanDecoder& distanceDecoder = GetDistanceDecoder(); 00502 00503 switch (m_nextDecode) 00504 { 00505 case LITERAL: 00506 while (true) 00507 { 00508 if (!literalDecoder.Decode(m_reader, m_literal)) 00509 { 00510 m_nextDecode = LITERAL; 00511 break; 00512 } 00513 if (m_literal < 256) 00514 OutputByte((byte)m_literal); 00515 else if (m_literal == 256) // end of block 00516 { 00517 blockEnd = true; 00518 break; 00519 } 00520 else 00521 { 00522 if (m_literal > 285) 00523 throw BadBlockErr(); 00524 unsigned int bits; 00525 case LENGTH_BITS: 00526 bits = lengthExtraBits[m_literal-257]; 00527 if (!m_reader.FillBuffer(bits)) 00528 { 00529 m_nextDecode = LENGTH_BITS; 00530 break; 00531 } 00532 m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257]; 00533 case DISTANCE: 00534 if (!distanceDecoder.Decode(m_reader, m_distance)) 00535 { 00536 m_nextDecode = DISTANCE; 00537 break; 00538 } 00539 case DISTANCE_BITS: 00540 bits = distanceExtraBits[m_distance]; 00541 if (!m_reader.FillBuffer(bits)) 00542 { 00543 m_nextDecode = DISTANCE_BITS; 00544 break; 00545 } 00546 m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance]; 00547 OutputPast(m_literal, m_distance); 00548 } 00549 } 00550 } 00551 } 00552 if (blockEnd) 00553 { 00554 if (m_eof) 00555 { 00556 FlushOutput(); 00557 m_reader.SkipBits(m_reader.BitsBuffered()%8); 00558 if (m_reader.BitsBuffered()) 00559 { 00560 // undo too much lookahead 00561 SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8); 00562 for (unsigned int i=0; i<buffer.size(); i++) 00563 buffer[i] = (byte)m_reader.GetBits(8); 00564 m_inQueue.Unget(buffer, buffer.size()); 00565 } 00566 m_state = POST_STREAM; 00567 } 00568 else 00569 m_state = WAIT_HEADER; 00570 } 00571 return blockEnd; 00572 } 00573 00574 void Inflator::FlushOutput() 00575 { 00576 if (m_state != PRE_STREAM) 00577 { 00578 assert(m_current >= m_lastFlush); 00579 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush); 00580 m_lastFlush = m_current; 00581 } 00582 } 00583 00584 struct NewFixedLiteralDecoder 00585 { 00586 HuffmanDecoder * operator()() const 00587 { 00588 unsigned int codeLengths[288]; 00589 std::fill(codeLengths + 0, codeLengths + 144, 8); 00590 std::fill(codeLengths + 144, codeLengths + 256, 9); 00591 std::fill(codeLengths + 256, codeLengths + 280, 7); 00592 std::fill(codeLengths + 280, codeLengths + 288, 8); 00593 std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder); 00594 pDecoder->Initialize(codeLengths, 288); 00595 return pDecoder.release(); 00596 } 00597 }; 00598 00599 struct NewFixedDistanceDecoder 00600 { 00601 HuffmanDecoder * operator()() const 00602 { 00603 unsigned int codeLengths[32]; 00604 std::fill(codeLengths + 0, codeLengths + 32, 5); 00605 std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder); 00606 pDecoder->Initialize(codeLengths, 32); 00607 return pDecoder.release(); 00608 } 00609 }; 00610 00611 const HuffmanDecoder& Inflator::GetLiteralDecoder() const 00612 { 00613 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder; 00614 } 00615 00616 const HuffmanDecoder& Inflator::GetDistanceDecoder() const 00617 { 00618 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder; 00619 } 00620 00621 NAMESPACE_END