# Inference Model API
This documents the basic API for the system. The methods here are just a small subset of the various functions in the system, but they should represent the majority of what's needed for run-time (aka inference) operation.
For training and test examples see Training or refer directly to the code.
For additional examples, see the `scripts/xx` or `tests/xx` directories.
## Sequence to Graph Functions (Parsing)
### load_stog_model()
```python
stog = load_stog_model(model_dir=None, **kwargs)
```
This method loads the sequence to graph model (aka parser).
If no `model_dir` is supplied, the default of `amrlib/data/model_stog` is used.
`kwargs` can be used to pass parameters such as `device`, `batch_size`, `beam_size`, etc. to the inference routine.
See specific model descriptions for additional parameters and their use.
The function returns a `STOGInferenceBase` type object, which is a simple abstract base class for the underlying model.
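For instance, a minimal sketch of loading the parser with explicit keyword arguments might look like the following; the values are illustrative only, and which kwargs are honored depends on the specific parse model installed.

```python
import amrlib

# Illustrative only: device, batch_size and beam_size are passed through
# as kwargs, and support for each depends on the installed parse model.
stog = amrlib.load_stog_model(device='cuda:0', batch_size=4, beam_size=4)
```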
### Inference.parse_sents()
```python
graphs = parse_sents(sents, add_metadata=True)
```
This method takes a list of sentence strings and converts them into a list of AMR graphs.
The optional parameter `add_metadata` tells the system if metadata such as "id", "snt", etc. should appear at the top of the graph string.
### Example
```python
import amrlib

stog = amrlib.load_stog_model()
graphs = stog.parse_sents(['This is a test of the system.', 'This is a second sentence.'])
for graph in graphs:
    print(graph)
```
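Building on the example above, the metadata lines can be omitted by passing `add_metadata=False`; this is a minimal sketch, and the exact metadata fields produced by a given model may vary.

```python
# Parse without the '# ::id', '# ::snt', ... header lines, so only the
# PENMAN graph itself is returned for each sentence.
graphs = stog.parse_sents(['This is a test of the system.'], add_metadata=False)
```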
## Graph to Sequence Functions (Generation)
### load_gtos_model()
```python
gtos = load_gtos_model(model_dir=None, **kwargs)
```
This method loads the graph to sequence model (aka generator).
If no `model_dir` is specified, the default `amrlib/data/model_gtos` is used.
`kwargs` can be used to pass parameters such as `device`, `batch_size`, `num_beams`, `num_ret_seq`, etc.
See specific [model descriptions](https://amrlib.readthedocs.io/en/latest/models/) for additional parameters and their use.
`device` is automatically selected, but you can pass `cpu` or `cuda:0` if needed.
The function returns a `GTOSInferenceBase` type object, which is a simple abstract base class for the underlying model.
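As with the parser, a minimal sketch of loading the generator with explicit options might look like this; the values are illustrative, and which kwargs are honored depends on the installed generate model.

```python
import amrlib

# Illustrative only: run on the CPU and request 2 returned sentences per
# graph via num_beams / num_ret_seq (support varies by model).
gtos = amrlib.load_gtos_model(device='cpu', num_beams=4, num_ret_seq=2)
```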
### Inference.generate()
```python
sents, clips = generate(graphs, disable_progress=False)
```
This method takes a list of AMR graph strings and returns a list of sentence strings and a list of booleans.
The boolean list `clips` tells if any of the returned sentences were clipped as a result of the tokenized graph being too long for the model.

`disable_progress` can be used to turn off the default `tqdm` progress bar.
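The `clips` list can be used to flag or filter truncated results, as in this small sketch (it assumes a loaded `gtos` model and a list of AMR graph strings `graphs`, as in the example below).

```python
sents, clips = gtos.generate(graphs, disable_progress=True)
for sent, clipped in zip(sents, clips):
    if clipped:
        print('CLIPPED:', sent)   # the tokenized graph was too long for the model
    else:
        print(sent)
```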
### Inference.get_ans_group()
```python
sents = get_ans_group(answers, group_num)
```
This is a simple slicing function that returns all the sentences associated with the input graph number.
`answers` is the list returned from `generate()` and `group_num` is the input graph number. The method will return `num_ret_seq` sentence strings.
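Conceptually, and assuming the answers are ordered group by group, the call behaves roughly like the slice below; this is a sketch only, since the inference object knows the model's actual `num_ret_seq` value.

```python
# Rough equivalent of gtos.get_ans_group(answers, group_num); num_ret_seq
# here stands in for whatever value the generator was loaded with.
group = answers[group_num * num_ret_seq : (group_num + 1) * num_ret_seq]
```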
### Example
```python
import amrlib

gtos = amrlib.load_gtos_model()
# 'graphs' is a list of AMR graph strings, e.g. returned from parse_sents() above
sents, _ = gtos.generate(graphs, disable_progress=True)
for sent in sents:
    print(sent)
for gnum in range(len(graphs)):
    print('graph number', gnum)
    for sent in gtos.get_ans_group(sents, gnum):
        print(sent)
```