Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #7

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
chore: running cargo fmt because I forgot to
MostWrong committed Apr 26, 2023
commit eae7a472797b3c73bdaff2908347fb276c272007
138 changes: 86 additions & 52 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
#![allow(non_snake_case)]

extern crate random_choice;
extern crate fnv;
extern crate random_choice;

use std::collections::{HashSet, HashMap};
use std::cmp::max;
use fnv::FnvHashMap;
use self::random_choice::random_choice;
use fnv::FnvHashMap;
use std::cmp::max;
use std::collections::{HashMap, HashSet};

pub struct GSDMM {
alpha: f64,
beta: f64,
K:usize,
V:f64,
D:usize,
maxit:isize,
K: usize,
V: f64,
D: usize,
maxit: isize,
clusters: Vec<usize>,
pub doc_vectors:Vec<Vec<usize>>,
pub doc_vectors: Vec<Vec<usize>>,
pub labels: Vec<usize>,
pub cluster_counts: Vec<u32>,
pub cluster_word_counts:Vec<u32>,
pub word_index_map:HashMap<String, usize>,
pub index_word_map:HashMap<usize, String>,
pub cluster_word_distributions: Vec<FnvHashMap<usize,u32>>
pub cluster_word_counts: Vec<u32>,
pub word_index_map: HashMap<String, usize>,
pub index_word_map: HashMap<usize, String>,
pub cluster_word_distributions: Vec<FnvHashMap<usize, u32>>,
}

impl GSDMM {
pub fn new(alpha:f64, beta:f64, K: usize, maxit:isize, vocab:&HashSet<String>, docs:&Vec<Vec<String>>) -> GSDMM {
pub fn new(
alpha: f64,
beta: f64,
K: usize,
maxit: isize,
vocab: &HashSet<String>,
docs: &Vec<Vec<String>>,
) -> GSDMM {
let D = docs.len();

// compute utilized vocabulary size.
let mut word_index_map = HashMap::<String, usize>::with_capacity(vocab.len()/2);
let mut index_word_map = HashMap::<usize, String>::with_capacity(vocab.len()/2);
let mut word_index_map = HashMap::<String, usize>::with_capacity(vocab.len() / 2);
let mut index_word_map = HashMap::<usize, String>::with_capacity(vocab.len() / 2);
let mut index = 0_usize;
let mut doc_vectors = Vec::<Vec<usize>>::with_capacity(D);
for doc in docs {
@@ -40,7 +47,7 @@ impl GSDMM {
if !word_index_map.contains_key(word) {
word_index_map.insert(word.clone(), index);
index_word_map.insert(index, word.clone());
index+=1;
index += 1;
}
doc_vector.push(*word_index_map.get(word).unwrap());
}
@@ -54,22 +61,28 @@ impl GSDMM {
doc_vectors.push(doc_vector);
}
let V = index as f64;
println!("Fitting with alpha={}, beta={}, K={}, maxit={}, vocab size={}", alpha, beta, K, maxit, V as u32);
println!(
"Fitting with alpha={}, beta={}, K={}, maxit={}, vocab size={}",
alpha, beta, K, maxit, V as u32
);

let clusters = (0_usize..K).collect::<Vec<usize>>();
let mut d_z: Vec<usize> = (0_usize..D).map(|_| 0_usize).collect::<Vec<usize>>(); // doc labels
let mut m_z: Vec<u32> = GSDMM::zero_vector(K); // cluster sizes
let mut n_z: Vec<u32> = GSDMM::zero_vector(K); // cluster word counts
let mut n_z_w = Vec::<FnvHashMap<usize, u32>>::with_capacity(K); // container for cluster word distributions
let mut m_z: Vec<u32> = GSDMM::zero_vector(K); // cluster sizes
let mut n_z: Vec<u32> = GSDMM::zero_vector(K); // cluster word counts
let mut n_z_w = Vec::<FnvHashMap<usize, u32>>::with_capacity(K); // container for cluster word distributions
for _ in 0_usize..K {
let m = FnvHashMap::<usize, u32>::with_capacity_and_hasher(max(vocab.len() / 10, 100), Default::default());
let m = FnvHashMap::<usize, u32>::with_capacity_and_hasher(
max(vocab.len() / 10, 100),
Default::default(),
);
let _ = &n_z_w.push(m);
}

// randomly initialize cluster assignment
let p = (0..K).map(|_| 1_f64 / (K as f64)).collect::<Vec<f64>>();

let choices = random_choice().random_choice_f64(&clusters, &p, D) ;
let choices = random_choice().random_choice_f64(&clusters, &p, D);
for i in 0..D {
let z = *choices[i];
let doc = &doc_vectors[i];
@@ -81,7 +94,7 @@ impl GSDMM {
if !clust_words.contains_key(word) {
clust_words.insert(*word, 0_u32);
}
* clust_words.get_mut(word).unwrap() += 1_u32;
*clust_words.get_mut(word).unwrap() += 1_u32;
}
}

@@ -93,13 +106,13 @@ impl GSDMM {
D,
maxit,
doc_vectors,
clusters: clusters.clone(), // Don't totally get why we need the clone here!
clusters: clusters.clone(), // Don't totally get why we need the clone here!
labels: d_z,
cluster_counts: m_z,
cluster_word_counts: n_z,
word_index_map,
index_word_map,
cluster_word_distributions: n_z_w
cluster_word_distributions: n_z_w,
}
}

@@ -118,7 +131,8 @@ impl GSDMM {

// modify the map: enclose it in a block so we can borrow views again.
{
let old_clust_words: &mut FnvHashMap<usize, u32> = &mut self.cluster_word_distributions[z_old];
let old_clust_words: &mut FnvHashMap<usize, u32> =
&mut self.cluster_word_distributions[z_old];
for word in doc {
*old_clust_words.get_mut(word).unwrap() -= 1_u32;

@@ -144,37 +158,48 @@ impl GSDMM {
self.cluster_word_counts[z_new] += doc_size;

{
let new_clust_words: &mut FnvHashMap<usize, u32> = &mut self.cluster_word_distributions[z_new];
let new_clust_words: &mut FnvHashMap<usize, u32> =
&mut self.cluster_word_distributions[z_new];
for word in doc {
if !new_clust_words.contains_key(word) {
new_clust_words.insert(*word, 0_u32);
}
*new_clust_words.get_mut(word).unwrap() += 1_u32;
*new_clust_words.get_mut(word).unwrap() += 1_u32;
}
}
}
let new_number_clusters = self.cluster_word_distributions.iter().map(|c| if !c.is_empty() {1} else {0} ).sum();
println!("Iteration {}: {} docs transferred with {} clusters populated.", it, total_transfers, new_number_clusters);
let new_number_clusters = self
.cluster_word_distributions
.iter()
.map(|c| if !c.is_empty() { 1 } else { 0 })
.sum();
println!(
"Iteration {}: {} docs transferred with {} clusters populated.",
it, total_transfers, new_number_clusters
);

// apply ad-hoc convergence test
if total_transfers==0 && new_number_clusters==number_clusters {
println!("Converged after {} iterations. Solution has {} clusters.", it, new_number_clusters);
break
if total_transfers == 0 && new_number_clusters == number_clusters {
println!(
"Converged after {} iterations. Solution has {} clusters.",
it, new_number_clusters
);
break;
}
number_clusters = new_number_clusters;
}
}

pub fn score(&self, doc:&Vec<usize>) -> Vec<f64> {
pub fn score(&self, doc: &Vec<usize>) -> Vec<f64> {
// Score an input document using the formula of Yin and Wang 2014 (equation 3)
// http://dbgroup.cs.tsinghua.edu.cn/wangjy/papers/KDD14-GSDMM.pdf
//
//
// # Arguments
//
//
// * `doc` - A vector of unique index tokens characterizing the document
//
//
// # Value
//
//
// Vec<f64> - A length K probability vector where each component represents the probability
// of the doc belonging to a particular cluster.
//
@@ -193,36 +218,34 @@ impl GSDMM {
let mut lN2 = 0_f64;
let mut lD2 = 0_f64;
let cluster: &FnvHashMap<usize, u32> = &self.cluster_word_distributions[label];

for word in doc {
lN2 += (*cluster.get(word).unwrap_or(&0_u32) as f64 + self.beta).ln();
}
for j in 1_u32..(doc_size + 1) {
lD2 += ((self.cluster_word_counts[label] + j) as f64 - 1_f64 + self.V * self.beta).ln();
lD2 += ((self.cluster_word_counts[label] + j) as f64 - 1_f64 + self.V * self.beta)
.ln();
}
*item = (lN1 - lD1 + lN2 - lD2).exp();
}


// normalize the probability
let pnorm: f64 = p.iter().sum();
if pnorm>0_f64 {
if pnorm > 0_f64 {
for item in p.iter_mut().take(self.K) {
*item /= pnorm;
}
}
}
p
}

fn zero_vector(size:usize) -> Vec<u32>
{
fn zero_vector(size: usize) -> Vec<u32> {
let mut v = Vec::<u32>::with_capacity(size);
for _ in 0_usize..size {
v.push(0_u32)
}
v
}

}

#[test]
@@ -253,17 +276,29 @@ fn simple_run() {
assert_eq!(18, model.cluster_counts.iter().sum::<u32>());

// check that we get three clusters
assert_eq!(3, model.cluster_counts.into_iter().filter(|x| x>&0_u32 ).collect::<Vec<u32>>().len());
assert_eq!(
3,
model
.cluster_counts
.into_iter()
.filter(|x| x > &0_u32)
.collect::<Vec<u32>>()
.len()
);

// check that the clusters are pure
let mut check_map = HashMap::<usize,String>::new();
for (i,label) in vec!("A","A","B","B","B","B","B","B","B","B","C","C","C","C","C","C","C","C").into_iter().enumerate() {
let mut check_map = HashMap::<usize, String>::new();
for (i, label) in vec![
"A", "A", "B", "B", "B", "B", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C", "C", "C",
]
.into_iter()
.enumerate()
{
if let std::collections::hash_map::Entry::Vacant(e) = check_map.entry(model.labels[i]) {
e.insert(label.to_string());
} else {
assert_eq!(check_map[&model.labels[i]], label);
}

}
}

@@ -292,5 +327,4 @@ fn indexing() {
assert_eq!(1_usize, *model.word_index_map.get("B").unwrap());
assert_eq!(2_usize, *model.word_index_map.get("D").unwrap());
assert_eq!(3_usize, *model.word_index_map.get("C").unwrap());

}
Loading