#' Build an unevaluated `nn_module()` call for a fully connected network
#'
#' Generates code for a feed-forward network whose layer widths are
#' `no_x -> hd_neurons[1] -> ... -> no_y`. Hidden layers are stored as
#' `self$fc1`, `self$fc2`, ...; the last layer is `self$out`. The result is a
#' call, not a module: it references `nn_module`, `nn_linear` and the
#' activation names, so it must be evaluated where those are visible
#' (presumably with torch attached -- confirm with callers).
#'
#' @param nn_name Class name passed as the first argument of `nn_module()`.
#' @param hd_neurons Widths of the hidden layers.
#' @param no_x Number of input features.
#' @param no_y Number of outputs.
#' @param activations Activation applied after each layer.
#'   * `NULL` (default): `nnf_relu` after every hidden layer, none after the
#'     output layer.
#'   * A single function / string / symbol: reused after every hidden layer,
#'     none after the output layer.
#'   * Otherwise one element per layer; `NA` / `NULL` elements mean "no
#'     activation", and layers beyond `length(activations)` get none.
#'   Function objects are turned back into names from the expression the
#'   caller wrote (a bare name or a `pkg::fn` reference) when that expression
#'   maps one-to-one onto the elements; otherwise the caller's environment is
#'   searched for a binding identical to the function.
#' @return A call `nn_module(nn_name, initialize = <fn>, forward = <fn>)`.
#'   The generated `forward` body chains steps with `%>%`.
create_nn_module = function(nn_name = "DeepFFN", hd_neurons = c(20, 30, 20, 15),
                            no_x = 10, no_y = 1, activations = NULL) {
  box::use(
    rlang[new_function, call2, caller_env, expr, exprs, sym, is_function, env_get_list],
    purrr[map, map2, reduce, set_names, compact, map_if, keep, map_lgl],
    glue[glue],
    magrittr[`%>%`]
  )
  # Capture the caller's environment here: inside a purrr lambda,
  # parent.frame() would point at purrr internals, not at our caller.
  caller = caller_env()

  nodes = c(no_x, hd_neurons, no_y)
  n_layers = length(nodes) - 1
  n_hidden = length(hd_neurons)

  # What the caller literally wrote for `activations` (NULL if omitted).
  activation_arg = match.call()$activations
  # `c(a, b)` / `list(a, b)` contribute one source expression per element;
  # anything else is a single expression.
  is_combine_call = is.call(activation_arg) &&
    (identical(activation_arg[[1]], quote(c)) ||
       identical(activation_arg[[1]], quote(list)))
  arg_exprs = if (is_combine_call) {
    as.list(activation_arg)[-1]
  } else if (is.null(activation_arg)) {
    list()
  } else {
    list(activation_arg)
  }

  n_given = length(activations)
  single = !is.null(activations) && (n_given == 1 || is.function(activations))
  if (is.null(activations)) {
    activations = c(rep("nnf_relu", n_hidden), NA)
  } else if (single) {
    # Unwrap `c(fn)` / `list(fn)` so the activation itself is replicated,
    # not a one-element list.
    one = if (is.list(activations)) activations[[1]] else activations
    activations = c(rep(list(one), n_hidden), list(NA))
  }
  # Source expressions can only be paired with elements when the caller
  # wrote exactly one expression per element (not, e.g., a variable holding
  # a list, or a `c()` call that silently dropped NULLs).
  exprs_align = n_given > 0 && length(arg_exprs) == n_given

  # TRUE for expressions usable as a call head in generated code:
  # a bare name or a `pkg::fn` / `pkg:::fn` reference.
  is_fn_ref = function(e) {
    is.symbol(e) ||
      (is.call(e) && (identical(e[[1]], quote(`::`)) ||
                        identical(e[[1]], quote(`:::`))))
  }

  # Recover a name for a function object from the caller's own bindings.
  name_from_caller = function(fn) {
    hits = keep(ls(envir = caller), function(nm) {
      tryCatch(
        identical(env_get_list(caller, nm)[[1]], fn),
        # Bindings that cannot be read (e.g. missing arguments) never match.
        error = function(e) FALSE
      )
    })
    if (length(hits) == 0) {
      stop("Could not determine function name for activation function")
    }
    sym(hits[1])
  }

  # Turn one activation spec into a call head, or NULL for "no activation".
  resolve_activation = function(x, i) {
    if (is.null(x)) {
      NULL
    } else if (is.function(x)) {
      src = if (exprs_align) arg_exprs[[if (single) 1 else i]] else NULL
      if (is_fn_ref(src)) src else name_from_caller(x)
    } else if (is.character(x)) {
      if (length(x) == 1 && is.na(x)) NULL else sym(x)
    } else if (is.symbol(x)) {
      x
    } else if (is.logical(x) && length(x) == 1 && is.na(x)) {
      NULL
    } else {
      stop("Activation must be a function, string, symbol, NA, or NULL")
    }
  }
  activations = map2(activations, seq_along(activations), resolve_activation)

  layer_name = function(i) if (i == n_layers) "out" else glue("fc{i}")
  layer_ref = function(i) call2("$", expr(self), sym(layer_name(i)))

  # initialize(): `self$<layer> = nn_linear(<in>, <out>)` for every layer.
  layer_dims = map2(nodes[-length(nodes)], nodes[-1], c)
  init_body = map2(seq_len(n_layers), layer_dims, function(i, dims) {
    call2("=", layer_ref(i), call2("nn_linear", !!!dims))
  })
  init = new_function(args = list(), body = call2("{", !!!init_body))

  # forward(x): each layer call, followed by its activation call if any.
  steps = map(seq_len(n_layers), function(i) {
    act = if (i <= length(activations)) activations[[i]] else NULL
    c(list(call2(layer_ref(i))), if (!is.null(act)) list(call2(act)))
  }) |>
    unlist(recursive = FALSE) |>
    compact()
  forward_body = reduce(steps, function(acc, step) {
    expr(!!acc %>% !!step)
  }, .init = expr(x))
  forward = new_function(
    args = list(x = expr()),
    body = call2("{", forward_body)
  )

  call2("nn_module", nn_name, initialize = init, forward = forward)
}