Skip to content

CSHARP-3458: Extend IAsyncCursor and IAsyncCursorSource to support IAsyncEnumerable #1708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/MongoDB.Driver/Core/IAsyncCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,17 @@ public static class IAsyncCursorExtensions
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor, cancellationToken);
}

/// <summary>
/// Wraps a cursor in an IAsyncEnumerable that can be enumerated one time.
/// </summary>
/// <typeparam name="TDocument">The type of the document.</typeparam>
/// <param name="cursor">The cursor.</param>
/// <returns>An IAsyncEnumerable.</returns>
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursor<TDocument> cursor)
{
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor);
}

/// <summary>
/// Returns a list containing all the documents returned by a cursor.
/// </summary>
Expand Down
11 changes: 11 additions & 0 deletions src/MongoDB.Driver/Core/IAsyncCursorSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,17 @@ public static class IAsyncCursorSourceExtensions
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source, cancellationToken);
}

/// <summary>
/// Wraps a cursor source in an IAsyncEnumerable. Each time GetAsyncEnumerator is called a new cursor is fetched from the cursor source.
/// </summary>
/// <typeparam name="TDocument">The type of the document.</typeparam>
/// <param name="source">The source.</param>
/// <returns>An IAsyncEnumerable.</returns>
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursorSource<TDocument> source)
{
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source);
}

/// <summary>
/// Returns a list containing all the documents returned by the cursor returned by a cursor source.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,33 @@

namespace MongoDB.Driver.Core.Operations
{
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
{
private readonly CancellationToken _cancellationToken;
private readonly IAsyncCursor<TDocument> _cursor;
private bool _hasBeenEnumerated;

public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor<TDocument> cursor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not have added this constructor.

It makes it too easy for the caller to forget to pass in the cancellationToken.

Let the caller provide CancellationToken.None if it needs to.

: this(cursor, CancellationToken.None)
{
}

public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor<TDocument> cursor, CancellationToken cancellationToken)
{
_cursor = Ensure.IsNotNull(cursor, nameof(cursor));
_cancellationToken = cancellationToken;
}

public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
if (_hasBeenEnumerated)
{
throw new InvalidOperationException("An IAsyncCursor can only be enumerated once.");
}
_hasBeenEnumerated = true;
return new AsyncCursorEnumerator<TDocument>(_cursor, cancellationToken);
}

public IEnumerator<TDocument> GetEnumerator()
{
if (_hasBeenEnumerated)
Expand Down
55 changes: 42 additions & 13 deletions src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver.Core.Operations
{
internal class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>
internal class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>, IAsyncEnumerator<TDocument>
{
// private fields
private IEnumerator<TDocument> _batchEnumerator;
Expand Down Expand Up @@ -72,6 +73,12 @@ public void Dispose()
}
}

public ValueTask DisposeAsync()
{
Dispose();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment like:

 Dispose(); // TODO: implement true async disposal
return default;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first reaction was that default was a bug, that we should be returning ValueTask.CompletedTask.

It took me some research to figure out why you were returning default.

I suggest the following:

#if NET6_0_OR_GREATER
            return ValueTask.CompletedTask;
#else
            return default; // prior to NET6_0 you have to fake ValueTask.CompletedTask using default
#endif

It's a little more verbose, but it documents why we sometimes have to use default instead of ValueTask.CompletedTask, and when we remove support for netstandard2.1 this will get cleaned up.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a bug definitely. This is documented behavior:

An instance created with the parameterless constructor or by the default(ValueTask) syntax (a zero-initialized structure) represents a synchronously, successfully completed operation.

https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.valuetask?view=net-9.0#remarks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't say it was a bug. I meant it "looked" like a bug.

I'm saying that we should use ValueTask.CompletedTask when available.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, but it's technically the same.
image
ValueTask.CompletedTask is more readable though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't say it wasn't technically the same.

I'm saying that we should use ValueTask.CompletedTask when available. It's more readable than default (who wants to go hunt down the documentation to figure out what default means).

}

public bool MoveNext()
{
ThrowIfDisposed();
Expand All @@ -82,24 +89,46 @@ public bool MoveNext()
return true;
}

while (true)
while (_cursor.MoveNext(_cancellationToken))
{
if (_cursor.MoveNext(_cancellationToken))
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
return true;
}
return true;
}
else
}

_batchEnumerator?.Dispose();
_batchEnumerator = null;
_finished = true;
return false;
}

public async ValueTask<bool> MoveNextAsync()
{
ThrowIfDisposed();
_started = true;

if (_batchEnumerator != null && _batchEnumerator.MoveNext())
{
return true;
}

while (await _cursor.MoveNextAsync(_cancellationToken).ConfigureAwait(false))
{
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
_batchEnumerator = null;
_finished = true;
return false;
return true;
}
}

_batchEnumerator?.Dispose();
_batchEnumerator = null;
_finished = true;
return false;
}

public void Reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,37 @@
* limitations under the License.
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver.Core.Operations
{
internal class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>
internal class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
{
// private fields
private readonly CancellationToken _cancellationToken;
private readonly IAsyncCursorSource<TDocument> _source;

// constructors
public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource<TDocument> source)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to call another constructor here, like:
: this(source, CancellationToken.None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not have added this constructor.

It makes it too easy for the caller to forget to pass in the cancellationToken.

Let the caller provide CancellationToken.None if it needs to.

: this(source, CancellationToken.None)
{
}

public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken)
{
_source = Ensure.IsNotNull(source, nameof(source));
_cancellationToken = cancellationToken;
}

public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
var cursor = _source.ToCursor(cancellationToken);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling _source.ToCursor is a BLOCKING call that executes the query.

Presumable the caller did not want to execute the query synchronously.

Can we defer query execution until asyncEnumerator.MoveNextAsync() is called?

Thoughts?

return new AsyncCursorEnumerator<TDocument>(cursor, cancellationToken);
}

// public methods
public IEnumerator<TDocument> GetEnumerator()
{
Expand Down
12 changes: 12 additions & 0 deletions src/MongoDB.Driver/Linq/MongoQueryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3385,6 +3385,18 @@ public static IQueryable<TSource> Take<TSource>(this IQueryable<TSource> source,
Expression.Constant(count)));
}

/// <summary>
/// Returns an <see cref="IAsyncEnumerable{T}" /> which can be enumerated asynchronously.
/// </summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">A sequence of values.</param>
/// <returns>An IAsyncEnumerable for the query results.</returns>
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IQueryable<TSource> source)
{
var cursorSource = GetCursorSource(source);
return cursorSource.ToAsyncEnumerable();
}

/// <summary>
/// Executes the LINQ query and returns a cursor to the results.
/// </summary>
Expand Down
51 changes: 51 additions & 0 deletions tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using MongoDB.Bson;
using MongoDB.Bson.Serialization.Serializers;
Expand Down Expand Up @@ -201,6 +203,55 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen
action.ShouldThrow<InvalidOperationException>();
}

[Fact]
public void ToAsyncEnumerable_result_should_only_be_enumerable_one_time()
{
var cursor = CreateCursor(2);
var enumerable = cursor.ToAsyncEnumerable();
enumerable.GetAsyncEnumerator();

Action action = () => enumerable.GetAsyncEnumerator();

action.ShouldThrow<InvalidOperationException>();
}

[Fact]
public async Task ToAsyncEnumerable_should_respect_cancellation_token()
{
var source = CreateCursor(5);
using var cts = new CancellationTokenSource();

var count = 0;
await Assert.ThrowsAsync<OperationCanceledException>(async () =>
{
await foreach (var doc in source.ToAsyncEnumerable().WithCancellation(cts.Token))
{
count++;
if (count == 2)
cts.Cancel();
}
});
}

[Fact]
public async Task ToAsyncEnumerable_should_return_expected_result()
{
var cursor = CreateCursor(2);
var expectedDocuments = new[]
{
new BsonDocument("_id", 0),
new BsonDocument("_id", 1)
};

var result = new List<BsonDocument>();
await foreach (var doc in cursor.ToAsyncEnumerable())
{
result.Add(doc);
}

result.Should().Equal(expectedDocuments);
}

[Fact]
public void ToEnumerable_result_should_only_be_enumerable_one_time()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,31 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen
action.ShouldThrow<InvalidOperationException>();
}

[Theory]
[ParameterAttributeData]
public async Task ToAsyncEnumerable_result_should_be_enumerable_multiple_times(
[Values(1, 2)] int times)
{
var source = CreateCursorSource(2);
var expectedDocuments = new[]
{
new BsonDocument("_id", 0),
new BsonDocument("_id", 1)
};

var result = new List<BsonDocument>();
for (var i = 0; i < times; i++)
{
await foreach (var doc in source.ToAsyncEnumerable())
{
result.Add(doc);
}

result.Should().Equal(expectedDocuments);
result.Clear();
}
}

[Theory]
[ParameterAttributeData]
public void ToEnumerable_result_should_be_enumerable_multiple_times(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
using System.Threading.Tasks;
using FluentAssertions;
using MongoDB.Bson;
using MongoDB.Driver;
using MongoDB.Driver.Core.Clusters;
using MongoDB.Driver.Core.Misc;
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
Expand Down Expand Up @@ -78,6 +77,21 @@ public async Task AnyAsync_with_predicate()
result.Should().BeTrue();
}

[Fact]
public async Task ToAsyncEnumerable()
{
var query = CreateQuery().Select(x => x.A);
var expectedResults = query.ToList();

var asyncResults = new List<string>();
await foreach (var item in query.ToAsyncEnumerable().ConfigureAwait(false))
{
asyncResults.Add(item);
}

asyncResults.Should().Equal(expectedResults);
}

[Fact]
public void Average()
{
Expand Down