Skip to content

Commit d816ee9

Browse files
authored
Merge pull request #626 from projectdiscovery/add_jsonutil
add `FilterStruct` and `GetStructFields` funcs
2 parents 36bda9d + b03223f commit d816ee9

File tree

2 files changed

+189
-1
lines changed

2 files changed

+189
-1
lines changed

structs/structs.go

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package structs
22

3-
import "reflect"
3+
import (
4+
"errors"
5+
"reflect"
6+
)
47

58
// CallbackFunc on the struct field
69
// example:
@@ -35,3 +38,65 @@ func Walk(s interface{}, callback CallbackFunc) {
3538
}
3639
}
3740
}
41+
42+
// FilterStruct filters the struct based on include and exclude fields and returns a new struct.
43+
// - input: the original struct.
44+
// - includeFields: list of fields to include (if empty, includes all).
45+
// - excludeFields: list of fields to exclude (processed after include).
46+
func FilterStruct(input interface{}, includeFields, excludeFields []string) (interface{}, error) {
47+
val := reflect.ValueOf(input)
48+
if val.Kind() == reflect.Ptr {
49+
val = val.Elem()
50+
}
51+
52+
if val.Kind() != reflect.Struct {
53+
return nil, errors.New("input must be a struct")
54+
}
55+
56+
includeMap := make(map[string]bool)
57+
excludeMap := make(map[string]bool)
58+
59+
for _, field := range includeFields {
60+
includeMap[field] = true
61+
}
62+
for _, field := range excludeFields {
63+
excludeMap[field] = true
64+
}
65+
66+
typeOfStruct := val.Type()
67+
filteredStruct := reflect.New(typeOfStruct).Elem()
68+
69+
for i := 0; i < val.NumField(); i++ {
70+
field := typeOfStruct.Field(i)
71+
fieldName := field.Name
72+
fieldValue := val.Field(i)
73+
74+
if (len(includeMap) == 0 || includeMap[fieldName]) && !excludeMap[fieldName] {
75+
filteredStruct.Field(i).Set(fieldValue)
76+
}
77+
}
78+
79+
return filteredStruct.Interface(), nil
80+
}
81+
82+
// GetStructFields returns all the top-level field names from the given struct.
83+
// - input: the original struct.
84+
// Returns a slice of field names or an error if the input is not a struct.
85+
func GetStructFields(input interface{}) ([]string, error) {
86+
val := reflect.ValueOf(input)
87+
if val.Kind() == reflect.Ptr {
88+
val = val.Elem()
89+
}
90+
91+
if val.Kind() != reflect.Struct {
92+
return nil, errors.New("input must be a struct")
93+
}
94+
95+
fields := make([]string, 0, val.NumField())
96+
typeOfStruct := val.Type()
97+
for i := 0; i < val.NumField(); i++ {
98+
fields = append(fields, typeOfStruct.Field(i).Name)
99+
}
100+
101+
return fields, nil
102+
}

structs/structs_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package structs
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
type TestStruct struct {
9+
Name string
10+
Age int
11+
Address string
12+
}
13+
14+
type NestedStruct struct {
15+
Basic TestStruct
16+
PtrField *TestStruct
17+
}
18+
19+
func TestFilterStruct(t *testing.T) {
20+
s := TestStruct{
21+
Name: "John",
22+
Age: 30,
23+
Address: "New York",
24+
}
25+
26+
tests := []struct {
27+
name string
28+
input interface{}
29+
includeFields []string
30+
excludeFields []string
31+
want TestStruct
32+
wantErr bool
33+
}{
34+
{
35+
name: "include specific fields",
36+
input: s,
37+
includeFields: []string{"Name", "Age"},
38+
excludeFields: []string{},
39+
want: TestStruct{
40+
Name: "John",
41+
Age: 30,
42+
},
43+
wantErr: false,
44+
},
45+
{
46+
name: "exclude specific fields",
47+
input: s,
48+
includeFields: []string{},
49+
excludeFields: []string{"Address"},
50+
want: TestStruct{
51+
Name: "John",
52+
Age: 30,
53+
},
54+
wantErr: false,
55+
},
56+
{
57+
name: "non-struct input",
58+
input: "not a struct",
59+
includeFields: []string{},
60+
excludeFields: []string{},
61+
want: TestStruct{},
62+
wantErr: true,
63+
},
64+
}
65+
66+
for _, tt := range tests {
67+
t.Run(tt.name, func(t *testing.T) {
68+
got, err := FilterStruct(tt.input, tt.includeFields, tt.excludeFields)
69+
if (err != nil) != tt.wantErr {
70+
t.Errorf("FilterStruct() error = %v, wantErr %v", err, tt.wantErr)
71+
return
72+
}
73+
if !tt.wantErr {
74+
if !reflect.DeepEqual(got, tt.want) {
75+
t.Errorf("FilterStruct() = %v, want %v", got, tt.want)
76+
}
77+
}
78+
})
79+
}
80+
}
81+
82+
func TestGetStructFields(t *testing.T) {
83+
s := TestStruct{
84+
Name: "John",
85+
Age: 30,
86+
Address: "New York",
87+
}
88+
89+
tests := []struct {
90+
name string
91+
input interface{}
92+
want []string
93+
wantErr bool
94+
}{
95+
{
96+
name: "valid struct",
97+
input: s,
98+
want: []string{"Name", "Age", "Address"},
99+
wantErr: false,
100+
},
101+
{
102+
name: "non-struct input",
103+
input: "not a struct",
104+
want: nil,
105+
wantErr: true,
106+
},
107+
}
108+
109+
for _, tt := range tests {
110+
t.Run(tt.name, func(t *testing.T) {
111+
got, err := GetStructFields(tt.input)
112+
if (err != nil) != tt.wantErr {
113+
t.Errorf("GetStructFields() error = %v, wantErr %v", err, tt.wantErr)
114+
return
115+
}
116+
if !tt.wantErr {
117+
if !reflect.DeepEqual(got, tt.want) {
118+
t.Errorf("GetStructFields() = %v, want %v", got, tt.want)
119+
}
120+
}
121+
})
122+
}
123+
}

0 commit comments

Comments
 (0)