diff --git a/src/UserCredentials.cs b/src/UserCredentials.cs index f33b416..29c9606 100644 --- a/src/UserCredentials.cs +++ b/src/UserCredentials.cs @@ -1,5 +1,6 @@ using System; using System.ComponentModel; +using System.Linq; using System.Runtime.InteropServices; using System.Security; using System.Security.Principal; @@ -190,8 +191,23 @@ private static void ValidateDomainAndUser(string domain, string username) if (domain.IndexOfAny(new[] { '\\', '@' }) != -1) throw new ArgumentException("Domain cannot contain \\ or @ characters.", nameof(domain)); - if (username.IndexOfAny(new[] { '\\', '@' }) != -1) - throw new ArgumentException("Username cannot contain \\ or @ characters when domain is provided separately.", nameof(username)); + if (string.Equals(domain, "AzureAD", StringComparison.OrdinalIgnoreCase)) + { + if (username.IndexOf('\\') != -1) + throw new ArgumentException("Username cannot contain \\ when the domain is AzureAD.", nameof(username)); + + int i = username.IndexOf('@'); + if (i == -1) + throw new ArgumentException("Username must contain @ when the domain is AzureAD.", nameof(username)); + + if (username.IndexOf('@', i+1) != -1) + throw new ArgumentException("Username cannot contain more than one @ when the domain is AzureAD.", nameof(username)); + } + else + { + if (username.IndexOfAny(new[] { '\\', '@' }) != -1) + throw new ArgumentException("Username cannot contain \\ or @ characters when domain is provided separately.", nameof(username)); + } } private static void ValidateUserWithoutDomain(string username) @@ -202,10 +218,27 @@ private static void ValidateUserWithoutDomain(string username) if (username.Trim() == string.Empty) throw new ArgumentException("Username cannot be empty or consist solely of whitespace characters.", nameof(username)); + char[] validSeparators; + if (username.StartsWith(@"AzureAD\", StringComparison.OrdinalIgnoreCase)) + { + username = username.Substring(8); + if (username.IndexOf('@') == -1) + throw new ArgumentException("Username must contain @ when the domain is AzureAD.", nameof(username)); + + if (username.IndexOf('\\') != -1) + throw new ArgumentException("Username cannot contain another \\ when the domain is AzureAD.", nameof(username)); + + validSeparators = new[] { '@' }; + } + else + { + validSeparators = new[] { '@', '\\' }; + } + int separatorCount = 0; foreach (var c in username) { - if (c == '\\' || c == '@') + if (validSeparators.Contains(c)) separatorCount++; } @@ -213,7 +246,7 @@ private static void ValidateUserWithoutDomain(string username) return; if (separatorCount > 1) - throw new ArgumentException("Username cannot contain more than one \\ or @ character.", nameof(username)); + throw new ArgumentException("Username cannot contain more than one separator.", nameof(username)); var firstChar = username[0]; var lastChar = username[username.Length - 1]; @@ -258,7 +291,7 @@ private static void SplitDomainFromUsername(ref string username, out string doma /// public override string ToString() { - return _domain == null ? _username : _username + "@" + _domain; + return _domain == null ? _username : _username.IndexOf('@') != -1 ? $@"{_domain}\{_username}" : $"{_username}@{_domain}"; } } } diff --git a/test/UserCredentialsTests.cs b/test/UserCredentialsTests.cs index 1d3afc3..fa30cf1 100644 --- a/test/UserCredentialsTests.cs +++ b/test/UserCredentialsTests.cs @@ -266,6 +266,56 @@ public void UserCredentials_LocalService_Valid() Assert.Equal("LOCAL SERVICE@NT AUTHORITY", UserCredentials.LocalService.ToString(), ignoreCase: true); } + [Fact] + public void UserCredentials_Constructor_Valid_AzureAD_1() + { + var creds = new UserCredentials(@"AzureAD\user@domain", "password"); + Assert.Equal(@"AzureAD\user@domain", creds.ToString()); + } + + [Fact] + public void UserCredentials_Constructor_Valid_AzureAD_2() + { + var creds = new UserCredentials("AzureAD", "user@domain", "password"); + Assert.Equal(@"AzureAD\user@domain", creds.ToString()); + } + + [Fact] + public void UserCredentials_Constructor_Invalid_AzureAD_1() + { + Assert.Throws(() => + { + var _ = new UserCredentials(@"AzureAD\user", "password"); + }); + } + + [Fact] + public void UserCredentials_Constructor_Invalid_AzureAD_2() + { + Assert.Throws(() => + { + var _ = new UserCredentials("AzureAD", "user", "password"); + }); + } + + [Fact] + public void UserCredentials_Constructor_Invalid_AzureAD_3() + { + Assert.Throws(() => + { + var _ = new UserCredentials("AzureAD", @"domain\user", "password"); + }); + } + + [Fact] + public void UserCredentials_Constructor_Invalid_AzureAD_4() + { + Assert.Throws(() => + { + var _ = new UserCredentials("AzureAD", "user@foo@bar", "password"); + }); + } + private static SecureString CreateSecureStringPasswordForTesting() { // Note: This is obviously not really a secure password. We just need something to test the API with.