LCOV - code coverage report
Current view: top level - lib/compression - lzxpress_huffman.c (source / functions) Hit Total Coverage
Test: coverage report for master 2f515e9b Lines: 584 798 73.2 %
Date: 2024-04-21 15:09:00 Functions: 28 33 84.8 %

          Line data    Source code
       1             : /*
       2             :  * Samba compression library - LGPLv3
       3             :  *
       4             :  * Copyright © Catalyst IT 2022
       5             :  *
       6             :  * Written by Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
       7             :  *        and Jo Sutton       <josutton@catalyst.net.nz>
       8             :  *
       9             :  *  ** NOTE! The following LGPL license applies to this file.
      10             :  *  ** It does NOT imply that all of Samba is released under the LGPL
      11             :  *
      12             :  *  This library is free software; you can redistribute it and/or
      13             :  *  modify it under the terms of the GNU Lesser General Public
      14             :  *  License as published by the Free Software Foundation; either
      15             :  *  version 3 of the License, or (at your option) any later version.
      16             :  *
      17             :  *  This library is distributed in the hope that it will be useful,
      18             :  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
      19             :  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      20             :  *  Lesser General Public License for more details.
      21             :  *
      22             :  *  You should have received a copy of the GNU Lesser General Public
      23             :  *  License along with this library; if not, see <http://www.gnu.org/licenses/>.
      24             :  */
      25             : 
      26             : #include <talloc.h>
      27             : 
      28             : #include "replace.h"
      29             : #include "lzxpress_huffman.h"
      30             : #include "lib/util/stable_sort.h"
      31             : #include "lib/util/debug.h"
      32             : #include "lib/util/byteorder.h"
      33             : #include "lib/util/bytearray.h"
      34             : 
      /*
 * DEBUG_NO_LZ77_MATCHES controls whether LZ77 matches are encoded as
 * matches. When it is true, each potential match is written out as a run
 * of literals instead: a valid but usually inefficient encoding, handy for
 * telling whether a problem lies in the LZ77 stage or in the Huffman stage.
 */
#if !defined(DEBUG_NO_LZ77_MATCHES)
#define DEBUG_NO_LZ77_MATCHES false
#endif

/*
 * DEBUG_HUFFMAN_TREE forces ASCII-art Huffman trees to be drawn during
 * compression and decompression; when it is true DBG() writes straight to
 * stderr instead of going through DBG_INFO().
 *
 * The trees are also drawn at DEBUG level 10, but that route does not work
 * inside cmocka tests.
 */
#if !defined(DEBUG_HUFFMAN_TREE)
#define DEBUG_HUFFMAN_TREE false
#endif

#if DEBUG_HUFFMAN_TREE
#define DBG(...) fprintf(stderr, __VA_ARGS__)
#else
#define DBG(...) DBG_INFO(__VA_ARGS__)
#endif


/* Failure value returned by internal helpers such as generate_huffman_codes(). */
#define LZXPRESS_ERROR -1LL

/*
 * MAX_MATCH_LENGTH is the longest match length we are willing to encode.
 * Windows is reported to have a limit at 64M.
 */
#define MAX_MATCH_LENGTH (64 * 1024 * 1024)
      71             : 
      72             : 
      /*
 * Bit-level stream state.
 *
 * NOTE(review): the code that uses this struct is not in view; the field
 * roles below are inferred from their names and should be confirmed.
 */
struct bitstream {
	const uint8_t *bytes;	/* the (read-only) byte buffer */
	size_t byte_pos;	/* presumably: next byte to consume */
	size_t byte_size;	/* presumably: length of bytes[] */
	uint32_t bits;		/* presumably: bit accumulator */
	int remaining_bits;	/* presumably: valid bits left in bits */
	uint16_t *table;	/* presumably: a Huffman decoding table */
};
      81             : 
      82             : 
      /*
 * Compilers that lack __has_builtin get a stand-in that reports every
 * builtin as unavailable, selecting the portable fallback below.
 */
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

/*
 * bitlen_nonzero_16() returns the index of the most significant set bit of
 * x, i.e. the integer base-2 logarithm:
 *   1     -> 0
 *   2,3   -> 1
 *   4-7   -> 2
 *   1024  -> 10, etc
 *
 * x must be non-zero: log(0) is undefined, and so is __builtin_clz(0).
 *
 * When available, __builtin_clz() is used; it typically compiles to a
 * single count-leading-zeros instruction.
 */
static inline int bitlen_nonzero_16(uint16_t x)
{
#if __has_builtin(__builtin_clz)
	/* index of the top bit = (bit width - 1) - number of leading zeros */
	const int top_bit = (int)(sizeof(unsigned int) * CHAR_BIT) - 1;

	return top_bit - __builtin_clz((unsigned int) x);
#else
	int msb;

	/* shift x down until it is gone, counting shifts from -1 */
	for (msb = -1; x != 0; x >>= 1) {
		msb++;
	}
	return msb;
#endif
}
     119             : 
     120             : 
     /*
 * State carried through one compression run.
 *
 * NOTE(review): the functions using this struct are not in view; the
 * field roles below are inferred from their names and should be confirmed.
 */
struct lzxhuff_compressor_context {
	const uint8_t *input_bytes;	/* data being compressed */
	size_t input_size;		/* length of input_bytes */
	size_t input_pos;		/* presumably: how much input is consumed */
	size_t prev_block_pos;		/* presumably: start of the previous block */
	uint8_t *output;		/* destination buffer */
	size_t available_size;		/* presumably: capacity of output */
	size_t output_pos;		/* presumably: bytes written so far */
};
     130             : 
     131     1050199 : static int compare_huffman_node_count(struct huffman_node *a,
     132             :                                       struct huffman_node *b)
     133             : {
     134     1050199 :         return a->count - b->count;
     135             : }
     136             : 
     137      852945 : static int compare_huffman_node_depth(struct huffman_node *a,
     138             :                                       struct huffman_node *b)
     139             : {
     140      852945 :         int c = a->depth - b->depth;
     141      852945 :         if (c != 0) {
     142       22828 :                 return c;
     143             :         }
     144      687969 :         return (int)a->symbol - (int)b->symbol;
     145             : }
     146             : 
     147             : 
     148             : #define HASH_MASK ((1 << LZX_HUFF_COMP_HASH_BITS) - 1)
     149             : 
     150    10568868 : static inline uint16_t three_byte_hash(const uint8_t *bytes)
     151             : {
     152             :         /*
     153             :          * MS-XCA says "three byte hash", but does not specify it.
     154             :          *
     155             :          * This one is just cobbled together, but has quite good distribution
     156             :          * in the 12-14 bit forms, which is what we care about most.
     157             :          * e.g: 13 bit: median 2048, min 2022, max 2074, stddev 6.0
     158             :          */
     159    10568868 :         uint16_t a = bytes[0];
     160    10568868 :         uint16_t b = bytes[1] ^ 0x2e;
     161    10568868 :         uint16_t c = bytes[2] ^ 0x55;
     162    10568868 :         uint16_t ca = c - a;
     163    10568868 :         uint16_t d = ((a + b) << 8) ^ (ca << 5) ^ (c + b) ^ (0xcab + a);
     164    10568868 :         return d & HASH_MASK;
     165             : }
     166             : 
     167             : 
     /*
 * encode_match() builds the 9-bit Huffman symbol for an LZ77 match:
 *   bit 8     : always set, marking a match (literals are 0-255)
 *   bits 4-7  : bitlen_nonzero_16(offset), the index of offset's top bit
 *   bits 0-3  : MIN(len - 3, 15); 15 is the largest value the field holds
 *               (presumably longer lengths are continued elsewhere - the
 *               writer is not in view)
 *
 * NOTE(review): offset is narrowed to uint16_t by bitlen_nonzero_16() and
 * must be non-zero; presumably callers guarantee 1 <= offset <= 0xffff and
 * len >= 3 (a smaller len wraps and would encode as 15).
 */
static inline uint16_t encode_match(size_t len, size_t offset)
{
	size_t len_bits = MIN(len - 3, 15);
	int offset_bits = bitlen_nonzero_16(offset);

	return (uint16_t)(256 | len_bits | (offset_bits << 4));
}
     175             : 
     176             : /*
     177             :  * debug_huffman_tree() uses debug_huffman_tree_print() to draw the Huffman
     178             :  * tree in ascii art.
     179             :  *
     180             :  * Note that the Huffman tree is probably not the same as that implied by the
     181             :  * canonical Huffman encoding that is finally used. That tree would be the
     182             :  * same shape, but with the left and right toggled to sort the branches by
     183             :  * length, after which the symbols for each length sorted by value.
     184             :  */
     185             : 
     186           0 : static void debug_huffman_tree_print(struct huffman_node *node,
     187             :                                      int *trail, int depth)
     188             : {
     189           0 :         if (node->left == NULL) {
     190             :                 /* time to print a row */
     191           0 :                 int j;
     192           0 :                 bool branched = false;
     193           0 :                 int row[17];
     194           0 :                 char c[100];
     195           0 :                 int s = node->symbol;
     196           0 :                 char code[17];
     197           0 :                 if (depth > 15) {
     198           0 :                         fprintf(stderr,
     199             :                                 " \033[1;31m Max depth exceeded! (%d)\033[0m "
     200             :                                 " symbol %#3x claimed depth %d count %d\n",
     201           0 :                                 depth, node->symbol, node->depth, node->count);
     202           0 :                         return;
     203             :                 }
     204           0 :                 for (j = depth - 1; j >= 0; j--) {
     205           0 :                         if (branched) {
     206           0 :                                 if (trail[j] == -1) {
     207           0 :                                         row[j] = -3;
     208             :                                 } else {
     209           0 :                                         row[j] = -2;
     210             :                                 }
     211           0 :                         } else if (trail[j] == -1) {
     212           0 :                                 row[j] = -1;
     213           0 :                                 branched = true;
     214             :                         } else {
     215           0 :                                 row[j] = trail[j];
     216             :                         }
     217             :                 }
     218           0 :                 for (j = 0; j < depth; j++) {
     219           0 :                         switch (row[j]) {
     220           0 :                         case -3:
     221           0 :                                 code[j] = '1';
     222           0 :                                 fprintf(stderr, "        ");
     223           0 :                                 break;
     224           0 :                         case -2:
     225           0 :                                 code[j] = '0';
     226           0 :                                 fprintf(stderr, "      │ ");
     227           0 :                                 break;
     228           0 :                         case -1:
     229           0 :                                 code[j] = '1';
     230           0 :                                 fprintf(stderr, "      ╰─");
     231           0 :                                 break;
     232           0 :                         default:
     233           0 :                                 code[j] = '0';
     234           0 :                                 fprintf(stderr, "%5d─┬─", row[j]);
     235           0 :                                 break;
     236             :                         }
     237             :                 }
     238           0 :                 code[depth] = 0;
     239           0 :                 if (s < 32) {
     240           0 :                         snprintf(c, sizeof(c),
     241             :                                 "\033[1;32m%02x\033[0m \033[1;33m%c%c%c\033[0m",
     242             :                                  s,
     243             :                                  0xE2, 0x90, 0x80 + s); /* utf-8 for symbol */
     244           0 :                 }  else if (s < 127) {
     245           0 :                         snprintf(c, sizeof(c),
     246             :                                  "\033[1;32m%2x\033[0m '\033[10;32m%c\033[0m'",
     247             :                                  s, s);
     248           0 :                 } else if (s < 256) {
     249           0 :                         snprintf(c, sizeof(c), "\033[1;32m%2x\033[0m", s);
     250             :                 } else {
     251           0 :                         uint16_t len = (s & 15) + 3;
     252           0 :                         uint16_t dbits = ((s >> 4) & 15) + 1;
     253           0 :                         snprintf(c, sizeof(c),
     254             :                                  " \033[0;33mlen:%2d%s, "
     255             :                                  "dist:%d-%d \033[0m \033[1;32m%3x\033[0m%s",
     256             :                                  len,
     257             :                                  len == 18 ? "+" : "",
     258           0 :                                  1 << (dbits - 1),
     259           0 :                                  (1 << dbits) - 1,
     260             :                                  s,
     261             :                                  s == 256 ? " \033[1;31mEOF\033[0m" : "");
     262             : 
     263             :                 }
     264             : 
     265           0 :                 fprintf(stderr, "──%5d %s \033[2;37m%s\033[0m\n",
     266             :                         node->count, c, code);
     267           0 :                 return;
     268             :         }
     269           0 :         trail[depth] = node->count;
     270           0 :         debug_huffman_tree_print(node->left, trail, depth + 1);
     271           0 :         trail[depth] = -1;
     272           0 :         debug_huffman_tree_print(node->right, trail, depth + 1);
     273             : }
     274             : 
     275             : 
     /*
 * If DEBUG_HUFFMAN_TREE is defined true, debug_huffman_tree()
 * will print a tree looking something like this:
 *
 *     7─┬───    3  len:18+, dist:1-1  10f 0
 *       ╰─    4─┬─    2─┬───    1 61 'a' 100
 *               │       ╰───    1 62 'b' 101
 *               ╰─    2─┬───    1 63 'c' 110
 *                       ╰───    1  len: 3, dist:1-1  100 EOF 111
 *
 * The drawing starts from the Huffman root node passed in, so the tree
 * shown may differ from the canonical tree that is finally encoded.
 */
static void debug_huffman_tree(struct huffman_node *root)
{
	/* per-depth connector state used by debug_huffman_tree_print() */
	int connectors[17];

	debug_huffman_tree_print(root, connectors, 0);
}
     294             : 
     295             : 
     296             : /*
     297             :  * If DEBUG_HUFFMAN_TREE is defined true, debug_huffman_tree_from_table()
     298             :  * will print something like this based on a decoding symbol table.
     299             :  *
     300             :  *  Tree from decoding table 9 nodes → 5 codes
     301             :  * 10000─┬─── 5000  len:18+, dist:1-1  10f 0
     302             :  *       ╰─ 5000─┬─ 2500─┬─── 1250 61 'a' 100
     303             :  *               │       ╰─── 1250 62 'b' 101
     304             :  *               ╰─ 2500─┬─── 1250 63 'c' 110
     305             :  *                       ╰─── 1250  len: 3, dist:1-1  100 EOF 111
     306             :  *
     307             :  * This is the canonical form of the Huffman tree where the actual counts
     308             :  * aren't known (we use "10000" to help indicate relative frequencies).
     309             :  */
     310           0 : static void debug_huffman_tree_from_table(uint16_t *table)
     311             : {
     312           0 :         int trail[17];
     313           0 :         struct huffman_node nodes[1024] = {{0}};
     314           0 :         uint16_t codes[1024];
     315           0 :         size_t n = 1;
     316           0 :         size_t i = 0;
     317           0 :         codes[0] = 0;
     318           0 :         nodes[0].count = 10000;
     319             : 
     320           0 :         while (i < n) {
     321           0 :                 uint16_t index = codes[i];
     322           0 :                 struct huffman_node *node = &nodes[i];
     323           0 :                 if (table[index] == 0xffff) {
     324             :                         /* internal node */
     325           0 :                         index <<= 1;
     326             :                         /* left */
     327           0 :                         index++;
     328           0 :                         codes[n] = index;
     329           0 :                         node->left = nodes + n;
     330           0 :                         nodes[n].count = node->count >> 1;
     331           0 :                         n++;
     332             :                         /*right*/
     333           0 :                         index++;
     334           0 :                         codes[n] = index;
     335           0 :                         node->right = nodes + n;
     336           0 :                         nodes[n].count = node->count >> 1;
     337           0 :                         n++;
     338             :                 } else {
     339             :                         /* leaf node */
     340           0 :                         node->symbol = table[index] & 511;
     341             :                 }
     342           0 :                 i++;
     343             :         }
     344             : 
     345           0 :         fprintf(stderr,
     346             :                 "\033[1;34m Tree from decoding table\033[0m "
     347             :                 "%zu nodes → %zu codes\n",
     348           0 :                 n, (n + 1) / 2);
     349           0 :         debug_huffman_tree_print(nodes, trail, 0);
     350           0 : }
     351             : 
     352             : 
     353      325569 : static bool depth_walk(struct huffman_node *n, uint32_t depth)
     354             : {
     355      260029 :         bool ok;
     356      325569 :         if (n->left == NULL) {
     357             :                 /* this is a leaf, record the depth */
     358      163269 :                 n->depth = depth;
     359      163269 :                 return true;
     360             :         }
     361      162300 :         if (depth > 14) {
     362           0 :                 return false;
     363             :         }
     364      324442 :         ok = (depth_walk(n->left, depth + 1) &&
     365      162162 :               depth_walk(n->right, depth + 1));
     366             : 
     367      162280 :         return ok;
     368             : }
     369             : 
     370             : 
     371        1127 : static bool check_and_record_depths(struct huffman_node *root)
     372             : {
     373        1127 :         return depth_walk(root, 0);
     374             : }
     375             : 
     376             : 
     377        1107 : static bool encode_values(struct huffman_node *leaves,
     378             :                           size_t n_leaves,
     379             :                           uint16_t symbol_values[512])
     380             : {
     381         811 :         size_t i;
     382             :         /*
     383             :          * See, we have a leading 1 in our internal code representation, which
     384             :          * indicates the code length.
     385             :          */
     386        1107 :         uint32_t code = 1;
     387        1107 :         uint32_t code_len = 0;
     388        1107 :         memset(symbol_values, 0, sizeof(uint16_t) * 512);
     389      160192 :         for (i = 0; i < n_leaves; i++) {
     390      159085 :                 code <<= leaves[i].depth - code_len;
     391      159085 :                 code_len = leaves[i].depth;
     392             : 
     393      159085 :                 symbol_values[leaves[i].symbol] = code;
     394      159085 :                 code++;
     395             :         }
     396             :         /*
     397             :          * The last code should be 11111... with code_len + 1 ones. The final
     398             :          * code++ will wrap this round to 1000... with code_len + 1 zeroes.
     399             :          */
     400             : 
     401        1107 :         if (code != 2 << code_len) {
     402           0 :                 return false;
     403             :         }
     404         296 :         return true;
     405             : }
     406             : 
     407             : 
     408        1107 : static int generate_huffman_codes(struct huffman_node *leaf_nodes,
     409             :                                   struct huffman_node *internal_nodes,
     410             :                                   uint16_t symbol_values[512])
     411             : {
     412        1107 :         size_t head_leaf = 0;
     413        1107 :         size_t head_branch = 0;
     414        1107 :         size_t tail_branch = 0;
     415        1107 :         struct huffman_node *huffman_root = NULL;
     416         811 :         size_t i, j;
     417        1107 :         size_t n_leaves = 0;
     418             : 
     419             :         /*
     420             :          * Before we sort the nodes, we can eliminate the unused ones.
     421             :          */
     422      567891 :         for (i = 0; i < 512; i++) {
     423      566784 :                 if (leaf_nodes[i].count) {
     424      159082 :                         leaf_nodes[n_leaves] = leaf_nodes[i];
     425      159082 :                         n_leaves++;
     426             :                 }
     427             :         }
     428        1107 :         if (n_leaves == 0) {
     429           0 :                 return LZXPRESS_ERROR;
     430             :         }
     431        1107 :         if (n_leaves == 1) {
     432             :                 /*
     433             :                  * There is *almost* no way this should happen, and it would
     434             :                  * ruin the tree (because the shortest possible codes are 1
     435             :                  * bit long, and there are two of them).
     436             :                  *
     437             :                  * The only way to get here is in an internal block in a
     438             :                  * 3-or-more block message (i.e. > 128k), which consists
     439             :                  * entirely of a match starting in the previous block (if it
     440             :                  * was the end block, it would have the EOF symbol).
     441             :                  *
     442             :                  * What we do is add a dummy symbol which is this one XOR 256.
     443             :                  * It won't be used in the stream but will balance the tree.
     444             :                  */
     445           3 :                 leaf_nodes[1] = leaf_nodes[0];
     446           3 :                 leaf_nodes[1].symbol ^= 0x100;
     447           3 :                 n_leaves = 2;
     448             :         }
     449             : 
     450             :         /* note, in sort we're using internal_nodes as auxiliary space */
     451        1107 :         stable_sort(leaf_nodes,
     452             :                     internal_nodes,
     453             :                     n_leaves,
     454             :                     sizeof(struct huffman_node),
     455             :                     (samba_compare_fn_t)compare_huffman_node_count);
     456             : 
     457             :         /*
     458             :          * This outer loop is for re-quantizing the counts if the tree is too
     459             :          * tall (>15), which we need to do because the final encoding can't
     460             :          * express a tree that deep.
     461             :          *
     462             :          * In theory, this should be a 'while (true)' loop, but we chicken
     463             :          * out with 10 iterations, just in case.
     464             :          *
     465             :          * In practice it will almost always resolve in the first round; if
     466             :          * not then, in the second or third. Remember we'll looking at 64k or
     467             :          * less, so the rarest we can have is 1 in 64k; each round of
     468             :          * quantization effectively doubles its frequency to 1 in 32k, 1 in
     469             :          * 16k, etc, until we're treating the rare symbol as actually quite
     470             :          * common.
     471             :          */
     472        1938 :         for (j = 0; j < 10; j++) {
     473      131521 :                 bool less_than_15_bits;
     474      164143 :                 while (true) {
     475      164439 :                         struct huffman_node *a = NULL;
     476      164439 :                         struct huffman_node *b = NULL;
     477      164439 :                         size_t leaf_len = n_leaves - head_leaf;
     478      164439 :                         size_t internal_len = tail_branch - head_branch;
     479             : 
     480      164439 :                         if (leaf_len + internal_len == 1) {
     481             :                                 /*
     482             :                                  * We have the complete tree. The root will be
     483             :                                  * an internal node unless there is just one
     484             :                                  * symbol, which is already impossible.
     485             :                                  */
     486        1127 :                                 if (unlikely(leaf_len == 1)) {
     487           0 :                                         return LZXPRESS_ERROR;
     488             :                                 } else {
     489        1127 :                                         huffman_root = \
     490        1127 :                                                 &internal_nodes[head_branch];
     491             :                                 }
     492        1127 :                                 break;
     493             :                         }
     494             :                         /*
     495             :                          * We know here we have at least two nodes, and we
     496             :                          * want to select the two lowest scoring ones. Those
     497             :                          * have to be either a) the head of each queue, or b)
     498             :                          * the first two nodes of either queue.
     499             :                          *
     500             :                          * The complicating factors are: a) we need to check
     501             :                          * the length of each queue, and b) in the case of
     502             :                          * ties, we prefer to pair leaves with leaves.
     503             :                          *
     504             :                          * Note a complication we don't have: the leaf node
     505             :                          * queue never grows, and the subtree queue starts
     506             :                          * empty and cannot grow beyond n - 1. It feeds on
     507             :                          * itself. We don't need to think about overflow.
     508             :                          */
     509      163312 :                         if (leaf_len == 0) {
     510             :                                 /* two from subtrees */
     511       20203 :                                 a = &internal_nodes[head_branch];
     512       20203 :                                 b = &internal_nodes[head_branch + 1];
     513       20203 :                                 head_branch += 2;
     514      143109 :                         } else if (internal_len == 0) {
     515             :                                 /* two from nodes */
     516        1127 :                                 a = &leaf_nodes[head_leaf];
     517        1127 :                                 b = &leaf_nodes[head_leaf + 1];
     518        1127 :                                 head_leaf += 2;
     519      141982 :                         } else if (leaf_len == 1 && internal_len == 1) {
     520             :                                 /* one of each */
     521         196 :                                 a = &leaf_nodes[head_leaf];
     522         196 :                                 b = &internal_nodes[head_branch];
     523         196 :                                 head_branch++;
     524         196 :                                 head_leaf++;
     525             :                         } else {
     526             :                                 /*
     527             :                                  * Take the lowest head, twice, checking for
     528             :                                  * length after taking the first one.
     529             :                                  */
     530      141786 :                                 if (leaf_nodes[head_leaf].count >
     531      141786 :                                     internal_nodes[head_branch].count) {
     532       60541 :                                         a = &internal_nodes[head_branch];
     533       60541 :                                         head_branch++;
     534       60541 :                                         if (internal_len == 1) {
     535          78 :                                                 b = &leaf_nodes[head_leaf];
     536          78 :                                                 head_leaf++;
     537          78 :                                                 goto done;
     538             :                                         }
     539             :                                 } else {
     540       81245 :                                         a = &leaf_nodes[head_leaf];
     541       81245 :                                         head_leaf++;
     542       81245 :                                         if (leaf_len == 1) {
     543         439 :                                                 b = &internal_nodes[head_branch];
     544         439 :                                                 head_branch++;
     545         439 :                                                 goto done;
     546             :                                         }
     547             :                                 }
     548             :                                 /* the other node */
     549      141269 :                                 if (leaf_nodes[head_leaf].count >
     550      141269 :                                     internal_nodes[head_branch].count) {
     551       60603 :                                         b = &internal_nodes[head_branch];
     552       60603 :                                         head_branch++;
     553             :                                 } else {
     554       80666 :                                         b = &leaf_nodes[head_leaf];
     555       80666 :                                         head_leaf++;
     556             :                                 }
     557             :                         }
     558      163312 :                 done:
     559             :                         /*
     560             :                          * Now we add a new node to the subtrees list that
     561             :                          * combines the score of node_a and node_b, and points
     562             :                          * to them as children.
     563             :                          */
     564      163312 :                         internal_nodes[tail_branch].count = a->count + b->count;
     565      163312 :                         internal_nodes[tail_branch].left = a;
     566      163312 :                         internal_nodes[tail_branch].right = b;
     567      163312 :                         tail_branch++;
     568      163312 :                         if (tail_branch == n_leaves) {
     569             :                                 /*
     570             :                                  * We're not getting here, no way, never ever.
     571             :                                  * Unless we made a terrible mistake.
     572             :                                  *
     573             :                                  * That is, in a binary tree with n leaves,
     574             :                                  * there are ALWAYS n-1 internal nodes.
     575             :                                  */
     576           0 :                                 return LZXPRESS_ERROR;
     577             :                         }
     578             :                 }
     579        1127 :                 if (CHECK_DEBUGLVL(10) || DEBUG_HUFFMAN_TREE) {
     580           0 :                         debug_huffman_tree(huffman_root);
     581             :                 }
     582             :                 /*
     583             :                  * We have a tree, and need to turn it into a lookup table,
     584             :                  * and see if it is shallow enough (<= 15).
     585             :                  */
     586        1127 :                 less_than_15_bits = check_and_record_depths(huffman_root);
     587        1127 :                 if (less_than_15_bits) {
     588             :                         /*
     589             :                          * Now the leaf nodes know how deep they are, and we
     590             :                          * no longer need the internal nodes.
     591             :                          *
     592             :                          * We need to sort the nodes of equal depth, so that
     593             :                          * they are sorted by depth first, and symbol value
     594             :                          * second. The internal_nodes can again be auxiliary
     595             :                          * memory.
     596             :                          */
     597        1107 :                         stable_sort(
     598             :                                 leaf_nodes,
     599             :                                 internal_nodes,
     600             :                                 n_leaves,
     601             :                                 sizeof(struct huffman_node),
     602             :                                 (samba_compare_fn_t)compare_huffman_node_depth);
     603             : 
     604        1107 :                         encode_values(leaf_nodes, n_leaves, symbol_values);
     605             : 
     606        1107 :                         return n_leaves;
     607             :                 }
     608             : 
     609             :                 /*
     610             :                  * requantize by halving and rounding up, so that small counts
     611             :                  * become relatively bigger. This will lead to a flatter tree.
     612             :                  */
     613        5374 :                 for (i = 0; i < n_leaves; i++) {
     614        5354 :                         leaf_nodes[i].count >>= 1;
     615        5354 :                         leaf_nodes[i].count += 1;
     616             :                 }
     617          20 :                 head_leaf = 0;
     618          20 :                 head_branch = 0;
     619          20 :                 tail_branch = 0;
     620             :         }
     621           0 :         return LZXPRESS_ERROR;
     622             : }
     623             : 
     624             : /*
     625             :  * LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS is how far ahead to search in the
     626             :  * circular hash table for a match, before we give up. A bigger number will
     627             :  * generally lead to better but slower compression, but a stupidly big number
     628             :  * will just be worse.
     629             :  *
     630             :  * If you're fiddling with this, consider also fiddling with
     631             :  * LZX_HUFF_COMP_HASH_BITS.
     632             :  */
     633             : #define LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS 5
     634             : 
     635    10568868 : static inline void store_match(uint16_t *hash_table,
     636             :                                uint16_t h,
     637             :                                uint16_t offset)
     638             : {
     639     9704610 :         int i;
     640    10568868 :         uint16_t o = hash_table[h];
     641     9704610 :         uint16_t h2;
     642     9704610 :         uint16_t worst_h;
     643     9704610 :         int worst_score;
     644             : 
     645    10568868 :         if (o == 0xffff) {
     646             :                 /* there is nothing there yet */
     647     2705647 :                 hash_table[h] = offset;
     648     2705647 :                 return;
     649             :         }
     650    33215454 :         for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     651    27420362 :                 h2 = (h + i) & HASH_MASK;
     652    27420362 :                 if (hash_table[h2] == 0xffff) {
     653     2068129 :                         hash_table[h2] = offset;
     654     2068129 :                         return;
     655             :                 }
     656             :         }
     657             :         /*
     658             :          * There are no slots, but we really want to store this, so we'll kick
     659             :          * out the one with the longest distance.
     660             :          */
     661     5795092 :         worst_h = h;
     662     5795092 :         worst_score = offset - o;
     663    28975460 :         for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     664    21528032 :                 int score;
     665    23180368 :                 h2 = (h + i) & HASH_MASK;
     666    23180368 :                 o = hash_table[h2];
     667    23180368 :                 score = offset - o;
     668    23180368 :                 if (score > worst_score) {
     669     6287992 :                         worst_score = score;
     670     6287992 :                         worst_h = h2;
     671             :                 }
     672             :         }
     673     5795092 :         hash_table[worst_h] = offset;
     674             : }
     675             : 
     676             : 
     677             : /*
     678             :  * Yes, struct match looks a lot like a DATA_BLOB.
     679             :  */
     680             : struct match {
     681             :         const uint8_t *there;
     682             :         size_t length;
     683             : };
     684             : 
     685             : 
     686    17347770 : static inline struct match lookup_match(uint16_t *hash_table,
     687             :                                         uint16_t h,
     688             :                                         const uint8_t *data,
     689             :                                         const uint8_t *here,
     690             :                                         size_t max_len)
     691             : {
     692    16116058 :         int i;
     693    17347770 :         uint16_t o = hash_table[h];
     694    16116058 :         uint16_t h2;
     695    16116058 :         size_t len;
     696    17347770 :         const uint8_t *there = NULL;
     697    17347770 :         struct match best = {0};
     698             : 
     699    81338522 :         for (i = 0; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
     700    69569836 :                 h2 = (h + i) & HASH_MASK;
     701    69569836 :                 o = hash_table[h2];
     702    69569836 :                 if (o == 0xffff) {
     703             :                         /*
     704             :                          * in setting this, we would never have stepped over
     705             :                          * an 0xffff, so we won't now.
     706             :                          */
     707      798054 :                         break;
     708             :                 }
     709    63990752 :                 there = data + o;
     710    63990752 :                 if (here - there > 65534 || there > here) {
     711     7164359 :                         continue;
     712             :                 }
     713             : 
     714             :                 /*
     715             :                  * When we already have a long match, we can try to avoid
     716             :                  * measuring out another long, but shorter match.
     717             :                  */
     718    56826393 :                 if (best.length > 1000 &&
     719         597 :                     there[best.length - 1] != best.there[best.length - 1]) {
     720         142 :                         continue;
     721             :                 }
     722             : 
     723     2381872 :                 for (len = 0;
     724   192423482 :                      len < max_len && here[len] == there[len];
     725   135597231 :                      len++) {
     726             :                         /* counting */
     727   119202329 :                 }
     728    56826251 :                 if (len > 2) {
     729             :                         /*
     730             :                          * As a tiebreaker, we prefer the closer match which
     731             :                          * is likely to encode smaller (and certainly no worse).
     732             :                          */
     733     8038688 :                         if (len > best.length ||
     734     3439520 :                             (len == best.length && there > best.there)) {
     735     5362318 :                                 best.length = len;
     736     5362318 :                                 best.there = there;
     737             :                         }
     738             :                 }
     739             :         }
     740    17347770 :         return best;
     741             : }
     742             : 
     743             : 
     744             : 
     745        1107 : static ssize_t lz77_encode_block(struct lzxhuff_compressor_context *cmp_ctx,
     746             :                                  struct lzxhuff_compressor_mem *cmp_mem,
     747             :                                  uint16_t *hash_table,
     748             :                                  uint16_t *prev_hash_table)
     749             : {
     750        1107 :         uint16_t *intermediate = cmp_mem->intermediate;
     751        1107 :         struct huffman_node *leaf_nodes = cmp_mem->leaf_nodes;
     752        1107 :         uint16_t *symbol_values = cmp_mem->symbol_values;
     753         811 :         size_t i, j, intermediate_len;
     754        1107 :         const uint8_t *data = cmp_ctx->input_bytes + cmp_ctx->input_pos;
     755        1107 :         const uint8_t *prev_block = NULL;
     756        1107 :         size_t remaining_size = cmp_ctx->input_size - cmp_ctx->input_pos;
     757        1107 :         size_t block_end = MIN(65536, remaining_size);
     758         811 :         struct match match;
     759         811 :         int n_symbols;
     760             : 
     761        1107 :         if (cmp_ctx->input_size < cmp_ctx->input_pos) {
     762           0 :                 return LZXPRESS_ERROR;
     763             :         }
     764             : 
     765        1107 :         if (cmp_ctx->prev_block_pos != cmp_ctx->input_pos) {
     766         553 :                 prev_block = cmp_ctx->input_bytes + cmp_ctx->prev_block_pos;
     767         554 :         } else if (prev_hash_table != NULL) {
     768             :                 /* we've got confused! hash and block should go together */
     769           0 :                 return LZXPRESS_ERROR;
     770             :         }
     771             : 
     772             :         /*
     773             :          * leaf_nodes is used to count the symbols seen, for later Huffman
     774             :          * encoding.
     775             :          */
     776      567891 :         for (i = 0; i < 512; i++) {
     777      566784 :                 leaf_nodes[i] = (struct huffman_node) {
     778             :                         .symbol = i
     779             :                 };
     780             :         }
     781             : 
     782        1107 :         j = 0;
     783             : 
     784        1107 :         if (remaining_size < 41 || DEBUG_NO_LZ77_MATCHES) {
     785             :                 /*
     786             :                  * There is no point doing a hash table and looking for
     787             :                  * matches in this tiny block (remembering we are committed to
     788             :                  * using 32 bits, so there's a good chance we wouldn't even
     789             :                  * save a byte). The threshold of 41 matches Windows.
     790             :                  * If remaining_size < 3, we *can't* do the hash.
     791             :                  */
     792           6 :                 i = 0;
     793             :         } else {
     794             :                 /*
     795             :                  * We use 0xffff as the unset value for table, because it is
     796             :                  * not a valid match offset (and 0x0 is).
     797             :                  */
     798         994 :                 memset(hash_table, 0xff, sizeof(cmp_mem->hash_table1));
     799             : 
     800    10569862 :                 for (i = 0; i <= block_end - 3; i++) {
     801     9704610 :                         uint16_t code;
     802    10568868 :                         const uint8_t *here = data + i;
     803    10568868 :                         uint16_t h = three_byte_hash(here);
     804    10568868 :                         size_t max_len = MIN(remaining_size - i, MAX_MATCH_LENGTH);
     805    10568868 :                         match = lookup_match(hash_table,
     806             :                                              h,
     807             :                                              data,
     808             :                                              here,
     809             :                                              max_len);
     810             : 
     811    10568868 :                         if (match.there == NULL && prev_hash_table != NULL) {
     812             :                                 /*
     813             :                                  * If this is not the first block,
     814             :                                  * backreferences can look into the previous
     815             :                                  * block (but only as far as 65535 bytes, so
     816             :                                  * the end of this block cannot see the start
     817             :                                  * of the last one).
     818             :                                  */
     819     6778902 :                                 match = lookup_match(prev_hash_table,
     820             :                                                      h,
     821             :                                                      prev_block,
     822             :                                                      here,
     823             :                                                      remaining_size - i);
     824             :                         }
     825             : 
     826    10568868 :                         store_match(hash_table, h, i);
     827             : 
     828    10568868 :                         if (match.there == NULL) {
     829             :                                 /* add a literal and move on. */
     830     7444317 :                                 uint8_t c = data[i];
     831     7444317 :                                 leaf_nodes[c].count++;
     832     7444317 :                                 intermediate[j] = c;
     833     7444317 :                                 j++;
     834     7444317 :                                 continue;
     835             :                         }
     836             : 
     837             :                         /* a real match */
     838     3124551 :                         if (match.length <= 65538) {
     839     3124482 :                                 intermediate[j] = 0xffff;
     840     3124482 :                                 intermediate[j + 1] = match.length - 3;
     841     3124482 :                                 intermediate[j + 2] = here - match.there;
     842     3124482 :                                 j += 3;
     843             :                         } else {
     844          69 :                                 size_t m = match.length - 3;
     845          69 :                                 intermediate[j] = 0xfffe;
     846          69 :                                 intermediate[j + 1] = m & 0xffff;
     847          69 :                                 intermediate[j + 2] = m >> 16;
     848          69 :                                 intermediate[j + 3] = here - match.there;
     849          69 :                                 j += 4;
     850             :                         }
     851     3124551 :                         code = encode_match(match.length, here - match.there);
     852     3124551 :                         leaf_nodes[code].count++;
     853     3124551 :                         i += match.length - 1; /* `- 1` for the loop i++ */
     854             :                         /*
     855             :                          * A match can take us past the intended block length,
     856             :                          * extending the block. We don't need to do anything
     857             :                          * special for this case -- the loops will naturally
     858             :                          * do the right thing.
     859             :                          */
     860             :                 }
     861             :         }
     862             : 
     863             :         /*
     864             :          * There might be some bytes at the end.
     865             :          */
     866        3929 :         for (; i < block_end; i++) {
     867        2822 :                 leaf_nodes[data[i]].count++;
     868        2822 :                 intermediate[j] = data[i];
     869        2822 :                 j++;
     870             :         }
     871             : 
     872        1107 :         if (i == remaining_size) {
     873             :                 /* add a trailing EOF marker (256) */
     874         554 :                 intermediate[j] = 0xffff;
     875         554 :                 intermediate[j + 1] = 0;
     876         554 :                 intermediate[j + 2] = 1;
     877         554 :                 j += 3;
     878         554 :                 leaf_nodes[256].count++;
     879             :         }
     880             : 
     881        1107 :         intermediate_len = j;
     882             : 
     883        1107 :         cmp_ctx->prev_block_pos = cmp_ctx->input_pos;
     884        1107 :         cmp_ctx->input_pos += i;
     885             : 
     886             :         /* fill in the symbols table */
     887        1918 :         n_symbols = generate_huffman_codes(leaf_nodes,
     888        1107 :                                            cmp_mem->internal_nodes,
     889             :                                            symbol_values);
     890        1107 :         if (n_symbols < 0) {
     891           0 :                 return n_symbols;
     892             :         }
     893             : 
     894        1107 :         return intermediate_len;
     895             : }
     896             : 
     897             : 
     898             : 
     899        1107 : static ssize_t write_huffman_table(uint16_t symbol_values[512],
     900             :                                    uint8_t *output,
     901             :                                    size_t available_size)
     902             : {
     903         811 :         size_t i;
     904             : 
     905        1107 :         if (available_size < 256) {
     906           0 :                 return LZXPRESS_ERROR;
     907             :         }
     908             : 
     909      284499 :         for (i = 0; i < 256; i++) {
     910      283392 :                 uint8_t b = 0;
     911      283392 :                 uint16_t even = symbol_values[i * 2];
     912      283392 :                 uint16_t odd = symbol_values[i * 2 + 1];
     913      283392 :                 if (even != 0) {
     914       83010 :                         b = bitlen_nonzero_16(even);
     915             :                 }
     916      283392 :                 if (odd != 0) {
     917       76075 :                         b |= bitlen_nonzero_16(odd) << 4;
     918             :                 }
     919      283392 :                 output[i] = b;
     920             :         }
     921         296 :         return i;
     922             : }
     923             : 
     924             : 
     925             : struct write_context {
     926             :         uint8_t *dest;
     927             :         size_t dest_len;
     928             :         size_t head;                 /* where lengths go */
     929             :         size_t next_code;            /* where symbol stream goes */
     930             :         size_t pending_next_code;    /* will be next_code */
     931             :         unsigned bit_len;
     932             :         uint32_t bits;
     933             : };
     934             : 
     935             : /*
     936             :  * Write out 16 bits, little-endian, for write_huffman_codes()
     937             :  *
     938             :  * As you'll notice, there's a bit to do.
     939             :  *
     940             :  * We are collecting up bits in a uint32_t, then when there are 16 of them we
     941             :  * write out a word into the stream, using a trio of offsets (wc->next_code,
     942             :  * wc->pending_next_code, and wc->head) which dance around ensuring that the
     943             :  * bitstream and the interspersed lengths are in the right places relative to
     944             :  * each other.
     945             :  */
     946             : 
     947    13698851 : static inline bool write_bits(struct write_context *wc,
     948             :                               uint16_t code, uint16_t length)
     949             : {
     950    13698851 :         wc->bits <<= length;
     951    13698851 :         wc->bits |= code;
     952    13698851 :         wc->bit_len += length;
     953    13698851 :         if (wc->bit_len > 16) {
     954     7138226 :                 uint32_t w = wc->bits >> (wc->bit_len - 16);
     955     7138226 :                 wc->bit_len -= 16;
     956     7138226 :                 if (wc->next_code + 2 > wc->dest_len ||
     957     7138226 :                     unlikely(wc->bit_len > 16)) {
     958           0 :                         return false;
     959             :                 }
     960     7138226 :                 wc->dest[wc->next_code] = w & 0xff;
     961     7138226 :                 wc->dest[wc->next_code + 1] = (w >> 8) & 0xff;
     962     7138226 :                 wc->next_code = wc->pending_next_code;
     963     7138226 :                 wc->pending_next_code = wc->head;
     964     7138226 :                 wc->head += 2;
     965             :         }
     966     1287994 :         return true;
     967             : }
     968             : 
     969             : 
     970    10572244 : static inline bool write_code(struct write_context *wc, uint16_t code)
     971             : {
     972    10572244 :         int code_bit_len = bitlen_nonzero_16(code);
     973    10572244 :         if (unlikely(code == 0)) {
     974           0 :                 return false;
     975             :         }
     976    10572244 :         code &= (1 << code_bit_len) - 1;
     977    10572244 :         return  write_bits(wc, code, code_bit_len);
     978             : }
     979             : 
     980       12939 : static inline bool write_byte(struct write_context *wc, uint8_t byte)
     981             : {
     982       12939 :         if (wc->head + 1 > wc->dest_len) {
     983           0 :                 return false;
     984             :         }
     985       12939 :         wc->dest[wc->head] = byte;
     986       12939 :         wc->head++;
     987       12939 :         return true;
     988             : }
     989             : 
     990             : 
     991        1625 : static inline bool write_long_len(struct write_context *wc, size_t len)
     992             : {
     993        1625 :         if (len < 65535) {
     994        1556 :                 if (wc->head + 3 > wc->dest_len) {
     995           0 :                         return false;
     996             :                 }
     997        1556 :                 wc->dest[wc->head] = 255;
     998        1556 :                 wc->dest[wc->head + 1] = len & 255;
     999        1556 :                 wc->dest[wc->head + 2] = len >> 8;
    1000        1556 :                 wc->head += 3;
    1001             :         } else {
    1002          69 :                 if (wc->head + 7 > wc->dest_len) {
    1003           0 :                         return false;
    1004             :                 }
    1005          69 :                 wc->dest[wc->head] = 255;
    1006          69 :                 wc->dest[wc->head + 1] = 0;
    1007          69 :                 wc->dest[wc->head + 2] = 0;
    1008          69 :                 wc->dest[wc->head + 3] = len & 255;
    1009          69 :                 wc->dest[wc->head + 4] = (len >> 8) & 255;
    1010          69 :                 wc->dest[wc->head + 5] = (len >> 16) & 255;
    1011          69 :                 wc->dest[wc->head + 6] = (len >> 24) & 255;
    1012          69 :                 wc->head += 7;
    1013             :         }
    1014          48 :         return true;
    1015             : }
    1016             : 
    1017        1107 : static ssize_t write_compressed_bytes(uint16_t symbol_values[512],
    1018             :                                       uint16_t *intermediate,
    1019             :                                       size_t intermediate_len,
    1020             :                                       uint8_t *dest,
    1021             :                                       size_t dest_len)
    1022             : {
    1023         811 :         bool ok;
    1024         811 :         size_t i;
    1025         811 :         size_t end;
    1026        1107 :         struct write_context wc = {
    1027             :                 .head = 4,
    1028             :                 .pending_next_code = 2,
    1029             :                 .dest = dest,
    1030             :                 .dest_len = dest_len
    1031             :         };
    1032    10573351 :         for (i = 0; i < intermediate_len; i++) {
    1033    10572244 :                 uint16_t c = intermediate[i];
    1034     9707680 :                 size_t len;
    1035     9707680 :                 uint16_t distance;
    1036    10572244 :                 uint16_t code_len = 0;
    1037    10572244 :                 uint16_t code_dist = 0;
    1038    10572244 :                 if (c < 256) {
    1039     7447139 :                         ok = write_code(&wc, symbol_values[c]);
    1040     7447139 :                         if (!ok) {
    1041           0 :                                 return LZXPRESS_ERROR;
    1042             :                         }
    1043     7447139 :                         continue;
    1044             :                 }
    1045             : 
    1046     3125105 :                 if (c == 0xfffe) {
    1047          69 :                         if (i > intermediate_len - 4) {
    1048           0 :                                 return LZXPRESS_ERROR;
    1049             :                         }
    1050             : 
    1051          69 :                         len = intermediate[i + 1];
    1052          69 :                         len |= (uint32_t)intermediate[i + 2] << 16;
    1053          69 :                         distance = intermediate[i + 3];
    1054          69 :                         i += 3;
    1055     3125036 :                 } else if (c == 0xffff) {
    1056     3125036 :                         if (i > intermediate_len - 3) {
    1057           0 :                                 return LZXPRESS_ERROR;
    1058             :                         }
    1059     3125036 :                         len = intermediate[i + 1];
    1060     3125036 :                         distance = intermediate[i + 2];
    1061     3125036 :                         i += 2;
    1062             :                 } else {
    1063           0 :                         return LZXPRESS_ERROR;
    1064             :                 }
    1065     3125105 :                 if (unlikely(distance == 0)) {
    1066           0 :                         return LZXPRESS_ERROR;
    1067             :                 }
    1068             :                 /* len has already had 3 subtracted */
    1069     3125105 :                 if (len >= 15) {
    1070             :                         /*
    1071             :                          * We are going to need to write extra length
    1072             :                          * bytes into the stream, but we don't do it
    1073             :                          * now, we do it after the code has been
    1074             :                          * written (and before the distance bits).
    1075             :                          */
    1076        1550 :                         code_len = 15;
    1077             :                 } else {
    1078      421750 :                         code_len = len;
    1079             :                 }
    1080     3125105 :                 code_dist = bitlen_nonzero_16(distance);
    1081     3125105 :                 c = 256 | (code_dist << 4) | code_len;
    1082     3125105 :                 if (c > 511) {
    1083           0 :                         return LZXPRESS_ERROR;
    1084             :                 }
    1085             : 
    1086     3125105 :                 ok = write_code(&wc, symbol_values[c]);
    1087     3125105 :                 if (!ok) {
    1088           0 :                         return LZXPRESS_ERROR;
    1089             :                 }
    1090             : 
    1091     3125105 :                 if (code_len == 15) {
    1092       14564 :                         if (len >= 270) {
    1093        1625 :                                 ok = write_long_len(&wc, len);
    1094             :                         } else {
    1095       12939 :                                 ok = write_byte(&wc, len - 15);
    1096             :                         }
    1097       14564 :                         if (! ok) {
    1098           0 :                                 return LZXPRESS_ERROR;
    1099             :                         }
    1100             :                 }
    1101     3125105 :                 if (code_dist != 0) {
    1102     3123286 :                         uint16_t dist_bits = distance - (1 << code_dist);
    1103     3123286 :                         ok = write_bits(&wc, dist_bits, code_dist);
    1104     3123286 :                         if (!ok) {
    1105           0 :                                 return LZXPRESS_ERROR;
    1106             :                         }
    1107             :                 }
    1108             :         }
    1109             :         /*
    1110             :          * There are some intricacies around flushing the bits and returning
    1111             :          * the length.
    1112             :          *
    1113             :          * If the returned length is not exactly right and there is another
    1114             :          * block, that block will read its huffman table from the wrong place,
    1115             :          * and have all the symbol codes out by a multiple of 4.
    1116             :          */
    1117        1107 :         end = wc.head;
    1118        1107 :         if (wc.bit_len == 0) {
    1119           0 :                 end -= 2;
    1120             :         }
    1121        1107 :         ok = write_bits(&wc, 0, 16 - wc.bit_len);
    1122        1107 :         if (!ok) {
    1123           0 :                 return LZXPRESS_ERROR;
    1124             :         }
    1125        3321 :         for (i = 0; i < 2; i++) {
    1126             :                 /*
    1127             :                  * Flush out the bits with zeroes. It doesn't matter if we do
    1128             :                  * a round too many, as we have buffer space, and have already
    1129             :                  * determined the returned length (end).
    1130             :                  */
    1131        2214 :                 ok = write_bits(&wc, 0, 16);
    1132        2214 :                 if (!ok) {
    1133           0 :                         return LZXPRESS_ERROR;
    1134             :                 }
    1135             :         }
    1136        1107 :         return end;
    1137             : }
    1138             : 
    1139             : 
    1140        1107 : static ssize_t lzx_huffman_compress_block(struct lzxhuff_compressor_context *cmp_ctx,
    1141             :                                           struct lzxhuff_compressor_mem *cmp_mem,
    1142             :                                           size_t block_no)
    1143             : {
    1144         811 :         ssize_t intermediate_size;
    1145        1107 :         uint16_t *hash_table = NULL;
    1146        1107 :         uint16_t *back_window_hash_table = NULL;
    1147         811 :         ssize_t bytes_written;
    1148             : 
    1149        1107 :         if (cmp_ctx->available_size - cmp_ctx->output_pos < 260) {
    1150             :                 /* huffman block + 4 bytes */
    1151           0 :                 return LZXPRESS_ERROR;
    1152             :         }
    1153             : 
    1154             :         /*
    1155             :          * For LZ77 compression, we keep a hash table for the previous block,
    1156             :          * via alternation after the first block.
    1157             :          *
    1158             :          * LZ77 writes into the intermediate buffer in the cmp_mem context.
    1159             :          */
    1160        1107 :         if (block_no == 0) {
    1161         554 :                 hash_table = cmp_mem->hash_table1;
    1162         554 :                 back_window_hash_table = NULL;
    1163         553 :         } else if (block_no & 1) {
    1164         343 :                 hash_table = cmp_mem->hash_table2;
    1165         343 :                 back_window_hash_table = cmp_mem->hash_table1;
    1166             :         } else {
    1167         210 :                 hash_table = cmp_mem->hash_table1;
    1168         210 :                 back_window_hash_table = cmp_mem->hash_table2;
    1169             :         }
    1170             : 
    1171        1107 :         intermediate_size = lz77_encode_block(cmp_ctx,
    1172             :                                               cmp_mem,
    1173             :                                               hash_table,
    1174             :                                               back_window_hash_table);
    1175             : 
    1176        1107 :         if (intermediate_size < 0) {
    1177           0 :                 return intermediate_size;
    1178             :         }
    1179             : 
    1180             :         /*
    1181             :          * Write the 256 byte Huffman table, based on the counts gained in
    1182             :          * LZ77 phase.
    1183             :          */
    1184        1918 :         bytes_written = write_huffman_table(
    1185        1107 :                 cmp_mem->symbol_values,
    1186        1107 :                 cmp_ctx->output + cmp_ctx->output_pos,
    1187        1107 :                 cmp_ctx->available_size - cmp_ctx->output_pos);
    1188             : 
    1189        1107 :         if (bytes_written != 256) {
    1190           0 :                 return LZXPRESS_ERROR;
    1191             :         }
    1192        1107 :         cmp_ctx->output_pos += 256;
    1193             : 
    1194             :         /*
    1195             :          * Write the compressed bytes using the LZ77 matches and Huffman codes
    1196             :          * worked out in the previous steps.
    1197             :          */
    1198        1918 :         bytes_written = write_compressed_bytes(
    1199         296 :                 cmp_mem->symbol_values,
    1200        1107 :                 cmp_mem->intermediate,
    1201             :                 intermediate_size,
    1202        1107 :                 cmp_ctx->output + cmp_ctx->output_pos,
    1203        1107 :                 cmp_ctx->available_size - cmp_ctx->output_pos);
    1204             : 
    1205        1107 :         if (bytes_written < 0) {
    1206           0 :                 return bytes_written;
    1207             :         }
    1208             : 
    1209        1107 :         cmp_ctx->output_pos += bytes_written;
    1210        1107 :         return bytes_written;
    1211             : }
    1212             : 
    1213             : /*
    1214             :  * lzxpress_huffman_max_compressed_size()
    1215             :  *
    1216             :  * Return the most bytes the compression can take, to allow
    1217             :  * pre-allocation.
    1218             :  */
    1219         547 : size_t lzxpress_huffman_max_compressed_size(size_t input_size)
    1220             : {
    1221             :         /*
    1222             :          * In the worst case, the output size should be about the same as the
    1223             :          * input size, plus the 256 byte header per 64k block. We aim for
    1224             :          * ample, but within the order of magnitude.
    1225             :          */
    1226         547 :         return input_size + (input_size / 8) + 270;
    1227             : }
    1228             : 
    1229             : /*
    1230             :  * lzxpress_huffman_compress_talloc()
    1231             :  *
    1232             :  * This is the convenience function that allocates the compressor context and
    1233             :  * output memory for you. The return value is the number of bytes written to
    1234             :  * the location indicated by the output pointer.
    1235             :  *
    1236             :  * The maximum input_size is effectively around 227MB due to the need to guess
    1237             :  * an upper bound on the output size that hits an internal limitation in
    1238             :  * talloc.
    1239             :  *
    1240             :  * @param mem_ctx      TALLOC_CTX parent for the compressed buffer.
    1241             :  * @param input_bytes  memory to be compressed.
    1242             :  * @param input_size   length of the input buffer.
    1243             :  * @param output       destination pointer for the compressed data.
    1244             :  *
    1245             :  * @return the number of bytes written or -1 on error.
    1246             :  */
    1247             : 
    1248         424 : ssize_t lzxpress_huffman_compress_talloc(TALLOC_CTX *mem_ctx,
    1249             :                                          const uint8_t *input_bytes,
    1250             :                                          size_t input_size,
    1251             :                                          uint8_t **output)
    1252             : {
    1253         424 :         struct lzxhuff_compressor_mem *cmp = NULL;
    1254         424 :         size_t alloc_size = lzxpress_huffman_max_compressed_size(input_size);
    1255             : 
    1256         305 :         ssize_t output_size;
    1257             : 
    1258         424 :         *output = talloc_array(mem_ctx, uint8_t, alloc_size);
    1259         424 :         if (*output == NULL) {
    1260           0 :                 return LZXPRESS_ERROR;
    1261             :         }
    1262             : 
    1263         424 :         cmp = talloc(mem_ctx, struct lzxhuff_compressor_mem);
    1264         424 :         if (cmp == NULL) {
    1265           0 :                 TALLOC_FREE(*output);
    1266           0 :                 return LZXPRESS_ERROR;
    1267             :         }
    1268             : 
    1269         424 :         output_size = lzxpress_huffman_compress(cmp,
    1270             :                                                 input_bytes,
    1271             :                                                 input_size,
    1272             :                                                 *output,
    1273             :                                                 alloc_size);
    1274             : 
    1275         424 :         talloc_free(cmp);
    1276             : 
    1277         424 :         if (output_size < 0) {
    1278           0 :                 TALLOC_FREE(*output);
    1279           0 :                 return LZXPRESS_ERROR;
    1280             :         }
    1281             : 
    1282         424 :         *output = talloc_realloc(mem_ctx, *output, uint8_t, output_size);
    1283         424 :         if (*output == NULL) {
    1284           0 :                 return LZXPRESS_ERROR;
    1285             :         }
    1286             : 
    1287         119 :         return output_size;
    1288             : }
    1289             : 
    1290             : /*
    1291             :  * lzxpress_huffman_compress()
    1292             :  *
    1293             :  * This is the inconvenience function, slightly faster and fiddlier than
    1294             :  * lzxpress_huffman_compress_talloc().
    1295             :  *
    1296             :  * To use this, you need to have allocated (but not initialised) a `struct
    1297             :  * lzxhuff_compressor_mem`, and an output buffer. If the buffer is not big
    1298             :  * enough (per `output_size`), you'll get a negative return value, otherwise
    1299             :  * the number of bytes actually consumed, which will always be at least 260.
    1300             :  *
    1301             :  * The `struct lzxhuff_compressor_mem` is reusable -- it is basically a
    1302             :  * collection of uninitialised memory buffers. The total size is less than
    1303             :  * 150k, so stack allocation is plausible.
    1304             :  *
    1305             :  * input_size and available_size are limited to the minimum of UINT32_MAX and
    1306             :  * SSIZE_MAX. On 64 bit machines that will be UINT32_MAX, or 4GB.
    1307             :  *
    1308             :  * @param cmp_mem         a struct lzxhuff_compressor_mem.
    1309             :  * @param input_bytes     memory to be compressed.
    1310             :  * @param input_size      length of the input buffer.
    1311             :  * @param output          destination for the compressed data.
    1312             :  * @param available_size  allocated output bytes.
    1313             :  *
    1314             :  * @return the number of bytes written or -1 on error.
    1315             :  */
    1316         560 : ssize_t lzxpress_huffman_compress(struct lzxhuff_compressor_mem *cmp_mem,
    1317             :                                   const uint8_t *input_bytes,
    1318             :                                   size_t input_size,
    1319             :                                   uint8_t *output,
    1320             :                                   size_t available_size)
    1321             : {
    1322         560 :         size_t i = 0;
    1323         560 :         struct lzxhuff_compressor_context cmp_ctx = {
    1324             :                 .input_bytes = input_bytes,
    1325             :                 .input_size = input_size,
    1326             :                 .input_pos = 0,
    1327             :                 .prev_block_pos = 0,
    1328             :                 .output = output,
    1329             :                 .available_size = available_size,
    1330             :                 .output_pos = 0
    1331             :         };
    1332             : 
    1333         560 :         if (input_size == 0) {
    1334             :                 /*
    1335             :                  * We can't deal with this for a number of reasons (e.g. it
    1336             :                  * breaks the Huffman tree), and the output will be infinitely
    1337             :                  * bigger than the input. The caller needs to go and think
    1338             :                  * about what they're trying to do here.
    1339             :                  */
    1340           0 :                 return LZXPRESS_ERROR;
    1341             :         }
    1342             : 
    1343         558 :         if (input_size > SSIZE_MAX ||
    1344         558 :             input_size > UINT32_MAX ||
    1345         558 :             available_size > SSIZE_MAX ||
    1346         558 :             available_size > UINT32_MAX ||
    1347             :             available_size == 0) {
    1348             :                 /*
    1349             :                  * We use negative ssize_t to return errors, which is limiting
    1350             :                  * on 32 bit machines; otherwise we adhere to Microsoft's 4GB
    1351             :                  * limit.
    1352             :                  *
    1353             :                  * lzxpress_huffman_compress_talloc() will not get this far,
    1354             :                  * having already have failed on talloc's 256 MB limit.
    1355             :                  */
    1356           0 :                 return LZXPRESS_ERROR;
    1357             :         }
    1358             : 
    1359         557 :         if (cmp_mem == NULL ||
    1360         557 :             output == NULL ||
    1361             :             input_bytes == NULL) {
    1362           0 :                 return LZXPRESS_ERROR;
    1363             :         }
    1364             : 
    1365        1661 :         while (cmp_ctx.input_pos < cmp_ctx.input_size) {
    1366         811 :                 ssize_t ret;
    1367        1107 :                 ret = lzx_huffman_compress_block(&cmp_ctx,
    1368             :                                                  cmp_mem,
    1369             :                                                  i);
    1370        1107 :                 if (ret < 0) {
    1371           0 :                         return ret;
    1372             :                 }
    1373        1107 :                 i++;
    1374             :         }
    1375             : 
    1376         554 :         return cmp_ctx.output_pos;
    1377             : }
    1378             : 
    1379           0 : static void debug_tree_codes(struct bitstream *input)
    1380             : {
    1381             :         /*
    1382             :          */
    1383           0 :         size_t head = 0;
    1384           0 :         size_t tail = 2;
    1385           0 :         size_t ffff_count = 0;
    1386           0 :         struct q {
    1387             :                 uint16_t tree_code;
    1388             :                 uint16_t code_code;
    1389             :         };
    1390           0 :         struct q queue[65536];
    1391           0 :         char bits[17];
    1392           0 :         uint16_t *t = input->table;
    1393           0 :         queue[0].tree_code = 1;
    1394           0 :         queue[0].code_code = 2;
    1395           0 :         queue[1].tree_code = 2;
    1396           0 :         queue[1].code_code = 3;
    1397           0 :         while (head < tail) {
    1398           0 :                 struct q q = queue[head];
    1399           0 :                 uint16_t x = t[q.tree_code];
    1400           0 :                 if (x != 0xffff) {
    1401           0 :                         int k;
    1402           0 :                         uint16_t j = q.code_code;
    1403           0 :                         size_t offset = bitlen_nonzero_16(j) - 1;
    1404           0 :                         if (unlikely(j == 0)) {
    1405           0 :                                 DBG("BROKEN code is 0!\n");
    1406           0 :                                 return;
    1407             :                         }
    1408             : 
    1409           0 :                         for (k = 0; k <= offset; k++) {
    1410           0 :                                 bool b = (j >> (offset - k)) & 1;
    1411           0 :                                 bits[k] = b ? '1' : '0';
    1412             :                         }
    1413           0 :                         bits[k] = 0;
    1414           0 :                         DBG("%03x   %s\n", x & 511, bits);
    1415           0 :                         head++;
    1416           0 :                         continue;
    1417             :                 }
    1418           0 :                 ffff_count++;
    1419           0 :                 queue[tail].tree_code = q.tree_code * 2 + 1;
    1420           0 :                 queue[tail].code_code = q.code_code * 2;
    1421           0 :                 tail++;
    1422           0 :                 queue[tail].tree_code = q.tree_code * 2 + 1 + 1;
    1423           0 :                 queue[tail].code_code = q.code_code * 2 + 1;
    1424           0 :                 tail++;
    1425           0 :                 head++;
    1426             :         }
    1427           0 :         DBG("0xffff count: %zu\n", ffff_count);
    1428             : }
    1429             : 
    1430             : /**
    1431             :  * Determines the sort order of one prefix_code_symbol relative to another
    1432             :  */
    1433      975144 : static int compare_uint16(const uint16_t *a, const uint16_t *b)
    1434             : {
    1435      975144 :         if (*a < *b) {
    1436       51960 :                 return -1;
    1437             :         }
    1438      386118 :         if (*a > *b) {
    1439      386118 :                 return 1;
    1440             :         }
    1441           0 :         return 0;
    1442             : }
    1443             : 
    1444             : 
    1445        1250 : static bool fill_decomp_table(struct bitstream *input)
    1446             : {
    1447             :         /*
    1448             :          * There are 512 symbols, each encoded in 4 bits, which indicates
    1449             :          * their depth in the Huffman tree. The even numbers get the lower
    1450             :          * nibble of each byte, so that the byte hex values look backwards
    1451             :          * (i.e. 0xab encodes b then a). These are allocated Huffman codes in
    1452             :          * order of appearance, per depth.
    1453             :          *
    1454             :          * For example, if the first two bytes were:
    1455             :          *
    1456             :          * 0x23 0x53
    1457             :          *
    1458             :          * the first four codes have the lengths 3, 2, 3, 5.
    1459             :          * Let's call them A, B, C, D.
    1460             :          *
    1461             :          * Suppose there is no other codeword with length 1 (which is
    1462             :          * necessarily true in this example) or 2, but there might be others
    1463             :          * of length 3 or 4. Then we can say this about the codes:
    1464             :          *
    1465             :          *        _ --*--_
    1466             :          *      /          \
    1467             :          *     0           1
    1468             :          *    / \         / \
    1469             :          *   0   1       0   1
    1470             :          *  B    |\     / \  |\
    1471             :          *       0 1   0   1 0 1
    1472             :          *       A C   |\ /| | |\
    1473             :          *
    1474             :          * pos bits  code
    1475             :          * A    3    010
    1476             :          * B    2    00
    1477             :          * C    3    011
    1478             :          * D    5    1????
    1479             :          *
    1480             :          * B has the shortest code, so takes the leftmost branch, 00. That
    1481             :          * ends the branch -- nothing else can start with 00. There are no
    1482             :          * more 2s, so we look at the 3s, starting as far left as possible. So
    1483             :          * A takes 010 and C takes 011. That means everything else has to
    1484             :          * start with 1xx. We don't know how many codewords of length 3 or 4
    1485             :          * there are; if there are none, D would end up with 10000, the
    1486             :          * leftmost available code of length 5. If the compressor is any good,
    1487             :          * there should be no unused leaf nodes left dangling at the end.
    1488             :          *
    1489             :          * (this is "Canonical Huffman Coding").
    1490             :          *
    1491             :          *
    1492             :          * But what symbols do these codes actually stand for?
    1493             :          * --------------------------------------------------
    1494             :          *
    1495             :          * Good question. The first 256 codes stand for the corresponding
    1496             :          * literal bytes. The codes from 256 to 511 stand for LZ77 matches,
    1497             :          * which have a distance and a length, encoded in a strange way that
    1498             :          * isn't entirely the purview of this function.
    1499             :          *
    1500             :          * What does the value 0 mean?
    1501             :          * ---------------------------
    1502             :          *
    1503             :          * The code does not occur. For example, if the next byte in the
    1504             :          * example above was 0x07, that would give the byte 0x04 a 7-long
    1505             :          * code, and no code to the 0x05 byte, which means we there is no way
    1506             :          * we going to see a 5 in the decoded stream.
    1507             :          *
    1508             :          * Isn't LZ77 + Huffman what zip/gzip/zlib do?
    1509             :          * -------------------------------------------
    1510             :          *
    1511             :          * Yes, DEFLATE is LZ77 + Huffman, but the details are quite different.
    1512             :          */
    1513         983 :         uint16_t symbols[512];
    1514         983 :         uint16_t sort_mem[512];
    1515         983 :         size_t i, n_symbols;
    1516         983 :         ssize_t code;
    1517        1250 :         uint16_t len = 0, prev_len;
    1518        1250 :         const uint8_t *table_bytes = input->bytes + input->byte_pos;
    1519             : 
    1520        1250 :         if (input->byte_pos + 260 > input->byte_size) {
    1521           0 :                 return false;
    1522             :         }
    1523             : 
    1524         267 :         n_symbols = 0;
    1525      321250 :         for (i = 0; i < 256; i++) {
    1526      320000 :                 uint16_t even = table_bytes[i] & 15;
    1527      320000 :                 uint16_t odd = table_bytes[i] >> 4;
    1528      320000 :                 if (even != 0) {
    1529       94392 :                         symbols[n_symbols] = (even << 9) + i * 2;
    1530       94392 :                         n_symbols++;
    1531             :                 }
    1532      320000 :                 if (odd != 0) {
    1533       87241 :                         symbols[n_symbols] = (odd << 9) + i * 2 + 1;
    1534       87241 :                         n_symbols++;
    1535             :                 }
    1536             :         }
    1537        1250 :         input->byte_pos += 256;
    1538        1250 :         if (n_symbols == 0) {
    1539           0 :                 return false;
    1540             :         }
    1541             : 
    1542        1250 :         stable_sort(symbols, sort_mem, n_symbols, sizeof(uint16_t),
    1543             :                     (samba_compare_fn_t)compare_uint16);
    1544             : 
    1545             :         /*
    1546             :          * we're using an implicit binary tree, as you'd see in a heap.
    1547             :          * table[0] = unused
    1548             :          * table[1] = '0'
    1549             :          * table[2] = '1'
    1550             :          * table[3] = '00'     <-- '00' and '01' are children of '0'
    1551             :          * table[4] = '01'     <-- '0' is [0], children are [0 * 2 + {1,2}]
    1552             :          * table[5] = '10'
    1553             :          * table[6] = '11'
    1554             :          * table[7] = '000'
    1555             :          * table[8] = '001'
    1556             :          * table[9] = '010'
    1557             :          * table[10]= '011'
    1558             :          * table[11]= '100
    1559             :          *'
    1560             :          * table[1 << n - 1] = '0' * n
    1561             :          * table[1 << n - 1 + x] = n-bit wide x (left padded with '0')
    1562             :          * table[1 << n - 2] = '1' * (n - 1)
    1563             :          *
    1564             :          * table[i]->left =  table[i*2 + 1]
    1565             :          * table[i]->right = table[i*2 + 2]
    1566             :          * table[0xffff] = unused (16 '0's, max len is 15)
    1567             :          *
    1568             :          * therefore e.g. table[70] = table[64     - 1 + 7]
    1569             :          *                          = table[1 << 6 - 1 + 7]
    1570             :          *                          = '000111' (binary 7, widened to 6 bits)
    1571             :          *
    1572             :          *   and if '000111' is a code,
    1573             :          *   '00011', '0001', '000', '00', '0' are unavailable prefixes.
    1574             :          *       34      16      7     3    1  are their indices
    1575             :          *   and (i - 1) >> 1 is the path back from 70 through these.
    1576             :          *
    1577             :          * the lookup is
    1578             :          *
    1579             :          * 1 start with i = 0
    1580             :          * 2 extract a symbol bit (i = (i << 1) + bit + 1)
    1581             :          * 3 is table[i] == 0xffff?
    1582             :          * 4  yes -- goto 2
    1583             :          * 4  table[i] & 511 is the symbol, stop
    1584             :          *
    1585             :          * and the construction (here) is sort of the reverse.
    1586             :          *
    1587             :          * Most of this table is free space that can never be reached, and
    1588             :          * most of the activity is at the beginning (since all codes start
    1589             :          * there, and by design the shortest codes are the most common).
    1590             :          */
    1591       42233 :         for (i = 0; i < 32; i++) {
    1592             :                 /* prefill the table head */
    1593       40000 :                 input->table[i] = 0xffff;
    1594             :         }
    1595         267 :         code = -1;
    1596         267 :         prev_len = 0;
    1597      182883 :         for (i = 0; i < n_symbols; i++) {
    1598      181633 :                 uint16_t s = symbols[i];
    1599      159916 :                 uint16_t prefix;
    1600      181633 :                 len = (s >> 9) & 15;
    1601      181633 :                 s &= 511;
    1602      181633 :                 code++;
    1603      191921 :                 while (len != prev_len) {
    1604       10288 :                         code <<= 1;
    1605       10288 :                         code++;
    1606       10288 :                         prev_len++;
    1607             :                 }
    1608             : 
    1609      181633 :                 if (code >= 65535) {
    1610           0 :                         return false;
    1611             :                 }
    1612      181633 :                 input->table[code] = s;
    1613      181633 :                 for(prefix = (code - 1) >> 1;
    1614      894684 :                     prefix > 31;
    1615      713051 :                     prefix = (prefix - 1) >> 1) {
    1616      713051 :                         input->table[prefix] = 0xffff;
    1617             :                 }
    1618             :         }
    1619        1250 :         if (CHECK_DEBUGLVL(10)) {
    1620           0 :                 debug_tree_codes(input);
    1621             :         }
    1622             : 
    1623             :         /*
    1624             :          * check that the last code encodes 11111..., with right number of
    1625             :          * ones, pointing to the right symbol -- otherwise we have a dangling
    1626             :          * uninitialised symbol.
    1627             :          */
    1628        1250 :         if (code != (1 << (len + 1)) - 2) {
    1629           0 :                 return false;
    1630             :         }
    1631         267 :         return true;
    1632             : }
    1633             : 
    1634             : 
    1635             : #define CHECK_READ_32(dest)                                       \
    1636             :         do {                                                      \
    1637             :                 if (input->byte_pos + 4 > input->byte_size) {     \
    1638             :                         return LZXPRESS_ERROR;                     \
    1639             :                 }                                                  \
    1640             :                 dest = PULL_LE_U32(input->bytes, input->byte_pos); \
    1641             :                 input->byte_pos += 4;                                   \
    1642             :         } while (0)
    1643             : 
    1644             : #define CHECK_READ_16(dest)                                       \
    1645             :         do {                                                      \
    1646             :                 if (input->byte_pos + 2 > input->byte_size) {     \
    1647             :                         return LZXPRESS_ERROR;                     \
    1648             :                 }                                                  \
    1649             :                 dest = PULL_LE_U16(input->bytes, input->byte_pos); \
    1650             :                 input->byte_pos += 2;                                   \
    1651             :         } while (0)
    1652             : 
    1653             : #define CHECK_READ_8(dest) \
    1654             :         do {                                                            \
    1655             :                 if (input->byte_pos >= input->byte_size) {             \
    1656             :                         return LZXPRESS_ERROR;                          \
    1657             :                 }                                                       \
    1658             :                 dest = PULL_LE_U8(input->bytes, input->byte_pos); \
    1659             :                 input->byte_pos++;                                   \
    1660             :         } while(0)
    1661             : 
    1662             : 
    1663     8809423 : static inline ssize_t pull_bits(struct bitstream *input)
    1664             : {
    1665     8809423 :         if (input->byte_pos + 1 < input->byte_size) {
    1666     8706401 :                 uint16_t tmp;
    1667     8809423 :                 CHECK_READ_16(tmp);
    1668     8809423 :                 input->remaining_bits += 16;
    1669     8809423 :                 input->bits <<= 16;
    1670     8809423 :                 input->bits |= tmp;
    1671           0 :         } else if (input->byte_pos < input->byte_size) {
    1672           0 :                 uint8_t tmp;
    1673           0 :                 CHECK_READ_8(tmp);
    1674           0 :                 input->remaining_bits += 8;
    1675           0 :                 input->bits <<= 8;
    1676           0 :                 input->bits |= tmp;
    1677             :         } else {
    1678           0 :                 return LZXPRESS_ERROR;
    1679             :         }
    1680      103022 :         return 0;
    1681             : }
    1682             : 
    1683             : 
    1684             : /*
    1685             :  * Decompress a block. The actual decompressed size is returned (or -1 on
    1686             :  * error). The putative block length is 64k (or shorter, if the message ends
    1687             :  * first), but a match can run over the end, extending the block. That's why
    1688             :  * we need the overall output size as well as the block size. A match encoded
    1689             :  * in this block can point back to previous blocks, but not before the
    1690             :  * beginning of the message, so we also need the previously decoded size.
    1691             :  *
    1692             :  * The compressed block will have 256 bytes for the Huffman table, and at
    1693             :  * least 4 bytes of (possibly padded) encoded values.
    1694             :  */
    1695        1250 : static ssize_t lzx_huffman_decompress_block(struct bitstream *input,
    1696             :                                             uint8_t *output,
    1697             :                                             size_t block_size,
    1698             :                                             size_t output_size,
    1699             :                                             size_t previous_size)
    1700             : {
    1701        1250 :         size_t output_pos = 0;
    1702         983 :         uint16_t symbol;
    1703         983 :         size_t index;
    1704        1250 :         uint16_t distance_bits_wanted = 0;
    1705        1250 :         size_t distance = 0;
    1706        1250 :         size_t length = 0;
    1707         983 :         bool ok;
    1708         983 :         uint32_t tmp;
    1709        1250 :         bool seen_eof_marker = false;
    1710             : 
    1711        1250 :         ok = fill_decomp_table(input);
    1712        1250 :         if (! ok) {
    1713           0 :                 return LZXPRESS_ERROR;
    1714             :         }
    1715        1250 :         if (CHECK_DEBUGLVL(10) || DEBUG_HUFFMAN_TREE) {
    1716           0 :                 debug_huffman_tree_from_table(input->table);
    1717             :         }
    1718             :         /*
    1719             :          * Always read 32 bits at the start, even if we don't need them.
    1720             :          */
    1721        1250 :         CHECK_READ_16(tmp);
    1722        1250 :         CHECK_READ_16(input->bits);
    1723        1250 :         input->bits |= tmp << 16;
    1724        1250 :         input->remaining_bits = 32;
    1725             : 
    1726             :         /*
    1727             :          * This loop iterates over individual *bits*. These are read from
    1728             :          * little-endian 16 bit words, most significant bit first.
    1729             :          *
    1730             :          * At points in the bitstream, the following are possible:
    1731             :          *
    1732             :          * # the source word is empty and needs to be refilled from the input
    1733             :          *    stream.
    1734             :          * # an incomplete codeword is being extended.
    1735             :          * # a codeword is resolved, either as a literal or a match.
    1736             :          * # a literal is written.
    1737             :          * # a match is collecting distance bits.
    1738             :          * # the output stream is copied, as specified by a match.
    1739             :          * # input bytes are read for match lengths.
    1740             :          *
    1741             :          * Note that we *don't* specifically check for the EOF marker (symbol
    1742             :          * 256) in this loop, because the precondition for stopping for the
    1743             :          * EOF marker is that the output buffer is full (otherwise, you
    1744             :          * wouldn't know which 256 is EOF, rather than an actual symbol), and
    1745             :          * we *always* want to stop when the buffer is full. So we work out if
    1746             :          * there is an EOF in another loop after we stop writing.
    1747             :          */
    1748             : 
    1749        1250 :         index = 0;
    1750   140959549 :         while (output_pos < block_size) {
    1751   139309195 :                 uint16_t b;
    1752   140958300 :                 if (input->remaining_bits == 16) {
    1753     8809246 :                         ssize_t ret = pull_bits(input);
    1754     8809246 :                         if (ret) {
    1755           0 :                                 return ret;
    1756             :                         }
    1757             :                 }
    1758   140958300 :                 input->remaining_bits--;
    1759             : 
    1760   140958300 :                 b = (input->bits >> input->remaining_bits) & 1;
    1761   140958300 :                 if (length == 0) {
    1762             :                         /* not in a match; pulling a codeword */
    1763   103567597 :                         index <<= 1;
    1764   103567597 :                         index += b + 1;
    1765   103567597 :                         if (input->table[index] == 0xffff) {
    1766             :                                 /* incomplete codeword, the common case */
    1767    89823391 :                                 continue;
    1768             :                         }
    1769             :                         /* found the symbol, reset the code string */
    1770    13744206 :                         symbol = input->table[index] & 511;
    1771    13744206 :                         index = 0;
    1772    13744206 :                         if (symbol < 256) {
    1773             :                                 /* a literal, the easy case */
    1774    10547072 :                                 output[output_pos] = symbol;
    1775    10547072 :                                 output_pos++;
    1776    10547072 :                                 continue;
    1777             :                         }
    1778             : 
    1779             :                         /* the beginning of a match */
    1780     3197134 :                         distance_bits_wanted = (symbol >> 4) & 15;
    1781     3197134 :                         distance = 1 << distance_bits_wanted;
    1782     3197134 :                         length = symbol & 15;
    1783     3197134 :                         if (length == 15) {
    1784       21043 :                                 CHECK_READ_8(tmp);
    1785       21043 :                                 length += tmp;
    1786       21043 :                                 if (length == 255 + 15) {
    1787             :                                         /*
    1788             :                                          * note, we discard (don't add) the
    1789             :                                          * length so far.
    1790             :                                          */
    1791        2513 :                                         CHECK_READ_16(length);
    1792        2513 :                                         if (length == 0) {
    1793          70 :                                                 CHECK_READ_32(length);
    1794             :                                         }
    1795             :                                 }
    1796             :                         }
    1797     3197134 :                         length += 3;
    1798             :                 } else {
    1799             :                         /* we are pulling extra distance bits */
    1800    37390703 :                         distance_bits_wanted--;
    1801    37390703 :                         distance |= b << distance_bits_wanted;
    1802             :                 }
    1803             : 
    1804    40587837 :                 if (distance_bits_wanted == 0) {
    1805             :                         /*
    1806             :                          * We have a complete match, and it is time to do the
    1807             :                          * copy (byte by byte, because the ranges can overlap,
    1808             :                          * and we might need to copy bytes we just copied in).
    1809             :                          *
    1810             :                          * It is possible that this match will extend beyond
    1811             :                          * the end of the expected block. That's fine, so long
    1812             :                          * as it doesn't extend past the total output size.
    1813             :                          */
    1814     3079095 :                         size_t i;
    1815     3197134 :                         size_t end = output_pos + length;
    1816     3197134 :                         uint8_t *here = output + output_pos;
    1817     3197134 :                         uint8_t *there = here - distance;
    1818     3197134 :                         if (end > output_size ||
    1819     3197133 :                             previous_size + output_pos < distance ||
    1820     3197133 :                             unlikely(end < output_pos || there > here)) {
    1821           0 :                                 return LZXPRESS_ERROR;
    1822             :                         }
    1823   114753169 :                         for (i = 0; i < length; i++) {
    1824   111556036 :                                 here[i] = there[i];
    1825             :                         }
    1826      118039 :                         output_pos += length;
    1827      118039 :                         distance = 0;
    1828      118039 :                         length = 0;
    1829             :                 }
    1830             :         }
    1831             : 
    1832        1249 :         if (length != 0 || index != 0) {
    1833             :                 /* it seems like we've hit an early end, mid-code */
    1834           0 :                 return LZXPRESS_ERROR;
    1835             :         }
    1836             : 
    1837        1249 :         if (input->byte_pos + 256 < input->byte_size) {
    1838             :                 /*
    1839             :                  * This block is over, but it clearly isn't the last block, so
    1840             :                  * we don't want to look for the EOF.
    1841             :                  */
    1842         585 :                 return output_pos;
    1843             :         }
    1844             :         /*
    1845             :          * We won't write any more, but we try to read some more to make sure
    1846             :          * we're finishing in a good place. That means we want to see a 256
    1847             :          * symbol and then some number of zeroes, possibly zero, but as many
    1848             :          * as 32.
    1849             :          *
    1850             :          * In this we are perhaps a bit stricter than Windows, which
    1851             :          * apparently does not insist on the EOF marker, nor on a lack of
    1852             :          * trailing bytes.
    1853             :          */
    1854        8825 :         while (true) {
    1855        4995 :                 uint16_t b;
    1856        8660 :                 if (input->remaining_bits == 16) {
    1857         495 :                         ssize_t ret;
    1858         840 :                         if (input->byte_pos == input->byte_size) {
    1859             :                                 /* FIN */
    1860         249 :                                 break;
    1861             :                         }
    1862         177 :                         ret = pull_bits(input);
    1863         177 :                         if (ret) {
    1864           0 :                                 return ret;
    1865             :                         }
    1866             :                 }
    1867        7997 :                 input->remaining_bits--;
    1868        7997 :                 b = (input->bits >> input->remaining_bits) & 1;
    1869        7997 :                 if (seen_eof_marker) {
    1870             :                         /*
    1871             :                          * we have read an EOF symbols. Now we just want to
    1872             :                          * see zeroes.
    1873             :                          */
    1874        4621 :                         if (b != 0) {
    1875           0 :                                 return LZXPRESS_ERROR;
    1876             :                         }
    1877        4621 :                         continue;
    1878             :                 }
    1879             : 
    1880             :                 /* we're pulling in a symbol, which had better be 256 */
    1881        3376 :                 index <<= 1;
    1882        3376 :                 index += b + 1;
    1883        3376 :                 if (input->table[index] == 0xffff) {
    1884        2712 :                         continue;
    1885             :                 }
    1886             : 
    1887         664 :                 symbol = input->table[index] & 511;
    1888         664 :                 if (symbol != 256) {
    1889           0 :                         return LZXPRESS_ERROR;
    1890             :                 }
    1891         663 :                 seen_eof_marker = true;
    1892         663 :                 continue;
    1893             :         }
    1894             : 
    1895         663 :         if (! seen_eof_marker) {
    1896           0 :                 return LZXPRESS_ERROR;
    1897             :         }
    1898             : 
    1899         663 :         return output_pos;
    1900             : }
    1901             : 
    1902         665 : static ssize_t lzxpress_huffman_decompress_internal(struct bitstream *input,
    1903             :                                                     uint8_t *output,
    1904             :                                                     size_t output_size)
    1905             : {
    1906         665 :         size_t output_pos = 0;
    1907             : 
    1908         665 :         if (input->byte_size < 260) {
    1909           0 :                 return LZXPRESS_ERROR;
    1910             :         }
    1911             : 
    1912        1913 :         while (input->byte_pos < input->byte_size) {
    1913         983 :                 ssize_t block_output_pos;
    1914         983 :                 ssize_t block_output_size;
    1915        1250 :                 size_t remaining_output_size = output_size - output_pos;
    1916             : 
    1917        1250 :                 block_output_size = MIN(65536, remaining_output_size);
    1918             : 
    1919        1250 :                 block_output_pos = lzx_huffman_decompress_block(
    1920             :                         input,
    1921             :                         output + output_pos,
    1922             :                         block_output_size,
    1923             :                         remaining_output_size,
    1924             :                         output_pos);
    1925             : 
    1926        1250 :                 if (block_output_pos < block_output_size) {
    1927           0 :                         return LZXPRESS_ERROR;
    1928             :                 }
    1929        1248 :                 output_pos += block_output_pos;
    1930        1248 :                 if (output_pos > output_size) {
    1931             :                         /* not expecting to get here. */
    1932           0 :                         return LZXPRESS_ERROR;
    1933             :                 }
    1934             :         }
    1935             : 
    1936         663 :         if (input->byte_pos != input->byte_size) {
    1937           0 :                 return LZXPRESS_ERROR;
    1938             :         }
    1939             : 
    1940         663 :         return output_pos;
    1941             : }
    1942             : 
    1943             : 
    1944             : /*
    1945             :  * lzxpress_huffman_decompress()
    1946             :  *
    1947             :  * output_size must be the expected length of the decompressed data.
    1948             :  * input_size and output_size are limited to the minimum of UINT32_MAX and
    1949             :  * SSIZE_MAX. On 64 bit machines that will be UINT32_MAX, or 4GB.
    1950             :  *
    1951             :  * @param input_bytes  memory to be decompressed.
    1952             :  * @param input_size   length of the compressed buffer.
    1953             :  * @param output       destination for the decompressed data.
    1954             :  * @param output_size  exact expected length of the decompressed data.
    1955             :  *
    1956             :  * @return the number of bytes written or -1 on error.
    1957             :  */
    1958             : 
    1959         669 : ssize_t lzxpress_huffman_decompress(const uint8_t *input_bytes,
    1960             :                                     size_t input_size,
    1961             :                                     uint8_t *output,
    1962             :                                     size_t output_size)
    1963             : {
    1964         420 :         uint16_t table[65536];
    1965         669 :         struct bitstream input = {
    1966             :                 .bytes = input_bytes,
    1967             :                 .byte_size = input_size,
    1968             :                 .byte_pos = 0,
    1969             :                 .bits = 0,
    1970             :                 .remaining_bits = 0,
    1971             :                 .table = table
    1972             :         };
    1973             : 
    1974         669 :         if (input_size > SSIZE_MAX ||
    1975         669 :             input_size > UINT32_MAX ||
    1976         669 :             output_size > SSIZE_MAX ||
    1977         669 :             output_size > UINT32_MAX ||
    1978         669 :             input_size == 0 ||
    1979         668 :             output_size == 0 ||
    1980         668 :             input_bytes == NULL ||
    1981             :             output == NULL) {
    1982             :                 /*
    1983             :                  * We use negative ssize_t to return errors, which is limiting
    1984             :                  * on 32 bit machines, and the 4GB limit exists on Windows.
    1985             :                  */
    1986           0 :                 return  LZXPRESS_ERROR;
    1987             :         }
    1988             : 
    1989         665 :         return lzxpress_huffman_decompress_internal(&input,
    1990             :                                                     output,
    1991             :                                                     output_size);
    1992             : }
    1993             : 
    1994             : 
    1995             : /**
    1996             :  * lzxpress_huffman_decompress_talloc()
    1997             :  *
    1998             :  * The caller must provide the exact size of the expected output.
    1999             :  *
    2000             :  * The input_size is limited to the minimum of UINT32_MAX and SSIZE_MAX, but
    2001             :  * output_size is limited to 256MB due to a limit in talloc. This effectively
    2002             :  * limits input_size too, as non-crafted compressed data will not exceed the
    2003             :  * decompressed size by very much.
    2004             :  *
    2005             :  * @param mem_ctx      TALLOC_CTX parent for the decompressed buffer.
    2006             :  * @param input_bytes  memory to be decompressed.
    2007             :  * @param input_size   length of the compressed buffer.
    2008             :  * @param output_size  expected decompressed size.
    2009             :  *
    2010             :  * @return a talloc'ed buffer exactly output_size in length, or NULL.
    2011             :  */
    2012             : 
    2013           0 : uint8_t *lzxpress_huffman_decompress_talloc(TALLOC_CTX *mem_ctx,
    2014             :                                             const uint8_t *input_bytes,
    2015             :                                             size_t input_size,
    2016             :                                             size_t output_size)
    2017             : {
    2018           0 :         ssize_t result;
    2019           0 :         uint8_t *output = NULL;
    2020           0 :         struct bitstream input = {
    2021             :                 .bytes = input_bytes,
    2022             :                 .byte_size = input_size
    2023             :         };
    2024             : 
    2025           0 :         output = talloc_array(mem_ctx, uint8_t, output_size);
    2026           0 :         if (output == NULL) {
    2027           0 :                 return NULL;
    2028             :         }
    2029             : 
    2030           0 :         input.table = talloc_array(mem_ctx, uint16_t, 65536);
    2031           0 :         if (input.table == NULL) {
    2032           0 :                 talloc_free(output);
    2033           0 :                 return NULL;
    2034             :         }
    2035           0 :         result = lzxpress_huffman_decompress_internal(&input,
    2036             :                                                       output,
    2037             :                                                       output_size);
    2038           0 :         talloc_free(input.table);
    2039             : 
    2040           0 :         if (result != output_size) {
    2041           0 :                 talloc_free(output);
    2042           0 :                 return NULL;
    2043             :         }
    2044           0 :         return output;
    2045             : }

Generated by: LCOV version 1.14