Skip to content

Track bytes used by in-memory postings #129969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
19 changes: 15 additions & 4 deletions server/src/main/java/org/elasticsearch/common/lucene/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -739,15 +739,26 @@ public static Version parseVersionLenient(String toParse, Version defaultValue)
* If no SegmentReader can be extracted an {@link IllegalStateException} is thrown.
*/
public static SegmentReader segmentReader(LeafReader reader) {
SegmentReader segmentReader = tryUnwrapSegmentReader(reader);
if (segmentReader == null) {
throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]");
}
return segmentReader;
}

/**
* Tries to extract a segment reader from the given index reader. Unlike {@link #segmentReader(LeafReader)} this method returns
* null if no SegmentReader can be unwrapped instead of throwing an exception.
*/
public static SegmentReader tryUnwrapSegmentReader(LeafReader reader) {
if (reader instanceof SegmentReader) {
return (SegmentReader) reader;
} else if (reader instanceof final FilterLeafReader fReader) {
return segmentReader(FilterLeafReader.unwrap(fReader));
return tryUnwrapSegmentReader(FilterLeafReader.unwrap(fReader));
} else if (reader instanceof final FilterCodecReader fReader) {
return segmentReader(FilterCodecReader.unwrap(fReader));
return tryUnwrapSegmentReader(FilterCodecReader.unwrap(fReader));
}
// hard fail - we can't get a SegmentReader
throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]");
return null;
}

@SuppressForbidden(reason = "Version#parseLeniently() used in a central place")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FieldsConsumer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.function.IntConsumer;

public class TrackingPostingsInMemoryBytesCodec extends FilterCodec {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add class level javadocs explain the purpose of this class?

public static final String IN_MEMORY_POSTINGS_BYTES_KEY = "es.postings.in_memory_bytes";

public TrackingPostingsInMemoryBytesCodec(Codec delegate) {
super(delegate.getName(), delegate);
}

@Override
public PostingsFormat postingsFormat() {
PostingsFormat format = super.postingsFormat();

return new PostingsFormat(format.getName()) {
@Override
public FieldsConsumer fieldsConsumer(SegmentWriteState state) throws IOException {
FieldsConsumer consumer = format.fieldsConsumer(state);
return new TrackingLengthFieldsConsumer(state, consumer);
}

@Override
public FieldsProducer fieldsProducer(SegmentReadState state) throws IOException {
return format.fieldsProducer(state);
}
};
}

static final class TrackingLengthFieldsConsumer extends FieldsConsumer {
final SegmentWriteState state;
final FieldsConsumer in;
final IntIntHashMap maxLengths;

TrackingLengthFieldsConsumer(SegmentWriteState state, FieldsConsumer in) {
this.state = state;
this.in = in;
this.maxLengths = new IntIntHashMap(state.fieldInfos.size());
}

@Override
public void write(Fields fields, NormsProducer norms) throws IOException {
in.write(new TrackingLengthFields(fields, maxLengths, state.fieldInfos), norms);
long totalLength = 0;
for (int len : maxLengths.values) {
totalLength += len; // minTerm
totalLength += len; // maxTerm
}
state.segmentInfo.putAttribute(IN_MEMORY_POSTINGS_BYTES_KEY, Long.toString(totalLength));
}

@Override
public void close() throws IOException {
in.close();
}
}

static final class TrackingLengthFields extends FilterLeafReader.FilterFields {
final IntIntHashMap maxLengths;
final FieldInfos fieldInfos;

TrackingLengthFields(Fields in, IntIntHashMap maxLengths, FieldInfos fieldInfos) {
super(in);
this.maxLengths = maxLengths;
this.fieldInfos = fieldInfos;
}

@Override
public Terms terms(String field) throws IOException {
Terms terms = super.terms(field);
if (terms == null) {
return terms;
}
int fieldNum = fieldInfos.fieldInfo(field).number;
return new TrackingLengthTerms(terms, len -> maxLengths.put(fieldNum, Math.max(maxLengths.getOrDefault(fieldNum, 0), len)));
Comment on lines +96 to +101
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we can do this instead:

Suggested change
Terms terms = super.terms(field);
if (terms == null) {
return terms;
}
int fieldNum = fieldInfos.fieldInfo(field).number;
return new TrackingLengthTerms(terms, len -> maxLengths.put(fieldNum, Math.max(maxLengths.getOrDefault(fieldNum, 0), len)));
Terms terms = super.terms(field);
// Only org.apache.lucene.codecs.lucene90.blocktree.FieldReader keeps min and max term in jvm heap,
// so only account for these cases:
if (terms instanceof FieldReader fieldReader) {
int fieldNum = fieldInfos.fieldInfo(field).number;
int length = fieldReader.getMin().length;
length += fieldReader.getMax().length;
maxLengths.put(fieldNum, length);
}
return terms;

This way there is way less wrapping. We only care about min and max term, given that this is loaded in jvm heap.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scratch that idea. The implementation provided here different. This gets invoked during indexing / merging. During indexing this implementation of terms is FreqProxTermsWriterPerField. Invoking getMax() is potentially expensive as it causes reading ahead to figure out which is the max term, these terms get later read via terms enum.

}
}

static final class TrackingLengthTerms extends FilterLeafReader.FilterTerms {
final IntConsumer onFinish;

TrackingLengthTerms(Terms in, IntConsumer onFinish) {
super(in);
this.onFinish = onFinish;
}

@Override
public TermsEnum iterator() throws IOException {
return new TrackingLengthTermsEnum(super.iterator(), onFinish);
}
}

static final class TrackingLengthTermsEnum extends FilterLeafReader.FilterTermsEnum {
int maxTermLength = 0;
final IntConsumer onFinish;

TrackingLengthTermsEnum(TermsEnum in, IntConsumer onFinish) {
super(in);
this.onFinish = onFinish;
}

@Override
public BytesRef next() throws IOException {
final BytesRef term = super.next();
if (term != null) {
maxTermLength = Math.max(maxTermLength, term.length);
} else {
onFinish.accept(maxTermLength);
}
return term;
}
Comment on lines +129 to +137
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we need to estimate the terms that get loaded in jvm heap would the following be more accurate?

Suggested change
public BytesRef next() throws IOException {
final BytesRef term = super.next();
if (term != null) {
maxTermLength = Math.max(maxTermLength, term.length);
} else {
onFinish.accept(maxTermLength);
}
return term;
}
int prevTermLength = 0;
@Override
public BytesRef next() throws IOException {
final BytesRef term = super.next();
if (term == null) {
maxTermLength += prevTermLength;
onFinish.accept(maxTermLength);
return term;
}
if (maxTermLength == 0) {
maxTermLength = term.length;
}
prevTermLength = term.length;
return term;
}

In the org.apache.lucene.codecs.lucene90.blocktree.FieldReader class, the lowest and highest lexicographically term is kept around in jvm heap. The current code just keeps track what the longest term is and report that, which doesn't map with the minTerm and maxTerm in FieldReader?

}
}
13 changes: 12 additions & 1 deletion server/src/main/java/org/elasticsearch/index/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.index.codec.FieldInfosWithUsages;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
import org.elasticsearch.index.mapper.DocumentParser;
import org.elasticsearch.index.mapper.LuceneDocument;
Expand Down Expand Up @@ -275,6 +276,7 @@ protected static ShardFieldStats shardFieldStats(List<LeafReaderContext> leaves)
int numSegments = 0;
int totalFields = 0;
long usages = 0;
long totalPostingBytes = 0;
for (LeafReaderContext leaf : leaves) {
numSegments++;
var fieldInfos = leaf.reader().getFieldInfos();
Expand All @@ -286,8 +288,17 @@ protected static ShardFieldStats shardFieldStats(List<LeafReaderContext> leaves)
} else {
usages = -1;
}
SegmentReader segmentReader = Lucene.tryUnwrapSegmentReader(leaf.reader());
if (segmentReader != null) {
String postingBytes = segmentReader.getSegmentInfo().info.getAttribute(
TrackingPostingsInMemoryBytesCodec.IN_MEMORY_POSTINGS_BYTES_KEY
);
if (postingBytes != null) {
totalPostingBytes += Long.parseLong(postingBytes);
}
}
}
return new ShardFieldStats(numSegments, totalFields, usages);
return new ShardFieldStats(numSegments, totalFields, usages, totalPostingBytes);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;
import org.elasticsearch.index.mapper.DocumentParser;
import org.elasticsearch.index.mapper.IdFieldMapper;
import org.elasticsearch.index.mapper.LuceneDocument;
Expand Down Expand Up @@ -2778,7 +2779,7 @@ private IndexWriterConfig getIndexWriterConfig() {
iwc.setMaxFullFlushMergeWaitMillis(-1);
iwc.setSimilarity(engineConfig.getSimilarity());
iwc.setRAMBufferSizeMB(engineConfig.getIndexingBufferSize().getMbFrac());
iwc.setCodec(engineConfig.getCodec());
iwc.setCodec(new TrackingPostingsInMemoryBytesCodec(engineConfig.getCodec()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what the overhead is of always wrapping the codec in TrackingPostingsInMemoryBytesCodec. Maybe let's quickly run benchmark? (elastic/logs?)

Additionally I wonder whether this should only be done for stateless only.

boolean useCompoundFile = engineConfig.getUseCompoundFile();
iwc.setUseCompoundFile(useCompoundFile);
if (useCompoundFile == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
* @param totalFields the total number of fields across the segments
* @param fieldUsages the number of usages for segment-level fields (e.g., doc_values, postings, norms, points)
* -1 if unavailable
* @param postingsInMemoryBytes the total bytes in memory used for postings across all fields
*/
public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages) {
public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages, long postingsInMemoryBytes) {

}
Original file line number Diff line number Diff line change
Expand Up @@ -1882,8 +1882,9 @@ public void testShardFieldStats() throws IOException {
assertThat(stats.numSegments(), equalTo(0));
assertThat(stats.totalFields(), equalTo(0));
assertThat(stats.fieldUsages(), equalTo(0L));
assertThat(stats.postingsInMemoryBytes(), equalTo(0L));
// index some documents
int numDocs = between(1, 10);
int numDocs = between(2, 10);
for (int i = 0; i < numDocs; i++) {
indexDoc(shard, "_doc", "first_" + i, """
{
Expand All @@ -1901,6 +1902,9 @@ public void testShardFieldStats() throws IOException {
// _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms),
// f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv),
assertThat(stats.fieldUsages(), equalTo(13L));
// _id: 8, f1: 3, f1.keyword: 3, f2: 3, f2.keyword: 3
// (8 + 3 + 3 + 3 + 3) * 2 = 40
assertThat(stats.postingsInMemoryBytes(), equalTo(40L));
// don't re-compute on refresh without change
if (randomBoolean()) {
shard.refresh("test");
Expand All @@ -1918,7 +1922,7 @@ public void testShardFieldStats() throws IOException {
}
assertThat(shard.getShardFieldStats(), sameInstance(stats));
// index more docs
numDocs = between(1, 10);
numDocs = between(2, 10);
for (int i = 0; i < numDocs; i++) {
indexDoc(shard, "_doc", "first_" + i, """
{
Expand Down Expand Up @@ -1948,13 +1952,20 @@ public void testShardFieldStats() throws IOException {
assertThat(stats.totalFields(), equalTo(21));
// first segment: 13, second segment: 13 + f3(postings,norms) + f3.keyword(term,dv), and __soft_deletes to previous segment
assertThat(stats.fieldUsages(), equalTo(31L));
// segment 1: 40 (see above)
// segment 2: _id: 8, f1: 3, f1.keyword: 3, f2: 3, f2.keyword: 3, f3: 6, f3.keyword: 6
// (8 + 3 + 3 + 3 + 3 + 6 + 6) * 2 q= 64
// 40 + 64 = 104
assertThat(stats.postingsInMemoryBytes(), equalTo(104L));
shard.forceMerge(new ForceMergeRequest().maxNumSegments(1).flush(true));
stats = shard.getShardFieldStats();
assertThat(stats.numSegments(), equalTo(1));
assertThat(stats.totalFields(), equalTo(12));
// _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms),
// f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv), f3(postings,norms), f3.keyword(term,dv), __soft_deletes
assertThat(stats.fieldUsages(), equalTo(18L));
// max(segment1: 40, segment2: 64) = 64
assertThat(stats.postingsInMemoryBytes(), equalTo(64L));
closeShards(shard);
}

Expand Down