Commit 1225959

Add transfer learning functionality between TFKG models with .SetLayerWeights() on layer interface
Fix issue where the first row of single_file_dataset was being ignored
Add transfer_learning example
Save the order of model weights in a json file in the save dir to make transfer learning nicer
1 parent 183e0f1 commit 1225959
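
The diffs shown below do not include the new .SetLayerWeights() method itself, so the following is only a rough, self-contained Go sketch of the idea described in the commit message: the save dir records the order of model weights in a json file, and each layer exposes a setter so weights from a previously trained model can be copied into matching layers of a new one. The Layer interface, denseLayer type, and json layout here are illustrative assumptions, not TFKG's actual API.

package main

import (
	"encoding/json"
	"fmt"
)

// Layer is an illustrative stand-in for a layer interface that can accept
// externally loaded weights (mirroring the new .SetLayerWeights()).
type Layer interface {
	Name() string
	SetLayerWeights(weights [][]float32)
}

// denseLayer is a hypothetical layer implementation used only for this sketch.
type denseLayer struct {
	name    string
	weights [][]float32
}

func (l *denseLayer) Name() string                        { return l.name }
func (l *denseLayer) SetLayerWeights(weights [][]float32) { l.weights = weights }

func main() {
	// Hypothetical contents of a weight-order json file saved alongside a model:
	// the order of names says which saved weight tensor belongs to which layer.
	weightOrderJson := []byte(`["conv_1", "dense_1", "output"]`)
	var weightOrder []string
	if err := json.Unmarshal(weightOrderJson, &weightOrder); err != nil {
		panic(err)
	}

	// Pretend these were read back from the saved source model, in the same order.
	savedWeights := [][][]float32{
		{{0.1, 0.2}},      // conv_1
		{{0.3, 0.4, 0.5}}, // dense_1
		{{0.6}},           // output
	}

	// Target model layers, keyed by name so weights can be transferred by matching names.
	target := map[string]Layer{
		"conv_1":  &denseLayer{name: "conv_1"},
		"dense_1": &denseLayer{name: "dense_1"},
	}

	// Copy weights for every layer name that exists in both models.
	for i, name := range weightOrder {
		if layer, ok := target[name]; ok {
			layer.SetLayerWeights(savedWeights[i])
			fmt.Printf("transferred weights into layer %q\n", name)
		}
	}
}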

File tree

102 files changed: +2409 / -426 lines changed (large commit, so only a subset of the file diffs is shown below)


Makefile

Lines changed: 14 additions & 0 deletions
@@ -87,6 +87,20 @@ examples-sign-raw:
 	go generate ./...
 	cd examples/sign && go run main.go
 
+examples-transfer:
+	go generate ./...
+	docker-compose up -d tf-jupyter-golang
+	docker-compose exec tf-jupyter-golang sh -c "cd /go/src/tfkg/examples/transfer_learning && go run main.go"
+
+examples-transfer-gpu:
+	go generate ./...
+	docker-compose up -d tf-jupyter-golang-gpu
+	docker-compose exec tf-jupyter-golang-gpu sh -c "cd /go/src/tfkg/examples/transfer_learning && go run main.go"
+
+examples-transfer-raw:
+	go generate ./...
+	cd examples/transfer_learning && go run main.go
+
 test-python:
 	docker-compose up -d tf-jupyter-golang
 	docker-compose exec tf-jupyter-golang sh -c "cd /go/src/tfkg && python test.py"

data/img_folder_dataset.go

Lines changed: 27 additions & 20 deletions
@@ -128,24 +128,26 @@ func NewImgFolderDataset(
 		},
 	}
 
-	e = os.MkdirAll(config.CacheDir, os.ModePerm)
-	if e != nil && e != os.ErrExist {
-		errorHandler.Error(e)
-		return nil, e
-	}
-
-	if _, e := os.Stat(filepath.Join(config.CacheDir, "category-tokenizer.json")); e == nil {
-		d.categoryTokenizer = preprocessor.NewTokenizer(
-			errorHandler,
-			1,
-			-1,
-			preprocessor.TokenizerConfig{IsCategoryTokenizer: true, DisableFiltering: true},
-		)
-		e = d.categoryTokenizer.Load(filepath.Join(config.CacheDir, "category-tokenizer.json"))
-		if e != nil {
+	if config.CacheDir != "" {
+		e = os.MkdirAll(config.CacheDir, os.ModePerm)
+		if e != nil && e != os.ErrExist {
 			errorHandler.Error(e)
 			return nil, e
 		}
+
+		if _, e := os.Stat(filepath.Join(config.CacheDir, "category-tokenizer.json")); e == nil {
+			d.categoryTokenizer = preprocessor.NewTokenizer(
+				errorHandler,
+				1,
+				-1,
+				preprocessor.TokenizerConfig{IsCategoryTokenizer: true, DisableFiltering: true},
+			)
+			e = d.categoryTokenizer.Load(filepath.Join(config.CacheDir, "category-tokenizer.json"))
+			if e != nil {
+				errorHandler.Error(e)
+				return nil, e
+			}
+		}
 	}
 
 	e = d.readFileNames()
@@ -169,11 +171,16 @@ type imgStatsCache struct {
 
 func (d *ImgFolderDataset) readFileNames() error {
 	cacheFileName := "file-stats.json"
-	cacheFileBytes, e := ioutil.ReadFile(filepath.Join(d.cacheDir, cacheFileName))
-	if e != nil && !errors.Is(e, os.ErrNotExist) {
-		d.errorHandler.Error(e)
-		return e
-	} else if e == nil {
+	var cacheFileBytes []byte
+	var e error
+	if d.cacheDir != "" {
+		cacheFileBytes, e = ioutil.ReadFile(filepath.Join(d.cacheDir, cacheFileName))
+		if e != nil && !errors.Is(e, os.ErrNotExist) {
+			d.errorHandler.Error(e)
+			return e
+		}
+	}
+	if len(cacheFileBytes) > 0 {
		var cache imgStatsCache
 		e = json.Unmarshal(cacheFileBytes, &cache)
 		if e != nil {
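
The practical effect of both hunks is that an empty CacheDir now simply skips the cache setup instead of attempting MkdirAll and ReadFile on an empty path. A stripped-down sketch of the same guard pattern outside TFKG's types (the statsCache struct and loadStats helper are illustrative only):

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
)

type statsCache struct {
	Count int `json:"count"`
}

// loadStats reads a cached stats file when cacheDir is set; an empty cacheDir
// or a missing file both fall through to "no cache" without an error.
func loadStats(cacheDir string) (*statsCache, error) {
	var raw []byte
	var err error
	if cacheDir != "" {
		raw, err = os.ReadFile(filepath.Join(cacheDir, "file-stats.json"))
		if err != nil && !errors.Is(err, os.ErrNotExist) {
			return nil, err
		}
	}
	if len(raw) == 0 {
		return nil, nil // no cache available; the caller recomputes the stats
	}
	var cache statsCache
	if err := json.Unmarshal(raw, &cache); err != nil {
		return nil, err
	}
	return &cache, nil
}

func main() {
	cache, err := loadStats("") // caching disabled: no filesystem access happens
	fmt.Println(cache, err)     // <nil> <nil>
}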

data/single_file_dataset.go

Lines changed: 4 additions & 1 deletion
@@ -225,7 +225,7 @@ func (d *SingleFileDataset) readLineOffsets() error {
 
 	lastPrint := time.Now().Unix()
 	progress, lastProgress := 0, 0
-	skippedHeaders := false
+	skippedHeaders, zeroAdded := false, false
 	swg := sizedwaitgroup.New(128)
 	var errs []error
 	for true {
@@ -238,6 +238,9 @@
 		if !skippedHeaders && d.skipHeaders {
 			skippedHeaders = true
 			continue
+		} else if !d.skipHeaders && !zeroAdded {
+			zeroAdded = true
+			d.lineOffsets = append(d.lineOffsets, 0)
 		}
 
 		if len(errs) > 0 {
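
This hunk is the fix for the first commit message item: offsets appear to be recorded as each line is consumed, so with no header row to skip the offset of the very first line (0) was never added and that row was silently dropped. A minimal, standalone sketch of the corrected logic (lineOffsets below is a simplification for illustration, not the TFKG implementation):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

// lineOffsets returns the byte offset of the start of every data row.
// When skipHeaders is false, offset 0 must be recorded explicitly,
// otherwise the first row is only reachable via an offset recorded
// after reading the previous line, and for the first row there is none.
func lineOffsets(data string, skipHeaders bool) []int64 {
	reader := bufio.NewReader(strings.NewReader(data))
	var offsets []int64
	var offset int64
	skippedHeaders, zeroAdded := false, false
	for {
		line, err := reader.ReadString('\n')
		offset += int64(len(line))
		if !skippedHeaders && skipHeaders {
			skippedHeaders = true
		} else if !skipHeaders && !zeroAdded {
			zeroAdded = true
			offsets = append(offsets, 0)
		}
		if err != nil {
			break // io.EOF for an in-memory reader: no more lines
		}
		offsets = append(offsets, offset)
	}
	return offsets
}

func main() {
	csv := "a,1\nb,2\nc,3"
	fmt.Println(lineOffsets(csv, false)) // [0 4 8]: row "a,1" is no longer dropped
	fmt.Println(lineOffsets(csv, true))  // [4 8]: header skipped; rows "b,2" and "c,3" remain
}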
