perf(YouTube): Reduce memory requirement for prefix tree searching (#501)

This commit is contained in:
LisoUseInAIKyrios 2023-10-17 13:08:35 +03:00 committed by GitHub
parent bd307e475f
commit f5add51fa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 125 additions and 23 deletions

View File

@ -425,15 +425,15 @@ public final class LithoFilterPatch {
static {
for (Filter filter : filters) {
filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList);
filterGroupLists(identifierSearchTree, filter, filter.identifierFilterGroupList);
filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList);
}
LogHelper.printDebug(() -> "Using: "
+ pathSearchTree.numberOfPatterns() + " path filters"
+ " (" + pathSearchTree.getEstimatedMemorySize() + " KB), "
+ identifierSearchTree.numberOfPatterns() + " identifier filters"
+ " (" + identifierSearchTree.getEstimatedMemorySize() + " KB)");
+ " (" + identifierSearchTree.getEstimatedMemorySize() + " KB), "
+ pathSearchTree.numberOfPatterns() + " path filters"
+ " (" + pathSearchTree.getEstimatedMemorySize() + " KB)");
}
private static <T> void filterGroupLists(TrieSearch<T> pathSearchTree,

View File

@ -8,9 +8,17 @@ import java.util.Objects;
public final class ByteTrieSearch extends TrieSearch<byte[]> {
private static final class ByteTrieNode extends TrieNode<byte[]> {
TrieNode<byte[]> createNode() {
return new ByteTrieNode();
ByteTrieNode() {
super();
}
ByteTrieNode(char nodeCharacterValue) {
super(nodeCharacterValue);
}
@Override
TrieNode<byte[]> createNode(char nodeCharacterValue) {
return new ByteTrieNode(nodeCharacterValue);
}
@Override
char getCharValue(byte[] text, int index) {
return (char) text[index];
}

View File

@ -11,9 +11,17 @@ import java.util.Objects;
public final class StringTrieSearch extends TrieSearch<String> {
private static final class StringTrieNode extends TrieNode<String> {
TrieNode<String> createNode() {
return new StringTrieNode();
StringTrieNode() {
super();
}
StringTrieNode(char nodeCharacterValue) {
super(nodeCharacterValue);
}
@Override
TrieNode<String> createNode(char nodeValue) {
return new StringTrieNode(nodeValue);
}
@Override
char getCharValue(String text, int index) {
return text.charAt(index);
}

View File

@ -71,15 +71,31 @@ public abstract class TrieSearch<T> {
}
static abstract class TrieNode<T> {
/**
* Dummy value used for root node. Value can be anything as it's never referenced.
*/
private static final char ROOT_NODE_CHARACTER_VALUE = 0; // ASCII null character.
// Support only ASCII letters/numbers/symbols and filter out all control characters.
private static final char MIN_VALID_CHAR = 32; // Space character.
private static final char MAX_VALID_CHAR = 126; // 127 = delete character.
private static final int NUMBER_OF_CHILDREN = MAX_VALID_CHAR - MIN_VALID_CHAR + 1;
/**
* How much to expand the children array when resizing.
*/
private static final int CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT = 2;
private static final int CHILDREN_ARRAY_MAX_SIZE = MAX_VALID_CHAR - MIN_VALID_CHAR + 1;
private static boolean isInvalidRange(char character) {
return character < MIN_VALID_CHAR || character > MAX_VALID_CHAR;
}
/**
* Character this node represents.
* This field is ignored for the root node (which does not represent any character).
*/
private final char nodeValue;
/**
* A compressed graph path that represents the remaining pattern characters of a single child node.
*
@ -91,6 +107,24 @@ public abstract class TrieSearch<T> {
/**
* All child nodes. Only present if no compressed leaf exist.
*
* Array is dynamically increased in size as needed,
* and uses perfect hashing for the elements it contains.
*
* So if the array contains a given character,
* the character will always map to the node with index: (character % arraySize).
*
* Elements not contained can collide with elements the array does contain,
* so must compare the nodes character value.
*
* Alternatively this array could be a sorted and densely packed array,
* and lookup is done using binary search.
* That would save a small amount of memory because there's no null children entries,
* but would give a worst case search of O(nlog(m)) where n is the number of
* characters in the searched text and m is the maximum size of the sorted character arrays.
* Using a hash table array always gives O(n) search time.
* The memory usage here is very small (all Litho filters use ~10KB of memory),
* so the more performant hash implementation is chosen.
*/
@Nullable
private TrieNode<T>[] children;
@ -101,6 +135,13 @@ public abstract class TrieSearch<T> {
@Nullable
private List<TriePatternMatchedCallback<T>> endOfPatternCallback;
TrieNode() {
this.nodeValue = ROOT_NODE_CHARACTER_VALUE;
}
TrieNode(char nodeCharacterValue) {
this.nodeValue = nodeCharacterValue;
}
/**
* @param pattern Pattern to add.
* @param patternLength Length of the pattern.
@ -121,7 +162,7 @@ public abstract class TrieSearch<T> {
// Recursively call back into this method and push the existing leaf down 1 level.
if (children != null) throw new IllegalStateException();
//noinspection unchecked
children = new TrieNode[NUMBER_OF_CHILDREN];
children = new TrieNode[1];
TrieCompressedPath<T> temp = leaf;
leaf = null;
addPattern(temp.pattern, temp.patternLength, temp.patternStartIndex, temp.callback);
@ -130,19 +171,65 @@ public abstract class TrieSearch<T> {
leaf = new TrieCompressedPath<>(pattern, patternLength, patternIndex, callback);
return;
}
char character = getCharValue(pattern, patternIndex);
final char character = getCharValue(pattern, patternIndex);
if (isInvalidRange(character)) {
throw new IllegalArgumentException("invalid character at index " + patternIndex + ": " + pattern);
}
character -= MIN_VALID_CHAR; // Adjust to the array range.
TrieNode<T> child = children[character];
final int arrayIndex = hashIndexForTableSize(children.length, character);
TrieNode<T> child = children[arrayIndex];
if (child == null) {
child = createNode();
children[character] = child;
child = createNode(character);
children[arrayIndex] = child;
} else if (child.nodeValue != character) {
// Hash collision. Resize the table until perfect hashing is found.
child = createNode(character);
expandChildArray(child);
}
child.addPattern(pattern, patternLength, patternIndex + 1, callback);
}
/**
* Resizes the children table until all nodes hash to exactly one array index.
* Worse case, this will resize the array to {@link #CHILDREN_ARRAY_MAX_SIZE} elements.
*/
private void expandChildArray(TrieNode<T> child) {
int replacementArraySize = Objects.requireNonNull(children).length;
while (true) {
replacementArraySize += CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT;
//noinspection unchecked
TrieNode<T>[] replacement = new TrieNode[replacementArraySize];
addNodeToArray(replacement, child);
boolean collision = false;
for (TrieNode<T> existingChild : children) {
if (existingChild != null) {
if (!addNodeToArray(replacement, existingChild)) {
collision = true;
break;
}
}
}
if (collision) {
if (replacementArraySize > CHILDREN_ARRAY_MAX_SIZE) throw new IllegalStateException();
continue;
}
children = replacement;
return;
}
}
private static <T> boolean addNodeToArray(TrieNode<T>[] array, TrieNode<T> childToAdd) {
final int insertIndex = hashIndexForTableSize(array.length, childToAdd.nodeValue);
if (array[insertIndex] != null ) {
return false; // Collision.
}
array[insertIndex] = childToAdd;
return true;
}
private static int hashIndexForTableSize(int arraySize, char nodeValue) {
return (nodeValue - MIN_VALID_CHAR) % arraySize;
}
/**
* @param searchText Text to search for patterns in.
* @param searchTextLength Length of the search text.
@ -170,18 +257,17 @@ public abstract class TrieSearch<T> {
if (children == null) {
return false; // Reached a graph end point and there's no further patterns to search.
}
if (searchTextIndex == searchTextLength) {
return false; // Reached end of the search text and found no matches.
}
char character = getCharValue(searchText, searchTextIndex);
final char character = getCharValue(searchText, searchTextIndex);
if (isInvalidRange(character)) {
return false; // Not an ASCII letter/number/symbol.
}
character -= MIN_VALID_CHAR; // Adjust to the array range.
TrieNode<T> child = children[character];
if (child == null) {
final int arrayIndex = hashIndexForTableSize(children.length, character);
TrieNode<T> child = children[arrayIndex];
if (child == null || child.nodeValue != character) {
return false;
}
return child.matches(searchText, searchTextLength, searchTextIndex + 1,
@ -194,7 +280,7 @@ public abstract class TrieSearch<T> {
* @return Estimated number of memory pointers used, starting from this node and including all children.
*/
private int estimatedNumberOfPointersUsed() {
int numberOfPointers = 3; // Number of fields in this class.
int numberOfPointers = 4; // Number of fields in this class.
if (leaf != null) {
numberOfPointers += 4; // Number of fields in leaf node.
}
@ -202,7 +288,7 @@ public abstract class TrieSearch<T> {
numberOfPointers += endOfPatternCallback.size();
}
if (children != null) {
numberOfPointers += NUMBER_OF_CHILDREN;
numberOfPointers += children.length;
for (TrieNode<T> child : children) {
if (child != null) {
numberOfPointers += child.estimatedNumberOfPointersUsed();
@ -212,7 +298,7 @@ public abstract class TrieSearch<T> {
return numberOfPointers;
}
abstract TrieNode<T> createNode();
abstract TrieNode<T> createNode(char nodeValue);
abstract char getCharValue(T text, int index);
}