
Commit 58dce32

predict command factored out of serve to run batch stdin->stdout prediction
1 parent 8486cb2 commit 58dce32

File tree

cluster.py
gp_learner.py
predict.py
serve.py

4 files changed: +246 -50 lines changed


cluster.py

Lines changed: 25 additions & 0 deletions
@@ -435,3 +435,28 @@ def select_best_variant(variant_max_k_prec_loss_reps, log_top_k=1):
         ])
     )
     return prec_loss, k, vn, reps
+
+
+def cluster_gps_to_reduce_queries(
+        gps, max_queries, gtp_scores, clustering_variant=None):
+    if 0 < max_queries < len(gps):
+        logger.info(
+            'reducing amount of queries from %d down to %d ...',
+            len(gps), max_queries
+        )
+        gtps = gtp_scores.ground_truth_pairs
+        var_max_k_prec_loss_reps = expected_precision_loss_by_query_reduction(
+            gps, gtps, [max_queries], gtp_scores,
+            variants=[clustering_variant] if clustering_variant else None,
+        )
+        prec_loss, k, vn, reps = select_best_variant(var_max_k_prec_loss_reps)
+
+        logger.info(
+            'reduced number of queries from %d to %d\n'
+            'used variant: %s\n'
+            'expected precision sum loss ratio: %0.3f '
+            '(precision sum loss: %.2f)',
+            len(gps), len(reps), vn, prec_loss, prec_loss * gtp_scores.score
+        )
+        gps = reps
+    return gps
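
For orientation, a minimal usage sketch of the new helper (hedged: the gps list and gtp_scores object are stand-ins, assumed to come from a loaded model as in predict.py below):

# sketch, not part of the commit: gps/gtp_scores assumed loaded from a model
from cluster import cluster_gps_to_reduce_queries

gps = cluster_gps_to_reduce_queries(
    gps, max_queries=100, gtp_scores=gtp_scores,
    clustering_variant=None,  # None: evaluate the variants, keep the best
)
# pass-through if max_queries is 0 (no limit) or len(gps) <= max_queries
assert len(gps) <= 100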

gp_learner.py

Lines changed: 9 additions & 26 deletions
@@ -41,7 +41,7 @@

 import logging_config
 from cluster import expected_precision_loss_by_query_reduction
-from cluster import select_best_variant
+from cluster import cluster_gps_to_reduce_queries
 import config
 from exception import GPLearnerAbortException
 from fusion import fuse_prediction_results
@@ -94,6 +94,10 @@
 signal.signal(signal.SIGUSR1, log_mem_usage)


+def init_workers():
+    parallel_map(_init_workers, range(1000))
+
+
 def _init_workers(_):
     # dummy method that makes workers load all import and config
     pass
@@ -1631,7 +1635,7 @@ def main(
     print(u'encoding check: äöüß\U0001F385')  # printing unicode string

     # init workers
-    parallel_map(_init_workers, range(1000))
+    init_workers()

     timer_start = datetime.utcnow()
     main_start = timer_start
@@ -1738,30 +1742,9 @@ def main(
     sys.stdout.flush()
     sys.stderr.flush()

-
-    if 0 < max_queries < len(gps):
-        print(
-            'reducing amount of queries from %d down to %d ...' % (
-                len(gps), max_queries)
-        )
-        sys.stdout.flush()
-        var_max_k_prec_loss_reps = expected_precision_loss_by_query_reduction(
-            gps, semantic_associations, [max_queries], gtp_scores,
-            variants=[clustering_variant] if clustering_variant else None,
-        )
-        prec_loss, k, vn, reps = select_best_variant(var_max_k_prec_loss_reps)
-        sys.stderr.flush()
-        print('reduced number of queries from %d to %d' % (len(gps), len(reps)))
-        print('used variant: %s' % vn)
-        print(
-            'expected precision sum loss ratio: %0.3f '
-            '(precision sum loss: %.2f)' % (
-                prec_loss, prec_loss * gtp_scores.score)
-        )
-        gps = reps
-
-    sys.stdout.flush()
-    sys.stderr.flush()
+    # reduce gps by clustering if mandated by max_queries
+    gps = cluster_gps_to_reduce_queries(
+        gps, max_queries, gtp_scores, clustering_variant)

     if print_query_patterns:
         print(
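
Aside on init_workers above: assuming parallel_map wraps something like scoop.futures.map (predict.py's comments mention scoop), mapping a no-op over many dummy values forces each worker process to import the module and load config before real tasks arrive. A rough sketch of that warm-up pattern, not the repo's actual wrapper:

# warm-up sketch; assumes workers started via `python -m scoop`
from scoop import futures

def _noop(_):
    pass  # the side effect is the module import on each worker

def warm_up():
    # futures.map is lazy; consume it so all dummy tasks really dispatch
    list(futures.map(_noop, range(1000)))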

predict.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+"""Script to predict with a fully trained model.
+
+Reads one source (TTL syntax) per line from stdin and writes one JSON line to
+stdout.
+"""
+
+import json
+import logging
+import sys
+
+import SPARQLWrapper
+from rdflib.util import from_n3
+
+
+# noinspection PyUnresolvedReferences
+import logging_config
+
+# not all imports at top due to scoop and init...
+
+logger = logging.getLogger(__name__)
+
+
+def predict(sparql, timeout, gps, source,
+            fusion_methods=None, max_results=0, max_target_candidates_per_gp=0):
+    from fusion import fuse_prediction_results
+    from gp_learner import predict_target_candidates
+
+    gp_tcs = predict_target_candidates(sparql, timeout, gps, source)
+    fused_results = fuse_prediction_results(
+        gps,
+        gp_tcs,
+        fusion_methods
+    )
+    orig_length = max([len(v) for k, v in fused_results.items()])
+    if max_results > 0:
+        for k, v in fused_results.items():
+            del v[max_results:]
+    mt = max_target_candidates_per_gp
+    if mt < 1:
+        mt = None
+    # logger.info(gp_tcs)
+    res = {
+        'source': source,
+        'orig_result_length': orig_length,
+        'graph_pattern_target_candidates': [sorted(tcs)[:mt] for tcs in gp_tcs],
+        'fused_results': fused_results,
+    }
+    return res
+
+
+def parse_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description='gp learner prediction',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--sparql_endpoint",
+        help="the SPARQL endpoint to query",
+        action="store",
+        default=config.SPARQL_ENDPOINT,
+    )
+    parser.add_argument(
+        "--max_queries",
+        help="limits the amount of queries per prediction (0: no limit). "
+             "You want to use the same limit as in training for late fusion "
+             "models.",
+        action="store",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
+        "--clustering_variant",
+        help="if specified use this clustering variant for query reduction, "
+             "otherwise select the best from various.",
+        action="store",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--fusion_methods",
+        help="Which fusion methods to use. During prediction, each of "
+             "the learned patterns can generate a list of target candidates. "
+             "Fusion re-combines these into a single ranked list of "
+             "predicted targets. By default this will use all "
+             "implemented fusion methods. Any of them, or a ',' delimited list "
+             "can be used to reduce the output (just make sure you ran "
+             "--predict=train_set on them before). Also supports 'basic' and "
+             "'classifier' as shorthands. Make sure to only select methods the "
+             "selected model was also trained on!",
+        action="store",
+        type=str,
+        default=None,
+    )
+
+    parser.add_argument(
+        "--timeout",
+        help="sets the timeout in seconds for each query (0: auto calibrate)",
+        action="store",
+        type=float,
+        default=2.,
+    )
+    parser.add_argument(
+        "--max_results",
+        help="limits the result list lengths to save bandwidth (0: no limit)",
+        action="store",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
+        "--max_target_candidates_per_gp",
+        help="limits the target candidate list lengths to save bandwidth "
+             "(0: no limit)",
+        action="store",
+        type=int,
+        default=100,
+    )
+
+    parser.add_argument(
+        "resdir",
+        help="result directory of the trained model (overrides --RESDIR)",
+        action="store",
+    )
+
+    cfg_group = parser.add_argument_group(
+        'Advanced config overrides',
+        'The following allow overriding default values from config/defaults.py'
+    )
+    config.arg_parse_config_vars(cfg_group)
+
+    prog_args = vars(parser.parse_args())
+    # the following were aliased above, make sure they're updated globally
+    prog_args.update({
+        'SPARQL_ENDPOINT': prog_args['sparql_endpoint'],
+        'RESDIR': prog_args['resdir'],
+    })
+    config.finalize(prog_args)
+
+    return prog_args
+
+
+def main(
+        resdir,
+        sparql_endpoint,
+        max_queries,
+        clustering_variant,
+        fusion_methods,
+        timeout,
+        max_results,
+        max_target_candidates_per_gp,
+        **_  # gulp remaining kwargs
+):
+    from gp_query import calibrate_query_timeout
+    from serialization import load_results
+    from serialization import find_last_result
+    from cluster import cluster_gps_to_reduce_queries
+    from gp_learner import init_workers
+
+    # init workers
+    init_workers()
+
+    sparql = SPARQLWrapper.SPARQLWrapper(sparql_endpoint)
+    timeout = timeout if timeout > 0 else calibrate_query_timeout(sparql)
+
+    # load model
+    last_res = find_last_result()
+    if not last_res:
+        logger.error('cannot find fully trained model in %s', resdir)
+        sys.exit(1)
+    result_patterns, coverage_counts, gtp_scores = load_results(last_res)
+    gps = [gp for gp, _ in result_patterns]
+    gps = cluster_gps_to_reduce_queries(
+        gps, max_queries, gtp_scores, clustering_variant)
+
+    # main loop
+    for line in sys.stdin:
+        line = line.strip()
+        if not line:
+            continue
+        if line[0] not in '<"':
+            logger.error(
+                'expected inputs to start with < or ", but got: %s', line)
+            sys.exit(1)
+        source = from_n3(line)
+
+        res = predict(
+            sparql, timeout, gps, source, fusion_methods,
+            max_results, max_target_candidates_per_gp
+        )
+        print(json.dumps(res))
+
+
+if __name__ == "__main__":
+    logger.info('init run: origin')
+    import config
+    prog_kwds = parse_args()
+    main(**prog_kwds)
+else:
+    logger.info('init run: worker')
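
Since the new script speaks a plain line protocol (one N3 term in, one JSON object out), it is easy to drive in batch from another process. A hypothetical driver sketch follows; the results directory name and the two IRIs are placeholders, and depending on the scoop setup the script may need to be launched via `python -m scoop predict.py` instead:

# hypothetical batch driver for predict.py; paths and IRIs are placeholders
import json
import subprocess

sources = [
    '<http://dbpedia.org/resource/Horse>',
    '<http://dbpedia.org/resource/Saddle>',
]
proc = subprocess.Popen(
    ['python', 'predict.py', 'results'],  # resdir is the positional arg
    stdin=subprocess.PIPE, stdout=subprocess.PIPE,
)
out, _ = proc.communicate('\n'.join(sources).encode('utf-8'))
for line in out.splitlines():
    res = json.loads(line)
    # keys as emitted by predict(): source, orig_result_length,
    # graph_pattern_target_candidates, fused_results
    print(res['source'], res['orig_result_length'])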

serve.py

Lines changed: 4 additions & 24 deletions
@@ -81,32 +81,12 @@ def predict():


 def _predict(source):
-    from fusion import fuse_prediction_results
-    from gp_learner import predict_target_candidates
     from gp_query import calibrate_query_timeout
-
+    from predict import predict
     timeout = TIMEOUT if TIMEOUT > 0 else calibrate_query_timeout(SPARQL)
-    gp_tcs = predict_target_candidates(SPARQL, timeout, GPS, source)
-    fused_results = fuse_prediction_results(
-        GPS,
-        gp_tcs,
-        FUSION_METHODS
-    )
-    orig_length = max([len(v) for k, v in fused_results.items()])
-    if MAX_RESULTS > 0:
-        for k, v in fused_results.items():
-            del v[MAX_RESULTS:]
-    mt = MAX_TARGET_CANDIDATES_PER_GP
-    if mt < 1:
-        mt = None
-    # logger.info(gp_tcs)
-    res = {
-        'source': source,
-        'orig_result_length': orig_length,
-        'graph_pattern_target_candidates': [sorted(tcs)[:mt] for tcs in gp_tcs],
-        'fused_results': fused_results,
-    }
-    return res
+    return predict(
+        SPARQL, timeout, GPS, source,
+        FUSION_METHODS, MAX_RESULTS, MAX_TARGET_CANDIDATES_PER_GP)


 @app.route("/api/feedback", methods=["POST"])
