diff --git a/fs/parseduration.go b/fs/parseduration.go index 3838d811c..7f8ba99d3 100644 --- a/fs/parseduration.go +++ b/fs/parseduration.go @@ -1,6 +1,7 @@ package fs import ( + "encoding/json" "fmt" "math" "strconv" @@ -231,10 +232,26 @@ func (d Duration) Type() string { // UnmarshalJSON makes sure the value can be parsed as a string or integer in JSON func (d *Duration) UnmarshalJSON(in []byte) error { - return UnmarshalJSONFlag(in, d, func(i int64) error { - *d = Duration(i) + // Check if the input is a string value. + if len(in) >= 2 && in[0] == '"' && in[len(in)-1] == '"' { + strVal := string(in[1 : len(in)-1]) // Remove the quotes + + // Attempt to parse the string as a duration. + parsedDuration, err := ParseDuration(strVal) + if err != nil { + return err + } + *d = Duration(parsedDuration) return nil - }) + } + // Handle numeric values. + var i int64 + err := json.Unmarshal(in, &i) + if err != nil { + return err + } + *d = Duration(i) + return nil } // Scan implements the fmt.Scanner interface diff --git a/fs/parseduration_test.go b/fs/parseduration_test.go index b95e7da8d..93340026c 100644 --- a/fs/parseduration_test.go +++ b/fs/parseduration_test.go @@ -214,3 +214,32 @@ func TestParseUnmarshalJSON(t *testing.T) { assert.Equal(t, Duration(test.want), duration, test.in) } } + +func TestUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want Duration + wantErr bool + }{ + {"off string", `"off"`, DurationOff, false}, + {"max int64", `9223372036854775807`, DurationOff, false}, + {"duration string", `"1h"`, Duration(time.Hour), false}, + {"invalid string", `"invalid"`, 0, true}, + {"negative int", `-1`, Duration(-1), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var d Duration + err := json.Unmarshal([]byte(tt.input), &d) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if d != tt.want { + t.Errorf("UnmarshalJSON() got = %v, want %v", d, tt.want) + } + }) + } +}