diff --git a/allocation/constant.go b/allocation/constant.go new file mode 100644 index 0000000..3fb91fd --- /dev/null +++ b/allocation/constant.go @@ -0,0 +1,29 @@ +package allocation + +import ( + "context" + "errors" + "time" + + "github.com/portfoliotree/portfolio/returns" +) + +const ConstantWeightsAlgorithmName = "Constant Weights" + +type ConstantWeights struct { + weights []float64 +} + +func (cw *ConstantWeights) Name() string { return ConstantWeightsAlgorithmName } + +func (cw *ConstantWeights) PolicyWeights(_ context.Context, _ time.Time, _ returns.Table, ws []float64) ([]float64, error) { + if len(cw.weights) != len(ws) { + return nil, errors.New("expected the number of policy weights to be the same as the number of assets") + } + copy(ws, cw.weights) + return ws, nil +} + +func (cw *ConstantWeights) SetWeights(in []float64) { + cw.weights = in +} diff --git a/allocation/equal.go b/allocation/equal.go new file mode 100644 index 0000000..5ddebbf --- /dev/null +++ b/allocation/equal.go @@ -0,0 +1,21 @@ +package allocation + +import ( + "context" + "time" + + "github.com/portfoliotree/portfolio/returns" +) + +const EqualWeightsAlgorithmName = "Equal Weights" + +type EqualWeights struct{} + +func (*EqualWeights) Name() string { return EqualWeightsAlgorithmName } + +func (*EqualWeights) PolicyWeights(_ context.Context, _ time.Time, _ returns.Table, ws []float64) ([]float64, error) { + for i := range ws { + ws[i] = 1.0 / float64(len(ws)) + } + return ws, nil +} diff --git a/allocation/functions.go b/allocation/functions.go new file mode 100644 index 0000000..110d0d4 --- /dev/null +++ b/allocation/functions.go @@ -0,0 +1,42 @@ +package allocation + +import ( + "golang.org/x/exp/slices" + + "github.com/portfoliotree/portfolio/backtest" +) + +type Algorithm interface { + backtest.PolicyWeightCalculator + Name() string +} + +func NewDefaultAlgorithmsList() []Algorithm { + return []Algorithm{ + new(ConstantWeights), + new(EqualWeights), + new(EqualInverseVariance), + new(EqualRiskContribution), + new(EqualVolatility), + new(EqualInverseVolatility), + } +} + +func AlgorithmNames(algorithmOptions []Algorithm) []string { + names := make([]string, 0, len(algorithmOptions)) + for _, alg := range algorithmOptions { + names = append(names, alg.Name()) + } + slices.Sort(names) + names = slices.Compact(names) + return names +} + +type WeightSetter interface { + SetWeights([]float64) +} + +func AlgorithmRequiresWeights(alg Algorithm) bool { + _, ok := alg.(WeightSetter) + return ok +} diff --git a/allocation/inverse_variance.go b/allocation/inverse_variance.go new file mode 100644 index 0000000..ed8c85c --- /dev/null +++ b/allocation/inverse_variance.go @@ -0,0 +1,44 @@ +package allocation + +import ( + "context" + "math" + "time" + + "github.com/portfoliotree/portfolio/returns" +) + +type EqualInverseVariance struct{} + +func (cw *EqualInverseVariance) Name() string { return "Equal Inverse Variance" } + +func (*EqualInverseVariance) PolicyWeights(_ context.Context, _ time.Time, assetReturns returns.Table, ws []float64) ([]float64, error) { + if isOnlyZeros(ws) { + for i := range ws { + ws[i] = 1.0 + } + scaleToUnitRange(ws) + } + + err := ensureEnoughReturns(assetReturns) + if err != nil { + return ws, err + } + + assetRisks := assetReturns.RisksFromStdDev() + for i := range assetRisks { + assetRisks[i] = 1.0 / math.Pow(assetRisks[i], 2) + } + + sumOfAssetRisks := 0.0 + for i := range assetRisks { + sumOfAssetRisks += assetRisks[i] + } + + newWeights := make([]float64, len(assetRisks)) + for i := range assetRisks { + newWeights[i] = assetRisks[i] / sumOfAssetRisks + } + + return newWeights, nil +} diff --git a/allocation/optimizer_internal.go b/allocation/optimizer_internal.go new file mode 100644 index 0000000..7c952ee --- /dev/null +++ b/allocation/optimizer_internal.go @@ -0,0 +1,95 @@ +package allocation + +import ( + "context" + "errors" + + "gonum.org/v1/gonum/optimize" + + "github.com/portfoliotree/portfolio/returns" +) + +const ( + maxTries = 50_000 + skipContextCheckCount = 500 + preCancelCheckTries = 10_000 +) + +func checkTries(ctx context.Context, try int) error { + switch { + case try > preCancelCheckTries && try%skipContextCheckCount == 0: + return ctx.Err() + case try > maxTries: + return errors.New("reached max tries to calculate policy") + default: + return nil + } +} + +func optWeights(ctx context.Context, weights []float64, fn func(ws []float64) float64) error { + var ( + try = 0 + m = &optimize.NelderMead{} + s = &optimize.Settings{ + Converger: &optimize.FunctionConverge{ + Absolute: 1e-10, + Relative: 1, + Iterations: 1000, + }, + } + ws = make([]float64, len(weights)) + p = optimize.Problem{ + Func: func(x []float64) float64 { + copy(ws, x) + scaleToUnitRange(ws) + return fn(ws) + }, + Status: func() (optimize.Status, error) { + err := checkTries(ctx, try) + if err != nil { + return optimize.RuntimeLimit, err + } + try++ + return optimize.NotTerminated, nil + }, + } + ) + optResult, err := optimize.Minimize(p, weights, s, m) + if err != nil { + return err + } + + copy(weights, optResult.X) + scaleToUnitRange(weights) + + return nil +} + +func ensureEnoughReturns(assetReturns returns.Table) error { + if assetReturns.NumberOfColumns() == 0 || assetReturns.NumberOfRows() < 2 { + return errors.New("not enough data") + } + return nil +} + +func isOnlyZeros(a []float64) bool { + for _, v := range a { + if v != 0 { + return false + } + } + return true +} + +func scaleToUnitRange(list []float64) { + sum := 0.0 + for _, v := range list { + sum += v + } + if sum == 0 { + return + } + for i := range list { + list[i] /= sum + } +} diff --git a/allocation/risk.go b/allocation/risk.go new file mode 100644 index 0000000..5df61fd --- /dev/null +++ b/allocation/risk.go @@ -0,0 +1,46 @@ +package allocation + +import ( + "context" + "math" + "time" + + "github.com/portfoliotree/portfolio/calculations" + "github.com/portfoliotree/portfolio/returns" +) + +type EqualRiskContribution struct{} + +func (*EqualRiskContribution) Name() string { return "Equal Risk Contribution" } + +func (*EqualRiskContribution) PolicyWeights(ctx context.Context, _ time.Time, assetReturns returns.Table, ws []float64) ([]float64, error) { + if isOnlyZeros(ws) { + for i := range ws { + ws[i] = 1.0 + } + scaleToUnitRange(ws) + } + + err := ensureEnoughReturns(assetReturns) + if err != nil { + return ws, err + } + + assetRisks := assetReturns.RisksFromStdDev() + + target := 1.0 / float64(len(assetRisks)) + + cm := assetReturns.CorrelationMatrix() + + weights := make([]float64, len(ws)) + copy(weights, ws) + + return weights, optWeights(ctx, weights, func(ws []float64) float64 { + _, _, riskWeights := calculations.RiskFromRiskContribution(assetRisks, ws, cm) + var diff float64 + for i := range riskWeights { + diff += math.Abs(target - riskWeights[i]) + } + return diff + }) +} diff --git a/allocation/volatility.go b/allocation/volatility.go new file mode 100644 index 0000000..3326d35 --- /dev/null +++ b/allocation/volatility.go @@ -0,0 +1,75 @@ +package allocation + +import ( + "context" + "time" + + "github.com/portfoliotree/portfolio/returns" +) + +type EqualVolatility struct{} + +func (*EqualVolatility) Name() string { return "Equal Volatility" } + +func (*EqualVolatility) PolicyWeights(_ context.Context, _ time.Time, assetReturns returns.Table, ws []float64) ([]float64, error) { + if isOnlyZeros(ws) { + for i := range ws { + ws[i] = 1.0 + } + scaleToUnitRange(ws) + } + + err := ensureEnoughReturns(assetReturns) + if err != nil { + return ws, err + } + + assetRisks := assetReturns.RisksFromStdDev() + + sumOfAssetRisks := 0.0 + for i := range assetRisks { + sumOfAssetRisks += assetRisks[i] + } + + newWeights := make([]float64, len(assetRisks)) + for i := range assetRisks { + newWeights[i] = assetRisks[i] / sumOfAssetRisks + } + + return newWeights, nil +} + +type EqualInverseVolatility struct{} + +func (*EqualInverseVolatility) Name() string { return "Equal Inverse Volatility" } + +func (*EqualInverseVolatility) PolicyWeights(_ context.Context, _ time.Time, assetReturns returns.Table, ws []float64) ([]float64, error) { + if isOnlyZeros(ws) { + for i := range ws { + ws[i] = 1.0 + } + scaleToUnitRange(ws) + } + + err := ensureEnoughReturns(assetReturns) + if err != nil { + return ws, err + } + + assetRisks := assetReturns.RisksFromStdDev() + for i := range assetRisks { + assetRisks[i] = 1.0 / assetRisks[i] + } + + sumOfAssetRisks := 0.0 + for i := range assetRisks { + sumOfAssetRisks += assetRisks[i] + } + + newWeights := make([]float64, len(assetRisks)) + for i := range assetRisks { + newWeights[i] = assetRisks[i] / sumOfAssetRisks + } + + return newWeights, nil +} diff --git a/api.go b/api.go index 21a6432..eb08b5b 100644 --- a/api.go +++ b/api.go @@ -11,6 +11,8 @@ import ( "os" "strings" + "go.mongodb.org/mongo-driver/bson/primitive" + "github.com/portfoliotree/portfolio/returns" ) @@ -57,6 +59,22 @@ func (pf *Specification) AssetReturns(ctx context.Context) (returns.Table, error return doJSONRequest[returns.Table](http.DefaultClient.Do, req) } +func ParseComponentsFromURL(values url.Values, prefix string) ([]Component, error) { + assetValues, ok := values[prefix+"-id"] + if !ok { + return nil, errors.New("use asset-id parameters to specify asset returns") + } + components := make([]Component, 0, len(assetValues)) + for _, v := range assetValues { + if _, err := primitive.ObjectIDFromHex(v); err == nil { + components = append(components, Component{Type: "Portfolio", ID: v}) + continue + } + components = append(components, Component{Type: "Security", ID: v}) + } + return components, nil +} + func doJSONRequest[T any](do func(r *http.Request) (*http.Response, error), req *http.Request) (T, error) { var result T req.Header.Set("accept", "application/json") diff --git a/api_test.go b/api_test.go index 873e1fa..42136dd 100644 --- a/api_test.go +++ b/api_test.go @@ -2,6 +2,7 @@ package portfolio_test import ( "context" + "os" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,29 @@ import ( "github.com/portfoliotree/portfolio" ) +func Test_APIEndpoints(t *testing.T) { + if value, found := os.LookupEnv("CI"); !found || value != "true" { + t.Skip("Skipping test in CI environment") + } + + t.Run("returns", func(t *testing.T) { + pf := portfolio.Specification{ + Assets: []portfolio.Component{ + {ID: "AAPL"}, + {ID: "GOOG"}, + }, + } + table, err := pf.AssetReturns(context.Background()) + assert.NoError(t, err) + if table.NumberOfColumns() != 2 { + t.Errorf("Expected 2 columns, got %d", table.NumberOfColumns()) + } + if table.NumberOfRows() < 10 { + t.Errorf("Expected at least 10 rows, got %d", table.NumberOfRows()) + } + }) +} + func TestSpecification_AssetReturns(t *testing.T) { for _, tt := range []struct { Name string diff --git a/backtest/backtestconfig/weight_functions.go b/backtest/backtestconfig/weight_functions.go deleted file mode 100644 index b3e631d..0000000 --- a/backtest/backtestconfig/weight_functions.go +++ /dev/null @@ -1,34 +0,0 @@ -package backtestconfig - -import ( - "context" - "time" - - "github.com/portfoliotree/portfolio/returns" -) - -// PolicyWeightCalculatorFunc can be used to wrap a function and pass it into Run as a PolicyWeightCalculator -type PolicyWeightCalculatorFunc func(ctx context.Context, today time.Time, assets returns.Table, currentWeights []float64) ([]float64, error) - -func (p PolicyWeightCalculatorFunc) PolicyWeights(ctx context.Context, today time.Time, assets returns.Table, currentWeights []float64) ([]float64, error) { - return p(ctx, today, assets, currentWeights) -} - -type ConstantWeights []float64 - -func (targetWeights ConstantWeights) PolicyWeights(_ context.Context, _ time.Time, _ returns.Table, ws []float64) ([]float64, error) { - copy(ws, targetWeights) - return ws, nil -} - -type EqualWeights struct{} - -func (EqualWeights) PolicyWeights(_ context.Context, _ time.Time, _ returns.Table, ws []float64) ([]float64, error) { - for i := range ws { - ws[i] = 1.0 / float64(len(ws)) - } - return ws, nil -} - -// Additional weight functions are maintained in portfoliotree.com proprietary code. -// If you'd like to read the code, feel free to ask us at support@portfoliotree.com, we are willing to share pseudocode. diff --git a/backtest/backtestconfig/weight_functions_test.go b/backtest/backtestconfig/weight_functions_test.go deleted file mode 100644 index 3824fad..0000000 --- a/backtest/backtestconfig/weight_functions_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package backtestconfig_test - -import ( - "github.com/portfoliotree/portfolio/backtest" - "github.com/portfoliotree/portfolio/backtest/backtestconfig" -) - -var ( - _ backtest.PolicyWeightCalculator = backtestconfig.ConstantWeights{} - _ backtest.PolicyWeightCalculator = backtestconfig.EqualWeights{} - _ backtest.PolicyWeightCalculator = backtestconfig.PolicyWeightCalculatorFunc(nil) -) diff --git a/backtest/backtestconfig/window.go b/backtest/backtestconfig/window.go index 8f18a4c..949c8d9 100644 --- a/backtest/backtestconfig/window.go +++ b/backtest/backtestconfig/window.go @@ -97,5 +97,8 @@ func (dur Window) Sub(t time.Time) time.Time { } func (dur Window) Function(today time.Time, table returns.Table) returns.Table { + if !dur.IsSet() { + return table.Between(today, table.FirstTime()) + } return table.Between(today, dur.Sub(today)) } diff --git a/backtest/backtestconfig/window_test.go b/backtest/backtestconfig/window_test.go index ac7c0ea..1ee02ab 100644 --- a/backtest/backtestconfig/window_test.go +++ b/backtest/backtestconfig/window_test.go @@ -2,13 +2,16 @@ package backtestconfig_test import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/portfoliotree/portfolio/backtest/backtestconfig" + "github.com/portfoliotree/portfolio/internal/fixtures" + "github.com/portfoliotree/portfolio/returns" ) -func TestDurations_Validate(t *testing.T) { +func TestWindows_Validate(t *testing.T) { for _, d := range backtestconfig.Windows() { t.Run(d.String(), func(t *testing.T) { err := d.Validate() @@ -26,3 +29,20 @@ func TestDurations_Validate(t *testing.T) { assert.Error(t, err) }) } + +func TestWindow_Function(t *testing.T) { + t.Run("not set", func(t *testing.T) { + var zero backtestconfig.Window + + today := fixtures.T(t, fixtures.Day2) + table := returns.NewTable([]returns.List{{ + returns.New(fixtures.T(t, fixtures.Day3), .1), + returns.New(today, .1), + returns.New(fixtures.T(t, fixtures.Day1), .1), + returns.New(fixtures.T(t, fixtures.Day0), .1), + }}) + + result := zero.Function(today, table) + assert.Equal(t, result.FirstTime().Format(time.DateOnly), fixtures.Day0) + }) +} diff --git a/backtest/run_benchmark_test.go b/backtest/run_benchmark_test.go index b2a80da..ada24e4 100644 --- a/backtest/run_benchmark_test.go +++ b/backtest/run_benchmark_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/portfoliotree/portfolio" + "github.com/portfoliotree/portfolio/allocation" "github.com/portfoliotree/portfolio/backtest" "github.com/portfoliotree/portfolio/backtest/backtestconfig" "github.com/portfoliotree/portfolio/portfoliotest" @@ -39,14 +40,14 @@ func benchmarkRun(b *testing.B, table returns.Table) { b.Helper() end := table.LastTime() start := table.FirstTime() - fn := backtestconfig.EqualWeights{} + alg := new(allocation.EqualWeights) lookback := backtestconfig.OneQuarterWindow.Function rebalance := backtestconfig.Daily() updatePolicyWeights := backtestconfig.Monthly() ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := backtest.Run(ctx, end, start, table, fn, lookback, rebalance, updatePolicyWeights) + _, err := backtest.Run(ctx, end, start, table, alg, lookback, rebalance, updatePolicyWeights) if err != nil { b.Fatal(err) } diff --git a/backtest/run_test.go b/backtest/run_test.go index 806e26e..583337c 100644 --- a/backtest/run_test.go +++ b/backtest/run_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/portfoliotree/portfolio/allocation" "github.com/portfoliotree/portfolio/backtest" "github.com/portfoliotree/portfolio/backtest/backtestconfig" "github.com/portfoliotree/portfolio/returns" @@ -32,7 +33,7 @@ func TestSpec_Run(t *testing.T) { {Time: date("2021-01-02"), Value: 0.2}, {Time: date("2021-01-01"), Value: 0.1}, }}) - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -62,7 +63,7 @@ func TestSpec_Run(t *testing.T) { policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() ws := []float64{.715, .315} - _, err := backtest.Run(context.Background(), date("2021-01-04"), date("2021-01-01"), assets, backtestconfig.PolicyWeightCalculatorFunc(func(_ context.Context, _ time.Time, _ returns.Table, currentWeights []float64) ([]float64, error) { + _, err := backtest.Run(context.Background(), date("2021-01-04"), date("2021-01-01"), assets, allocationFunction(func(_ context.Context, _ time.Time, _ returns.Table, currentWeights []float64) ([]float64, error) { return ws, nil }), windowFunc, rebalanceIntervalFunc, policyUpdateIntervalFunc) assert.NoError(t, err) @@ -78,7 +79,7 @@ func TestSpec_Run(t *testing.T) { assert.Error(t, err) }) t.Run("end date does not have a return", func(t *testing.T) { - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -96,7 +97,7 @@ func TestSpec_Run(t *testing.T) { t.Run("with no returns", func(t *testing.T) { assets := returns.Table{} - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -110,7 +111,7 @@ func TestSpec_Run(t *testing.T) { }) t.Run("when there is one asset", func(t *testing.T) { - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.OneDayWindow.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -132,7 +133,7 @@ func TestSpec_Run(t *testing.T) { }) t.Run("when called repeatedly", func(t *testing.T) { - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.OneDayWindow.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -178,7 +179,7 @@ func TestSpec_Run(t *testing.T) { <-c cancel() }() - alg := backtestconfig.PolicyWeightCalculatorFunc(func(ctx context.Context, _ time.Time, _ returns.Table, ws []float64) (targetWeights []float64, err error) { + alg := allocationFunction(func(ctx context.Context, _ time.Time, _ returns.Table, ws []float64) (targetWeights []float64, err error) { close(c) <-ctx.Done() return ws, ctx.Err() @@ -213,7 +214,7 @@ func TestSpec_Run(t *testing.T) { } assets := returns.NewTable([]returns.List{asset1, asset2}) - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.OneDayWindow.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -247,12 +248,12 @@ func TestSpec_Run(t *testing.T) { {Time: date("2021-04-15"), Value: 0}, } assets := returns.NewTable([]returns.List{asset1, asset2}) - - alg := backtestconfig.PolicyWeightCalculatorFunc(func(ctx context.Context, t time.Time, assetReturns returns.Table, currentWeights []float64) ([]float64, error) { + fallback := testAlgorithm() + alg := allocationFunction(func(ctx context.Context, t time.Time, assetReturns returns.Table, currentWeights []float64) ([]float64, error) { if t.Before(date("2021-04-20")) { return nil, backtest.ErrorNotEnoughData{} } - return backtestconfig.EqualWeights{}.PolicyWeights(ctx, t, assetReturns, currentWeights) + return fallback.PolicyWeights(ctx, t, assetReturns, currentWeights) }) windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -290,7 +291,7 @@ func TestSpec_Run(t *testing.T) { } assets := returns.NewTable([]returns.List{asset1, asset2}) - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.OneWeekWindow.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -327,7 +328,7 @@ func TestSpec_Run(t *testing.T) { } callCount := 0 - alg := backtestconfig.PolicyWeightCalculatorFunc(func(_ context.Context, tm time.Time, assetReturns returns.Table, currentWeights []float64) ([]float64, error) { + alg := allocationFunction(func(_ context.Context, tm time.Time, assetReturns returns.Table, currentWeights []float64) ([]float64, error) { callCount++ assert.Equalf(t, assetReturns.NumberOfColumns(), 1, "call count %d", callCount) for c := 0; c < assetReturns.NumberOfColumns(); c++ { @@ -360,7 +361,7 @@ func TestSpec_Run(t *testing.T) { } assert.Lenf(t, rs, 5, "call count %d", callCount) } - return backtestconfig.EqualWeights{}.PolicyWeights(context.Background(), tm, assetReturns, currentWeights) + return (&allocation.EqualWeights{}).PolicyWeights(context.Background(), tm, assetReturns, currentWeights) }) windowFunc := backtestconfig.OneWeekWindow.Function @@ -403,7 +404,7 @@ func TestSpec_Run_weightHistory(t *testing.T) { assets := returns.NewTable([]returns.List{asset}) - alg := backtestconfig.PolicyWeightCalculatorFunc(randomWeights) + alg := allocationFunction(randomWeights) windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() @@ -493,7 +494,7 @@ func TestSpec_Run_weightHistory(t *testing.T) { } assets := returns.NewTable([]returns.List{asset1, asset2}) - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalWeekly.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalMonthly.CheckFunction() @@ -549,7 +550,7 @@ func TestSpec_Run_weightHistory(t *testing.T) { } assets := returns.NewTable([]returns.List{asset1, asset2}) - alg := backtestconfig.EqualWeights{} + alg := testAlgorithm() windowFunc := backtestconfig.WindowNotSet.Function rebalanceIntervalFunc := backtestconfig.IntervalDaily.CheckFunction() policyUpdateIntervalFunc := backtestconfig.IntervalWeekly.CheckFunction() @@ -574,3 +575,13 @@ func date(str string) time.Time { d, _ := time.Parse(time.DateOnly, str) return d } + +type allocationFunction func(_ context.Context, _ time.Time, _ returns.Table, currentWeights []float64) (targetWeights []float64, err error) + +func (function allocationFunction) PolicyWeights(ctx context.Context, today time.Time, assets returns.Table, ws []float64) (targetWeights []float64, err error) { + return function(ctx, today, assets, ws) +} + +func testAlgorithm() allocation.Algorithm { + return new(allocation.EqualWeights) +} diff --git a/component.go b/component.go index 4fd9863..de31c5f 100644 --- a/component.go +++ b/component.go @@ -10,8 +10,9 @@ import ( ) type Component struct { - Type string `yaml:"type,omitempty"` - ID string `yaml:"id,omitempty"` + Type string `yaml:"type,omitempty" json:"type,omitempty" bson:"type"` + ID string `yaml:"id,omitempty" json:"id,omitempty" bson:"id"` + Label string `yaml:"label,omitempty" json:"label,omitempty" bson:"label"` } var componentExpression = regexp.MustCompile(`^[a-zA-Z0-9.:s]{1,24}$`) diff --git a/fs_test.go b/fs_test.go index 37d4865..3102d13 100644 --- a/fs_test.go +++ b/fs_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/portfoliotree/portfolio" + "github.com/portfoliotree/portfolio/allocation" ) func TestParseSpecificationFile(t *testing.T) { @@ -36,7 +37,7 @@ func TestParseSpecificationFile(t *testing.T) { }, Policy: portfolio.Policy{ Weights: []float64{60, 40}, - WeightsAlgorithm: portfolio.PolicyAlgorithmConstantWeights, + WeightsAlgorithm: allocation.ConstantWeightsAlgorithmName, RebalancingInterval: "Quarterly", }, Filepath: "examples/60-40_portfolio.yml", @@ -59,7 +60,7 @@ func TestParseSpecificationFile(t *testing.T) { }, Policy: portfolio.Policy{ RebalancingInterval: "Quarterly", - WeightsAlgorithm: portfolio.PolicyAlgorithmEqualWeights, + WeightsAlgorithm: allocation.EqualWeightsAlgorithmName, }, Filepath: "examples/maang_portfolio.yml", }, diff --git a/go.mod b/go.mod index 10360d0..a35e991 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/portfoliotree/round v0.0.0-20230629094931-8afd986aa2f1 github.com/stretchr/testify v1.8.4 go.mongodb.org/mongo-driver v1.12.1 + golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb gonum.org/v1/gonum v0.13.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -15,6 +16,6 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb // indirect + golang.org/x/tools v0.7.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index b7b3951..41380c9 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM= diff --git a/portfolio.go b/portfolio.go index b065364..056b11e 100644 --- a/portfolio.go +++ b/portfolio.go @@ -5,19 +5,36 @@ import ( "errors" "fmt" "io" - "net/url" - "strconv" "strings" "time" + "go.mongodb.org/mongo-driver/bson/primitive" "golang.org/x/exp/slices" "gopkg.in/yaml.v3" + "github.com/portfoliotree/portfolio/allocation" "github.com/portfoliotree/portfolio/backtest" "github.com/portfoliotree/portfolio/backtest/backtestconfig" "github.com/portfoliotree/portfolio/returns" ) +type Identifier = primitive.ObjectID + +type Document struct { + ID Identifier `json:"_id" yaml:"_id" bson:"_id"` + Type string `json:"type" yaml:"type" bson:"type"` + Metadata Metadata `json:"metadata" yaml:"metadata" bson:"metadata"` + Spec Specification `json:"spec" yaml:"spec" bson:"spec"` +} + +type Metadata struct { + Name string `json:"name,omitempty" yaml:"name,omitempty" bson:"name,omitempty"` + Benchmark Component `json:"benchmark,omitempty" yaml:"benchmark,omitempty" bson:"benchmark,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty" bson:"description,omitempty"` + Privacy string `json:"privacy,omitempty" yaml:"privacy,omitempty" bson:"privacy,omitempty"` + Factors []Component `json:"factors,omitempty" yaml:"factors,omitempty" bson:"factors,omitempty"` +} + // Specification models a portfolio. type Specification struct { Name string `yaml:"name"` @@ -74,10 +91,13 @@ func ParseSpecifications(r io.Reader) ([]Specification, error) { default: return result, fmt.Errorf("incorrect specification type got %q but expected %q", spec.Type, portfolioTypeName) } + pf := spec.Spec pf.setDefaultPolicyWeightAlgorithm() - if err := pf.ensureEqualNumberOfWeightsAndAssets(); err != nil { - return result, err + if pf.Policy.WeightsAlgorithm == allocation.ConstantWeightsAlgorithmName { + if len(pf.Policy.Weights) != len(pf.Assets) { + return result, errAssetAndWeightsLenMismatch(&spec.Spec) + } } result = append(result, pf) } @@ -91,61 +111,27 @@ func (pf *Specification) RemoveAsset(index int) error { return nil } -func (pf *Specification) Backtest(ctx context.Context, assets returns.Table, weightsAlgorithm backtestconfig.PolicyWeightCalculatorFunc) (backtest.Result, error) { - return pf.BacktestWithStartAndEndTime(ctx, time.Time{}, time.Time{}, assets, weightsAlgorithm) +func (pf *Specification) Backtest(ctx context.Context, assets returns.Table, alg allocation.Algorithm) (backtest.Result, error) { + return pf.BacktestWithStartAndEndTime(ctx, time.Time{}, time.Time{}, assets, alg) } -const ( - PolicyAlgorithmEqualWeights = "EqualWeights" - PolicyAlgorithmConstantWeights = "ConstantWeights" -) - func (pf *Specification) setDefaultPolicyWeightAlgorithm() { - if pf.Policy.WeightsAlgorithm != "" { - return - } if len(pf.Policy.Weights) > 0 { - pf.Policy.WeightsAlgorithm = PolicyAlgorithmConstantWeights + pf.Policy.WeightsAlgorithm = (*allocation.ConstantWeights)(nil).Name() } else { - pf.Policy.WeightsAlgorithm = PolicyAlgorithmEqualWeights - } -} - -func (pf *Specification) ensureEqualNumberOfWeightsAndAssets() error { - switch pf.Policy.WeightsAlgorithm { - case PolicyAlgorithmConstantWeights: - if len(pf.Policy.Weights) != len(pf.Assets) { - return fmt.Errorf("the number of assets and number of weights must be equal: len(assets) is %d and len(weights) is %d", len(pf.Assets), len(pf.Policy.Weights)) - } + pf.Policy.WeightsAlgorithm = (*allocation.EqualWeights)(nil).Name() } - return nil } -func (pf *Specification) policyWeightFunction(weights backtestconfig.PolicyWeightCalculatorFunc) (backtestconfig.PolicyWeightCalculatorFunc, error) { - switch pf.Policy.WeightsAlgorithm { - case PolicyAlgorithmEqualWeights: - return backtestconfig.EqualWeights{}.PolicyWeights, nil - case PolicyAlgorithmConstantWeights: - return backtestconfig.ConstantWeights(pf.Policy.Weights).PolicyWeights, nil - default: - if weights == nil { - return nil, fmt.Errorf("policy %q not supported by the backtest runner", pf.Policy.WeightsAlgorithm) +func (pf *Specification) BacktestWithStartAndEndTime(ctx context.Context, start, end time.Time, assets returns.Table, alg allocation.Algorithm) (backtest.Result, error) { + if alg == nil { + var err error + alg, err = pf.Algorithm(nil) + if err != nil { + return backtest.Result{}, err } - return weights, nil - } -} - -func (pf *Specification) BacktestWithStartAndEndTime(ctx context.Context, start, end time.Time, assets returns.Table, weightsFn backtestconfig.PolicyWeightCalculatorFunc) (backtest.Result, error) { - if err := pf.ensureEqualNumberOfWeightsAndAssets(); err != nil { - return backtest.Result{}, err } - var err error - weightsFn, err = pf.policyWeightFunction(weightsFn) - if err != nil { - return backtest.Result{}, err - } - - return backtest.Run(ctx, end, start, assets, weightsFn, + return backtest.Run(ctx, end, start, assets, alg, pf.Policy.WeightsAlgorithmLookBack.Function, pf.Policy.WeightsUpdatingInterval.CheckFunction(), pf.Policy.RebalancingInterval.CheckFunction(), @@ -161,84 +147,6 @@ type Policy struct { WeightsUpdatingInterval backtestconfig.Interval `yaml:"weights_updating_interval,omitempty"` } -func (pf *Specification) ParseValues(q url.Values) error { - if q.Has("asset-id") { - pf.Assets = pf.Assets[:0] - for _, assetID := range q["asset-id"] { - pf.Assets = append(pf.Assets, Component{ID: assetID}) - } - } - if q.Has("benchmark-id") { - pf.Benchmark.ID = q.Get("benchmark-id") - } - if q.Has("name") { - pf.Name = q.Get("name") - } - if q.Has("filepath") { - pf.Filepath = q.Get("filepath") - } - if q.Has("policy-rebalance") { - pf.Policy.RebalancingInterval = backtestconfig.Interval(q.Get("policy-rebalance")) - } - if q.Has("policy-weights-algorithm") { - pf.Policy.WeightsAlgorithm = q.Get("policy-weights-algorithm") - } - if q.Has("policy-weight") { - pf.Policy.Weights = pf.Policy.Weights[:0] - for i, weight := range q["policy-weight"] { - f, err := strconv.ParseFloat(weight, 64) - if err != nil { - return fmt.Errorf("failed to parse policy weight at indx %d: %w", i, err) - } - pf.Policy.Weights = append(pf.Policy.Weights, f) - } - } - if q.Has("policy-update-weights") { - pf.Policy.WeightsUpdatingInterval = backtestconfig.Interval(q.Get("policy-update-weights")) - } - if q.Has("policy-weight-algorithm-look-back") { - pf.Policy.WeightsAlgorithmLookBack = backtestconfig.Window(q.Get("policy-weight-algorithm-look-back")) - } - pf.filterEmptyAssetIDs() - return pf.Validate() -} - -func (pf *Specification) Values() url.Values { - q := make(url.Values) - if pf.Name != "" { - q.Set("name", pf.Name) - } - if pf.Benchmark.ID != "" { - q.Set("benchmark-id", pf.Benchmark.ID) - } - if pf.Filepath != "" { - q.Set("filepath", pf.Filepath) - } - if pf.Assets != nil { - for _, asset := range pf.Assets { - q.Add("asset-id", asset.ID) - } - } - if pf.Policy.RebalancingInterval != "" { - q.Set("policy-rebalance", pf.Policy.RebalancingInterval.String()) - } - if pf.Policy.WeightsAlgorithm != "" { - q.Set("policy-weights-algorithm", pf.Policy.WeightsAlgorithm) - } - if pf.Policy.Weights != nil { - for _, w := range pf.Policy.Weights { - q.Add("policy-weight", strconv.FormatFloat(w, 'f', 4, 64)) - } - } - if pf.Policy.WeightsUpdatingInterval != "" { - q.Set("policy-update-weights", string(pf.Policy.WeightsUpdatingInterval)) - } - if pf.Policy.WeightsAlgorithmLookBack != "" { - q.Set("policy-weight-algorithm-look-back", pf.Policy.WeightsAlgorithmLookBack.String()) - } - return q -} - // Validate does some simple validations. // Server you should do additional validations. func (pf *Specification) Validate() error { @@ -263,3 +171,25 @@ func (pf *Specification) filterEmptyAssetIDs() { } pf.Assets = filtered } + +func (pf *Specification) Algorithm(algorithmOptions []allocation.Algorithm) (allocation.Algorithm, error) { + if len(algorithmOptions) == 0 { + algorithmOptions = allocation.NewDefaultAlgorithmsList() + } + + for _, alg := range algorithmOptions { + if alg.Name() != pf.Policy.WeightsAlgorithm { + continue + } + if se, ok := alg.(allocation.WeightSetter); ok { + se.SetWeights(slices.Clone(pf.Policy.Weights)) + } + return alg, nil // algorithm is known + } + + return nil, errors.New("unknown algorithm") +} + +func errAssetAndWeightsLenMismatch(spec *Specification) error { + return fmt.Errorf("expected the number of policy weights to be the same as the number of assets got %d but expected %d", len(spec.Policy.Weights), len(spec.Assets)) +} diff --git a/portfolio_test.go b/portfolio_test.go index 6925450..8e894fc 100644 --- a/portfolio_test.go +++ b/portfolio_test.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "os" "path/filepath" "testing" @@ -32,12 +31,12 @@ func TestMain(m *testing.M) { func testdataAssetReturns(crp portfolio.ComponentReturnsProvider) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { - var pf portfolio.Specification - if err := pf.ParseValues(req.URL.Query()); err != nil { - http.Error(res, err.Error(), http.StatusInternalServerError) + assets, err := portfolio.ParseComponentsFromURL(req.URL.Query(), "asset") + if err != nil { + http.Error(res, err.Error(), http.StatusBadRequest) return } - table, err := crp.ComponentReturnsTable(req.Context(), pf.Assets...) + table, err := crp.ComponentReturnsTable(req.Context(), assets...) if err != nil { http.Error(res, err.Error(), http.StatusInternalServerError) return @@ -67,7 +66,7 @@ spec: assets: [ACWI, AGG] policy: weights: [60, 40] - weights_algorithm: PolicyAlgorithmConstantWeights + weights_algorithm: Constant Weights rebalancing_interval: Quarterly ` @@ -80,7 +79,7 @@ spec: // Output: // Name: 60/40 - // Alg: PolicyAlgorithmConstantWeights + // Alg: Constant Weights } func ExampleOpen() { @@ -94,7 +93,7 @@ func ExampleOpen() { // Output: // Name: 60/40 - // Alg: ConstantWeights + // Alg: Constant Weights } func TestParse(t *testing.T) { @@ -119,7 +118,7 @@ func TestParse(t *testing.T) { Name: "the number of assets and policy weights do not match", // language=yaml SpecYAML: `{type: Portfolio, spec: {assets: ["a"], policy: {weights: [1, 2]}}}`, - ErrorStringContains: "the number of assets and number of weights must be equal:", + ErrorStringContains: "expected the number of policy weights to be the same as the number of assets", }, { Name: "component field is invalid", @@ -153,11 +152,10 @@ func TestParse(t *testing.T) { p, err := portfolio.ParseOneSpecification(tt.SpecYAML) if tt.ErrorStringContains == "" { assert.NoError(t, err) + assert.Equal(t, tt.Portfolio, p) } else { - assert.Error(t, err) assert.ErrorContains(t, err, tt.ErrorStringContains) } - assert.Equal(t, tt.Portfolio, p) }) } } @@ -219,11 +217,11 @@ func TestPortfolio_Backtest(t *testing.T) { Assets: []portfolio.Component{{ID: "AAPL"}}, Policy: portfolio.Policy{ Weights: []float64{50, 50}, - WeightsAlgorithm: portfolio.PolicyAlgorithmConstantWeights, + WeightsAlgorithm: "Constant Weights", }, }, ctx: context.Background(), - ErrorSubstring: "the number of assets and number of weights must be equal:", + ErrorSubstring: "expected the number of policy weights to be the same as the number of assets", }, { Name: "unknown policy algorithm", @@ -235,12 +233,12 @@ func TestPortfolio_Backtest(t *testing.T) { }, }, ctx: context.Background(), - ErrorSubstring: `policy "unknown" not supported by the backtest runner`, + ErrorSubstring: `unknown algorithm`, }, } { t.Run(tt.Name, func(t *testing.T) { pf := tt.Portfolio - _, err := pf.Backtest(tt.ctx, returns.Table{}, nil) + _, err := pf.Backtest(tt.ctx, returns.NewTable([]returns.List{{}}), nil) if tt.ErrorSubstring == "" { assert.NoError(t, err) } else { @@ -256,9 +254,7 @@ func TestPortfolio_Backtest_custom_function(t *testing.T) { {ID: "AAPL"}, {ID: "GOOG"}, }, - }).Backtest(context.Background(), returns.NewTable([]returns.List{{}}), func(ctx context.Context, today time.Time, assets returns.Table, currentWeights []float64) ([]float64, error) { - return nil, fmt.Errorf("lemon") - }) + }).Backtest(context.Background(), returns.NewTable([]returns.List{{}}), ErrorAlg{}) assert.EqualError(t, err, "lemon") } @@ -297,112 +293,6 @@ func Test_Portfolio_Validate(t *testing.T) { } } -func Test_Portfolio_ParseValues(t *testing.T) { - for _, tt := range []struct { - Name string - Values url.Values - In, Out portfolio.Specification - ExpectErr bool - }{ - { - Name: "set everything", - Values: url.Values{ - "name": []string{"X"}, - "asset-id": []string{"y", "z"}, - "benchmark-id": []string{"b"}, - "filepath": []string{"f"}, - "policy-weight": []string{".5", ".5"}, - "policy-rebalance": []string{"Daily"}, - "policy-weights-algorithm": []string{"Static"}, - "policy-update-weights": []string{"Daily"}, - "policy-weight-algorithm-look-back": []string{"1 Week"}, - }, - Out: portfolio.Specification{ - Name: "X", - Assets: []portfolio.Component{ - {ID: "y"}, - {ID: "z"}, - }, - Benchmark: portfolio.Component{ - ID: "b", - }, - Filepath: "f", - Policy: portfolio.Policy{ - RebalancingInterval: "Daily", - WeightsAlgorithm: "Static", - Weights: []float64{0.5, 0.5}, - WeightsUpdatingInterval: "Daily", - WeightsAlgorithmLookBack: "1 Week", - }, - }, - }, - { - Name: "empty values do not override", - Values: url.Values{}, - In: portfolio.Specification{ - Name: "no change", - Benchmark: portfolio.Component{ID: "b"}, - Assets: []portfolio.Component{{ID: "a1"}}, - Filepath: "f", - }, - Out: portfolio.Specification{ - Name: "no change", - Benchmark: portfolio.Component{ID: "b"}, - Assets: []portfolio.Component{{ID: "a1"}}, - Filepath: "f", - }, - }, - } { - t.Run(tt.Name, func(t *testing.T) { - pf := &tt.In - err := pf.ParseValues(tt.Values) - if tt.ExpectErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tt.Out, *pf) - }) - } -} - -func Test_Portfolio_Values(t *testing.T) { - t.Run("encode and decode", func(t *testing.T) { - pf := portfolio.Specification{ - Name: "X", - Assets: []portfolio.Component{ - {ID: "y"}, - {ID: "z"}, - }, - Benchmark: portfolio.Component{ - ID: "b", - }, - Filepath: "f", - Policy: portfolio.Policy{ - RebalancingInterval: "Daily", - WeightsAlgorithm: "Static", - Weights: []float64{0.5, 0.5}, - WeightsUpdatingInterval: "Daily", - WeightsAlgorithmLookBack: "1 Week", - }, - } - - var update portfolio.Specification - e := pf.Values().Encode() - q, err := url.ParseQuery(e) - require.NoError(t, err) - assert.NoError(t, update.ParseValues(q)) - assert.Equal(t, pf, update) - }) - - t.Run("fail to parse float", func(t *testing.T) { - values, err := url.ParseQuery(`policy-weight=x`) - require.NoError(t, err) - var pf portfolio.Specification - assert.Error(t, pf.ParseValues(values)) - }) -} - func TestPortfolio_RemoveAsset(t *testing.T) { t.Run("nil", func(t *testing.T) { var zero portfolio.Specification @@ -469,3 +359,11 @@ func TestPortfolio_RemoveAsset(t *testing.T) { require.Error(t, pf.RemoveAsset(-1)) }) } + +type ErrorAlg struct{} + +func (ErrorAlg) Name() string { return "" } + +func (ErrorAlg) PolicyWeights(ctx context.Context, today time.Time, assets returns.Table, currentWeights []float64) ([]float64, error) { + return nil, fmt.Errorf("lemon") +} diff --git a/returns/table.go b/returns/table.go index 6650ed4..b4b4130 100644 --- a/returns/table.go +++ b/returns/table.go @@ -48,10 +48,7 @@ func (table *Table) UnmarshalBSON(buf []byte) error { } func (table Table) MarshalBSON() ([]byte, error) { - return bson.Marshal(encodedTable{ - Times: table.times, - Values: table.values, - }) + return bson.Marshal(newEncodedTable(table.times, table.values)) } type encodedTable struct { @@ -59,6 +56,24 @@ type encodedTable struct { Values [][]float64 `json:"values" bson:"values"` } +func newEncodedTable(times []time.Time, values [][]float64) encodedTable { + if times == nil { + times = make([]time.Time, 0) + } + if values == nil { + values = make([][]float64, 0) + } + for i := range values { + if values[i] == nil { + values[i] = make([]float64, 0) + } + } + return encodedTable{ + Times: times, + Values: values, + } +} + func (table *Table) UnmarshalJSON(buf []byte) error { var enc encodedTable err := json.Unmarshal(buf, &enc) @@ -68,10 +83,7 @@ func (table *Table) UnmarshalJSON(buf []byte) error { } func (table Table) MarshalJSON() ([]byte, error) { - t := encodedTable{ - Times: table.times, - Values: table.values, - } + t := newEncodedTable(table.times, table.values) err := round.Recursive(t.Values, 6) if err != nil { return nil, err @@ -184,39 +196,23 @@ func (table Table) addAdditionalColumn(list List) Table { list = list.Between(table.LastTime(), table.FirstTime()) updated := table.Between(list.LastTime(), list.FirstTime()) + newValues := make([]float64, len(updated.times)) for _, r := range list { - _, updated = updated.ensureRowForTime(r.Time) + i, found := updated.rowForTime(r.Time) + if !found { + continue + } + newValues[i] = r.Value } - newValues := make([]float64, len(updated.times)) - for i, tm := range updated.times { - value, _ := list.Value(tm) - newValues[i] = value - } updated.values = append(updated.values, newValues) return updated } -func (table Table) ensureRowForTime(tm time.Time) (index int, updated Table) { - for i, et := range table.times { - if et.Equal(tm) { - return i, table - } - if tm.After(et) { - index, updated = i, table - updated.times = append(updated.times[:i], append([]time.Time{tm}, updated.times[i:]...)...) - for j, values := range updated.values { - updated.values[j] = append(values[:i], append([]float64{0}, values[i:]...)...) - } - break - - //// an early return makes the coverage dip below 100% because the - //// empty block outside the loop would never execute. This break - //// is essentially like the following line - // return index, updated - } - } - return index, updated +func (table Table) rowForTime(tm time.Time) (index int, exists bool) { + return slices.BinarySearchFunc(table.times, tm, func(et time.Time, t time.Time) int { + return et.Compare(t) * -1 + }) } func (table Table) FirstTime() time.Time { return indexOrEmpty(table.times, firstIndex(table.times)) } @@ -245,6 +241,12 @@ func (table Table) TimeBefore(tm time.Time) (time.Time, bool) { return next, !next.IsZero() } +func (table Table) ClosestTimeOnOrBefore(tm time.Time) (time.Time, bool) { + index := indexOfClosest(table.times, identity[time.Time], tm) + next := indexOrEmpty(table.times, index) + return next, !next.IsZero() +} + func identity[T any](t T) T { return t } func (table Table) Lists() []List { diff --git a/returns/table_test.go b/returns/table_test.go index 186a0b1..50416af 100644 --- a/returns/table_test.go +++ b/returns/table_test.go @@ -280,6 +280,9 @@ func TestTable_Between(t *testing.T) { func TestTable_AddColumn(t *testing.T) { t.Run("when adding list with an additional row", func(t *testing.T) { + t.Skip(` +AddColumn now does not add a column to the table if the table does not already have a row. +`) table := returns.NewTable([]returns.List{ {rtn(t, fixtures.Day3, .1), rtn(t, fixtures.Day1, .1), rtn(t, fixtures.Day0, .1)}, }) @@ -456,22 +459,84 @@ func TestTable_TimeBefore(t *testing.T) { assert.False(t, hasReturn) }) t.Run("on a Monday", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day2) + require.Equal(t, time.Monday, in.Weekday()) table := returns.NewTable([]returns.List{ {rtn(t, fixtures.LastDay, 0), rtn(t, fixtures.Day2, 0), rtn(t, fixtures.Day1, 0), rtn(t, fixtures.FirstDay, 0)}, {rtn(t, fixtures.LastDay, 0), rtn(t, fixtures.Day2, 0), rtn(t, fixtures.Day1, 0), rtn(t, fixtures.FirstDay, 0)}, }) - after, hasReturn := table.TimeBefore(fixtures.T(t, fixtures.Day2)) + result, hasReturn := table.TimeBefore(in) assert.True(t, hasReturn) - assert.Equal(t, after, fixtures.T(t, fixtures.Day1)) + assert.Equal(t, fixtures.T(t, fixtures.Day1), result) }) t.Run("on a Friday", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day1) + require.Equal(t, in.Weekday(), time.Friday) table := returns.NewTable([]returns.List{ {rtn(t, fixtures.LastDay, 0), rtn(t, fixtures.Day2, 0), rtn(t, fixtures.Day1, 0), rtn(t, fixtures.FirstDay, 0)}, {rtn(t, fixtures.LastDay, 0), rtn(t, fixtures.Day2, 0), rtn(t, fixtures.Day1, 0), rtn(t, fixtures.FirstDay, 0)}, }) - after, hasReturn := table.TimeBefore(fixtures.T(t, fixtures.Day3)) + result, hasReturn := table.TimeBefore(in) assert.True(t, hasReturn) - assert.Equal(t, after, fixtures.T(t, fixtures.Day2)) + assert.Equal(t, fixtures.T(t, fixtures.Day0), result) + }) +} + +func TestTable_ClosestTimeOnOrBefore(t *testing.T) { + t.Run("on a Friday", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day1) + require.Equal(t, in.Weekday(), time.Friday) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.LastDay, 0), rtn(t, fixtures.Day2, 0), rtn(t, fixtures.Day1, 0), rtn(t, fixtures.FirstDay, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day1), result) + }) + t.Run("exactly between", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day2) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.Day3, 0), rtn(t, fixtures.Day1, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day1), result) + }) + t.Run("between closer to final day", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day2) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.Day3, 0), rtn(t, fixtures.Day0, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day0), result) + }) + t.Run("between closer to first day", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day1) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.Day3, 0), rtn(t, fixtures.Day0, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day0), result) + }) + t.Run("exactly first", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day0) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.Day1, 0), rtn(t, fixtures.Day0, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day0), result) + }) + t.Run("exactly last", func(t *testing.T) { + in := fixtures.T(t, fixtures.Day1) + table := returns.NewTable([]returns.List{ + {rtn(t, fixtures.Day1, 0), rtn(t, fixtures.Day0, 0)}, + }) + result, hasReturn := table.ClosestTimeOnOrBefore(in) + assert.True(t, hasReturn) + assert.Equal(t, fixtures.T(t, fixtures.Day1), result) }) }