Skip to content

Commit

Permalink
Forward OPTIONS request
Browse files Browse the repository at this point in the history
  • Loading branch information
guoye-zhang committed Nov 3, 2024
1 parent 1de2667 commit 33e1bc1
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 18 deletions.
19 changes: 16 additions & 3 deletions Sources/NIOResumableUpload/HTTPResumableUpload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ final class HTTPResumableUpload {
private var requestIsCreation: Bool = false
/// The end of the current request is the end of the upload.
private var requestIsComplete: Bool = true
/// Whether the request is OPTIONS
private var requestIsOptions: Bool = false
/// The interop version of the current request
private var interopVersion: HTTPResumableUploadProtocol.InteropVersion = .latest
/// Whether you have received the entire upload.
Expand Down Expand Up @@ -307,8 +309,9 @@ extension HTTPResumableUpload {
self.respondAndDetach(response, handler: handler)
}
case .options:
let response = HTTPResumableUploadProtocol.optionsResponse(version: version)
self.respondAndDetach(response, handler: handler)
self.requestIsOptions = true
let channel = self.createChannel(handler: handler, parent: channel)
channel.receive(.head(request))
}
} catch {
let response = HTTPResumableUploadProtocol.badRequestResponse()
Expand Down Expand Up @@ -425,7 +428,17 @@ extension HTTPResumableUpload {
uploadHandler.write(part, promise: promise)
}
} else {
uploadHandler.write(part, promise: promise)
if self.requestIsOptions {
switch part {
case .head(let head):
let response = HTTPResumableUploadProtocol.processOptionsResponse(head)
uploadHandler.write(.head(response), promise: promise)
case .body, .end:
uploadHandler.write(part, promise: promise)
}
} else {
uploadHandler.write(part, promise: promise)
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions Sources/NIOResumableUpload/HTTPResumableUploadProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ enum HTTPResumableUploadProtocol {
finalResponse.headerFields.uploadOffset = offset
return finalResponse
}

static func processOptionsResponse(_ response: HTTPResponse) -> HTTPResponse {
var response = response
if response.status == .notImplemented {
response = HTTPResponse(status: .ok)
}
response.headerFields.uploadLimit = .init(minSize: 0)
return response
}
}

extension HTTPField.Name {
Expand Down
50 changes: 35 additions & 15 deletions Tests/NIOResumableUploadTests/NIOResumableUploadTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,33 @@ import NIOResumableUpload
import XCTest

/// A handler that keeps track of all reads made on a channel.
private final class InboundRecorder<Frame>: ChannelInboundHandler {
typealias InboundIn = Frame
private final class InboundRecorder<FrameIn, FrameOut>: ChannelDuplexHandler {
typealias InboundIn = FrameIn
typealias OutboundIn = Never
typealias OutboundOut = FrameOut

var receivedFrames: [Frame] = []
private var context: ChannelHandlerContext? = nil

var receivedFrames: [FrameIn] = []

func channelActive(context: ChannelHandlerContext) {
self.context = context
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.receivedFrames.append(self.unwrapInboundIn(data))
}

func write(_ frame: FrameOut) {
self.write(context: self.context!, data: self.wrapOutboundOut(frame), promise: nil)
self.flush(context: self.context!)
}
}

final class NIOResumableUploadTests: XCTestCase {
func testNonUpload() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand All @@ -50,7 +63,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testNotResumableUpload() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand All @@ -69,7 +82,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testOptions() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, HTTPResponsePart>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand All @@ -79,6 +92,13 @@ final class NIOResumableUploadTests: XCTestCase {
try channel.writeInbound(HTTPRequestPart.head(request))
try channel.writeInbound(HTTPRequestPart.end(nil))

XCTAssertEqual(recorder.receivedFrames.count, 2)
XCTAssertEqual(recorder.receivedFrames[0], HTTPRequestPart.head(request))
XCTAssertEqual(recorder.receivedFrames[1], HTTPRequestPart.end(nil))

recorder.write(HTTPResponsePart.head(HTTPResponse(status: .notImplemented)))
recorder.write(HTTPResponsePart.end(nil))

let responsePart = try channel.readOutbound(as: HTTPResponsePart.self)
guard case .head(let response) = responsePart else {
XCTFail("Part is not response headers")
Expand All @@ -95,7 +115,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadUninterruptedV3() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -127,7 +147,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadUninterruptedV5() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -159,7 +179,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadUninterruptedV6() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -192,7 +212,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadInterruptedV3() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -254,7 +274,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadInterruptedV5() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -318,7 +338,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadInterruptedV6() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -383,7 +403,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadChunkedV3() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -453,7 +473,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadChunkedV5() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down Expand Up @@ -525,7 +545,7 @@ final class NIOResumableUploadTests: XCTestCase {

func testResumableUploadChunkedV6() throws {
let channel = EmbeddedChannel()
let recorder = InboundRecorder<HTTPRequestPart>()
let recorder = InboundRecorder<HTTPRequestPart, Never>()

let context = HTTPResumableUploadContext(origin: "https://example.com")
try channel.pipeline.addHandler(HTTPResumableUploadHandler(context: context, handlers: [recorder])).wait()
Expand Down

0 comments on commit 33e1bc1

Please sign in to comment.