from tensorflow.keras import Sequential
from tensorflow.keras.layers import Embedding, Dense, GlobalAveragePooling1D, Flatten, Reshape, Concatenate
from tensorflow.keras.layers.experimental.preprocessing import StringLookup, Normalization, TextVectorization
# Preprocessing layers.
#
# Each layer below is constructed once at import time and fitted ("adapted")
# to data frames expected in the enclosing scope (users_df, stims_df,
# stim_id_mapping). The model-building functions further down call these
# fitted instances directly.


def _adapted(layer, data):
    """Fit the Keras preprocessing ``layer`` to ``data`` and return the same layer."""
    layer.adapt(data)
    return layer


def _unique_items(column):
    """Return the unique values found across a column whose cells are sequences."""
    return np.unique(np.concatenate(column.values))


# Categorical user features -> integer ids.
gender_lookup = _adapted(StringLookup(), users_df['gender'])
sexuality_lookup = _adapted(StringLookup(), users_df['sexuality'])
nd_condition_lookup = _adapted(StringLookup(), _unique_items(users_df['neurodivergent_conditions']))
hobby_lookup = _adapted(StringLookup(), _unique_items(users_df['hobbies']))

# Numeric user features -> standardized values.
age_normalizer = _adapted(Normalization(), users_df['age'])
score_normalizer = _adapted(Normalization(), users_df['stimming_essentiality_score'])

# Text settings shared by the stim-name and description vectorizers.
max_tokens = 2000  # cap on the vocabulary size; tune for the dataset
# Padded/truncated token count per text. NOTE(review): descriptions are cut to
# this same length as names -- confirm 16 tokens is enough for them.
output_sequence_length = 16
embedding_dim = 32  # width of the learned text embedding vectors

text_vectorization = _adapted(
    TextVectorization(max_tokens=max_tokens, output_sequence_length=output_sequence_length),
    list(stim_id_mapping.keys()),
)
description_vectorization = _adapted(
    TextVectorization(max_tokens=max_tokens, output_sequence_length=output_sequence_length),
    stims_df['description'],
)

# Numeric stim feature -> standardized value.
harmfulness_score_normalizer = _adapted(Normalization(), stims_df['harmfulness_score'])
def create_user_input_layers():
    """Create the named Keras ``Input`` tensors consumed by the user tower.

    Returns:
        dict mapping each user feature name to its ``tf.keras.Input``.
        "age", "gender", "sexuality" and "stimming_essentiality_score" have
        shape ``(1,)``; "neurodivergent_conditions" and "hobbies" are ragged
        string inputs with a variable-length inner dimension.
    """
    fixed_width_specs = (
        ("age", tf.int32),
        ("gender", tf.string),
        ("sexuality", tf.string),
    )
    inputs = {
        feature: tf.keras.Input(shape=(1,), name=feature, dtype=dtype)
        for feature, dtype in fixed_width_specs
    }
    # Multi-valued features: each user can list any number of entries.
    for feature in ("neurodivergent_conditions", "hobbies"):
        inputs[feature] = tf.keras.Input(shape=(None,), ragged=True, name=feature, dtype=tf.string)
    inputs["stimming_essentiality_score"] = tf.keras.Input(
        shape=(1,), name="stimming_essentiality_score", dtype=tf.int32
    )
    return inputs
def create_stim_input_layers():
    """Create the named Keras ``Input`` tensors consumed by the stim tower.

    Returns:
        dict with keys "name" and "description" (string inputs) and
        "harmfulness_score" (int32 input), each of shape ``(1,)``.
    """
    dtype_by_feature = {
        "name": tf.string,
        "description": tf.string,
        "harmfulness_score": tf.int32,
    }
    return {
        feature: tf.keras.Input(shape=(1,), name=feature, dtype=dtype)
        for feature, dtype in dtype_by_feature.items()
    }
def create_user_model(user_input_layers, embedding_dim=8):
    """Build the user tower: raw user features -> 32-unit ReLU representation.

    Args:
        user_input_layers: dict of ``tf.keras.Input`` tensors keyed by
            "age", "gender", "sexuality", "neurodivergent_conditions",
            "hobbies" and "stimming_essentiality_score", as returned by
            ``create_user_input_layers()``.
        embedding_dim: output width of each categorical ``Embedding``.
            Defaults to 8, the value previously hard-coded here; it is
            independent of the module-level ``embedding_dim`` used for text.

    Returns:
        A ``tf.keras.Model`` named "user_model" whose output is the final
        ``Dense(32, activation="relu")`` layer.
    """
    # Numeric features: standardize with the module-level adapted normalizers,
    # then reshape to (batch, 1) so each feeds a Dense layer as one column.
    age_normalized = tf.reshape(age_normalizer(user_input_layers["age"]), (-1, 1))
    score_normalized = tf.reshape(
        score_normalizer(user_input_layers["stimming_essentiality_score"]), (-1, 1)
    )

    # Categorical features: strings -> integer ids via the adapted lookups.
    gender_ids = gender_lookup(user_input_layers["gender"])
    sexuality_ids = sexuality_lookup(user_input_layers["sexuality"])
    nd_condition_ids = nd_condition_lookup(user_input_layers["neurodivergent_conditions"])
    hobby_ids = hobby_lookup(user_input_layers["hobbies"])

    # One trainable Embedding per categorical feature. input_dim is the size of
    # the lookup's vocabulary, so every id the lookup can emit has a row.
    gender_embedding = Embedding(
        input_dim=len(gender_lookup.get_vocabulary()), output_dim=embedding_dim
    )(gender_ids)
    sexuality_embedding = Embedding(
        input_dim=len(sexuality_lookup.get_vocabulary()), output_dim=embedding_dim
    )(sexuality_ids)
    nd_condition_embedding = Embedding(
        input_dim=len(nd_condition_lookup.get_vocabulary()), output_dim=embedding_dim
    )(nd_condition_ids)
    hobby_embedding = Embedding(
        input_dim=len(hobby_lookup.get_vocabulary()), output_dim=embedding_dim
    )(hobby_ids)

    # Per-feature projections.
    age_dense = Dense(4, activation="relu")(age_normalized)
    score_dense = Dense(6, activation="relu")(score_normalized)
    gender_dense = Dense(2, activation="relu")(Flatten()(gender_embedding))
    # Fix: this previously flattened the raw lookup ids (sexuality_ids), which
    # fed an arbitrary integer id into Dense and left the sexuality Embedding
    # above unused and untrained.
    sexuality_dense = Dense(2, activation="relu")(Flatten()(sexuality_embedding))

    # Multi-valued features: average the per-item embeddings into one vector.
    # NOTE(review): the inputs are ragged; confirm the installed TF version's
    # GlobalAveragePooling1D handles ragged tensors as intended.
    nd_condition_pooled = GlobalAveragePooling1D()(nd_condition_embedding)
    hobbies_pooled = GlobalAveragePooling1D()(hobby_embedding)
    nd_condition_dense = Dense(9, activation="relu")(nd_condition_pooled)
    hobbies_dense = Dense(9, activation="relu")(hobbies_pooled)

    concatenated = Concatenate(axis=-1)([
        age_dense,
        gender_dense,
        sexuality_dense,
        nd_condition_dense,
        hobbies_dense,
        score_dense,
    ])
    # Every branch is already (batch, features); Flatten is kept as a guard.
    flat_embeddings = Flatten()(concatenated)

    # Shared user MLP.
    dense_1 = Dense(64, activation='relu')(flat_embeddings)
    dense_2 = Dense(32, activation='relu')(dense_1)
    return tf.keras.Model(inputs=user_input_layers, outputs=dense_2, name="user_model")
def create_stim_model(stim_input_layers):
    """Build the stim tower: raw stim features -> 32-unit ReLU representation.

    Args:
        stim_input_layers: dict of ``tf.keras.Input`` tensors keyed by
            "name", "description" and "harmfulness_score", as returned by
            ``create_stim_input_layers()``.

    Returns:
        A ``tf.keras.Model`` named "stim_model" whose output is the final
        ``Dense(32, activation="relu")`` layer.
    """
    # Tokenize each text field with its module-level adapted vectorizer.
    token_ids = {
        "name": text_vectorization(stim_input_layers["name"]),
        "description": description_vectorization(stim_input_layers["description"]),
    }

    # A separate Embedding per text field (name first, then description).
    embedded = {
        field: Embedding(
            input_dim=max_tokens,
            output_dim=embedding_dim,
            input_length=output_sequence_length,
        )(ids)
        for field, ids in token_ids.items()
    }
    flattened = {field: Flatten()(vectors) for field, vectors in embedded.items()}

    score = tf.expand_dims(
        harmfulness_score_normalizer(stim_input_layers["harmfulness_score"]), -1
    )

    # Per-feature projections, built in the order name, description, score.
    branches = [
        Dense(8, activation="relu")(flattened["name"]),
        Dense(16, activation="relu")(flattened["description"]),
        Dense(8, activation="relu")(tf.reshape(score, (-1, 1))),
    ]
    merged = Concatenate(axis=-1)(branches)
    merged_flat = Flatten()(merged)

    # Shared stim MLP.
    hidden = Dense(64, activation='relu')(merged_flat)
    representation = Dense(32, activation='relu')(hidden)
    return tf.keras.Model(inputs=stim_input_layers, outputs=representation, name="stim_model")
# Create the input tensors for both towers, then wire each tower on top of them.
user_input_layers, stim_input_layers = create_user_input_layers(), create_stim_input_layers()
user_model, stim_model = create_user_model(user_input_layers), create_stim_model(stim_input_layers)