Skip to content

Commit 75d2ea1

Browse files
committed
make python training and java inference configurable
1 parent 4fb494f commit 75d2ea1

File tree

4 files changed

+150
-26
lines changed

4 files changed

+150
-26
lines changed

tf-ner-poc/src/main/java/org/apache/opennlp/namefinder/PredictionConfiguration.java

+38
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.FileInputStream;
2121
import java.io.IOException;
2222
import java.io.InputStream;
23+
import java.util.regex.Pattern;
2324

2425
public class PredictionConfiguration {
2526

@@ -28,6 +29,11 @@ public class PredictionConfiguration {
2829
private String vocabTags;
2930
private String savedModel;
3031

32+
private boolean useLowerCaseEmbeddings;
33+
private boolean allowUNK;
34+
private boolean allowNUM;
35+
private Pattern digitPattern = Pattern.compile("\\d+(,\\d+)*(\\.\\d+)?");
36+
3137
public PredictionConfiguration(String vocabWords, String vocabChars, String vocabTags, String savedModel) {
3238
this.vocabWords = vocabWords;
3339
this.vocabChars = vocabChars;
@@ -51,6 +57,38 @@ public String getSavedModel() {
5157
return savedModel;
5258
}
5359

60+
public boolean isUseLowerCaseEmbeddings() {
61+
return useLowerCaseEmbeddings;
62+
}
63+
64+
public void setUseLowerCaseEmbeddings(boolean useLowerCaseEmbeddings) {
65+
this.useLowerCaseEmbeddings = useLowerCaseEmbeddings;
66+
}
67+
68+
public boolean isAllowUNK() {
69+
return allowUNK;
70+
}
71+
72+
public void setAllowUNK(boolean allowUNK) {
73+
this.allowUNK = allowUNK;
74+
}
75+
76+
public boolean isAllowNUM() {
77+
return allowNUM;
78+
}
79+
80+
public void setAllowNUM(boolean allowNUM) {
81+
this.allowNUM = allowNUM;
82+
}
83+
84+
public Pattern getDigitPattern() {
85+
return digitPattern;
86+
}
87+
88+
public void setDigitPattern(Pattern digitPattern) {
89+
this.digitPattern = digitPattern;
90+
}
91+
5492
public InputStream getVocabWordsInputStream() throws IOException{
5593
return new FileInputStream(getVocabWords());
5694
}

tf-ner-poc/src/main/java/org/apache/opennlp/namefinder/SequenceTagging.java

+11-3
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,23 @@ public SequenceTagging(PredictionConfiguration config) throws IOException {
4444
model = SavedModelBundle.load(config.getSavedModel(), "serve");
4545
session = model.session();
4646

47-
this.wordIndexer = new WordIndexer(new FileInputStream(config.getVocabWords()),
47+
// WordIndexer expects its flags in the order (lowerCaseTokens, allowUnk, allowNum),
// so each configuration getter must be passed in its matching position.
this.wordIndexer = new WordIndexer(config.isUseLowerCaseEmbeddings(), config.isAllowUNK(), config.isAllowNUM(),
    new FileInputStream(config.getVocabWords()),
    new FileInputStream(config.getVocabChars()));
4950

51+
this.wordIndexer.setDigitPattern(config.getDigitPattern());
52+
5053
this.indexTagger = new IndexTagger((new FileInputStream(config.getVocabTags())));
5154
}
5255

5356
public SequenceTagging(InputStream modelZipPackage) throws IOException {
5457

5558
Path tmpDir = ModelUtil.writeModelToTmpDir(modelZipPackage);
5659

57-
try (InputStream wordsIn = Files.newInputStream(tmpDir.resolve("word_dict.txt"));
60+
try (InputStream configIn = Files.newInputStream(tmpDir.resolve("config.properties"));
61+
InputStream wordsIn = Files.newInputStream(tmpDir.resolve("word_dict.txt"));
5862
InputStream charsIn = Files.newInputStream(tmpDir.resolve("char_dict.txt"))) {
59-
wordIndexer = new WordIndexer(wordsIn, charsIn);
63+
wordIndexer = new WordIndexer(configIn, wordsIn, charsIn);
6064
}
6165

6266
try (InputStream in = Files.newInputStream(tmpDir.resolve("label_dict.txt"))) {
@@ -122,6 +126,10 @@ private String[][] predict(TokenIds tokenIds) {
122126
}
123127
}
124128

129+
public WordIndexer getWordIndexer() {
130+
return wordIndexer;
131+
}
132+
125133
@Override
126134
public void clearAdaptiveData() {
127135
}

tf-ner-poc/src/main/java/org/apache/opennlp/namefinder/WordIndexer.java

+60-7
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,83 @@
1717

1818
package org.apache.opennlp.namefinder;
1919

20+
import opennlp.tools.util.StringUtil;
21+
2022
import java.io.BufferedReader;
2123
import java.io.IOException;
2224
import java.io.InputStream;
2325
import java.io.InputStreamReader;
2426
import java.util.Arrays;
2527
import java.util.HashMap;
2628
import java.util.Map;
29+
import java.util.Properties;
2730
import java.util.regex.Pattern;
2831

29-
import opennlp.tools.util.StringUtil;
30-
3132
public class WordIndexer {
3233

3334
private final Map<Character, Integer> char2idx;
3435
private final Map<String, Integer> word2idx;
3536

36-
public static String UNK = "$UNK$";
37-
public static String NUM = "$NUM$";
37+
public static String UNK = "__UNK__";
38+
public static String NUM = "__NUM__";
3839

3940
private boolean lowerCase = false;
40-
private boolean allowUnk = false;
41+
private boolean allowUnk = true;
42+
private boolean allowNum = false;
4143

4244
private Pattern digitPattern = Pattern.compile("\\d+(,\\d+)*(\\.\\d+)?");
4345

46+
public boolean isLowerCase() {
47+
return lowerCase;
48+
}
49+
50+
public void setLowerCase(boolean lowerCase) {
51+
this.lowerCase = lowerCase;
52+
}
53+
54+
public boolean isAllowUnk() {
55+
return allowUnk;
56+
}
57+
58+
public void setAllowUnk(boolean allowUnk) {
59+
this.allowUnk = allowUnk;
60+
}
61+
62+
public boolean isAllowNum() {
63+
return allowNum;
64+
}
65+
66+
public void setAllowNum(boolean allowNum) {
67+
this.allowNum = allowNum;
68+
}
69+
70+
public Pattern getDigitPattern() {
71+
return digitPattern;
72+
}
73+
74+
public void setDigitPattern(Pattern digitPattern) {
75+
this.digitPattern = digitPattern;
76+
}
77+
78+
public WordIndexer(InputStream config, InputStream vocabWords, InputStream vocabChars) throws IOException {
79+
this(vocabWords, vocabChars);
80+
Properties props = new Properties();
81+
if (config != null) {
82+
props.load(new InputStreamReader(config, "UTF8"));
83+
this.setLowerCase(Boolean.valueOf(props.getProperty("lower_case_embeddings")));
84+
this.setAllowUnk(Boolean.valueOf(props.getProperty("allow_unk")));
85+
this.setAllowNum(Boolean.valueOf(props.getProperty("allow_num")));
86+
this.setDigitPattern(Pattern.compile(props.getProperty("digit_pattern")));
87+
}
88+
}
89+
90+
/**
 * Creates an indexer from the two vocabularies and then applies the given flags.
 *
 * @param lowerCaseTokens value stored in the {@code lowerCase} field
 * @param allowUnk value stored in the {@code allowUnk} field
 * @param allowNum value stored in the {@code allowNum} field
 * @param vocabWords forwarded to the two-stream constructor
 * @param vocabChars forwarded to the two-stream constructor
 * @throws IOException propagated from the two-stream constructor
 */
public WordIndexer(boolean lowerCaseTokens, boolean allowUnk, boolean allowNum,
                   InputStream vocabWords, InputStream vocabChars) throws IOException {
  this(vocabWords, vocabChars);
  this.lowerCase = lowerCaseTokens;
  this.allowNum = allowNum;
  this.allowUnk = allowUnk;
}
96+
4497
public WordIndexer(InputStream vocabWords, InputStream vocabChars) throws IOException {
4598
this.word2idx = new HashMap<>();
4699
try(BufferedReader in = new BufferedReader(new InputStreamReader(vocabWords, "UTF8"))) {
@@ -113,8 +166,8 @@ private Ids apply(String word) {
113166
word = StringUtil.toLowerCase(word);
114167
}
115168

116-
// if (digitPattern.matcher(word).find())
117-
// word = NUM;
169+
if (allowNum && digitPattern.matcher(word).find())
170+
word = NUM;
118171

119172
// 2. get id of word
120173
Integer wordId;

tf-ner-poc/src/main/python/namefinder/namefinder.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
# This poc is based on source code taken from:
2121
# https://github.com/guillaumegenthial/sequence_tagging
2222

23-
import sys
2423
from math import floor
2524
import tensorflow as tf
2625
import re
2726
import numpy as np
2827
import zipfile
2928
import os
3029
from tempfile import TemporaryDirectory
30+
import argparse
3131

3232
# global variables for unknown word and numbers
3333
__UNK__ = '__UNK__'
@@ -68,12 +68,16 @@ def __str__(self):
6868
class NameFinder:
6969
label_dict = {}
7070

71-
def __init__(self, use_lower_case_embeddings=False, vector_size=100):
71+
def __init__(self, use_lower_case_embeddings, allow_unk, allow_num, digit_pattern, encoding, vector_size=100):
7272
self.__vector_size = vector_size
7373
self.__use_lower_case_embeddings = use_lower_case_embeddings
74+
self.__allow_unk = allow_unk
75+
self.__allow_num = allow_num
76+
self.__digit_pattern = re.compile(digit_pattern)
77+
self.__encoding = encoding
7478

7579
def load_data(self, word_dict, file):
76-
with open(file) as f:
80+
with open(file, encoding=self.__encoding) as f:
7781
raw_data = f.readlines()
7882

7983
sentences = []
@@ -96,7 +100,8 @@ def load_data(self, word_dict, file):
96100
if self.__use_lower_case_embeddings:
97101
token = token.lower()
98102

99-
# TODO: implement NUM encoding
103+
if self.__allow_num and self.__digit_pattern.match(token):
104+
token = __NUM__
100105

101106
if word_dict.get(token) is not None:
102107
vector = word_dict[token]
@@ -340,8 +345,8 @@ def write_mapping(tags, output_filename):
340345
f.write('{}\n'.format(tag))
341346

342347

343-
def load_glove(glove_file):
344-
with open(glove_file) as f:
348+
def load_glove(glove_file, encoding='utf-8'):
349+
with open(glove_file, encoding=encoding) as f:
345350

346351
word_dict = {}
347352
embeddings = []
@@ -381,16 +386,28 @@ def load_glove(glove_file):
381386

382387

383388
def main():
384-
if len(sys.argv) != 5:
385-
print("Usage namefinder.py embedding_file train_file dev_file test_file")
386-
return
387-
388-
word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])
389-
390-
name_finder = NameFinder(vector_size)
391-
392-
sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
393-
sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])
389+
def _str2bool(value):
    """Convert a command line value such as 'true' or 'False' into a bool."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got '{}'".format(value))

# The boolean options need an explicit converter: without one argparse stores
# "--allow_unk False" as the non-empty string 'False', which is truthy.
parser = argparse.ArgumentParser()
parser.add_argument("embedding_file", help="path to the embeddings file.")
parser.add_argument("train_file", help="path to the training file.")
parser.add_argument("dev_file", help="path to the dev file.")
parser.add_argument("--allow_unk", help="use general UNK vector for unknown tokens.",
                    type=_str2bool, default=True)
parser.add_argument("--allow_num", help="use general NUM vector for all numeric tokens.",
                    type=_str2bool, default=False)
parser.add_argument("--lower_case_embeddings", help="convert tokens to lowercase for embeddings lookup.",
                    type=_str2bool, default=False)
parser.add_argument("--digit_pattern", help="regex to use for identifying numeric tokens.",
                    default='^\\d+(,\\d+)*(\\.\\d+)?$')
parser.add_argument("--data_encoding", help="set encoding of train and dev data.", default='utf-8')
parser.add_argument("--embeddings_encoding", help="set encoding of the embeddings.", default='utf-8')
args = parser.parse_args()
402+
403+
word_dict, rev_word_dict, embeddings, vector_size = load_glove(args.embedding_file, args.embeddings_encoding)
404+
405+
name_finder = NameFinder(use_lower_case_embeddings=args.lower_case_embeddings, allow_unk=args.allow_unk,
406+
allow_num=args.allow_num, digit_pattern=args.digit_pattern,
407+
encoding=args.data_encoding, vector_size=vector_size)
408+
409+
sentences, labels, char_set = name_finder.load_data(word_dict, args.train_file)
410+
sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, args.dev_file)
394411

395412
char_dict = {k: v for v, k in enumerate(char_set | char_set_dev)}
396413

@@ -472,6 +489,14 @@ def main():
472489
write_mapping(name_finder.label_dict, temp_model_dir + "/label_dict.txt")
473490
write_mapping(char_dict, temp_model_dir + "/char_dict.txt")
474491

492+
write_mapping({'lower_case_embeddings=' + str(args.lower_case_embeddings).lower(): 0,
493+
'allow_unk=' + str(args.allow_unk).lower(): 1,
494+
'allow_num=' + str(args.allow_num).lower(): 2,
495+
'digit_pattern=' + re.escape(args.digit_pattern): 3,
496+
'data_encoding=' + args.data_encoding: 4,
497+
'embeddings_encoding=' + args.embeddings_encoding: 5},
498+
temp_model_dir + "/config.properties")
499+
475500
zipf = zipfile.ZipFile("namefinder-" + str(epoch) + ".zip", 'w', zipfile.ZIP_DEFLATED)
476501

477502
for root, dirs, files in os.walk(temp_model_dir):

0 commit comments

Comments
 (0)