# This poc is based on source code taken from:
# https://github.com/guillaumegenthial/sequence_tagging

- import sys
from math import floor
import tensorflow as tf
import re
import numpy as np
import zipfile
import os
from tempfile import TemporaryDirectory
+ import argparse

# global variables for unknown word and numbers
__UNK__ = '__UNK__'
@@ -68,12 +68,16 @@ def __str__(self):
class NameFinder:
    label_dict = {}

-     def __init__(self, use_lower_case_embeddings=False, vector_size=100):
+     def __init__(self, use_lower_case_embeddings, allow_unk, allow_num, digit_pattern, encoding, vector_size=100):
        self.__vector_size = vector_size
        self.__use_lower_case_embeddings = use_lower_case_embeddings
+         self.__allow_unk = allow_unk
+         self.__allow_num = allow_num
+         self.__digit_pattern = re.compile(digit_pattern)
+         self.__encoding = encoding

    def load_data(self, word_dict, file):
-         with open(file) as f:
+         with open(file, encoding=self.__encoding) as f:
            raw_data = f.readlines()

        sentences = []
@@ -96,7 +100,8 @@ def load_data(self, word_dict, file):
            if self.__use_lower_case_embeddings:
                token = token.lower()

-             # TODO: implement NUM encoding
+             if self.__allow_num and self.__digit_pattern.match(token):
+                 token = __NUM__

            if word_dict.get(token) is not None:
                vector = word_dict[token]
@@ -340,8 +345,8 @@ def write_mapping(tags, output_filename):
            f.write('{}\n'.format(tag))


- def load_glove(glove_file):
-     with open(glove_file) as f:
+ def load_glove(glove_file, encoding='utf-8'):
+     with open(glove_file, encoding=encoding) as f:

        word_dict = {}
        embeddings = []
@@ -381,16 +386,28 @@ def load_glove(glove_file):
def main():
-     if len(sys.argv) != 5:
-         print("Usage namefinder.py embedding_file train_file dev_file test_file")
-         return
-
-     word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])
-
-     name_finder = NameFinder(vector_size)
-
-     sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
-     sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])
+     parser = argparse.ArgumentParser()
+     parser.add_argument("embedding_file", help="path to the embeddings file.")
+     parser.add_argument("train_file", help="path to the training file.")
+     parser.add_argument("dev_file", help="path to the dev file.")
+     parser.add_argument("--allow_unk", help="use general UNK vector for unknown tokens.", default=True)
+     parser.add_argument("--allow_num", help="use general NUM vector for all numeric tokens.", default=False)
+     parser.add_argument("--lower_case_embeddings", help="convert tokens to lowercase for embeddings lookup.",
+                         default=False)
+     parser.add_argument("--digit_pattern", help="regex to use for identifying numeric tokens.",
+                         default='^\\d+(,\\d+)*(\\.\\d+)?$')
+     parser.add_argument("--data_encoding", help="set encoding of train and dev data.", default='utf-8')
+     parser.add_argument("--embeddings_encoding", help="set encoding of the embeddings.", default='utf-8')
+     args = parser.parse_args()
+
+     word_dict, rev_word_dict, embeddings, vector_size = load_glove(args.embedding_file, args.embeddings_encoding)
+
+     name_finder = NameFinder(use_lower_case_embeddings=args.lower_case_embeddings, allow_unk=args.allow_unk,
+                              allow_num=args.allow_num, digit_pattern=args.digit_pattern,
+                              encoding=args.data_encoding, vector_size=vector_size)
+
+     sentences, labels, char_set = name_finder.load_data(word_dict, args.train_file)
+     sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, args.dev_file)

    char_dict = {k: v for v, k in enumerate(char_set | char_set_dev)}
@@ -472,6 +489,14 @@ def main():
        write_mapping(name_finder.label_dict, temp_model_dir + "/label_dict.txt")
        write_mapping(char_dict, temp_model_dir + "/char_dict.txt")

+         write_mapping({'lower_case_embeddings=' + str(args.lower_case_embeddings).lower(): 0,
+                        'allow_unk=' + str(args.allow_unk).lower(): 1,
+                        'allow_num=' + str(args.allow_num).lower(): 2,
+                        'digit_pattern=' + re.escape(args.digit_pattern): 3,
+                        'data_encoding=' + args.data_encoding: 4,
+                        'embeddings_encoding=' + args.embeddings_encoding: 5},
+                       temp_model_dir + "/config.properties")
+
        zipf = zipfile.ZipFile("namefinder-" + str(epoch) + ".zip", 'w', zipfile.ZIP_DEFLATED)

        for root, dirs, files in os.walk(temp_model_dir):
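For reference, the updated script would be invoked roughly like this (the embedding and data file names below are placeholders, not taken from the patch):

python namefinder.py glove_vectors.txt train.txt dev.txt --allow_num 1 --data_encoding utf-8

Note that --allow_unk, --allow_num and --lower_case_embeddings are parsed as plain strings as written, so any non-empty value passed on the command line evaluates as true.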
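A small standalone sketch (not part of the patch; the sample tokens are invented) of how the default --digit_pattern behaves with re.match(), which is how load_data decides whether to substitute __NUM__ for a token:

import re

# default value of --digit_pattern from the argument parser above
digit_pattern = re.compile('^\\d+(,\\d+)*(\\.\\d+)?$')

# load_data replaces a token with __NUM__ only when allow_num is enabled
# and the whole token matches this pattern
for token in ['7', '1,000', '3.14', '12,345.67', '3rd', '2019-04']:
    print(token, bool(digit_pattern.match(token)))
# prints True for the first four tokens and False for '3rd' and '2019-04'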