From f5add51fa7eb620a6edd1b27f02d38618f144480 Mon Sep 17 00:00:00 2001 From: LisoUseInAIKyrios <118716522+LisoUseInAIKyrios@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:08:35 +0300 Subject: [PATCH] perf(YouTube): Reduce memory requirement for prefix tree searching (#501) --- .../patches/components/LithoFilterPatch.java | 8 +- .../integrations/utils/ByteTrieSearch.java | 12 +- .../integrations/utils/StringTrieSearch.java | 12 +- .../integrations/utils/TrieSearch.java | 116 +++++++++++++++--- 4 files changed, 125 insertions(+), 23 deletions(-) diff --git a/app/src/main/java/app/revanced/integrations/patches/components/LithoFilterPatch.java b/app/src/main/java/app/revanced/integrations/patches/components/LithoFilterPatch.java index 5099a4df..a4051c55 100644 --- a/app/src/main/java/app/revanced/integrations/patches/components/LithoFilterPatch.java +++ b/app/src/main/java/app/revanced/integrations/patches/components/LithoFilterPatch.java @@ -425,15 +425,15 @@ public final class LithoFilterPatch { static { for (Filter filter : filters) { - filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList); filterGroupLists(identifierSearchTree, filter, filter.identifierFilterGroupList); + filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList); } LogHelper.printDebug(() -> "Using: " - + pathSearchTree.numberOfPatterns() + " path filters" - + " (" + pathSearchTree.getEstimatedMemorySize() + " KB), " + identifierSearchTree.numberOfPatterns() + " identifier filters" - + " (" + identifierSearchTree.getEstimatedMemorySize() + " KB)"); + + " (" + identifierSearchTree.getEstimatedMemorySize() + " KB), " + + pathSearchTree.numberOfPatterns() + " path filters" + + " (" + pathSearchTree.getEstimatedMemorySize() + " KB)"); } private static void filterGroupLists(TrieSearch pathSearchTree, diff --git a/app/src/main/java/app/revanced/integrations/utils/ByteTrieSearch.java b/app/src/main/java/app/revanced/integrations/utils/ByteTrieSearch.java index 807f3f8a..02a1ff70 100644 --- a/app/src/main/java/app/revanced/integrations/utils/ByteTrieSearch.java +++ b/app/src/main/java/app/revanced/integrations/utils/ByteTrieSearch.java @@ -8,9 +8,17 @@ import java.util.Objects; public final class ByteTrieSearch extends TrieSearch { private static final class ByteTrieNode extends TrieNode { - TrieNode createNode() { - return new ByteTrieNode(); + ByteTrieNode() { + super(); } + ByteTrieNode(char nodeCharacterValue) { + super(nodeCharacterValue); + } + @Override + TrieNode createNode(char nodeCharacterValue) { + return new ByteTrieNode(nodeCharacterValue); + } + @Override char getCharValue(byte[] text, int index) { return (char) text[index]; } diff --git a/app/src/main/java/app/revanced/integrations/utils/StringTrieSearch.java b/app/src/main/java/app/revanced/integrations/utils/StringTrieSearch.java index 1a1a0a9e..28d960cf 100644 --- a/app/src/main/java/app/revanced/integrations/utils/StringTrieSearch.java +++ b/app/src/main/java/app/revanced/integrations/utils/StringTrieSearch.java @@ -11,9 +11,17 @@ import java.util.Objects; public final class StringTrieSearch extends TrieSearch { private static final class StringTrieNode extends TrieNode { - TrieNode createNode() { - return new StringTrieNode(); + StringTrieNode() { + super(); } + StringTrieNode(char nodeCharacterValue) { + super(nodeCharacterValue); + } + @Override + TrieNode createNode(char nodeValue) { + return new StringTrieNode(nodeValue); + } + @Override char getCharValue(String text, int index) { return text.charAt(index); } diff --git a/app/src/main/java/app/revanced/integrations/utils/TrieSearch.java b/app/src/main/java/app/revanced/integrations/utils/TrieSearch.java index 8b9fecb8..d42c305c 100644 --- a/app/src/main/java/app/revanced/integrations/utils/TrieSearch.java +++ b/app/src/main/java/app/revanced/integrations/utils/TrieSearch.java @@ -71,15 +71,31 @@ public abstract class TrieSearch { } static abstract class TrieNode { + /** + * Dummy value used for root node. Value can be anything as it's never referenced. + */ + private static final char ROOT_NODE_CHARACTER_VALUE = 0; // ASCII null character. + // Support only ASCII letters/numbers/symbols and filter out all control characters. private static final char MIN_VALID_CHAR = 32; // Space character. private static final char MAX_VALID_CHAR = 126; // 127 = delete character. - private static final int NUMBER_OF_CHILDREN = MAX_VALID_CHAR - MIN_VALID_CHAR + 1; + + /** + * How much to expand the children array when resizing. + */ + private static final int CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT = 2; + private static final int CHILDREN_ARRAY_MAX_SIZE = MAX_VALID_CHAR - MIN_VALID_CHAR + 1; private static boolean isInvalidRange(char character) { return character < MIN_VALID_CHAR || character > MAX_VALID_CHAR; } + /** + * Character this node represents. + * This field is ignored for the root node (which does not represent any character). + */ + private final char nodeValue; + /** * A compressed graph path that represents the remaining pattern characters of a single child node. * @@ -91,6 +107,24 @@ public abstract class TrieSearch { /** * All child nodes. Only present if no compressed leaf exist. + * + * Array is dynamically increased in size as needed, + * and uses perfect hashing for the elements it contains. + * + * So if the array contains a given character, + * the character will always map to the node with index: (character % arraySize). + * + * Elements not contained can collide with elements the array does contain, + * so must compare the nodes character value. + * + * Alternatively this array could be a sorted and densely packed array, + * and lookup is done using binary search. + * That would save a small amount of memory because there's no null children entries, + * but would give a worst case search of O(nlog(m)) where n is the number of + * characters in the searched text and m is the maximum size of the sorted character arrays. + * Using a hash table array always gives O(n) search time. + * The memory usage here is very small (all Litho filters use ~10KB of memory), + * so the more performant hash implementation is chosen. */ @Nullable private TrieNode[] children; @@ -101,6 +135,13 @@ public abstract class TrieSearch { @Nullable private List> endOfPatternCallback; + TrieNode() { + this.nodeValue = ROOT_NODE_CHARACTER_VALUE; + } + TrieNode(char nodeCharacterValue) { + this.nodeValue = nodeCharacterValue; + } + /** * @param pattern Pattern to add. * @param patternLength Length of the pattern. @@ -121,7 +162,7 @@ public abstract class TrieSearch { // Recursively call back into this method and push the existing leaf down 1 level. if (children != null) throw new IllegalStateException(); //noinspection unchecked - children = new TrieNode[NUMBER_OF_CHILDREN]; + children = new TrieNode[1]; TrieCompressedPath temp = leaf; leaf = null; addPattern(temp.pattern, temp.patternLength, temp.patternStartIndex, temp.callback); @@ -130,19 +171,65 @@ public abstract class TrieSearch { leaf = new TrieCompressedPath<>(pattern, patternLength, patternIndex, callback); return; } - char character = getCharValue(pattern, patternIndex); + final char character = getCharValue(pattern, patternIndex); if (isInvalidRange(character)) { throw new IllegalArgumentException("invalid character at index " + patternIndex + ": " + pattern); } - character -= MIN_VALID_CHAR; // Adjust to the array range. - TrieNode child = children[character]; + final int arrayIndex = hashIndexForTableSize(children.length, character); + TrieNode child = children[arrayIndex]; if (child == null) { - child = createNode(); - children[character] = child; + child = createNode(character); + children[arrayIndex] = child; + } else if (child.nodeValue != character) { + // Hash collision. Resize the table until perfect hashing is found. + child = createNode(character); + expandChildArray(child); } child.addPattern(pattern, patternLength, patternIndex + 1, callback); } + /** + * Resizes the children table until all nodes hash to exactly one array index. + * Worse case, this will resize the array to {@link #CHILDREN_ARRAY_MAX_SIZE} elements. + */ + private void expandChildArray(TrieNode child) { + int replacementArraySize = Objects.requireNonNull(children).length; + while (true) { + replacementArraySize += CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT; + //noinspection unchecked + TrieNode[] replacement = new TrieNode[replacementArraySize]; + addNodeToArray(replacement, child); + boolean collision = false; + for (TrieNode existingChild : children) { + if (existingChild != null) { + if (!addNodeToArray(replacement, existingChild)) { + collision = true; + break; + } + } + } + if (collision) { + if (replacementArraySize > CHILDREN_ARRAY_MAX_SIZE) throw new IllegalStateException(); + continue; + } + children = replacement; + return; + } + } + + private static boolean addNodeToArray(TrieNode[] array, TrieNode childToAdd) { + final int insertIndex = hashIndexForTableSize(array.length, childToAdd.nodeValue); + if (array[insertIndex] != null ) { + return false; // Collision. + } + array[insertIndex] = childToAdd; + return true; + } + + private static int hashIndexForTableSize(int arraySize, char nodeValue) { + return (nodeValue - MIN_VALID_CHAR) % arraySize; + } + /** * @param searchText Text to search for patterns in. * @param searchTextLength Length of the search text. @@ -170,18 +257,17 @@ public abstract class TrieSearch { if (children == null) { return false; // Reached a graph end point and there's no further patterns to search. } - if (searchTextIndex == searchTextLength) { return false; // Reached end of the search text and found no matches. } - char character = getCharValue(searchText, searchTextIndex); + final char character = getCharValue(searchText, searchTextIndex); if (isInvalidRange(character)) { return false; // Not an ASCII letter/number/symbol. } - character -= MIN_VALID_CHAR; // Adjust to the array range. - TrieNode child = children[character]; - if (child == null) { + final int arrayIndex = hashIndexForTableSize(children.length, character); + TrieNode child = children[arrayIndex]; + if (child == null || child.nodeValue != character) { return false; } return child.matches(searchText, searchTextLength, searchTextIndex + 1, @@ -194,7 +280,7 @@ public abstract class TrieSearch { * @return Estimated number of memory pointers used, starting from this node and including all children. */ private int estimatedNumberOfPointersUsed() { - int numberOfPointers = 3; // Number of fields in this class. + int numberOfPointers = 4; // Number of fields in this class. if (leaf != null) { numberOfPointers += 4; // Number of fields in leaf node. } @@ -202,7 +288,7 @@ public abstract class TrieSearch { numberOfPointers += endOfPatternCallback.size(); } if (children != null) { - numberOfPointers += NUMBER_OF_CHILDREN; + numberOfPointers += children.length; for (TrieNode child : children) { if (child != null) { numberOfPointers += child.estimatedNumberOfPointersUsed(); @@ -212,7 +298,7 @@ public abstract class TrieSearch { return numberOfPointers; } - abstract TrieNode createNode(); + abstract TrieNode createNode(char nodeValue); abstract char getCharValue(T text, int index); }