Skip to content

Commit 6efde10

Browse files
committed
Add an option to drop malformed URIs, and a bit of logging.
1 parent addf836 commit 6efde10

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

predict.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
import json
1414
import logging
1515
import sys
16+
import utils
17+
import time
1618

1719
import SPARQLWrapper
1820
from splendid import chunker
@@ -118,6 +120,13 @@ def parse_args():
118120
type=str,
119121
default=None,
120122
)
123+
parser.add_argument(
124+
"--drop_bad_uris",
125+
help="URIs that cannot be curified are ignored",
126+
action="store",
127+
type=bool,
128+
default=False,
129+
)
121130
parser.add_argument(
122131
"--fusion_methods",
123132
help="Which fusion methods to use. During prediction, each of "
@@ -198,6 +207,7 @@ def main(
198207
max_results,
199208
max_target_candidates_per_gp,
200209
batch_predict,
210+
drop_bad_uris,
201211
**_ # gulp remaining kwargs
202212
):
203213
from gp_query import calibrate_query_timeout
@@ -222,6 +232,8 @@ def main(
222232
gps = cluster_gps_to_reduce_queries(
223233
gps, max_queries, gtp_scores, clustering_variant)
224234

235+
processed = 0
236+
start = time.time()
225237
batch_size = config.BATCH_SIZE if batch_predict else 1
226238
# main loop
227239
for lines in chunker(sys.stdin, batch_size):
@@ -230,6 +242,13 @@ def main(
230242
line = line.strip()
231243
if not line:
232244
continue
245+
if drop_bad_uris:
246+
try:
247+
source = from_n3(line)
248+
utils.curify(source)
249+
except:
250+
logger.warning('Warning: Could not curify URI %s! Skip.', line)
251+
continue
233252
if line[0] not in '<"':
234253
logger.error(
235254
'expected inputs to start with < or ", but got: %s', line)
@@ -238,7 +257,9 @@ def main(
238257
batch.append(source)
239258
batch = list(OrderedDict.fromkeys(batch))
240259

241-
if len(batch) == 1:
260+
if len(batch) == 0:
261+
pass
262+
elif len(batch) == 1:
242263
res = predict(
243264
sparql, timeout, gps, batch[0], fusion_methods,
244265
max_results, max_target_candidates_per_gp
@@ -252,6 +273,8 @@ def main(
252273
for r in res:
253274
print(json.dumps(r))
254275

276+
processed += len(batch)
277+
logger.info('Have processed %d URIs now. Took %s sec', processed, time.time()-start)
255278

256279
if __name__ == "__main__":
257280
logger.info('init run: origin')

0 commit comments

Comments
 (0)