This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f44f6cf

daveliepmann authored and kedarbellare committed
Extend Clojure BERT example (#15023)
* Clojure predictor example: add rich comment

  This provides an entry point for folks working on this example in their REPL rather than the command line.

* Clojure BERT example: refactor prepare-data fn for purity
* Clojure BERT example: test fitted model on samples
* Clojure BERT example: namespace docstring & comment
* Clojure BERT example: format intro, add references
* Clojure BERT example: minor refactor
* Clojure BERT example: trim sentence pair explorations
* Clojure BERT example: port experiment to iPynb
* Clojure BERT example: fix test

  Underlying fn was refactored

* Clojure BERT example: add sentence-pair prediction test
1 parent 0340536 commit f44f6cf

4 files changed: +248 -40 lines changed


contrib/clojure-package/examples/bert/fine-tune-bert.ipynb

Lines changed: 132 additions & 13 deletions
@@ -10,15 +10,16 @@
 "\n",
 "Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering, and natural language inference. To apply pre-trained representations to these tasks, there are two strategies:\n",
 "\n",
-"feature-based approach, which uses the pre-trained representations as additional features to the downstream task.\n",
-"fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.\n",
-"While feature-based approaches such as ELMo [3] (introduced in the previous tutorial) are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n",
+" - **feature-based approach**, which uses the pre-trained representations as additional features to the downstream task.\n",
+" - **fine-tuning based approach**, which trains the downstream tasks by fine-tuning pre-trained parameters.\n",
+" \n",
+"While feature-based approaches such as ELMo [1] are effective in improving many downstream tasks, they require task-specific architectures. Devlin, Jacob, et al proposed BERT [2] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters, and obtained state- of-the-art results.\n",
 "\n",
 "In this tutorial, we will focus on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will:\n",
 "\n",
-"load the state-of-the-art pre-trained BERT model and attach an additional layer for classification,\n",
-"process and transform sentence pair data for the task at hand, and\n",
-"fine-tune BERT model for sentence classification.\n",
+" 1. load the state-of-the-art pre-trained BERT model and attach an additional layer for classification\n",
+" 2. process and transform sentence pair data for the task at hand, and \n",
+" 3. fine-tune BERT model for sentence classification.\n",
 "\n"
 ]
 },
@@ -59,6 +60,7 @@
 " [org.apache.clojure-mxnet.callback :as callback]\n",
 " [org.apache.clojure-mxnet.context :as context]\n",
 " [org.apache.clojure-mxnet.dtype :as dtype]\n",
+" [org.apache.clojure-mxnet.infer :as infer]\n",
 " [org.apache.clojure-mxnet.eval-metric :as eval-metric]\n",
 " [org.apache.clojure-mxnet.io :as mx-io]\n",
 " [org.apache.clojure-mxnet.layout :as layout]\n",
@@ -89,7 +91,7 @@
 "\n",
 "![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png)\n",
 "\n",
-"where the model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n",
+"where the model takes a pair of sequences and *pools* the representation of the first token in the sequence. Note that the original BERT model was trained for masked language model and next sentence prediction tasks, which includes layers for language model decoding and classification. These layers will not be used for fine-tuning sentence pair classification.\n",
 "\n",
 "Let's load the pre-trained BERT using the module API in MXNet."
 ]
@@ -114,12 +116,15 @@
 ],
 "source": [
 "(def model-path-prefix \"data/static_bert_base_net\")\n",
+"\n",
 ";; the vocabulary used in the model\n",
 "(def vocab (bert-util/get-vocab))\n",
-";; the input question\n",
+"\n",
 ";; the maximum length of the sequence\n",
 "(def seq-length 128)\n",
 "\n",
+"(def batch-size 32)\n",
+"\n",
 "(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))"
 ]
 },
@@ -291,7 +296,7 @@
 "source": [
 "(defn pre-processing\n",
 " \"Preprocesses the sentences in the format that BERT is expecting\"\n",
-" [ctx idx->token token->idx train-item]\n",
+" [idx->token token->idx train-item]\n",
 " (let [[sentence-a sentence-b label] train-item\n",
 " ;;; pre-processing tokenize sentence\n",
 " token-1 (bert-util/tokenize (string/lower-case sentence-a))\n",
@@ -319,7 +324,7 @@
 "(def idx->token (:idx->token vocab))\n",
 "(def token->idx (:token->idx vocab))\n",
 "(def dev (context/default-context))\n",
-"(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) data-train-raw))\n",
+"(def processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw))\n",
 "(def train-count (count processed-datas))\n",
 "(println \"Train Count is = \" train-count)\n",
 "(println \"[PAD] token id = \" (get token->idx \"[PAD]\"))\n",
@@ -375,8 +380,6 @@
 " (into []))\n",
 " :train-num (count processed-datas)})\n",
 "\n",
-"(def batch-size 32)\n",
-"\n",
 "(def train-data\n",
 " (let [{:keys [data0s data1s data2s labels train-num]} prepared-data\n",
 " data-desc0 (mx-io/data-desc {:name \"data0\"\n",
@@ -480,7 +483,7 @@
 "(def num-epoch 3)\n",
 "\n",
 "(def fine-tune-model (m/module model-sym {:contexts [dev]\n",
-" :data-names [\"data0\" \"data1\" \"data2\"]}))\n",
+" :data-names [\"data0\" \"data1\" \"data2\"]}))\n",
 "\n",
 "(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch\n",
 " :fit-params (m/fit-params {:allow-missing true\n",
@@ -489,6 +492,122 @@
 " :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})\n",
 " :batch-end-callback (callback/speedometer batch-size 1)})})\n"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Explore results from the fine-tuned model\n",
+"\n",
+"Now that our model is fitted, we can use it to infer semantic equivalence of arbitrary sentence pairs. Note that for demonstration purpose we skipped the warmup learning rate schedule and validation on dev dataset used in the original implementation. This means that our model's performance will be significantly less than optimal. Please visit [here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html) for the complete fine-tuning scripts (using Python and GluonNLP).\n",
+"\n",
+"To do inference with our model we need a predictor. It must have a batch size of 1 so we can feed the model a single sentence pair."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 14,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"#'bert.bert-sentence-classification/fine-tuned-predictor"
+]
+},
+"execution_count": 14,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"(def fine-tuned-prefix \"fine-tune-sentence-bert\")\n",
+"\n",
+"(m/save-checkpoint fine-tune-model {:prefix fine-tuned-prefix :epoch 3})\n",
+"\n",
+"(def fine-tuned-predictor\n",
+" (infer/create-predictor (infer/model-factory fine-tuned-prefix\n",
+" [{:name \"data0\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
+" {:name \"data1\" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
+" {:name \"data2\" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])\n",
+" {:epoch 3}))"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Now we can write a function that feeds a sentence pair to the fine-tuned model:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 15,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"#'bert.bert-sentence-classification/predict-equivalence"
+]
+},
+"execution_count": 15,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"(defn predict-equivalence\n",
+" [predictor sentence1 sentence2]\n",
+" (let [vocab (bert.util/get-vocab)\n",
+" processed-test-data (mapv #(pre-processing (:idx->token vocab)\n",
+" (:token->idx vocab) %)\n",
+" [[sentence1 sentence2]])\n",
+" prediction (infer/predict-with-ndarray predictor\n",
+" [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])\n",
+" (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])\n",
+" (ndarray/array (slice-inputs-data processed-test-data 2) [1])])]\n",
+" (ndarray/->vec (first prediction))))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 22,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"[0.2633881 0.7366119]"
+]
+},
+"execution_count": 22,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+";; Modify an existing sentence pair to test:\n",
+";; [\"1\"\n",
+";; \"69773\"\n",
+";; \"69792\"\n",
+";; \"Cisco pared spending to compensate for sluggish sales .\"\n",
+";; \"In response to sluggish sales , Cisco pared spending .\"]\n",
+"(predict-equivalence fine-tuned-predictor\n",
+" \"The company cut spending to compensate for weak sales .\"\n",
+" \"In response to poor sales results, the company cut spending .\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## References\n",
+"\n",
+"[1] Peters, Matthew E., et al. “Deep contextualized word representations.” arXiv preprint arXiv:1802.05365 (2018).\n",
+"\n",
+"[2] Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018)."
+]
 }
 ],
 "metadata": {

contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj

Lines changed: 89 additions & 24 deletions
@@ -16,42 +16,67 @@
 ;;
 
 (ns bert.bert-sentence-classification
+"Fine-tuning Sentence Pair Classification with BERT
+This tutorial focuses on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs.
+
+Specifically, we will:
+1. load the state-of-the-art pre-trained BERT model
+2. attach an additional layer for classification
+3. process and transform sentence pair data for the task at hand
+4. fine-tune BERT model for sentence classification"
 (:require [bert.util :as bert-util]
 [clojure-csv.core :as csv]
 [clojure.string :as string]
 [org.apache.clojure-mxnet.callback :as callback]
 [org.apache.clojure-mxnet.context :as context]
 [org.apache.clojure-mxnet.dtype :as dtype]
+[org.apache.clojure-mxnet.infer :as infer]
 [org.apache.clojure-mxnet.io :as mx-io]
 [org.apache.clojure-mxnet.layout :as layout]
 [org.apache.clojure-mxnet.module :as m]
 [org.apache.clojure-mxnet.ndarray :as ndarray]
 [org.apache.clojure-mxnet.optimizer :as optimizer]
 [org.apache.clojure-mxnet.symbol :as sym]))
 
+;; Pre-trained language representations have been shown to improve
+;; many downstream NLP tasks such as question answering, and natural
+;; language inference. To apply pre-trained representations to these
+;; tasks, there are two strategies:
+
+;; * feature-based approach, which uses the pre-trained representations as additional features to the downstream task.
+;; * fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.
+
+;; While feature-based approaches such as ELMo are effective in
+;; improving many downstream tasks, they require task-specific
+;; architectures. Devlin, Jacob, et al proposed BERT (Bidirectional
+;; Encoder Representations from Transformers), which fine-tunes deep
+;; bidirectional representations on a wide range of tasks with minimal
+;; task-specific parameters, and obtained state-of-the-art results.
+
 (def model-path-prefix "data/static_bert_base_net")
-;; epoch number of the model
+
+(def fine-tuned-prefix "fine-tune-sentence-bert")
+
 ;; the maximum length of the sequence
 (def seq-length 128)
 
 (defn pre-processing
 "Preprocesses the sentences in the format that BERT is expecting"
 [idx->token token->idx train-item]
 (let [[sentence-a sentence-b label] train-item
-;;; pre-processing tokenize sentence
+;; pre-processing tokenize sentence
 token-1 (bert-util/tokenize (string/lower-case sentence-a))
 token-2 (bert-util/tokenize (string/lower-case sentence-b))
 valid-length (+ (count token-1) (count token-2))
-;;; generate token types [0000...1111...0000]
+;; generate token types [0000...1111...0000]
 qa-embedded (into (bert-util/pad [] 0 (count token-1))
-
 (bert-util/pad [] 1 (count token-2)))
 token-types (bert-util/pad qa-embedded 0 seq-length)
-;;; make BERT pre-processing standard
+;; make BERT pre-processing standard
 token-2 (conj token-2 "[SEP]")
 token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
 tokens (bert-util/pad token-1 "[PAD]" seq-length)
-;;; pre-processing - token to index translation
+;; pre-processing - token to index translation
 indexes (bert-util/tokens->idxs token->idx tokens)]
 {:input-batch [indexes
 token-types
@@ -83,19 +108,18 @@
 
 (defn get-raw-data []
 (csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "")
-:delimiter \tab
-:strict true))
+:delimiter \tab
+:strict true))
 
 (defn prepare-data
-"This prepares the senetence pairs into NDArrays for use in NDArrayIterator"
-[]
-(let [raw-file (get-raw-data)
-vocab (bert-util/get-vocab)
+"This prepares the sentence pairs into NDArrays for use in NDArrayIterator"
+[raw-data]
+(let [vocab (bert-util/get-vocab)
 idx->token (:idx->token vocab)
 token->idx (:token->idx vocab)
-data-train-raw (->> raw-file
+data-train-raw (->> raw-data
 (mapv #(vals (select-keys % [3 4 0])))
-(rest) ;;drop header
+(rest) ; drop header
 (into []))
 processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)]
 {:data0s (slice-inputs-data processed-datas 0)
@@ -111,7 +135,7 @@
 [dev num-epoch]
 (let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})
 model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1})
-{:keys [data0s data1s data2s labels train-num]} (prepare-data)
+{:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data))
 batch-size 32
 data-desc0 (mx-io/data-desc {:name "data0"
 :shape [train-num seq-length]
@@ -138,14 +162,16 @@
 {:label {label-desc (ndarray/array labels [train-num]
 {:ctx dev})}
 :data-batch-size batch-size})
-model (m/module model-sym {:contexts [dev]
-:data-names ["data0" "data1" "data2"]})]
-(m/fit model {:train-data train-data :num-epoch num-epoch
-:fit-params (m/fit-params {:allow-missing true
-:arg-params (m/arg-params bert-base)
-:aux-params (m/aux-params bert-base)
-:optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})
-:batch-end-callback (callback/speedometer batch-size 1)})})))
+fitted-model (m/fit (m/module model-sym {:contexts [dev]
+:data-names ["data0" "data1" "data2"]})
+{:train-data train-data :num-epoch num-epoch
+:fit-params (m/fit-params {:allow-missing true
+:arg-params (m/arg-params bert-base)
+:aux-params (m/aux-params bert-base)
+:optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})
+:batch-end-callback (callback/speedometer batch-size 1)})})]
+(m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch})
+fitted-model))
 
 (defn -main [& args]
 (let [[dev-arg num-epoch-arg] args
@@ -154,7 +180,46 @@
 (println "Running example with " dev " and " num-epoch " epochs ")
 (train dev num-epoch)))
 
+;; For evaluating the model
+(defn predict-equivalence
+"Get the fine-tuned model's opinion on whether two sentences are equivalent:"
+[predictor sentence1 sentence2]
+(let [vocab (bert.util/get-vocab)
+processed-test-data (mapv #(pre-processing (:idx->token vocab)
+(:token->idx vocab) %)
+[[sentence1 sentence2]])
+prediction (infer/predict-with-ndarray predictor
+[(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])
+(ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])
+(ndarray/array (slice-inputs-data processed-test-data 2) [1])])]
+(ndarray/->vec (first prediction))))
+
 (comment
 
 (train (context/cpu 0) 3)
-(m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3}))
+
+(m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3})
+
+
+;;;; Explore results from the fine-tuned model
+
+;; We need a predictor with a batch size of 1, so we can feed the
+;; model a single sentence pair.
+(def fine-tuned-predictor
+(infer/create-predictor (infer/model-factory fine-tuned-prefix
+[{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+{:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+{:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])
+{:epoch 3}))
+
+;; Modify an existing sentence pair to test:
+;; ["1"
+;; "69773"
+;; "69792"
+;; "Cisco pared spending to compensate for sluggish sales ."
+;; "In response to sluggish sales , Cisco pared spending ."]
+(predict-equivalence fine-tuned-predictor
+"The company cut spending to compensate for weak sales ."
+"In response to poor sales results, the company cut spending .")
+
+)
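One practical effect of the prepare-data refactor above is that the function no longer reads data/dev.tsv itself; it takes the parsed rows as an argument, so it can be exercised from the REPL on a handful of in-memory rows. Below is a hypothetical sketch, not part of the commit, assuming the column layout used by get-raw-data (label in column 0, the two sentences in columns 3 and 4, preceded by a header row) and assuming the vocabulary file under data/ has been downloaded as elsewhere in this example.

;; Hypothetical REPL sketch, evaluated in the bert.bert-sentence-classification
;; namespace; not part of this commit.
(def tiny-raw-rows
  ;; header row (contents irrelevant; prepare-data drops it),
  ;; then [label id-1 id-2 sentence-1 sentence-2]
  [["label" "id1" "id2" "sentence1" "sentence2"]
   ["1" "69773" "69792"
    "Cisco pared spending to compensate for sluggish sales ."
    "In response to sluggish sales , Cisco pared spending ."]])

;; prepare-data selects columns 3, 4, and 0, drops the header, and runs
;; pre-processing, so this builds the NDArrayIterator inputs without
;; touching data/dev.tsv.
(def tiny-prepared (prepare-data tiny-raw-rows))

(:train-num tiny-prepared) ;=> 1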

0 commit comments