Skip to content

Commit

Permalink
Merge pull request #42 from nortti/token-mem-leak
Browse files Browse the repository at this point in the history
Address memory leak when using 'Token' authentication (#41)
  • Loading branch information
EkiH authored Sep 14, 2021
2 parents 0a84fbe + b45a25c commit fbc80f1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 86 deletions.
58 changes: 58 additions & 0 deletions Frends.Web.Tests/UnitTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,33 @@ public async Task RequestShouldAddOAuthBearerHeader(
_mockHttpMessageHandler.VerifyNoOutstandingExpectation();
}

[Theory]
[MemberData(nameof(TestDataSource.TestCases), MemberType = typeof(TestDataSource))]
public async Task AuthorizationHeaderShouldOverrideOption(
Func<Input, Options, CancellationToken, Task<object>> requestFunc)
{
const string expectedReturn = @"'FooBar'";

var input = new Input
{
Method = Method.GET,
Url = "http://localhost:9191/endpoint",
Headers = new[] { new Header() { Name = "Authorization", Value = "Basic fooToken" } },
Message = ""
};
var options = new Options
{
ConnectionTimeoutSeconds = 60,
Authentication = Authentication.OAuth,
Token = "barToken"
};

_mockHttpMessageHandler.Expect($"{BasePath}/endpoint").WithHeaders("Authorization", "Basic fooToken")
.Respond("application/json", expectedReturn);
await requestFunc(input, options, CancellationToken.None);
_mockHttpMessageHandler.VerifyNoOutstandingExpectation();
}

[Fact]
public async Task RequestShouldAddClientCertificate()
{
Expand Down Expand Up @@ -502,5 +529,36 @@ public async Task HttpSendAndReceiveBytesShouldBeAbleToReturnBinary()

Assert.Equal(actualFileBytes, result.BodyBytes);
}

[Fact]
public async Task ShouldUseSameClientEvenIfOAuthTokenChanges()
{
var requestBytes = Encoding.UTF8.GetBytes("some request data");

var input = new ByteInput
{
Method = SendMethod.POST,
Url = "http://localhost:9191/endpoint",
Headers = new[] {
new Header { Name = "Content-Type", Value = "text/plain; charset=utf-8" },
new Header { Name ="Content-Length", Value = requestBytes.Length.ToString() }
},
ContentBytes = requestBytes
};
var options = new Options { ConnectionTimeoutSeconds = 60, Authentication = Authentication.OAuth };

_mockHttpMessageHandler.When(input.Url)
.Respond("application/octet-stream", string.Empty);

for (var i = 0; i < 2; i++)
{
options.Token = i + "";
var result = (HttpByteResponse)await Web.HttpSendAndReceiveBytes(input, options, CancellationToken.None);
Assert.Equal(0, result.BodySizeInMegaBytes);
Assert.Empty(result.BodyBytes);
}

Assert.Equal(1, Web.ClientCache.GetCount());
}
}
}
1 change: 1 addition & 0 deletions Frends.Web/Frends.Web.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="System.ComponentModel.Annotations" Version="4.3.0" />
<PackageReference Include="System.Runtime.Caching" Version="5.0.0" />
</ItemGroup>
<ItemGroup>
<None Include="FrendsTaskMetadata.json" Pack="true" PackagePath="/">
Expand Down
141 changes: 55 additions & 86 deletions Frends.Web/Web.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.Caching;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.RegularExpressions;
Expand Down Expand Up @@ -101,7 +102,7 @@ public class ByteInput
public Header[] Headers { get; set; }
}

public class Options : IEquatable<Options>
public class Options
{
/// <summary>
/// Method of authenticating request
Expand Down Expand Up @@ -197,60 +198,6 @@ public class Options : IEquatable<Options>
/// </summary>
[DefaultValue(true)]
public bool AutomaticCookieHandling { get; set; } = true;



public bool Equals(Options other)
{
if (ReferenceEquals(null, other)) return false;
if (ReferenceEquals(this, other)) return true;
return Authentication == other.Authentication &&
string.Equals(Username, other.Username) &&
string.Equals(Password, other.Password) &&
string.Equals(Token, other.Token) &&
string.Equals(CertificateThumbprint, other.CertificateThumbprint) &&
ClientCertificateSource == other.ClientCertificateSource &&
ClientCertificateInBase64 == other.ClientCertificateInBase64 &&
ClientCertificateFilePath == other.ClientCertificateFilePath &&
LoadEntireChainForCertificate == other.LoadEntireChainForCertificate &&
ConnectionTimeoutSeconds == other.ConnectionTimeoutSeconds &&
FollowRedirects == other.FollowRedirects &&
AllowInvalidCertificate == other.AllowInvalidCertificate &&
AllowInvalidResponseContentTypeCharSet == other.AllowInvalidResponseContentTypeCharSet &&
ThrowExceptionOnErrorResponse == other.ThrowExceptionOnErrorResponse &&
AutomaticCookieHandling == other.AutomaticCookieHandling;
}

public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((Options)obj);
}

public override int GetHashCode()
{
unchecked
{
var hashCode = (int)Authentication;
hashCode = (hashCode * 397) ^ (Username != null ? Username.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (Password != null ? Password.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (Token != null ? Token.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ ClientCertificateSource.GetHashCode();
hashCode = (hashCode * 397) ^ (CertificateThumbprint != null ? CertificateThumbprint.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ClientCertificateInBase64 != null ? ClientCertificateInBase64.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ (ClientCertificateFilePath != null ? ClientCertificateFilePath.GetHashCode() : 0);
hashCode = (hashCode * 397) ^ LoadEntireChainForCertificate.GetHashCode();
hashCode = (hashCode * 397) ^ ConnectionTimeoutSeconds;
hashCode = (hashCode * 397) ^ FollowRedirects.GetHashCode();
hashCode = (hashCode * 397) ^ AllowInvalidCertificate.GetHashCode();
hashCode = (hashCode * 397) ^ AllowInvalidResponseContentTypeCharSet.GetHashCode();
hashCode = (hashCode * 397) ^ ThrowExceptionOnErrorResponse.GetHashCode();
hashCode = (hashCode * 397) ^ AutomaticCookieHandling.GetHashCode();
return hashCode;
}
}
}

public enum CertificateSource
Expand Down Expand Up @@ -300,11 +247,19 @@ public HttpClient CreateClient(Options options)

public class Web
{
private static readonly ConcurrentDictionary<Options, HttpClient> ClientCache = new ConcurrentDictionary<Options, HttpClient>();
// For tests
public static readonly ObjectCache ClientCache = MemoryCache.Default;

private static readonly CacheItemPolicy _cachePolicy = new CacheItemPolicy() { SlidingExpiration = TimeSpan.FromHours(1) };


public static void ClearClientCache()
{
ClientCache.Clear();
var cacheKeys = ClientCache.Select(kvp => kvp.Key).ToList();
foreach (var cacheKey in cacheKeys)
{
ClientCache.Remove(cacheKey);
}
}
// For tests
public static IHttpClientFactory ClientFactory = new HttpClientFactory();
Expand All @@ -318,7 +273,7 @@ public static void ClearClientCache()
public static async Task<object> RestRequest([PropertyTab] Input input, [PropertyTab] Options options, CancellationToken cancellationToken)
{
var httpClient = GetHttpClientForOptions(options);
var headers = GetHeaderDictionary(input.Headers);
var headers = GetHeaderDictionary(input.Headers, options);
using (var content = GetContent(input, headers))
{
using (var responseMessage = await GetHttpRequestResponseAsync(
Expand Down Expand Up @@ -356,16 +311,29 @@ public static async Task<object> RestRequest([PropertyTab] Input input, [Propert

private static HttpClient GetHttpClientForOptions(Options options)
{
var cacheKey = GetHttpClientCacheKey(options);

return ClientCache.GetOrAdd(options, (opts) =>
if (ClientCache.Get(cacheKey) is HttpClient httpClient)
{
// might get called more than once if e.g. many process instances execute at once,
// but that should not matter much, as only one client will get cached
var httpClient = ClientFactory.CreateClient(options);
httpClient.SetDefaultRequestHeadersBasedOnOptions(opts);
return httpClient;
});
}

httpClient = ClientFactory.CreateClient(options);
httpClient.SetDefaultRequestHeadersBasedOnOptions(options);

ClientCache.Add(cacheKey, httpClient, _cachePolicy);

return httpClient;
}

private static string GetHttpClientCacheKey(Options options)
{
// Includes everything except for options.Token, which is used on request level, not http client level
return $"{options.Authentication}:{options.Username}:{options.Password}:{options.ClientCertificateSource}"
+ $":{options.ClientCertificateFilePath}:{options.ClientCertificateInBase64}:{options.ClientCertificateKeyPhrase}"
+ $":{options.CertificateThumbprint}:{options.LoadEntireChainForCertificate}:{options.ConnectionTimeoutSeconds}"
+ $":{options.FollowRedirects}:{options.AllowInvalidCertificate}:{options.AllowInvalidResponseContentTypeCharSet}"
+ $":{options.ThrowExceptionOnErrorResponse}:{options.AutomaticCookieHandling}";
}

/// <summary>
Expand All @@ -377,7 +345,7 @@ private static HttpClient GetHttpClientForOptions(Options options)
public static async Task<object> HttpRequest([PropertyTab] Input input, [PropertyTab] Options options, CancellationToken cancellationToken)
{
var httpClient = GetHttpClientForOptions(options);
var headers = GetHeaderDictionary(input.Headers);
var headers = GetHeaderDictionary(input.Headers, options);

using (var content = GetContent(input, headers))
{
Expand Down Expand Up @@ -422,7 +390,7 @@ public static async Task<object> HttpRequest([PropertyTab] Input input, [Propert
public static async Task<object> HttpRequestBytes([PropertyTab]Input input, [PropertyTab] Options options, CancellationToken cancellationToken)
{
var httpClient = GetHttpClientForOptions(options);
var headers = GetHeaderDictionary(input.Headers);
var headers = GetHeaderDictionary(input.Headers, options);

using (var content = GetContent(input, headers))
{
Expand Down Expand Up @@ -467,7 +435,7 @@ public static async Task<object> HttpRequestBytes([PropertyTab]Input input, [Pro
public static async Task<object> HttpSendBytes([PropertyTab]ByteInput input, [PropertyTab] Options options, CancellationToken cancellationToken)
{
var httpClient = GetHttpClientForOptions(options);
var headers = GetHeaderDictionary(input.Headers);
var headers = GetHeaderDictionary(input.Headers, options);

using (var content = GetContent(input))
{
Expand Down Expand Up @@ -511,7 +479,7 @@ public static async Task<object> HttpSendBytes([PropertyTab]ByteInput input, [Pr
public static async Task<object> HttpSendAndReceiveBytes([PropertyTab]ByteInput input, [PropertyTab] Options options, CancellationToken cancellationToken)
{
var httpClient = GetHttpClientForOptions(options);
var headers = GetHeaderDictionary(input.Headers);
var headers = GetHeaderDictionary(input.Headers, options);

using (var content = GetContent(input))
{
Expand Down Expand Up @@ -556,8 +524,25 @@ private static Dictionary<string, string> GetResponseHeaderDictionary(HttpRespon
return allHeaders;
}

private static IDictionary<string, string> GetHeaderDictionary(IEnumerable<Header> headers)
private static IDictionary<string, string> GetHeaderDictionary(Header[] headers, Options options)
{
if (!headers.Any(header => header.Name.ToLower().Equals("authorization")))
{

var authHeader = new Header { Name = "Authorization" };
switch (options.Authentication)
{
case Authentication.Basic:
authHeader.Value = $"Basic {Convert.ToBase64String(Encoding.ASCII.GetBytes($"{options.Username}:{options.Password}"))}";
headers = headers.Concat(new[] { authHeader }).ToArray();
break;
case Authentication.OAuth:
authHeader.Value = $"Bearer {options.Token}";
headers = headers.Concat(new[] { authHeader }).ToArray();
break;
}
}

//Ignore case for headers and key comparison
return headers.ToDictionary(key => key.Name, value => value.Value, StringComparer.InvariantCultureIgnoreCase);
}
Expand Down Expand Up @@ -692,22 +677,6 @@ internal static void SetHandlerSettingsBasedOnOptions(this HttpClientHandler han

internal static void SetDefaultRequestHeadersBasedOnOptions(this HttpClient httpClient, Options options)
{
if (options.Authentication == Authentication.Basic || options.Authentication == Authentication.OAuth)
{
switch (options.Authentication)
{
case Authentication.Basic:
httpClient.DefaultRequestHeaders.Authorization =
new AuthenticationHeaderValue(
"Basic",
Convert.ToBase64String(Encoding.ASCII.GetBytes($"{options.Username}:{options.Password}")));
break;
case Authentication.OAuth:
httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", options.Token);
break;
}
}

//Do not automatically set expect 100-continue response header
httpClient.DefaultRequestHeaders.ExpectContinue = false;
httpClient.DefaultRequestHeaders.TryAddWithoutValidation("content-type", "application/json");
Expand Down

0 comments on commit fbc80f1

Please sign in to comment.