Skip to content

Commit dd9763b

Browse files
committed
Rating support for WD Tagger
1 parent b86af67 commit dd9763b

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

finetune/tag_images_by_wd14_tagger.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ def main(args):
174174
rows = l[1:]
175175
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
176176

177-
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
178-
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
177+
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
178+
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
179+
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
179180

180181
# 画像を読み込む
181182

@@ -202,17 +203,13 @@ def run_batch(path_imgs):
202203
probs = probs.numpy()
203204

204205
for (image_path, _), prob in zip(path_imgs, probs):
205-
# 最初の4つはratingなので無視する
206-
# # First 4 labels are actually ratings: pick one with argmax
207-
# ratings_names = label_names[:4]
208-
# rating_index = ratings_names["probs"].argmax()
209-
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
206+
combined_tags = []
207+
rating_tag_text = ""
208+
character_tag_text = ""
209+
general_tag_text = ""
210210

211211
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
212212
# Everything else is tags: pick any where prediction confidence > threshold
213-
combined_tags = []
214-
general_tag_text = ""
215-
character_tag_text = ""
216213
for i, p in enumerate(prob[4:]):
217214
if i < len(general_tags) and p >= args.general_threshold:
218215
tag_name = general_tags[i]
@@ -231,7 +228,20 @@ def run_batch(path_imgs):
231228
if tag_name not in undesired_tags:
232229
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
233230
character_tag_text += caption_separator + tag_name
234-
combined_tags.append(tag_name)
231+
combined_tags.insert(0,tag_name) # insert at the beginning
232+
233+
#最初の4つはratingなのでargmaxで1つ選択する
234+
# First 4 labels are actually ratings: pick one with argmax
235+
ratings_names = prob[:4]
236+
rating_index = ratings_names.argmax()
237+
found_rating = rating_tags[rating_index]
238+
if args.remove_underscore and len(found_rating) > 3:
239+
found_rating = found_rating.replace("_", " ")
240+
241+
if found_rating not in undesired_tags:
242+
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
243+
rating_tag_text = found_rating
244+
combined_tags.insert(0,found_rating) # insert at the beginning
235245

236246
# 先頭のカンマを取る
237247
if len(general_tag_text) > 0:
@@ -264,6 +274,7 @@ def run_batch(path_imgs):
264274
if args.debug:
265275
logger.info("")
266276
logger.info(f"{image_path}:")
277+
logger.info(f"\tRating tags: {rating_tag_text}")
267278
logger.info(f"\tCharacter tags: {character_tag_text}")
268279
logger.info(f"\tGeneral tags: {general_tag_text}")
269280

0 commit comments

Comments
 (0)