[NLP] Word Embedding - GloVe [practice]
๐ง๐ป๐ป ์ฃผ์ ์ ๋ฆฌ
NLP
Word Embedding
GloVe
์๋์ ๊ฐ์ด, GloVe์ ๋ํ ์ฝ๋๋ฅผ ์ดํด๋ด ๋๋ค.
GloVe
์๋ ์๋ฒ ๋ฉ์ ํ๋์ ์-ํซ ์ธ์ฝ๋ฉ ๋ฒกํฐ(ํ ์์๋ง 1์ด๊ณ ๋๋จธ์ง๋ 0์ธ ๋ฒกํฐ)๋ฅผ ํจ์ฌ ์์ ์ค์ ๊ฐ์ ๋ฒกํฐ๋ก ๋ณํํฉ๋๋ค. ์-ํซ ์ธ์ฝ๋ฉ ๋ฒกํฐ๋ ํฌ์ ๋ฒกํฐ์ด๋ฉฐ, ์ค์ ๊ฐ ๋ฒกํฐ๋ ๋ฐ์ง ๋ฒกํฐ์ ๋๋ค.
์ด ์๋ ์๋ฒ ๋ฉ์์ ๊ฐ์ฅ ์ค์ํ ๊ฐ๋ ์ ๋น์ทํ ๋ฌธ๋งฅ์ ๋ํ๋๋ ๋จ์ด๋ค์ ๋ฒกํฐ ๊ณต๊ฐ์์ ๊ฐ๊น์ด ์์นํ๋ค๋ ๊ฒ์ ๋๋ค. ์ฌ๊ธฐ์ ๋ฌธ๋งฅ์ด๋ ์ฃผ๋ณ ๋จ์ด๋ฅผ ๋งํฉ๋๋ค. ์๋ฅผ ๋ค์ด "I purchased some items at the shop"๊ณผ "I purchased some items at the store" ๋ ๋ฌธ์ฅ์์ 'shop'๊ณผ 'store' ๋จ์ด๋ ๊ฐ์ ๋ฌธ๋งฅ์ ๋ํ๋๊ธฐ ๋๋ฌธ์ ๋ฒกํฐ ๊ณต๊ฐ์์ ์๋ก ๊ฐ๊น์ด ์์ด์ผ ํฉ๋๋ค.
์ฌ๊ธฐ์ ์ด๋ฏธ GloVe ๋ฒกํฐ๋ก ๋ฏธ๋ฆฌ ํ์ต๋ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. GloVe๋ word2vec๊ณผ ์ ์ฌํ์ง๋ง ๋ค๋ฅธ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ์ด๋ฌํ ๋ฏธ๋ฆฌ ํ์ต๋ ์๋ฒ ๋ฉ์ ๊ฑฐ๋ํ ๋ง๋ญ์น์์ ํ์ต๋์์ผ๋ฉฐ, ๋ชจ๋ธ ๋ด์์ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๋จ์ด์ ๋ฌธ๋งฅ์ ์ด๋ฏธ ํ์ตํ ์ํ์์ ์์์ ์ผ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๋ ์ผ๋ฐ์ ์ผ๋ก ๋ ๋น ๋ฅธ ํ์ต ์๊ฐ๊ณผ/๋๋ ํฅ์๋ ์ ํ๋๋ฅผ ์ ๊ณตํฉ๋๋ค.
ํ์ดํ ์น์์๋ ๋จ์ด ๋ฒกํฐ๋ฅผ nn.Embedding ๋ ์ด์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉํฉ๋๋ค. ์ด ๋ ์ด์ด๋ [๋ฌธ์ฅ ๊ธธ์ด, ๋ฐฐ์น ํฌ๊ธฐ] ํ ์๋ฅผ ๊ฐ์ ธ์ [๋ฌธ์ฅ ๊ธธ์ด, ๋ฐฐ์น ํฌ๊ธฐ, ์๋ฒ ๋ฉ ์ฐจ์] ํ ์๋ก ๋ณํํฉ๋๋ค. nn.Embedding ๋ ์ด์ด๋ ์ฒ์๋ถํฐ ํ๋ จํ ์๋ ์๊ณ , ๋ฏธ๋ฆฌ ํ์ต๋ ์๋ฒ ๋ฉ ๋ฐ์ดํฐ๋ก ์ด๊ธฐํํ๊ณ (์ ํ์ ์ผ๋ก ๊ณ ์ ์ํฌ ์๋ ์์) ์ฌ์ฉํ ์๋ ์์ต๋๋ค. nn.Embedding์ ํต์ฌ์ ๋ช ์์ ์ผ๋ก ์-ํซ ๋ฒกํฐ ํํ์ ์ฌ์ฉํ์ง ์์๋ ๋๋ค๋ ๊ฒ์ ๋๋ค. ๋จ์ํ ์ธ๋ฑ์ค๋ฅผ ๋ฒกํฐ์ ๋งคํํ๋ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ ๊ณ์ฐ์ ์ธ ์ธก๋ฉด์์ ๋งค์ฐ ์ค์ํฉ๋๋ค.
๊ตฌ์ฒด์ ์ผ๋ก๋, nn.Embedding์ ์-ํซ sparse-๋ฒกํฐ๋ฅผ ๋ํ๋ด๋ ์ธ๋ฑ์ค์ ํด๋นํ๋ ๊ฐ์ค์น ํ๋ ฌ์ ์ด์ ์ ํํ์ฌ ๋ฎ์ ์ฐจ์ (dense) ์ถ๋ ฅ์ ์์ฑํ๋ ์ ํ ๋งต์ ๋๋ค. ์ด๋ฒ ํํธ์์๋ ๋ชจ๋ธ์ ํ๋ จํ์ง ์๊ณ , ๋จ์ด ์๋ฒ ๋ฉ์ ์ดํด๋ณด๊ณ , ๊ทธ๊ฒ๋ค๋ก ํ ์ ์๋ ๋ช ๊ฐ์ง ํฅ๋ฏธ๋ก์ด ๊ฒ๋ค์ ์กฐ์ฌํด ๋ณผ ๊ฒ์ ๋๋ค.
๋จผ์ , ๋ฏธ๋ฆฌ ํ์ต๋ GloVe ๋ฒกํฐ๋ฅผ ๋ก๋ํ ๊ฒ์ ๋๋ค. name ํ๋๋ ๋ฒกํฐ๊ฐ ํ์ต๋ ๋ฐ์ดํฐ๋ฅผ ์ง์ ํ๋ฉฐ, ์ฌ๊ธฐ์ 6B๋ 60์ต ๊ฐ์ ๋จ์ด ์ฝํผ์ค๋ฅผ ์๋ฏธํฉ๋๋ค. dim ์ธ์๋ ๋จ์ด ๋ฒกํฐ์ ์ฐจ์์ ์ง์ ํฉ๋๋ค. GloVe ๋ฒกํฐ๋ 50, 100, 200 ๋ฐ 300 ์ฐจ์์ผ๋ก ์ ๊ณต๋ฉ๋๋ค. ๋ํ 42B ๋ฐ 840B glove ๋ฒกํฐ๋ ์์ง๋ง, ์ด๋ค์ 300 ์ฐจ์์์๋ง ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๊ฒ์ ์ฒ์ ์คํํ ๋, ๋ฒกํฐ๋ฅผ ๋ค์ด๋ก๋ํ๋ ๋ฐ ์๊ฐ์ด ๊ฑธ๋ฆฝ๋๋ค.
import torchtext.vocab
glove = torchtext.vocab.GloVe(name='6B', dim=100) # ์ด๋ฏธ ํ์ต๋ ๊ฒ์ ๋ถ๋ฌ์ค๊ธฐ๋ง ํ๋ ๊ฒ์. torchtext์ vocab์ ์กด์ฌ.
# hidden size๋ฅผ ๋ช์ผ๋ก ์ ํ ๋์ด๋ค. V -> hidden (100 / 300 / 512 / 256)
print(f'There are {len(glove.itos)} words in the vocabulary')
# itos -> dictionary์ ์ฌ์ด์ฆ๊ฐ ๋ฌด์์ด๋๋ ๊ฒ.
# ์ฐ๋ฆฌ๋ ์ง๊ธ ์๋์ ๊ฐ์ด 400000๊ฐ์ words๋ฅผ ์ฌ์ฉ ๊ฐ๋ฅํจ.
.vector_cache/glove.6B.zip: 862MB [02:39, 5.40MB/s]
100%|โโโโโโโโโโ| 399999/400000 [00:19<00:00, 20515.34it/s]
There are 400000 words in the vocabulary
์ฆ, ์ฐ๋ฆฌ๋ 40000๊ฐ์ words๊ฐ ๋ด๊ธด vocabulary๊ฐ ์๊ธด ๊ฒ์ ๋๋ค.
glove.vectors.shape
์ ๊ฒฐ๊ณผ๋ก
torch.Size([400000, 100])
์ด๊ฒ๊ณผ ๊ฐ์ด 40000 x 100 ์ฐจ์ ์ง๋ฆฌ๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
๊ฐ ํ์ด ์ด๋ค ๋จ์ด์ ๊ด๋ จ์ด ์๋์ง๋ itos(int to string) ๋ฆฌ์คํธ๋ฅผ ํ์ธํ์ฌ ์ ์ ์์ต๋๋ค.
์๋์ ์์๋ 0๋ฒ ํ์ด 'the'์ ๊ด๋ จ๋ ๋ฒกํฐ, 1๋ฒ ํ์ด ','(์ผํ)์ ๊ด๋ จ๋ ๋ฒกํฐ, 2๋ฒ ํ์ด '.'(๋ง์นจํ)์ ๊ด๋ จ๋ ๋ฒกํฐ ๋ฑ์ผ๋ก ์ดํดํ ์ ์์ต๋๋ค.
glove.itos[:100]
์ด๊ฒ์ ์ฐ์ผ๋ฉด ์๋์ ๊ฐ์ด ๊ฒฐ๊ณผ๊ฐ ๋์ต๋๋ค.
['the',
',',
'.',
'of',
'to',
'and',
'in',
'a',
'"',
"'s",
'for',
'-',
'that',
'on',
'is',
'was',
'said',
'with',
'he',
.
.
.
'']
์ผ๋ถ ์๋ตํ์ต๋๋ค.
์์ ๊ฐ์ด ๋์จ ๊ฒ์ ๋ณด์์,
glove๋ ์ฌ์ ๊ณผ๋ ๊ฐ๋ค๋ ๊ฒ์ ์ ์ ์์์ต๋๋ค. ์ด๋ค ๋จ์ด์ ๋ํ embedding ๊ฐ์ ์ญ ๊ฐ์ง๊ณ ์๋ ๊ฒ์ ๋๋ค.
glove.stoi['the'] # ๋จ์ด๊ฐ ๋ช ๋ฒ์งธ์ ์กด์ฌํ๋. keyerror ๋ ์๋ ๊ฒ.
# ๊ฐ์ฅ ๋ง์ด ๋ฑ์ฅํ๋ ๋จ์ด๋ฅผ ์์์๋ถํฐ ๋ฃ๋๋ค.
0
๋จ์ด๊ฐ ์ด๋์ ์กด์ฌํ๋๋ ์ ์ ์์ต๋๋ค.
print(glove.stoi['the'])
glove.vectors[glove.stoi['the']]
์ฐ๋ฆฌ์ ๋ชฉ์ ์, ๋จ์ด์ ์ธํ ์ค์์ ๋ฒกํฐ๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ์ ๋๋ค.
์ด ๋ฒกํฐ๋ 400000 x 100 ์ฐจ์ ์ง๋ฆฌ์ธ๋ฐ, the ๋ผ๋ ๋จ์ด๋ฅผ ๊ฐ์ ธ์ค๋ ค๋ฉด, the ๋ผ๋ ๋จ์ด์ ์ธ๋ฑ์ค๋ฅผ 0์ผ๋ก ๋ถ์ฌํจ.
์ปดํจํฐ๋ ๋จ์ด๋ฅผ ์์ง๋ ๋ชป ํ๋ค. ์ซ์๋ก ํํํ๋ค. ๊ทธ๋์ man - woman ์ ๊ฐ์ ์ฐ์ฐ์ด ๊ฐ๋ฅํ ๊ฒ์ด๋ค.
์๋ ๊ตฌํ๋ ํจ์๋ฅผ ์ดํด๋ณด์.
def get_vector(embeddings, word):
return embeddings.vectors[embeddings.stoi[word]]
- Glove๋ผ๋ embedding์ vectors๊ฐ์ ๋ฃ์.
- ํน์ embedding์ ๋ํ ๋จ์ด๋ฅผ ๋ฑ์ด๋ผ.
import torch
def closest_words(embeddings, vector, n=10):
distances = [(w, torch.dist(vector, get_vector(embeddings, w)).item()) for w in embeddings.itos]
return sorted(distances, key = lambda w: w[1])[:n]
- ๋ชจ๋ ๋จ์ด์ ์ฃผ์ด์ง ๋จ์ด์์ ๋ชจ๋ distance๋ฅผ ๊ตฌํด์ ๊ฐ์ฅ ๊ฑฐ๋ฆฌ๊ฐ ์งง์ ์ n๊ฐ๋ฅผ ๋ฆฌํดํด์ค.
- ๊ฐ์ฅ ๊ฐ๊น์ด ๊ฒ์ ์ฐพ์.
- embedding๊ณผ vector๊ฐ ๋ค์ด์ค๋ฉด ์ฐพ์.
- ์ ํด๋ฆฌ๋์ distance๋ก distance๋ฅผ ๊ตฌํด์์ ๊ณ์ฐํจ.
closest_words(glove, get_vector(glove, 'korea'))
- glove์ korea๋ผ๋ ๋จ์ด์ ๊ฐ์ฅ ๊ฐ๊น์ด ๋จ์ด 10๊ฐ๋ ์๋์ ๊ฐ๋ค.
[('korea', 0.0),
('pyongyang', 3.9039547443389893),
('korean', 4.068886756896973),
('dprk', 4.2631049156188965),
('seoul', 4.340494632720947),
('japan', 4.551243305206299),
('koreans', 4.615607738494873),
('south', 4.65822696685791),
('china', 4.8395185470581055),
('north', 4.986356735229492)]
def print_tuples(tuples):
for w, d in tuples:
print(f'({d:02.04f}) {w}')
์ tuple์ ์ด์๊ฒ ์ถ๋ ฅํ๋ ๊ฒ.
print_tuples(closest_words(glove, get_vector(glove, 'ai')))
์ฐ๋ฆฌ๋ ai๋ผ๋ ๋จ์ด๋ฅผ ์ํ์ง๋ง ,์๋์ ๊ฐ์ด ๋ค๋ฅธ ๋จ์ด๊ฐ ๋์ฌ ์ ์๋ค.
glove๋ผ๋ embedding์ ํ์ตํ ๋ฐ์ดํฐ๊ฐ generalํ๊ณ , ๋ํ ํ์์ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๊ธฐ ๋๋ฌธ์ ์ด๋ ๊ฒ ๋์ด.
(0.0000) ai
(4.5332) hey
(4.5842) ok
(4.6785) fukuhara
(4.8145) fortunately
(4.8299) cause
(4.8935) yeah
(4.9061) hi
(4.9083) luckily
(4.9333) …
def analogy(embeddings, word1, word2, word3, n=5):
candidate_words = closest_words(embeddings, get_vector(embeddings, word2) - get_vector(embeddings, word1) + get_vector(embeddings, word3), n+3)
candidate_words = [x for x in candidate_words if x[0] not in [word1, word2, word3]][:n]
print(f'{word1} is to {word2} as {word3} is to...')
return candidate_words
- vector๋ก ํํ ํ์ผ๋ ๋นผ๊ธฐ ๋ํ๊ธฐ ์ด๋ฐ ๊ฒ์ด ๋์ง ์์๊น ?
- king + woman - man =====> queen?์ด ๋์ค์ง ์์๊น?
- ์ด๋ค embedding ๊ฐ์ด ๋์ฌ ํ ๋ฐ, ๊ทธ๊ฒ์ด 400000๊ฐ ์ค ๊ฐ์ ๊ฒ ๋์ฌ ๊ฐ๋ฅ์ฑ์ ๊ฑฐ์ ์๋ค.
- ๊ทธ ๊ฒฐ๊ณผ๋ก ๋์จ vector์ 40000 ๋จ์ด ์ค์์ ๊ทธ๋๋ง distance๊ฐ ์์ ๊ฒ์ ๊ณ ๋ฅธ๋ค.
print_tuples(analogy(glove, 'man', 'king', 'woman', n = 10))
print_tuples(analogy(glove, 'seoul', 'korea', 'india', n = 10))
์ ํ๊ธด ํ์ง๋ง ์ฐ๋ฆฌ๊ฐ ์๋ฒฝํ ์ํ๋ ๊ฒ์ ์ป๋ ๋ฐ์๋ ํ๊ณ๊ฐ ์๋ค.
man is to king as woman is to...
(4.0811) queen
(4.6429) monarch
(4.9055) throne
(4.9216) elizabeth
(4.9811) prince
(4.9857) daughter
(5.0641) mother
(5.0775) cousin
(5.0787) princess
(5.1283) widow
seoul is to korea as india is to...
(5.8343) pakistan
(6.2924) lanka
(6.5571) australia
(6.5643) bangladesh
(6.5883) africa
(6.6894) sri
(6.7463) indonesia
(6.7763) indian
(6.9396) japan
(6.9865) zealand
๋ฐ๋ก ํ์ตํ์ง ์๊ณ LSTM์ด๋ ๋ค๋ฅธ ๊ฒ์ ๊ปด์ ํ์ต์ ์งํํฉ๋๋ค.
'Artificial Intelligence > Natural Language Processing' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[NLP] RNN - LSTM, GRU (0) | 2023.04.04 |
---|---|
[NLP] RNN (0) | 2023.04.04 |
[NLP] Word Embedding - GloVe (0) | 2023.03.31 |
[NLP] Word Embedding - CBOW and Skip-Gram (2) | 2023.03.27 |
[NLP] Word Embedding - Word2Vec (0) | 2023.03.27 |