Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/Core/Resolvers/SqlMutationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,12 @@ await queryExecutor.ExecuteQueryAsync(
case EntityActionOperation.Insert:

HttpContext httpContext = GetHttpContext();
// Use scheme/host from X-Forwarded-* headers if present, else fallback to request values
string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request);
string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request);
string locationHeaderURL = UriHelper.BuildAbsolute(
scheme: httpContext.Request.Scheme,
host: httpContext.Request.Host,
scheme: scheme,
host: new HostString(host),
pathBase: GetBaseRouteFromConfig(_runtimeConfigProvider.GetConfig()),
path: httpContext.Request.Path);

Expand Down
10 changes: 5 additions & 5 deletions src/Core/Resolvers/SqlPaginationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,12 +751,12 @@ public static string FormatQueryString(NameValueCollection? queryStringParameter
}

/// <summary>
/// Extracts and request scheme from "X-Forwarded-Proto" or falls back to the request scheme.
/// Extracts the request scheme from "X-Forwarded-Proto" or falls back to the request scheme.
/// Invalid forwarded values are ignored.
/// </summary>
/// <param name="req">The HTTP request.</param>
/// <returns>The scheme string ("http" or "https").</returns>
/// <exception cref="DataApiBuilderException">Thrown when client explicitly sets an invalid scheme.</exception>
private static string ResolveRequestScheme(HttpRequest req)
internal static string ResolveRequestScheme(HttpRequest req)
{
string? rawScheme = req.Headers["X-Forwarded-Proto"].FirstOrDefault();
string? normalized = rawScheme?.Trim().ToLowerInvariant();
Expand All @@ -776,11 +776,11 @@ private static string ResolveRequestScheme(HttpRequest req)

/// <summary>
/// Extracts the request host from "X-Forwarded-Host" or falls back to the request host.
/// Invalid forwarded values are ignored.
/// </summary>
/// <param name="req">The HTTP request.</param>
/// <returns>The host string.</returns>
/// <exception cref="DataApiBuilderException">Thrown when client explicitly sets an invalid host.</exception>
private static string ResolveRequestHost(HttpRequest req)
internal static string ResolveRequestHost(HttpRequest req)
{
string? rawHost = req.Headers["X-Forwarded-Host"].FirstOrDefault();
string? trimmed = rawHost?.Trim();
Expand Down
7 changes: 5 additions & 2 deletions src/Core/Resolvers/SqlResponseHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,12 @@ HttpContext httpContext
// The third part is the computed primary key route.
if (operationType is EntityActionOperation.Insert && !string.IsNullOrEmpty(primaryKeyRoute))
{
// Use scheme/host from X-Forwarded-* headers if present, else fallback to request values
string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request);
string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request);
locationHeaderURL = UriHelper.BuildAbsolute(
scheme: httpContext.Request.Scheme,
host: httpContext.Request.Host,
scheme: scheme,
host: new HostString(host),
pathBase: baseRoute,
path: httpContext.Request.Path);

Expand Down
130 changes: 130 additions & 0 deletions src/Service.Tests/Configuration/ConfigurationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
using Azure.DataApiBuilder.Service.Tests.SqlTests;
using HotChocolate;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.IdentityModel.Tokens;
Expand Down Expand Up @@ -3448,6 +3451,133 @@ public async Task ValidateLocationHeaderWhenBaseRouteIsConfigured(
}
}

/// <summary>
/// Validates that the Location header returned for POST requests respects X-Forwarded-Host and X-Forwarded-Proto.
/// This covers both table and stored procedure insert endpoints.
/// </summary>
/// <param name="entityType">Type of entity under test.</param>
/// <param name="requestPath">REST endpoint path for POST request.</param>
/// <param name="forwardedHost">Value for X-Forwarded-Host header.</param>
/// <param name="forwardedProto">Value for X-Forwarded-Proto header.</param>
/// <param name="expectedScheme">Expected scheme in Location header.</param>
[DataTestMethod]
[TestCategory(TestCategory.MSSQL)]
[DataRow(EntitySourceType.Table, "/api/Book", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for stored procedure POST")]
[DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for stored procedure POST")]
[DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for stored procedure POST")]
public async Task ValidateLocationHeaderRespectsXForwardedHostAndProto(
EntitySourceType entityType,
string requestPath,
string forwardedHost,
string forwardedProto,
string expectedScheme)
{
TestHelper.SetupDatabaseEnvironment(MSSQL_ENVIRONMENT);

GraphQLRuntimeOptions graphqlOptions = new(Enabled: false);
RestRuntimeOptions restRuntimeOptions = new(Enabled: true);
McpRuntimeOptions mcpRuntimeOptions = new(Enabled: false);

SqlConnectionStringBuilder connectionStringBuilder = new(GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))
{
TrustServerCertificate = true
};

DataSource dataSource = new(DatabaseType.MSSQL,
connectionStringBuilder.ConnectionString, Options: null);

RuntimeConfig configuration;
if (entityType is EntitySourceType.StoredProcedure)
{
Entity entity = new(Source: new("get_books", EntitySourceType.StoredProcedure, null, null),
Fields: null,
Rest: new(new SupportedHttpVerb[] { SupportedHttpVerb.Get, SupportedHttpVerb.Post }),
GraphQL: null,
Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) },
Relationships: null,
Mappings: null
);

configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions, entity, entityName: "GetBooks");
}
else
{
configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions);
}

const string CUSTOM_CONFIG = "custom-config.json";
File.WriteAllText(CUSTOM_CONFIG, configuration.ToJson());
string[] args = new[] { $"--ConfigFileName={CUSTOM_CONFIG}" };

// Intentionally bind HTTP to simulate the proxy-to-app internal hop.
using IWebHost host = Program.CreateWebHostBuilder(args)
.UseUrls("http://127.0.0.1:0")
.Build();
await host.StartAsync();

IServerAddressesFeature addresses = host.ServerFeatures.Get<IServerAddressesFeature>();
Assert.IsNotNull(addresses);

string baseAddress = addresses.Addresses.FirstOrDefault();
Assert.IsFalse(string.IsNullOrEmpty(baseAddress));

using HttpClient client = new()
{
BaseAddress = new Uri(baseAddress)
};

HttpRequestMessage request = new(HttpMethod.Post, requestPath);
if (!string.IsNullOrEmpty(forwardedHost))
{
request.Headers.Add("X-Forwarded-Host", forwardedHost);
}

if (!string.IsNullOrEmpty(forwardedProto))
{
request.Headers.Add("X-Forwarded-Proto", forwardedProto);
}

if (entityType is EntitySourceType.Table)
{
JsonElement requestBodyElement = JsonDocument.Parse(@"{
""title"": ""Forwarded Header Location Test"",
""publisher_id"": 1234
}").RootElement.Clone();

request.Content = JsonContent.Create(requestBodyElement);
}

HttpResponseMessage response = await client.SendAsync(request);

Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
Assert.IsNotNull(response.Headers.Location, "Location header should be present for successful POST create.");

Uri location = response.Headers.Location;
Assert.AreEqual(expectedScheme, location.Scheme, $"Expected Location scheme '{expectedScheme}', got '{location.Scheme}'.");

if (!string.IsNullOrEmpty(forwardedHost))
{
Assert.AreEqual(forwardedHost, location.Host, $"Expected Location host '{forwardedHost}', got '{location.Host}'.");
}

// Since forwarded host is external, validate follow-up using local path only.
string localPathAndQuery = string.IsNullOrEmpty(location.Query) ? location.AbsolutePath : location.AbsolutePath + location.Query;
HttpRequestMessage followUpRequest = new(HttpMethod.Get, localPathAndQuery);
HttpResponseMessage followUpResponse = await client.SendAsync(followUpRequest);
Assert.AreEqual(HttpStatusCode.OK, followUpResponse.StatusCode);

if (entityType is EntitySourceType.Table)
{
HttpRequestMessage cleanupRequest = new(HttpMethod.Delete, localPathAndQuery);
await client.SendAsync(cleanupRequest);
}

await host.StopAsync();
}

/// <summary>
/// Test to validate that when the property rest.request-body-strict is absent from the rest runtime section in config file, DAB runs in strict mode.
/// In strict mode, presence of extra fields in the request body is not permitted and leads to HTTP 400 - BadRequest error.
Expand Down
Loading