@@ -174,8 +174,9 @@ def main(args):
174
174
rows = l [1 :]
175
175
assert header [0 ] == "tag_id" and header [1 ] == "name" and header [2 ] == "category" , f"unexpected csv format: { header } "
176
176
177
- general_tags = [row [1 ] for row in rows [1 :] if row [2 ] == "0" ]
178
- character_tags = [row [1 ] for row in rows [1 :] if row [2 ] == "4" ]
177
+ rating_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "9" ]
178
+ general_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "0" ]
179
+ character_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "4" ]
179
180
180
181
# 画像を読み込む
181
182
@@ -202,17 +203,13 @@ def run_batch(path_imgs):
202
203
probs = probs .numpy ()
203
204
204
205
for (image_path , _ ), prob in zip (path_imgs , probs ):
205
- # 最初の4つはratingなので無視する
206
- # # First 4 labels are actually ratings: pick one with argmax
207
- # ratings_names = label_names[:4]
208
- # rating_index = ratings_names["probs"].argmax()
209
- # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
206
+ combined_tags = []
207
+ rating_tag_text = ""
208
+ character_tag_text = ""
209
+ general_tag_text = ""
210
210
211
211
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
212
212
# Everything else is tags: pick any where prediction confidence > threshold
213
- combined_tags = []
214
- general_tag_text = ""
215
- character_tag_text = ""
216
213
for i , p in enumerate (prob [4 :]):
217
214
if i < len (general_tags ) and p >= args .general_threshold :
218
215
tag_name = general_tags [i ]
@@ -231,7 +228,20 @@ def run_batch(path_imgs):
231
228
if tag_name not in undesired_tags :
232
229
tag_freq [tag_name ] = tag_freq .get (tag_name , 0 ) + 1
233
230
character_tag_text += caption_separator + tag_name
234
- combined_tags .append (tag_name )
231
+ combined_tags .insert (0 ,tag_name ) # insert to the beggining
232
+
233
+ #最初の4つはratingなので無視する
234
+ # First 4 labels are actually ratings: pick one with argmax
235
+ ratings_names = prob [:4 ]
236
+ rating_index = ratings_names .argmax ()
237
+ found_rating = rating_tags [rating_index ]
238
+ if args .remove_underscore and len (found_rating ) > 3 :
239
+ found_rating = found_rating .replace ("_" , " " )
240
+
241
+ if found_rating not in undesired_tags :
242
+ tag_freq [found_rating ] = tag_freq .get (found_rating , 0 ) + 1
243
+ rating_tag_text = found_rating
244
+ combined_tags .insert (0 ,found_rating ) # insert to the beggining
235
245
236
246
# 先頭のカンマを取る
237
247
if len (general_tag_text ) > 0 :
@@ -264,6 +274,7 @@ def run_batch(path_imgs):
264
274
if args .debug :
265
275
logger .info ("" )
266
276
logger .info (f"{ image_path } :" )
277
+ logger .info (f"\t Rating tags: { rating_tag_text } " )
267
278
logger .info (f"\t Character tags: { character_tag_text } " )
268
279
logger .info (f"\t General tags: { general_tag_text } " )
269
280
0 commit comments