Package pypln :: Module client

Source code for module pypln.client

  1  #!/usr/bin/env python 
  2  # coding: utf-8 
  3   
from copy import deepcopy
from logging import Logger, NullHandler, getLogger

import zmq
  7   
  8   
class ManagerClient(object):
    """Thin wrapper around the ZeroMQ sockets used to talk to a Manager.

    Holds one REQ socket for the request/reply API (``manager_api``) and
    one SUB socket for the broadcast channel (``manager_broadcast``).
    """

    def __init__(self, logger=None, logger_name='ManagerClient'):
        """Create the ZeroMQ context and set up logging.

        :param logger: an existing ``logging.Logger`` to use; when ``None``,
            a logger named *logger_name* with a ``NullHandler`` is used so
            logging calls are silently discarded.
        :param logger_name: name for the fallback logger.
        """
        self.context = zmq.Context()
        if logger is None:
            # Fix: use getLogger() instead of instantiating Logger directly.
            # The logging docs state loggers must never be created by calling
            # the class; getLogger() registers the logger in the hierarchy.
            self.logger = getLogger(logger_name)
            self.logger.addHandler(NullHandler())
        else:
            self.logger = logger

    def connect(self, api_host_port, broadcast_host_port):
        """Connect the REQ socket to the Manager API and the SUB socket to
        the broadcast channel.

        :param api_host_port: ``(host, port)`` tuple for the Manager API.
        :param broadcast_host_port: ``(host, port)`` tuple for broadcasts.
        """
        self.api_host_port = api_host_port
        self.broadcast_host_port = broadcast_host_port
        self.api_connection_string = 'tcp://{}:{}'.format(*api_host_port)
        self.broadcast_connection_string = \
                'tcp://{}:{}'.format(*broadcast_host_port)

        self.manager_api = self.context.socket(zmq.REQ)
        self.manager_broadcast = self.context.socket(zmq.SUB)

        self.manager_api.connect(self.api_connection_string)
        self.manager_broadcast.connect(self.broadcast_connection_string)

    def __del__(self):
        # Best-effort cleanup when the client is garbage-collected.
        self.close_sockets()

    def close_sockets(self):
        """Close whichever sockets ``connect`` created (safe to call even
        if ``connect`` was never called)."""
        for socket_name in ('manager_api', 'manager_broadcast'):
            if hasattr(self, socket_name):
                getattr(self, socket_name).close()
39
class Worker(object):
    """A named pipeline stage that may be followed by other stages."""

    def __init__(self, worker_name):
        """Store the stage's name; it starts with no successor stages."""
        self.name, self.after = worker_name, []

    def then(self, *after):
        """Declare the workers that run after this one finishes.

        Overwrites any previously declared successors and returns ``self``
        so calls can be chained.
        """
        self.after = after
        return self
48
class Pipeline(object):
    """Drives documents through a tree of ``Worker`` stages via a Manager.

    Jobs are submitted on the Manager API socket; completion notifications
    arrive on the broadcast socket as ``'job finished: <id>'`` messages,
    which trigger submission of the finished worker's successors.
    """

    def __init__(self, pipeline, api_host_port, broadcast_host_port,
                 logger=None, logger_name='Pipeline', time_to_wait=0.1):
        """
        :param pipeline: root ``Worker`` of the stage tree; deep-copied once
            per document in ``distribute``.
        :param api_host_port: ``(host, port)`` of the Manager API.
        :param broadcast_host_port: ``(host, port)`` of the broadcast channel.
        :param logger: optional logger, forwarded to ``ManagerClient``.
        :param logger_name: fallback logger name.
        :param time_to_wait: timeout passed to the broadcast socket's
            ``poll``. NOTE(review): pyzmq's ``poll`` timeout is in
            milliseconds, so the default 0.1 is a tenth of a millisecond --
            confirm the intended unit.
        """
        self.client = ManagerClient(logger, logger_name)
        self.client.connect(api_host_port, broadcast_host_port)
        self.pipeline = pipeline
        self.time_to_wait = time_to_wait
        self.logger = self.client.logger

    def send_job(self, worker):
        """Submit *worker*'s job to the Manager API and subscribe on the
        broadcast socket to its 'job finished' notification."""
        job = {'command': 'add job', 'worker': worker.name,
               'document': worker.document}
        self.client.manager_api.send_json(job)
        self.logger.info('Sent job: {}'.format(job))
        message = self.client.manager_api.recv_json()
        self.logger.info('Received from Manager API: {}'.format(message))
        self.waiting[message['job id']] = worker
        subscribe_message = 'job finished: {}'.format(message['job id'])
        self.client.manager_broadcast.setsockopt(zmq.SUBSCRIBE,
                                                 subscribe_message)
        self.logger.info('Subscribed on Manager Broadcast to: {}'\
                         .format(subscribe_message))

    def distribute(self):
        """Send the root job of the pipeline for every document."""
        self.waiting = {}
        for document in self.documents:
            # Each document gets its own copy of the worker tree so the
            # per-worker 'document' attribute does not clash across docs.
            worker = deepcopy(self.pipeline)
            worker.document = document
            self.send_job(worker)

    def run(self, documents):
        """Process *documents* until every submitted job has finished.

        Blocks polling the broadcast socket; each finished job triggers
        submission of its successors. Ctrl-C closes the sockets.
        """
        self.documents = documents
        self.distribute()
        try:
            while True:
                if self.client.manager_broadcast.poll(self.time_to_wait):
                    message = self.client.manager_broadcast.recv()
                    self.logger.info('[Client] Received from broadcast: {}'\
                                     .format(message))
                    if message.startswith('job finished: '):
                        #TODO: unsubscribe
                        job_id = message.split(': ')[1].split(' ')[0]
                        worker = self.waiting[job_id]
                        for next_worker in worker.after:
                            next_worker.document = worker.document
                            self.send_job(next_worker)
                        del self.waiting[job_id]
                        # Idiom fix: truth-test the dict directly instead of
                        # materializing .keys().
                        if not self.waiting:
                            break
        except KeyboardInterrupt:
            self.client.close_sockets()
100
def main():
    """Command-line entry point: store the files named in ``argv`` into
    GridFS and run them through the extractor/tokenizer pipeline."""
    import os
    from logging import Logger, StreamHandler, Formatter
    from sys import stdout, argv
    from pymongo import Connection
    from gridfs import GridFS

    if len(argv) == 1:
        # Parenthesized form prints identically under Python 2 and 3 for a
        # single argument.
        print('ERROR: you need to pass files to import')
        exit(1)

    api_host_port = ('localhost', 5555)
    broadcast_host_port = ('localhost', 5556)
    #TODO: should get config from manager
    config = {'db': {'host': 'localhost', 'port': 27017,
                     'database': 'pypln',
                     'collection': 'documents',
                     'gridfs collection': 'files',
                     'monitoring collection': 'monitoring'},
              'monitoring interval': 60,}
    db_config = config['db']
    mongo_connection = Connection(db_config['host'], db_config['port'])
    db = mongo_connection[db_config['database']]
    if 'username' in db_config and 'password' in db_config and \
       db_config['username'] and db_config['password']:
        db.authenticate(db_config['username'], db_config['password'])
    gridfs = GridFS(db, db_config['gridfs collection'])
    #TODO: connection/collection lines should be in pypln.stores.mongodb

    logger = Logger('Pipeline')
    handler = StreamHandler(stdout)
    formatter = Formatter('%(asctime)s - %(name)s - %(levelname)s - '
                          '%(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    my_docs = []
    filenames = argv[1:]
    logger.info('Inserting files...')
    for filename in filenames:
        if os.path.exists(filename):
            logger.debug('  {}'.format(filename))
            # Fix: close the file handle deterministically instead of
            # leaking it until garbage collection.
            with open(filename) as file_obj:
                doc_id = gridfs.put(file_obj.read(), filename=filename)
            my_docs.append(str(doc_id))

    #TODO: use et2 to create the tree/pipeline image
    # Alias trick: makes Worker instances callable, so calling a worker
    # behaves like .then() and the tree below reads as nested calls.
    W, W.__call__ = Worker, Worker.then
    workers = W('extractor')(W('tokenizer')(W('pos'),
                                            W('freqdist')))
    pipeline = Pipeline(workers, api_host_port, broadcast_host_port, logger)
    pipeline.run(my_docs)
    #TODO: should receive a 'job error' from manager when some job could not
    #      be processed (timeout, worker not found etc.)


if __name__ == '__main__':
    main()